yangdx committed · Commit 41e44a6 · Parent(s): 2ef219d

Complete the code for the Ollama interface

Files changed:
- lightrag/api/lightrag_ollama.py +128 -7
- setup.py +1 -0
lightrag/api/lightrag_ollama.py
CHANGED
@@ -2,14 +2,11 @@ from fastapi import FastAPI, HTTPException, File, UploadFile, Form
 from pydantic import BaseModel
 import logging
 import argparse
+from typing import List, Dict, Any, Optional
 from lightrag import LightRAG, QueryParam
-# from lightrag.llm import lollms_model_complete, lollms_embed
-# from lightrag.llm import ollama_model_complete, ollama_embed, openai_embedding
 from lightrag.llm import openai_complete_if_cache, ollama_embedding
-# from lightrag.llm import azure_openai_complete_if_cache, azure_openai_embedding
 
 from lightrag.utils import EmbeddingFunc
-from typing import Optional, List
 from enum import Enum
 from pathlib import Path
 import shutil
@@ -26,6 +23,13 @@ from starlette.status import HTTP_403_FORBIDDEN
 from dotenv import load_dotenv
 load_dotenv()
 
+# Constants for model information
+LIGHTRAG_NAME = "lightrag"
+LIGHTRAG_TAG = "latest"
+LIGHTRAG_MODEL = "{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
+LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
+LIGHTRAG_DIGEST = "sha256:lightrag"
+
 async def llm_model_func(
     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
@@ -219,21 +223,43 @@ class DocumentManager:
 class SearchMode(str, Enum):
     naive = "naive"
     local = "local"
-    global_ = "global"
+    global_ = "global"  # global_ is used because "global" is a Python reserved keyword, but the enum value converts to the string "global"
     hybrid = "hybrid"
 
+# Ollama API compatible models
+class OllamaMessage(BaseModel):
+    role: str
+    content: str
+
+class OllamaChatRequest(BaseModel):
+    model: str = LIGHTRAG_MODEL
+    messages: List[OllamaMessage]
+    stream: bool = False
+    options: Optional[Dict[str, Any]] = None
+
+class OllamaChatResponse(BaseModel):
+    model: str
+    created_at: str
+    message: OllamaMessage
+    done: bool
 
+class OllamaVersionResponse(BaseModel):
+    version: str
+    build: str = "default"
+
+class OllamaTagResponse(BaseModel):
+    models: List[Dict[str, str]]
+
+# Original LightRAG models
 class QueryRequest(BaseModel):
     query: str
     mode: SearchMode = SearchMode.hybrid
     stream: bool = False
    only_need_context: bool = False
 
-
 class QueryResponse(BaseModel):
     response: str
 
-
 class InsertTextRequest(BaseModel):
     text: str
     description: Optional[str] = None
@@ -555,6 +581,101 @@ def create_app(args):
         except Exception as e:
             raise HTTPException(status_code=500, detail=str(e))
 
+    # Ollama compatible API endpoints
+    @app.get("/api/version")
+    async def get_version():
+        """Get Ollama version information"""
+        return OllamaVersionResponse(
+            version="0.1.0"
+        )
+
+    @app.get("/api/tags")
+    async def get_tags():
+        """Get available models"""
+        return OllamaTagResponse(
+            models=[{
+                "name": LIGHTRAG_NAME,
+                "tag": LIGHTRAG_TAG,
+                "size": 0,
+                "digest": LIGHTRAG_DIGEST,
+                "modified_at": LIGHTRAG_CREATED_AT
+            }]
+        )
+
+    def parse_query_mode(query: str) -> tuple[str, SearchMode]:
+        """Parse query prefix to determine search mode
+        Returns tuple of (cleaned_query, search_mode)
+        """
+        mode_map = {
+            "/local ": SearchMode.local,
+            "/global ": SearchMode.global_,  # global_ is used because 'global' is a Python keyword
+            "/naive ": SearchMode.naive,
+            "/hybrid ": SearchMode.hybrid
+        }
+
+        for prefix, mode in mode_map.items():
+            if query.startswith(prefix):
+                return query[len(prefix):], mode
+
+        return query, SearchMode.hybrid
+
+    @app.post("/api/chat")
+    async def chat(request: OllamaChatRequest):
+        """Handle chat completion requests"""
+        try:
+            # Convert chat format to query
+            query = request.messages[-1].content if request.messages else ""
+
+            # Parse query mode and clean query
+            cleaned_query, mode = parse_query_mode(query)
+
+            # Call RAG with determined mode
+            response = await rag.aquery(
+                cleaned_query,
+                param=QueryParam(
+                    mode=mode,
+                    stream=request.stream
+                )
+            )
+
+            if request.stream:
+                async def stream_generator():
+                    result = ""
+                    async for chunk in response:
+                        result += chunk
+                        yield OllamaChatResponse(
+                            model=LIGHTRAG_MODEL,
+                            created_at=LIGHTRAG_CREATED_AT,
+                            message=OllamaMessage(
+                                role="assistant",
+                                content=chunk
+                            ),
+                            done=False
+                        )
+                    # Send final message
+                    yield OllamaChatResponse(
+                        model=LIGHTRAG_MODEL,
+                        created_at=LIGHTRAG_CREATED_AT,
+                        message=OllamaMessage(
+                            role="assistant",
+                            content=result
+                        ),
+                        done=True
+                    )
+                return stream_generator()
+            else:
+                return OllamaChatResponse(
+                    model=LIGHTRAG_MODEL,
+                    created_at=LIGHTRAG_CREATED_AT,
+                    message=OllamaMessage(
+                        role="assistant",
+                        content=response
+                    ),
+                    done=True
+                )
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=str(e))
+
     @app.get("/health", dependencies=[Depends(optional_api_key)])
     async def get_status():
         """Get current system status"""
setup.py
CHANGED
@@ -101,6 +101,7 @@ setuptools.setup(
     entry_points={
         "console_scripts": [
             "lightrag-server=lightrag.api.lightrag_server:main [api]",
+            "lightrag-ollama=lightrag.api.lightrag_ollama:main [api]",
         ],
     },
 )
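The new entry point exposes a lightrag-ollama console command next to the existing lightrag-server one. As a sketch, the generated script is roughly equivalent to the following, assuming lightrag.api.lightrag_ollama defines a main() and the package is installed with its api extra (which the [api] marker requires):

# Rough equivalent of the console script generated from the entry point
# "lightrag-ollama=lightrag.api.lightrag_ollama:main [api]".
from lightrag.api.lightrag_ollama import main

if __name__ == "__main__":
    main()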