Merge pull request #1673 from danielaskdd/litellm-problem
feat: Support `application/octet-stream` requests from LiteLLM clients for Ollama Emulation
- lightrag/api/__init__.py +1 -1
- lightrag/api/routers/ollama_api.py +111 -8
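LiteLLM's Ollama-compatible clients can send their JSON payload with a `Content-Type: application/octet-stream` header. FastAPI's declarative Pydantic body binding only treats `application/json` bodies as JSON, so such requests failed validation before ever reaching the `/api/generate` and `/api/chat` handlers. The sketch below illustrates the kind of request this change makes acceptable (the host/port, model name, and use of `requests` are illustrative assumptions, not part of the diff):

```python
import json

import requests

# JSON payload delivered as raw bytes with an octet-stream Content-Type,
# as some LiteLLM client versions do.
payload = {"model": "lightrag:latest", "prompt": "Hello", "stream": False}
resp = requests.post(
    "http://localhost:9621/api/generate",  # assumed default LightRAG server address
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/octet-stream"},
)
print(resp.status_code, resp.json())
```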
lightrag/api/__init__.py
CHANGED
@@ -1 +1 @@
-__api_version__ = "
+__api_version__ = "0173"
lightrag/api/routers/ollama_api.py
CHANGED
@@ -1,7 +1,7 @@
 from fastapi import APIRouter, HTTPException, Request
 from pydantic import BaseModel
-from typing import List, Dict, Any, Optional
-import
+from typing import List, Dict, Any, Optional, Type
+from lightrag.utils import logger
 import time
 import json
 import re
@@ -95,6 +95,68 @@ class OllamaTagResponse(BaseModel):
     models: List[OllamaModel]


+class OllamaRunningModelDetails(BaseModel):
+    parent_model: str
+    format: str
+    family: str
+    families: List[str]
+    parameter_size: str
+    quantization_level: str
+
+
+class OllamaRunningModel(BaseModel):
+    name: str
+    model: str
+    size: int
+    digest: str
+    details: OllamaRunningModelDetails
+    expires_at: str
+    size_vram: int
+
+
+class OllamaPsResponse(BaseModel):
+    models: List[OllamaRunningModel]
+
+
+async def parse_request_body(
+    request: Request, model_class: Type[BaseModel]
+) -> BaseModel:
+    """
+    Parse request body based on Content-Type header.
+    Supports both application/json and application/octet-stream.
+
+    Args:
+        request: The FastAPI Request object
+        model_class: The Pydantic model class to parse the request into
+
+    Returns:
+        An instance of the provided model_class
+    """
+    content_type = request.headers.get("content-type", "").lower()
+
+    try:
+        if content_type.startswith("application/json"):
+            # FastAPI already handles JSON parsing for us
+            body = await request.json()
+        elif content_type.startswith("application/octet-stream"):
+            # Manually parse octet-stream as JSON
+            body_bytes = await request.body()
+            body = json.loads(body_bytes.decode("utf-8"))
+        else:
+            # Try to parse as JSON for any other content type
+            body_bytes = await request.body()
+            body = json.loads(body_bytes.decode("utf-8"))
+
+        # Create an instance of the model
+        return model_class(**body)
+    except json.JSONDecodeError:
+        raise HTTPException(status_code=400, detail="Invalid JSON in request body")
+    except Exception as e:
+        raise HTTPException(
+            status_code=400, detail=f"Error parsing request body: {str(e)}"
+        )
+
+
 def estimate_tokens(text: str) -> int:
     """Estimate the number of tokens in text using tiktoken"""
     tokens = TiktokenTokenizer().encode(text)
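A minimal sketch of how the new helper behaves outside the Ollama routes (the toy `/echo` route, the `EchoRequest` model, and the `TestClient` usage are illustrative assumptions; `content=` assumes the httpx-based `TestClient` shipped with recent Starlette releases):

```python
import json

from fastapi import FastAPI, Request
from fastapi.testclient import TestClient
from pydantic import BaseModel

# Module-level helper added by this PR (see the hunk above).
from lightrag.api.routers.ollama_api import parse_request_body


class EchoRequest(BaseModel):
    prompt: str


app = FastAPI()


@app.post("/echo")
async def echo(raw_request: Request):
    # Delegate Content-Type handling to the new helper.
    req = await parse_request_body(raw_request, EchoRequest)
    return {"prompt": req.prompt}


client = TestClient(app)
payload = json.dumps({"prompt": "hi"}).encode("utf-8")

# Both Content-Types are parsed into the same Pydantic model.
for content_type in ("application/json", "application/octet-stream"):
    resp = client.post(
        "/echo", content=payload, headers={"Content-Type": content_type}
    )
    assert resp.status_code == 200 and resp.json() == {"prompt": "hi"}
```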
@@ -197,13 +259,43 @@ class OllamaAPI:
                 ]
             )

-        @self.router.
-        async def
+        @self.router.get("/ps", dependencies=[Depends(combined_auth)])
+        async def get_running_models():
+            """List Running Models - returns currently running models"""
+            return OllamaPsResponse(
+                models=[
+                    {
+                        "name": self.ollama_server_infos.LIGHTRAG_MODEL,
+                        "model": self.ollama_server_infos.LIGHTRAG_MODEL,
+                        "size": self.ollama_server_infos.LIGHTRAG_SIZE,
+                        "digest": self.ollama_server_infos.LIGHTRAG_DIGEST,
+                        "details": {
+                            "parent_model": "",
+                            "format": "gguf",
+                            "family": "llama",
+                            "families": ["llama"],
+                            "parameter_size": "7.2B",
+                            "quantization_level": "Q4_0",
+                        },
+                        "expires_at": "2050-12-31T14:38:31.83753-07:00",
+                        "size_vram": self.ollama_server_infos.LIGHTRAG_SIZE,
+                    }
+                ]
+            )
+
+        @self.router.post(
+            "/generate", dependencies=[Depends(combined_auth)], include_in_schema=True
+        )
+        async def generate(raw_request: Request):
             """Handle generate completion requests acting as an Ollama model
             For compatibility purpose, the request is not processed by LightRAG,
             and will be handled by underlying LLM model.
+            Supports both application/json and application/octet-stream Content-Types.
             """
             try:
+                # Parse the request body manually
+                request = await parse_request_body(raw_request, OllamaGenerateRequest)
+
                 query = request.prompt
                 start_time = time.time_ns()
                 prompt_tokens = estimate_tokens(query)
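Ollama clients (for example the `ollama ps` CLI command) call `/api/ps` to list running models, so the emulation now answers it with a single static entry describing the LightRAG pseudo-model. A quick way to check the route (the address is an assumption, and an Authorization header would be needed if an API key is configured):

```python
import requests

# Assumed default LightRAG server address with the Ollama routes mounted under /api.
resp = requests.get("http://localhost:9621/api/ps")
resp.raise_for_status()
for m in resp.json()["models"]:
    # Entries follow the OllamaRunningModel schema defined above.
    print(m["name"], m["size_vram"], m["expires_at"])
```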
@@ -278,7 +370,7 @@ class OllamaAPI:
                     else:
                         error_msg = f"Provider error: {error_msg}"

-
+                    logger.error(f"Stream error: {error_msg}")

                     # Send error message to client
                     error_data = {
@@ -363,13 +455,19 @@ class OllamaAPI:
                 trace_exception(e)
                 raise HTTPException(status_code=500, detail=str(e))

-        @self.router.post(
-
+        @self.router.post(
+            "/chat", dependencies=[Depends(combined_auth)], include_in_schema=True
+        )
+        async def chat(raw_request: Request):
             """Process chat completion requests acting as an Ollama model
             Routes user queries through LightRAG by selecting query mode based on prefix indicators.
             Detects and forwards OpenWebUI session-related requests (for meta data generation task) directly to LLM.
+            Supports both application/json and application/octet-stream Content-Types.
             """
             try:
+                # Parse the request body manually
+                request = await parse_request_body(raw_request, OllamaChatRequest)
+
                 # Get all messages
                 messages = request.messages
                 if not messages:
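The `/chat` route is the one LiteLLM's Ollama chat provider talks to. A hedged sketch of pointing a LiteLLM caller at the emulated server (the provider prefix and `api_base` follow LiteLLM's documented Ollama conventions; the address and model name are assumptions, and exact behaviour depends on the LiteLLM version):

```python
import litellm

# Assumed: LightRAG server at the default address, emulating an Ollama backend.
# The "ollama_chat/" prefix routes through Ollama's /api/chat endpoint in LiteLLM.
response = litellm.completion(
    model="ollama_chat/lightrag:latest",
    api_base="http://localhost:9621",
    messages=[{"role": "user", "content": "What is LightRAG?"}],
)
print(response.choices[0].message.content)
```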
@@ -496,7 +594,7 @@ class OllamaAPI:
                     else:
                         error_msg = f"Provider error: {error_msg}"

-
+                    logger.error(f"Stream error: {error_msg}")

                     # Send error message to client
                     error_data = {
@@ -530,6 +628,11 @@ class OllamaAPI:
                     data = {
                         "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                         "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
+                        "message": {
+                            "role": "assistant",
+                            "content": "",
+                            "images": None,
+                        },
                         "done": True,
                         "total_duration": total_time,
                         "load_duration": 0,
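With the final hunk, the terminating statistics chunk of a streamed `/api/chat` response carries an empty assistant `message`, so client loops that read `chunk["message"]["content"]` on every NDJSON line no longer break on the last chunk. A client-side sketch under assumed settings (URL, model name, and `requests` usage are illustrative):

```python
import json

import requests

payload = {
    "model": "lightrag:latest",  # assumed default emulated model name
    "messages": [{"role": "user", "content": "What is LightRAG?"}],
    "stream": True,
}
with requests.post(
    "http://localhost:9621/api/chat", json=payload, stream=True
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)
        # Safe for every chunk, including the final done/statistics one,
        # since "message" is now always present.
        print(chunk["message"]["content"], end="", flush=True)
        if chunk.get("done"):
            break
```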