Daniel.y committed
Commit 132bd6e · unverified · 2 Parent(s): b05bfe9 6f663b5

Merge pull request #1673 from danielaskdd/litellm-problem

feat: Support `application/octet-stream` requests from LiteLLM clients for Ollama Emulation
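
For context, a minimal client-side sketch of the request shape this change is meant to accept: LiteLLM-style clients post the JSON body with `Content-Type: application/octet-stream` rather than `application/json`, which the previous endpoint signatures (Pydantic body parameters) could not parse. The URL, port, and payload below are illustrative assumptions, not taken from this PR:

```python
# Hypothetical client: posts a JSON payload as raw bytes with an
# application/octet-stream Content-Type, the case this PR adds support for.
# The URL, port, and payload values are assumptions for illustration only.
import json
import requests

payload = {"model": "lightrag:latest", "prompt": "Hello", "stream": False}
resp = requests.post(
    "http://localhost:9621/api/generate",      # assumed LightRAG Ollama-emulation endpoint
    data=json.dumps(payload).encode("utf-8"),  # raw bytes instead of json=...
    headers={"Content-Type": "application/octet-stream"},
)
print(resp.status_code, resp.json())
```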

lightrag/api/__init__.py CHANGED
@@ -1 +1 @@
- __api_version__ = "0172"
+ __api_version__ = "0173"
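
The ollama_api.py diff below centralizes that handling in a new parse_request_body helper. As a rough standalone illustration of the same Content-Type dispatch (not part of this PR; the route, model, and app here are made up, and a recent FastAPI/Starlette whose TestClient is httpx-based, i.e. accepts `content=`, is assumed):

```python
# Standalone sketch (not from the PR): a toy FastAPI route that applies the
# same Content-Type dispatch as the parse_request_body helper added below,
# exercised with both Content-Types.
import json

from fastapi import FastAPI, HTTPException, Request
from fastapi.testclient import TestClient
from pydantic import BaseModel


class ToyGenerateRequest(BaseModel):  # hypothetical stand-in for OllamaGenerateRequest
    model: str
    prompt: str


app = FastAPI()


@app.post("/api/generate")
async def generate(raw_request: Request):
    content_type = raw_request.headers.get("content-type", "").lower()
    try:
        if content_type.startswith("application/json"):
            body = await raw_request.json()
        else:
            # application/octet-stream (or anything else): parse the raw bytes as JSON
            body = json.loads((await raw_request.body()).decode("utf-8"))
        req = ToyGenerateRequest(**body)
    except json.JSONDecodeError:
        raise HTTPException(status_code=400, detail="Invalid JSON in request body")
    return {"echo": req.prompt}


client = TestClient(app)
payload = json.dumps({"model": "lightrag:latest", "prompt": "hi"}).encode("utf-8")

for content_type in ("application/json", "application/octet-stream"):
    r = client.post(
        "/api/generate",
        content=payload,
        headers={"Content-Type": content_type},
    )
    print(content_type, "->", r.json())  # both print {'echo': 'hi'}
```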
lightrag/api/routers/ollama_api.py CHANGED
@@ -1,7 +1,7 @@
  from fastapi import APIRouter, HTTPException, Request
  from pydantic import BaseModel
- from typing import List, Dict, Any, Optional
- import logging
+ from typing import List, Dict, Any, Optional, Type
+ from lightrag.utils import logger
  import time
  import json
  import re
@@ -95,6 +95,68 @@ class OllamaTagResponse(BaseModel):
      models: List[OllamaModel]
 
 
+ class OllamaRunningModelDetails(BaseModel):
+     parent_model: str
+     format: str
+     family: str
+     families: List[str]
+     parameter_size: str
+     quantization_level: str
+
+
+ class OllamaRunningModel(BaseModel):
+     name: str
+     model: str
+     size: int
+     digest: str
+     details: OllamaRunningModelDetails
+     expires_at: str
+     size_vram: int
+
+
+ class OllamaPsResponse(BaseModel):
+     models: List[OllamaRunningModel]
+
+
+ async def parse_request_body(
+     request: Request, model_class: Type[BaseModel]
+ ) -> BaseModel:
+     """
+     Parse request body based on Content-Type header.
+     Supports both application/json and application/octet-stream.
+
+     Args:
+         request: The FastAPI Request object
+         model_class: The Pydantic model class to parse the request into
+
+     Returns:
+         An instance of the provided model_class
+     """
+     content_type = request.headers.get("content-type", "").lower()
+
+     try:
+         if content_type.startswith("application/json"):
+             # FastAPI already handles JSON parsing for us
+             body = await request.json()
+         elif content_type.startswith("application/octet-stream"):
+             # Manually parse octet-stream as JSON
+             body_bytes = await request.body()
+             body = json.loads(body_bytes.decode("utf-8"))
+         else:
+             # Try to parse as JSON for any other content type
+             body_bytes = await request.body()
+             body = json.loads(body_bytes.decode("utf-8"))
+
+         # Create an instance of the model
+         return model_class(**body)
+     except json.JSONDecodeError:
+         raise HTTPException(status_code=400, detail="Invalid JSON in request body")
+     except Exception as e:
+         raise HTTPException(
+             status_code=400, detail=f"Error parsing request body: {str(e)}"
+         )
+
+
  def estimate_tokens(text: str) -> int:
      """Estimate the number of tokens in text using tiktoken"""
      tokens = TiktokenTokenizer().encode(text)
@@ -197,13 +259,43 @@ class OllamaAPI:
                  ]
              )
 
-         @self.router.post("/generate", dependencies=[Depends(combined_auth)])
-         async def generate(raw_request: Request, request: OllamaGenerateRequest):
+         @self.router.get("/ps", dependencies=[Depends(combined_auth)])
+         async def get_running_models():
+             """List Running Models - returns currently running models"""
+             return OllamaPsResponse(
+                 models=[
+                     {
+                         "name": self.ollama_server_infos.LIGHTRAG_MODEL,
+                         "model": self.ollama_server_infos.LIGHTRAG_MODEL,
+                         "size": self.ollama_server_infos.LIGHTRAG_SIZE,
+                         "digest": self.ollama_server_infos.LIGHTRAG_DIGEST,
+                         "details": {
+                             "parent_model": "",
+                             "format": "gguf",
+                             "family": "llama",
+                             "families": ["llama"],
+                             "parameter_size": "7.2B",
+                             "quantization_level": "Q4_0",
+                         },
+                         "expires_at": "2050-12-31T14:38:31.83753-07:00",
+                         "size_vram": self.ollama_server_infos.LIGHTRAG_SIZE,
+                     }
+                 ]
+             )
+
+         @self.router.post(
+             "/generate", dependencies=[Depends(combined_auth)], include_in_schema=True
+         )
+         async def generate(raw_request: Request):
              """Handle generate completion requests acting as an Ollama model
              For compatibility purpose, the request is not processed by LightRAG,
              and will be handled by underlying LLM model.
+             Supports both application/json and application/octet-stream Content-Types.
              """
              try:
+                 # Parse the request body manually
+                 request = await parse_request_body(raw_request, OllamaGenerateRequest)
+
                  query = request.prompt
                  start_time = time.time_ns()
                  prompt_tokens = estimate_tokens(query)
@@ -278,7 +370,7 @@ class OllamaAPI:
                          else:
                              error_msg = f"Provider error: {error_msg}"
 
-                         logging.error(f"Stream error: {error_msg}")
+                         logger.error(f"Stream error: {error_msg}")
 
                          # Send error message to client
                          error_data = {
@@ -363,13 +455,19 @@ class OllamaAPI:
                  trace_exception(e)
                  raise HTTPException(status_code=500, detail=str(e))
 
-         @self.router.post("/chat", dependencies=[Depends(combined_auth)])
-         async def chat(raw_request: Request, request: OllamaChatRequest):
+         @self.router.post(
+             "/chat", dependencies=[Depends(combined_auth)], include_in_schema=True
+         )
+         async def chat(raw_request: Request):
              """Process chat completion requests acting as an Ollama model
              Routes user queries through LightRAG by selecting query mode based on prefix indicators.
              Detects and forwards OpenWebUI session-related requests (for meta data generation task) directly to LLM.
+             Supports both application/json and application/octet-stream Content-Types.
              """
              try:
+                 # Parse the request body manually
+                 request = await parse_request_body(raw_request, OllamaChatRequest)
+
                  # Get all messages
                  messages = request.messages
                  if not messages:
@@ -496,7 +594,7 @@ class OllamaAPI:
                          else:
                              error_msg = f"Provider error: {error_msg}"
 
-                         logging.error(f"Stream error: {error_msg}")
+                         logger.error(f"Stream error: {error_msg}")
 
                          # Send error message to client
                          error_data = {
@@ -530,6 +628,11 @@ class OllamaAPI:
                  data = {
                      "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                      "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
+                     "message": {
+                         "role": "assistant",
+                         "content": "",
+                         "images": None,
+                     },
                      "done": True,
                      "total_duration": total_time,
                      "load_duration": 0,