Merge pull request #717 from danielaskdd/split-ollama-api-to-separated-file
Browse files- lightrag/api/lightrag_server.py +42 -566
- lightrag/api/ollama_api.py +574 -0
lightrag/api/lightrag_server.py
CHANGED
@@ -4,28 +4,20 @@ from fastapi import (
|
|
4 |
File,
|
5 |
UploadFile,
|
6 |
Form,
|
7 |
-
Request,
|
8 |
BackgroundTasks,
|
9 |
)
|
10 |
|
11 |
-
# Backend (Python)
|
12 |
-
# Add this to store progress globally
|
13 |
-
from typing import Dict
|
14 |
import threading
|
15 |
-
import asyncio
|
16 |
-
import json
|
17 |
import os
|
18 |
-
|
|
|
19 |
from fastapi.staticfiles import StaticFiles
|
20 |
-
from pydantic import BaseModel
|
21 |
import logging
|
22 |
import argparse
|
23 |
-
import
|
24 |
-
import
|
25 |
-
from typing import List, Any, Optional, Union
|
26 |
from lightrag import LightRAG, QueryParam
|
27 |
from lightrag.api import __api_version__
|
28 |
-
|
29 |
from lightrag.utils import EmbeddingFunc
|
30 |
from enum import Enum
|
31 |
from pathlib import Path
|
@@ -34,20 +26,32 @@ import aiofiles
|
|
34 |
from ascii_colors import trace_exception, ASCIIColors
|
35 |
import sys
|
36 |
import configparser
|
37 |
-
|
38 |
from fastapi import Depends, Security
|
39 |
from fastapi.security import APIKeyHeader
|
40 |
from fastapi.middleware.cors import CORSMiddleware
|
41 |
from contextlib import asynccontextmanager
|
42 |
-
|
43 |
from starlette.status import HTTP_403_FORBIDDEN
|
44 |
import pipmaster as pm
|
45 |
-
|
46 |
from dotenv import load_dotenv
|
|
|
|
|
|
|
|
|
47 |
|
48 |
# Load environment variables
|
49 |
load_dotenv()
|
50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
# Global progress tracker
|
52 |
scan_progress: Dict = {
|
53 |
"is_scanning": False,
|
@@ -76,24 +80,6 @@ def estimate_tokens(text: str) -> int:
|
|
76 |
return int(tokens)
|
77 |
|
78 |
|
79 |
-
class OllamaServerInfos:
|
80 |
-
# Constants for emulated Ollama model information
|
81 |
-
LIGHTRAG_NAME = "lightrag"
|
82 |
-
LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
|
83 |
-
LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
|
84 |
-
LIGHTRAG_SIZE = 7365960935 # it's a dummy value
|
85 |
-
LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
|
86 |
-
LIGHTRAG_DIGEST = "sha256:lightrag"
|
87 |
-
|
88 |
-
KV_STORAGE = "JsonKVStorage"
|
89 |
-
DOC_STATUS_STORAGE = "JsonDocStatusStorage"
|
90 |
-
GRAPH_STORAGE = "NetworkXStorage"
|
91 |
-
VECTOR_STORAGE = "NanoVectorDBStorage"
|
92 |
-
|
93 |
-
|
94 |
-
# Add infos
|
95 |
-
ollama_server_infos = OllamaServerInfos()
|
96 |
-
|
97 |
# read config.ini
|
98 |
config = configparser.ConfigParser()
|
99 |
config.read("config.ini", "utf-8")
|
@@ -101,8 +87,8 @@ config.read("config.ini", "utf-8")
|
|
101 |
redis_uri = config.get("redis", "uri", fallback=None)
|
102 |
if redis_uri:
|
103 |
os.environ["REDIS_URI"] = redis_uri
|
104 |
-
|
105 |
-
|
106 |
|
107 |
# Neo4j config
|
108 |
neo4j_uri = config.get("neo4j", "uri", fallback=None)
|
@@ -112,7 +98,7 @@ if neo4j_uri:
|
|
112 |
os.environ["NEO4J_URI"] = neo4j_uri
|
113 |
os.environ["NEO4J_USERNAME"] = neo4j_username
|
114 |
os.environ["NEO4J_PASSWORD"] = neo4j_password
|
115 |
-
|
116 |
|
117 |
# Milvus config
|
118 |
milvus_uri = config.get("milvus", "uri", fallback=None)
|
@@ -124,7 +110,7 @@ if milvus_uri:
|
|
124 |
os.environ["MILVUS_USER"] = milvus_user
|
125 |
os.environ["MILVUS_PASSWORD"] = milvus_password
|
126 |
os.environ["MILVUS_DB_NAME"] = milvus_db_name
|
127 |
-
|
128 |
|
129 |
# MongoDB config
|
130 |
mongo_uri = config.get("mongodb", "uri", fallback=None)
|
@@ -132,8 +118,8 @@ mongo_database = config.get("mongodb", "LightRAG", fallback=None)
|
|
132 |
if mongo_uri:
|
133 |
os.environ["MONGO_URI"] = mongo_uri
|
134 |
os.environ["MONGO_DATABASE"] = mongo_database
|
135 |
-
|
136 |
-
|
137 |
|
138 |
|
139 |
def get_default_host(binding_type: str) -> str:
|
@@ -535,6 +521,7 @@ def parse_args() -> argparse.Namespace:
|
|
535 |
help="Cosine similarity threshold (default: from env or 0.4)",
|
536 |
)
|
537 |
|
|
|
538 |
parser.add_argument(
|
539 |
"--simulated-model-name",
|
540 |
type=str,
|
@@ -599,83 +586,13 @@ class DocumentManager:
|
|
599 |
return any(filename.lower().endswith(ext) for ext in self.supported_extensions)
|
600 |
|
601 |
|
602 |
-
#
|
603 |
class SearchMode(str, Enum):
|
604 |
naive = "naive"
|
605 |
local = "local"
|
606 |
global_ = "global"
|
607 |
hybrid = "hybrid"
|
608 |
mix = "mix"
|
609 |
-
bypass = "bypass"
|
610 |
-
|
611 |
-
|
612 |
-
class OllamaMessage(BaseModel):
|
613 |
-
role: str
|
614 |
-
content: str
|
615 |
-
images: Optional[List[str]] = None
|
616 |
-
|
617 |
-
|
618 |
-
class OllamaChatRequest(BaseModel):
|
619 |
-
model: str = ollama_server_infos.LIGHTRAG_MODEL
|
620 |
-
messages: List[OllamaMessage]
|
621 |
-
stream: bool = True # Default to streaming mode
|
622 |
-
options: Optional[Dict[str, Any]] = None
|
623 |
-
system: Optional[str] = None
|
624 |
-
|
625 |
-
|
626 |
-
class OllamaChatResponse(BaseModel):
|
627 |
-
model: str
|
628 |
-
created_at: str
|
629 |
-
message: OllamaMessage
|
630 |
-
done: bool
|
631 |
-
|
632 |
-
|
633 |
-
class OllamaGenerateRequest(BaseModel):
|
634 |
-
model: str = ollama_server_infos.LIGHTRAG_MODEL
|
635 |
-
prompt: str
|
636 |
-
system: Optional[str] = None
|
637 |
-
stream: bool = False
|
638 |
-
options: Optional[Dict[str, Any]] = None
|
639 |
-
|
640 |
-
|
641 |
-
class OllamaGenerateResponse(BaseModel):
|
642 |
-
model: str
|
643 |
-
created_at: str
|
644 |
-
response: str
|
645 |
-
done: bool
|
646 |
-
context: Optional[List[int]]
|
647 |
-
total_duration: Optional[int]
|
648 |
-
load_duration: Optional[int]
|
649 |
-
prompt_eval_count: Optional[int]
|
650 |
-
prompt_eval_duration: Optional[int]
|
651 |
-
eval_count: Optional[int]
|
652 |
-
eval_duration: Optional[int]
|
653 |
-
|
654 |
-
|
655 |
-
class OllamaVersionResponse(BaseModel):
|
656 |
-
version: str
|
657 |
-
|
658 |
-
|
659 |
-
class OllamaModelDetails(BaseModel):
|
660 |
-
parent_model: str
|
661 |
-
format: str
|
662 |
-
family: str
|
663 |
-
families: List[str]
|
664 |
-
parameter_size: str
|
665 |
-
quantization_level: str
|
666 |
-
|
667 |
-
|
668 |
-
class OllamaModel(BaseModel):
|
669 |
-
name: str
|
670 |
-
model: str
|
671 |
-
size: int
|
672 |
-
digest: str
|
673 |
-
modified_at: str
|
674 |
-
details: OllamaModelDetails
|
675 |
-
|
676 |
-
|
677 |
-
class OllamaTagResponse(BaseModel):
|
678 |
-
models: List[OllamaModel]
|
679 |
|
680 |
|
681 |
class QueryRequest(BaseModel):
|
@@ -920,10 +837,10 @@ def create_app(args):
|
|
920 |
if args.llm_binding == "lollms" or args.llm_binding == "ollama"
|
921 |
else {},
|
922 |
embedding_func=embedding_func,
|
923 |
-
kv_storage=
|
924 |
-
graph_storage=
|
925 |
-
vector_storage=
|
926 |
-
doc_status_storage=
|
927 |
vector_db_storage_cls_kwargs={
|
928 |
"cosine_better_than_threshold": args.cosine_threshold
|
929 |
},
|
@@ -949,10 +866,10 @@ def create_app(args):
|
|
949 |
llm_model_max_async=args.max_async,
|
950 |
llm_model_max_token_size=args.max_tokens,
|
951 |
embedding_func=embedding_func,
|
952 |
-
kv_storage=
|
953 |
-
graph_storage=
|
954 |
-
vector_storage=
|
955 |
-
doc_status_storage=
|
956 |
vector_db_storage_cls_kwargs={
|
957 |
"cosine_better_than_threshold": args.cosine_threshold
|
958 |
},
|
@@ -1475,450 +1392,9 @@ def create_app(args):
|
|
1475 |
async def get_graphs(label: str):
|
1476 |
return await rag.get_graps(nodel_label=label, max_depth=100)
|
1477 |
|
1478 |
-
# Ollama
|
1479 |
-
|
1480 |
-
|
1481 |
-
async def get_version():
|
1482 |
-
"""Get Ollama version information"""
|
1483 |
-
return OllamaVersionResponse(version="0.5.4")
|
1484 |
-
|
1485 |
-
@app.get("/api/tags")
|
1486 |
-
async def get_tags():
|
1487 |
-
"""Retrun available models acting as an Ollama server"""
|
1488 |
-
return OllamaTagResponse(
|
1489 |
-
models=[
|
1490 |
-
{
|
1491 |
-
"name": ollama_server_infos.LIGHTRAG_MODEL,
|
1492 |
-
"model": ollama_server_infos.LIGHTRAG_MODEL,
|
1493 |
-
"size": ollama_server_infos.LIGHTRAG_SIZE,
|
1494 |
-
"digest": ollama_server_infos.LIGHTRAG_DIGEST,
|
1495 |
-
"modified_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
|
1496 |
-
"details": {
|
1497 |
-
"parent_model": "",
|
1498 |
-
"format": "gguf",
|
1499 |
-
"family": ollama_server_infos.LIGHTRAG_NAME,
|
1500 |
-
"families": [ollama_server_infos.LIGHTRAG_NAME],
|
1501 |
-
"parameter_size": "13B",
|
1502 |
-
"quantization_level": "Q4_0",
|
1503 |
-
},
|
1504 |
-
}
|
1505 |
-
]
|
1506 |
-
)
|
1507 |
-
|
1508 |
-
def parse_query_mode(query: str) -> tuple[str, SearchMode]:
|
1509 |
-
"""Parse query prefix to determine search mode
|
1510 |
-
Returns tuple of (cleaned_query, search_mode)
|
1511 |
-
"""
|
1512 |
-
mode_map = {
|
1513 |
-
"/local ": SearchMode.local,
|
1514 |
-
"/global ": SearchMode.global_, # global_ is used because 'global' is a Python keyword
|
1515 |
-
"/naive ": SearchMode.naive,
|
1516 |
-
"/hybrid ": SearchMode.hybrid,
|
1517 |
-
"/mix ": SearchMode.mix,
|
1518 |
-
"/bypass ": SearchMode.bypass,
|
1519 |
-
}
|
1520 |
-
|
1521 |
-
for prefix, mode in mode_map.items():
|
1522 |
-
if query.startswith(prefix):
|
1523 |
-
# After removing prefix an leading spaces
|
1524 |
-
cleaned_query = query[len(prefix) :].lstrip()
|
1525 |
-
return cleaned_query, mode
|
1526 |
-
|
1527 |
-
return query, SearchMode.hybrid
|
1528 |
-
|
1529 |
-
@app.post("/api/generate")
|
1530 |
-
async def generate(raw_request: Request, request: OllamaGenerateRequest):
|
1531 |
-
"""Handle generate completion requests acting as an Ollama model
|
1532 |
-
For compatiblity purpuse, the request is not processed by LightRAG,
|
1533 |
-
and will be handled by underlying LLM model.
|
1534 |
-
"""
|
1535 |
-
try:
|
1536 |
-
query = request.prompt
|
1537 |
-
start_time = time.time_ns()
|
1538 |
-
prompt_tokens = estimate_tokens(query)
|
1539 |
-
|
1540 |
-
if request.system:
|
1541 |
-
rag.llm_model_kwargs["system_prompt"] = request.system
|
1542 |
-
|
1543 |
-
if request.stream:
|
1544 |
-
from fastapi.responses import StreamingResponse
|
1545 |
-
|
1546 |
-
response = await rag.llm_model_func(
|
1547 |
-
query, stream=True, **rag.llm_model_kwargs
|
1548 |
-
)
|
1549 |
-
|
1550 |
-
async def stream_generator():
|
1551 |
-
try:
|
1552 |
-
first_chunk_time = None
|
1553 |
-
last_chunk_time = None
|
1554 |
-
total_response = ""
|
1555 |
-
|
1556 |
-
# Ensure response is an async generator
|
1557 |
-
if isinstance(response, str):
|
1558 |
-
# If it's a string, send in two parts
|
1559 |
-
first_chunk_time = time.time_ns()
|
1560 |
-
last_chunk_time = first_chunk_time
|
1561 |
-
total_response = response
|
1562 |
-
|
1563 |
-
data = {
|
1564 |
-
"model": ollama_server_infos.LIGHTRAG_MODEL,
|
1565 |
-
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
|
1566 |
-
"response": response,
|
1567 |
-
"done": False,
|
1568 |
-
}
|
1569 |
-
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
1570 |
-
|
1571 |
-
completion_tokens = estimate_tokens(total_response)
|
1572 |
-
total_time = last_chunk_time - start_time
|
1573 |
-
prompt_eval_time = first_chunk_time - start_time
|
1574 |
-
eval_time = last_chunk_time - first_chunk_time
|
1575 |
-
|
1576 |
-
data = {
|
1577 |
-
"model": ollama_server_infos.LIGHTRAG_MODEL,
|
1578 |
-
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
|
1579 |
-
"done": True,
|
1580 |
-
"total_duration": total_time,
|
1581 |
-
"load_duration": 0,
|
1582 |
-
"prompt_eval_count": prompt_tokens,
|
1583 |
-
"prompt_eval_duration": prompt_eval_time,
|
1584 |
-
"eval_count": completion_tokens,
|
1585 |
-
"eval_duration": eval_time,
|
1586 |
-
}
|
1587 |
-
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
1588 |
-
else:
|
1589 |
-
async for chunk in response:
|
1590 |
-
if chunk:
|
1591 |
-
if first_chunk_time is None:
|
1592 |
-
first_chunk_time = time.time_ns()
|
1593 |
-
|
1594 |
-
last_chunk_time = time.time_ns()
|
1595 |
-
|
1596 |
-
total_response += chunk
|
1597 |
-
data = {
|
1598 |
-
"model": ollama_server_infos.LIGHTRAG_MODEL,
|
1599 |
-
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
|
1600 |
-
"response": chunk,
|
1601 |
-
"done": False,
|
1602 |
-
}
|
1603 |
-
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
1604 |
-
|
1605 |
-
completion_tokens = estimate_tokens(total_response)
|
1606 |
-
total_time = last_chunk_time - start_time
|
1607 |
-
prompt_eval_time = first_chunk_time - start_time
|
1608 |
-
eval_time = last_chunk_time - first_chunk_time
|
1609 |
-
|
1610 |
-
data = {
|
1611 |
-
"model": ollama_server_infos.LIGHTRAG_MODEL,
|
1612 |
-
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
|
1613 |
-
"done": True,
|
1614 |
-
"total_duration": total_time,
|
1615 |
-
"load_duration": 0,
|
1616 |
-
"prompt_eval_count": prompt_tokens,
|
1617 |
-
"prompt_eval_duration": prompt_eval_time,
|
1618 |
-
"eval_count": completion_tokens,
|
1619 |
-
"eval_duration": eval_time,
|
1620 |
-
}
|
1621 |
-
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
1622 |
-
return
|
1623 |
-
|
1624 |
-
except Exception as e:
|
1625 |
-
logging.error(f"Error in stream_generator: {str(e)}")
|
1626 |
-
raise
|
1627 |
-
|
1628 |
-
return StreamingResponse(
|
1629 |
-
stream_generator(),
|
1630 |
-
media_type="application/x-ndjson",
|
1631 |
-
headers={
|
1632 |
-
"Cache-Control": "no-cache",
|
1633 |
-
"Connection": "keep-alive",
|
1634 |
-
"Content-Type": "application/x-ndjson",
|
1635 |
-
"Access-Control-Allow-Origin": "*",
|
1636 |
-
"Access-Control-Allow-Methods": "POST, OPTIONS",
|
1637 |
-
"Access-Control-Allow-Headers": "Content-Type",
|
1638 |
-
},
|
1639 |
-
)
|
1640 |
-
else:
|
1641 |
-
first_chunk_time = time.time_ns()
|
1642 |
-
response_text = await rag.llm_model_func(
|
1643 |
-
query, stream=False, **rag.llm_model_kwargs
|
1644 |
-
)
|
1645 |
-
last_chunk_time = time.time_ns()
|
1646 |
-
|
1647 |
-
if not response_text:
|
1648 |
-
response_text = "No response generated"
|
1649 |
-
|
1650 |
-
completion_tokens = estimate_tokens(str(response_text))
|
1651 |
-
total_time = last_chunk_time - start_time
|
1652 |
-
prompt_eval_time = first_chunk_time - start_time
|
1653 |
-
eval_time = last_chunk_time - first_chunk_time
|
1654 |
-
|
1655 |
-
return {
|
1656 |
-
"model": ollama_server_infos.LIGHTRAG_MODEL,
|
1657 |
-
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
|
1658 |
-
"response": str(response_text),
|
1659 |
-
"done": True,
|
1660 |
-
"total_duration": total_time,
|
1661 |
-
"load_duration": 0,
|
1662 |
-
"prompt_eval_count": prompt_tokens,
|
1663 |
-
"prompt_eval_duration": prompt_eval_time,
|
1664 |
-
"eval_count": completion_tokens,
|
1665 |
-
"eval_duration": eval_time,
|
1666 |
-
}
|
1667 |
-
except Exception as e:
|
1668 |
-
trace_exception(e)
|
1669 |
-
raise HTTPException(status_code=500, detail=str(e))
|
1670 |
-
|
1671 |
-
@app.post("/api/chat")
|
1672 |
-
async def chat(raw_request: Request, request: OllamaChatRequest):
|
1673 |
-
"""Process chat completion requests acting as an Ollama model
|
1674 |
-
Routes user queries through LightRAG by selecting query mode based on prefix indicators.
|
1675 |
-
Detects and forwards OpenWebUI session-related requests (for meta data generation task) directly to LLM.
|
1676 |
-
"""
|
1677 |
-
try:
|
1678 |
-
# Get all messages
|
1679 |
-
messages = request.messages
|
1680 |
-
if not messages:
|
1681 |
-
raise HTTPException(status_code=400, detail="No messages provided")
|
1682 |
-
|
1683 |
-
# Get the last message as query and previous messages as history
|
1684 |
-
query = messages[-1].content
|
1685 |
-
# Convert OllamaMessage objects to dictionaries
|
1686 |
-
conversation_history = [
|
1687 |
-
{"role": msg.role, "content": msg.content} for msg in messages[:-1]
|
1688 |
-
]
|
1689 |
-
|
1690 |
-
# Check for query prefix
|
1691 |
-
cleaned_query, mode = parse_query_mode(query)
|
1692 |
-
|
1693 |
-
start_time = time.time_ns()
|
1694 |
-
prompt_tokens = estimate_tokens(cleaned_query)
|
1695 |
-
|
1696 |
-
param_dict = {
|
1697 |
-
"mode": mode,
|
1698 |
-
"stream": request.stream,
|
1699 |
-
"only_need_context": False,
|
1700 |
-
"conversation_history": conversation_history,
|
1701 |
-
"top_k": args.top_k,
|
1702 |
-
}
|
1703 |
-
|
1704 |
-
if args.history_turns is not None:
|
1705 |
-
param_dict["history_turns"] = args.history_turns
|
1706 |
-
|
1707 |
-
query_param = QueryParam(**param_dict)
|
1708 |
-
|
1709 |
-
if request.stream:
|
1710 |
-
from fastapi.responses import StreamingResponse
|
1711 |
-
|
1712 |
-
# Determine if the request is prefix with "/bypass"
|
1713 |
-
if mode == SearchMode.bypass:
|
1714 |
-
if request.system:
|
1715 |
-
rag.llm_model_kwargs["system_prompt"] = request.system
|
1716 |
-
response = await rag.llm_model_func(
|
1717 |
-
cleaned_query,
|
1718 |
-
stream=True,
|
1719 |
-
history_messages=conversation_history,
|
1720 |
-
**rag.llm_model_kwargs,
|
1721 |
-
)
|
1722 |
-
else:
|
1723 |
-
response = await rag.aquery( # Need await to get async generator
|
1724 |
-
cleaned_query, param=query_param
|
1725 |
-
)
|
1726 |
-
|
1727 |
-
async def stream_generator():
|
1728 |
-
first_chunk_time = None
|
1729 |
-
last_chunk_time = None
|
1730 |
-
total_response = ""
|
1731 |
-
|
1732 |
-
try:
|
1733 |
-
# Ensure response is an async generator
|
1734 |
-
if isinstance(response, str):
|
1735 |
-
# If it's a string, send in two parts
|
1736 |
-
first_chunk_time = time.time_ns()
|
1737 |
-
last_chunk_time = first_chunk_time
|
1738 |
-
total_response = response
|
1739 |
-
|
1740 |
-
data = {
|
1741 |
-
"model": ollama_server_infos.LIGHTRAG_MODEL,
|
1742 |
-
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
|
1743 |
-
"message": {
|
1744 |
-
"role": "assistant",
|
1745 |
-
"content": response,
|
1746 |
-
"images": None,
|
1747 |
-
},
|
1748 |
-
"done": False,
|
1749 |
-
}
|
1750 |
-
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
1751 |
-
|
1752 |
-
completion_tokens = estimate_tokens(total_response)
|
1753 |
-
total_time = last_chunk_time - start_time
|
1754 |
-
prompt_eval_time = first_chunk_time - start_time
|
1755 |
-
eval_time = last_chunk_time - first_chunk_time
|
1756 |
-
|
1757 |
-
data = {
|
1758 |
-
"model": ollama_server_infos.LIGHTRAG_MODEL,
|
1759 |
-
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
|
1760 |
-
"done": True,
|
1761 |
-
"total_duration": total_time,
|
1762 |
-
"load_duration": 0,
|
1763 |
-
"prompt_eval_count": prompt_tokens,
|
1764 |
-
"prompt_eval_duration": prompt_eval_time,
|
1765 |
-
"eval_count": completion_tokens,
|
1766 |
-
"eval_duration": eval_time,
|
1767 |
-
}
|
1768 |
-
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
1769 |
-
else:
|
1770 |
-
try:
|
1771 |
-
async for chunk in response:
|
1772 |
-
if chunk:
|
1773 |
-
if first_chunk_time is None:
|
1774 |
-
first_chunk_time = time.time_ns()
|
1775 |
-
|
1776 |
-
last_chunk_time = time.time_ns()
|
1777 |
-
|
1778 |
-
total_response += chunk
|
1779 |
-
data = {
|
1780 |
-
"model": ollama_server_infos.LIGHTRAG_MODEL,
|
1781 |
-
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
|
1782 |
-
"message": {
|
1783 |
-
"role": "assistant",
|
1784 |
-
"content": chunk,
|
1785 |
-
"images": None,
|
1786 |
-
},
|
1787 |
-
"done": False,
|
1788 |
-
}
|
1789 |
-
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
1790 |
-
except (asyncio.CancelledError, Exception) as e:
|
1791 |
-
error_msg = str(e)
|
1792 |
-
if isinstance(e, asyncio.CancelledError):
|
1793 |
-
error_msg = "Stream was cancelled by server"
|
1794 |
-
else:
|
1795 |
-
error_msg = f"Provider error: {error_msg}"
|
1796 |
-
|
1797 |
-
logging.error(f"Stream error: {error_msg}")
|
1798 |
-
|
1799 |
-
# Send error message to client
|
1800 |
-
error_data = {
|
1801 |
-
"model": ollama_server_infos.LIGHTRAG_MODEL,
|
1802 |
-
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
|
1803 |
-
"message": {
|
1804 |
-
"role": "assistant",
|
1805 |
-
"content": f"\n\nError: {error_msg}",
|
1806 |
-
"images": None,
|
1807 |
-
},
|
1808 |
-
"done": False,
|
1809 |
-
}
|
1810 |
-
yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
|
1811 |
-
|
1812 |
-
# Send final message to close the stream
|
1813 |
-
final_data = {
|
1814 |
-
"model": ollama_server_infos.LIGHTRAG_MODEL,
|
1815 |
-
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
|
1816 |
-
"done": True,
|
1817 |
-
}
|
1818 |
-
yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
|
1819 |
-
return
|
1820 |
-
|
1821 |
-
if last_chunk_time is not None:
|
1822 |
-
completion_tokens = estimate_tokens(total_response)
|
1823 |
-
total_time = last_chunk_time - start_time
|
1824 |
-
prompt_eval_time = first_chunk_time - start_time
|
1825 |
-
eval_time = last_chunk_time - first_chunk_time
|
1826 |
-
|
1827 |
-
data = {
|
1828 |
-
"model": ollama_server_infos.LIGHTRAG_MODEL,
|
1829 |
-
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
|
1830 |
-
"done": True,
|
1831 |
-
"total_duration": total_time,
|
1832 |
-
"load_duration": 0,
|
1833 |
-
"prompt_eval_count": prompt_tokens,
|
1834 |
-
"prompt_eval_duration": prompt_eval_time,
|
1835 |
-
"eval_count": completion_tokens,
|
1836 |
-
"eval_duration": eval_time,
|
1837 |
-
}
|
1838 |
-
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
1839 |
-
|
1840 |
-
except Exception as e:
|
1841 |
-
error_msg = f"Error in stream_generator: {str(e)}"
|
1842 |
-
logging.error(error_msg)
|
1843 |
-
|
1844 |
-
# Send error message to client
|
1845 |
-
error_data = {
|
1846 |
-
"model": ollama_server_infos.LIGHTRAG_MODEL,
|
1847 |
-
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
|
1848 |
-
"error": {"code": "STREAM_ERROR", "message": error_msg},
|
1849 |
-
}
|
1850 |
-
yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
|
1851 |
-
|
1852 |
-
# Ensure sending end marker
|
1853 |
-
final_data = {
|
1854 |
-
"model": ollama_server_infos.LIGHTRAG_MODEL,
|
1855 |
-
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
|
1856 |
-
"done": True,
|
1857 |
-
}
|
1858 |
-
yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
|
1859 |
-
return
|
1860 |
-
|
1861 |
-
return StreamingResponse(
|
1862 |
-
stream_generator(),
|
1863 |
-
media_type="application/x-ndjson",
|
1864 |
-
headers={
|
1865 |
-
"Cache-Control": "no-cache",
|
1866 |
-
"Connection": "keep-alive",
|
1867 |
-
"Content-Type": "application/x-ndjson",
|
1868 |
-
"Access-Control-Allow-Origin": "*",
|
1869 |
-
"Access-Control-Allow-Methods": "POST, OPTIONS",
|
1870 |
-
"Access-Control-Allow-Headers": "Content-Type",
|
1871 |
-
},
|
1872 |
-
)
|
1873 |
-
else:
|
1874 |
-
first_chunk_time = time.time_ns()
|
1875 |
-
|
1876 |
-
# Determine if the request is prefix with "/bypass" or from Open WebUI's session title and session keyword generation task
|
1877 |
-
match_result = re.search(
|
1878 |
-
r"\n<chat_history>\nUSER:", cleaned_query, re.MULTILINE
|
1879 |
-
)
|
1880 |
-
if match_result or mode == SearchMode.bypass:
|
1881 |
-
if request.system:
|
1882 |
-
rag.llm_model_kwargs["system_prompt"] = request.system
|
1883 |
-
|
1884 |
-
response_text = await rag.llm_model_func(
|
1885 |
-
cleaned_query,
|
1886 |
-
stream=False,
|
1887 |
-
history_messages=conversation_history,
|
1888 |
-
**rag.llm_model_kwargs,
|
1889 |
-
)
|
1890 |
-
else:
|
1891 |
-
response_text = await rag.aquery(cleaned_query, param=query_param)
|
1892 |
-
|
1893 |
-
last_chunk_time = time.time_ns()
|
1894 |
-
|
1895 |
-
if not response_text:
|
1896 |
-
response_text = "No response generated"
|
1897 |
-
|
1898 |
-
completion_tokens = estimate_tokens(str(response_text))
|
1899 |
-
total_time = last_chunk_time - start_time
|
1900 |
-
prompt_eval_time = first_chunk_time - start_time
|
1901 |
-
eval_time = last_chunk_time - first_chunk_time
|
1902 |
-
|
1903 |
-
return {
|
1904 |
-
"model": ollama_server_infos.LIGHTRAG_MODEL,
|
1905 |
-
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
|
1906 |
-
"message": {
|
1907 |
-
"role": "assistant",
|
1908 |
-
"content": str(response_text),
|
1909 |
-
"images": None,
|
1910 |
-
},
|
1911 |
-
"done": True,
|
1912 |
-
"total_duration": total_time,
|
1913 |
-
"load_duration": 0,
|
1914 |
-
"prompt_eval_count": prompt_tokens,
|
1915 |
-
"prompt_eval_duration": prompt_eval_time,
|
1916 |
-
"eval_count": completion_tokens,
|
1917 |
-
"eval_duration": eval_time,
|
1918 |
-
}
|
1919 |
-
except Exception as e:
|
1920 |
-
trace_exception(e)
|
1921 |
-
raise HTTPException(status_code=500, detail=str(e))
|
1922 |
|
1923 |
@app.get("/documents", dependencies=[Depends(optional_api_key)])
|
1924 |
async def documents():
|
@@ -1945,10 +1421,10 @@ def create_app(args):
|
|
1945 |
"embedding_binding_host": args.embedding_binding_host,
|
1946 |
"embedding_model": args.embedding_model,
|
1947 |
"max_tokens": args.max_tokens,
|
1948 |
-
"kv_storage":
|
1949 |
-
"doc_status_storage":
|
1950 |
-
"graph_storage":
|
1951 |
-
"vector_storage":
|
1952 |
},
|
1953 |
}
|
1954 |
|
|
|
4 |
File,
|
5 |
UploadFile,
|
6 |
Form,
|
|
|
7 |
BackgroundTasks,
|
8 |
)
|
9 |
|
|
|
|
|
|
|
10 |
import threading
|
|
|
|
|
11 |
import os
|
12 |
+
import json
|
13 |
+
import re
|
14 |
from fastapi.staticfiles import StaticFiles
|
|
|
15 |
import logging
|
16 |
import argparse
|
17 |
+
from typing import List, Any, Optional, Union, Dict
|
18 |
+
from pydantic import BaseModel
|
|
|
19 |
from lightrag import LightRAG, QueryParam
|
20 |
from lightrag.api import __api_version__
|
|
|
21 |
from lightrag.utils import EmbeddingFunc
|
22 |
from enum import Enum
|
23 |
from pathlib import Path
|
|
|
26 |
from ascii_colors import trace_exception, ASCIIColors
|
27 |
import sys
|
28 |
import configparser
|
|
|
29 |
from fastapi import Depends, Security
|
30 |
from fastapi.security import APIKeyHeader
|
31 |
from fastapi.middleware.cors import CORSMiddleware
|
32 |
from contextlib import asynccontextmanager
|
|
|
33 |
from starlette.status import HTTP_403_FORBIDDEN
|
34 |
import pipmaster as pm
|
|
|
35 |
from dotenv import load_dotenv
|
36 |
+
from .ollama_api import (
|
37 |
+
OllamaAPI,
|
38 |
+
)
|
39 |
+
from .ollama_api import ollama_server_infos
|
40 |
|
41 |
# Load environment variables
|
42 |
load_dotenv()
|
43 |
|
44 |
+
|
45 |
+
class RAGStorageConfig:
|
46 |
+
KV_STORAGE = "JsonKVStorage"
|
47 |
+
DOC_STATUS_STORAGE = "JsonDocStatusStorage"
|
48 |
+
GRAPH_STORAGE = "NetworkXStorage"
|
49 |
+
VECTOR_STORAGE = "NanoVectorDBStorage"
|
50 |
+
|
51 |
+
|
52 |
+
# Initialize rag storage config
|
53 |
+
rag_storage_config = RAGStorageConfig()
|
54 |
+
|
55 |
# Global progress tracker
|
56 |
scan_progress: Dict = {
|
57 |
"is_scanning": False,
|
|
|
80 |
return int(tokens)
|
81 |
|
82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
# read config.ini
|
84 |
config = configparser.ConfigParser()
|
85 |
config.read("config.ini", "utf-8")
|
|
|
87 |
redis_uri = config.get("redis", "uri", fallback=None)
|
88 |
if redis_uri:
|
89 |
os.environ["REDIS_URI"] = redis_uri
|
90 |
+
rag_storage_config.KV_STORAGE = "RedisKVStorage"
|
91 |
+
rag_storage_config.DOC_STATUS_STORAGE = "RedisKVStorage"
|
92 |
|
93 |
# Neo4j config
|
94 |
neo4j_uri = config.get("neo4j", "uri", fallback=None)
|
|
|
98 |
os.environ["NEO4J_URI"] = neo4j_uri
|
99 |
os.environ["NEO4J_USERNAME"] = neo4j_username
|
100 |
os.environ["NEO4J_PASSWORD"] = neo4j_password
|
101 |
+
rag_storage_config.GRAPH_STORAGE = "Neo4JStorage"
|
102 |
|
103 |
# Milvus config
|
104 |
milvus_uri = config.get("milvus", "uri", fallback=None)
|
|
|
110 |
os.environ["MILVUS_USER"] = milvus_user
|
111 |
os.environ["MILVUS_PASSWORD"] = milvus_password
|
112 |
os.environ["MILVUS_DB_NAME"] = milvus_db_name
|
113 |
+
rag_storage_config.VECTOR_STORAGE = "MilvusVectorDBStorge"
|
114 |
|
115 |
# MongoDB config
|
116 |
mongo_uri = config.get("mongodb", "uri", fallback=None)
|
|
|
118 |
if mongo_uri:
|
119 |
os.environ["MONGO_URI"] = mongo_uri
|
120 |
os.environ["MONGO_DATABASE"] = mongo_database
|
121 |
+
rag_storage_config.KV_STORAGE = "MongoKVStorage"
|
122 |
+
rag_storage_config.DOC_STATUS_STORAGE = "MongoKVStorage"
|
123 |
|
124 |
|
125 |
def get_default_host(binding_type: str) -> str:
|
|
|
521 |
help="Cosine similarity threshold (default: from env or 0.4)",
|
522 |
)
|
523 |
|
524 |
+
# Ollama model name
|
525 |
parser.add_argument(
|
526 |
"--simulated-model-name",
|
527 |
type=str,
|
|
|
586 |
return any(filename.lower().endswith(ext) for ext in self.supported_extensions)
|
587 |
|
588 |
|
589 |
+
# LightRAG query mode
|
590 |
class SearchMode(str, Enum):
|
591 |
naive = "naive"
|
592 |
local = "local"
|
593 |
global_ = "global"
|
594 |
hybrid = "hybrid"
|
595 |
mix = "mix"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
596 |
|
597 |
|
598 |
class QueryRequest(BaseModel):
|
|
|
837 |
if args.llm_binding == "lollms" or args.llm_binding == "ollama"
|
838 |
else {},
|
839 |
embedding_func=embedding_func,
|
840 |
+
kv_storage=rag_storage_config.KV_STORAGE,
|
841 |
+
graph_storage=rag_storage_config.GRAPH_STORAGE,
|
842 |
+
vector_storage=rag_storage_config.VECTOR_STORAGE,
|
843 |
+
doc_status_storage=rag_storage_config.DOC_STATUS_STORAGE,
|
844 |
vector_db_storage_cls_kwargs={
|
845 |
"cosine_better_than_threshold": args.cosine_threshold
|
846 |
},
|
|
|
866 |
llm_model_max_async=args.max_async,
|
867 |
llm_model_max_token_size=args.max_tokens,
|
868 |
embedding_func=embedding_func,
|
869 |
+
kv_storage=rag_storage_config.KV_STORAGE,
|
870 |
+
graph_storage=rag_storage_config.GRAPH_STORAGE,
|
871 |
+
vector_storage=rag_storage_config.VECTOR_STORAGE,
|
872 |
+
doc_status_storage=rag_storage_config.DOC_STATUS_STORAGE,
|
873 |
vector_db_storage_cls_kwargs={
|
874 |
"cosine_better_than_threshold": args.cosine_threshold
|
875 |
},
|
|
|
1392 |
async def get_graphs(label: str):
|
1393 |
return await rag.get_graps(nodel_label=label, max_depth=100)
|
1394 |
|
1395 |
+
# Add Ollama API routes
|
1396 |
+
ollama_api = OllamaAPI(rag)
|
1397 |
+
app.include_router(ollama_api.router, prefix="/api")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1398 |
|
1399 |
@app.get("/documents", dependencies=[Depends(optional_api_key)])
|
1400 |
async def documents():
|
|
|
1421 |
"embedding_binding_host": args.embedding_binding_host,
|
1422 |
"embedding_model": args.embedding_model,
|
1423 |
"max_tokens": args.max_tokens,
|
1424 |
+
"kv_storage": rag_storage_config.KV_STORAGE,
|
1425 |
+
"doc_status_storage": rag_storage_config.DOC_STATUS_STORAGE,
|
1426 |
+
"graph_storage": rag_storage_config.GRAPH_STORAGE,
|
1427 |
+
"vector_storage": rag_storage_config.VECTOR_STORAGE,
|
1428 |
},
|
1429 |
}
|
1430 |
|
lightrag/api/ollama_api.py
ADDED
@@ -0,0 +1,574 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import APIRouter, HTTPException, Request
|
2 |
+
from pydantic import BaseModel
|
3 |
+
from typing import List, Dict, Any, Optional
|
4 |
+
import logging
|
5 |
+
import time
|
6 |
+
import json
|
7 |
+
import re
|
8 |
+
import os
|
9 |
+
from enum import Enum
|
10 |
+
from fastapi.responses import StreamingResponse
|
11 |
+
import asyncio
|
12 |
+
from ascii_colors import trace_exception
|
13 |
+
from lightrag import LightRAG, QueryParam
|
14 |
+
|
15 |
+
|
16 |
+
class OllamaServerInfos:
|
17 |
+
# Constants for emulated Ollama model information
|
18 |
+
LIGHTRAG_NAME = "lightrag"
|
19 |
+
LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
|
20 |
+
LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
|
21 |
+
LIGHTRAG_SIZE = 7365960935 # it's a dummy value
|
22 |
+
LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
|
23 |
+
LIGHTRAG_DIGEST = "sha256:lightrag"
|
24 |
+
|
25 |
+
|
26 |
+
ollama_server_infos = OllamaServerInfos()
|
27 |
+
|
28 |
+
|
29 |
+
# query mode according to query prefix (bypass is not LightRAG quer mode)
|
30 |
+
class SearchMode(str, Enum):
|
31 |
+
naive = "naive"
|
32 |
+
local = "local"
|
33 |
+
global_ = "global"
|
34 |
+
hybrid = "hybrid"
|
35 |
+
mix = "mix"
|
36 |
+
bypass = "bypass"
|
37 |
+
|
38 |
+
|
39 |
+
class OllamaMessage(BaseModel):
|
40 |
+
role: str
|
41 |
+
content: str
|
42 |
+
images: Optional[List[str]] = None
|
43 |
+
|
44 |
+
|
45 |
+
class OllamaChatRequest(BaseModel):
|
46 |
+
model: str
|
47 |
+
messages: List[OllamaMessage]
|
48 |
+
stream: bool = True
|
49 |
+
options: Optional[Dict[str, Any]] = None
|
50 |
+
system: Optional[str] = None
|
51 |
+
|
52 |
+
|
53 |
+
class OllamaChatResponse(BaseModel):
|
54 |
+
model: str
|
55 |
+
created_at: str
|
56 |
+
message: OllamaMessage
|
57 |
+
done: bool
|
58 |
+
|
59 |
+
|
60 |
+
class OllamaGenerateRequest(BaseModel):
|
61 |
+
model: str
|
62 |
+
prompt: str
|
63 |
+
system: Optional[str] = None
|
64 |
+
stream: bool = False
|
65 |
+
options: Optional[Dict[str, Any]] = None
|
66 |
+
|
67 |
+
|
68 |
+
class OllamaGenerateResponse(BaseModel):
|
69 |
+
model: str
|
70 |
+
created_at: str
|
71 |
+
response: str
|
72 |
+
done: bool
|
73 |
+
context: Optional[List[int]]
|
74 |
+
total_duration: Optional[int]
|
75 |
+
load_duration: Optional[int]
|
76 |
+
prompt_eval_count: Optional[int]
|
77 |
+
prompt_eval_duration: Optional[int]
|
78 |
+
eval_count: Optional[int]
|
79 |
+
eval_duration: Optional[int]
|
80 |
+
|
81 |
+
|
82 |
+
class OllamaVersionResponse(BaseModel):
|
83 |
+
version: str
|
84 |
+
|
85 |
+
|
86 |
+
class OllamaModelDetails(BaseModel):
|
87 |
+
parent_model: str
|
88 |
+
format: str
|
89 |
+
family: str
|
90 |
+
families: List[str]
|
91 |
+
parameter_size: str
|
92 |
+
quantization_level: str
|
93 |
+
|
94 |
+
|
95 |
+
class OllamaModel(BaseModel):
|
96 |
+
name: str
|
97 |
+
model: str
|
98 |
+
size: int
|
99 |
+
digest: str
|
100 |
+
modified_at: str
|
101 |
+
details: OllamaModelDetails
|
102 |
+
|
103 |
+
|
104 |
+
class OllamaTagResponse(BaseModel):
|
105 |
+
models: List[OllamaModel]
|
106 |
+
|
107 |
+
|
108 |
+
def estimate_tokens(text: str) -> int:
|
109 |
+
"""Estimate the number of tokens in text
|
110 |
+
Chinese characters: approximately 1.5 tokens per character
|
111 |
+
English characters: approximately 0.25 tokens per character
|
112 |
+
"""
|
113 |
+
# Use regex to match Chinese and non-Chinese characters separately
|
114 |
+
chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text))
|
115 |
+
non_chinese_chars = len(re.findall(r"[^\u4e00-\u9fff]", text))
|
116 |
+
|
117 |
+
# Calculate estimated token count
|
118 |
+
tokens = chinese_chars * 1.5 + non_chinese_chars * 0.25
|
119 |
+
|
120 |
+
return int(tokens)
|
121 |
+
|
122 |
+
|
123 |
+
def parse_query_mode(query: str) -> tuple[str, SearchMode]:
|
124 |
+
"""Parse query prefix to determine search mode
|
125 |
+
Returns tuple of (cleaned_query, search_mode)
|
126 |
+
"""
|
127 |
+
mode_map = {
|
128 |
+
"/local ": SearchMode.local,
|
129 |
+
"/global ": SearchMode.global_, # global_ is used because 'global' is a Python keyword
|
130 |
+
"/naive ": SearchMode.naive,
|
131 |
+
"/hybrid ": SearchMode.hybrid,
|
132 |
+
"/mix ": SearchMode.mix,
|
133 |
+
"/bypass ": SearchMode.bypass,
|
134 |
+
}
|
135 |
+
|
136 |
+
for prefix, mode in mode_map.items():
|
137 |
+
if query.startswith(prefix):
|
138 |
+
# After removing prefix an leading spaces
|
139 |
+
cleaned_query = query[len(prefix) :].lstrip()
|
140 |
+
return cleaned_query, mode
|
141 |
+
|
142 |
+
return query, SearchMode.hybrid
|
143 |
+
|
144 |
+
|
145 |
+
class OllamaAPI:
|
146 |
+
def __init__(self, rag: LightRAG):
|
147 |
+
self.rag = rag
|
148 |
+
self.ollama_server_infos = ollama_server_infos
|
149 |
+
self.router = APIRouter()
|
150 |
+
self.setup_routes()
|
151 |
+
|
152 |
+
def setup_routes(self):
|
153 |
+
@self.router.get("/version")
|
154 |
+
async def get_version():
|
155 |
+
"""Get Ollama version information"""
|
156 |
+
return OllamaVersionResponse(version="0.5.4")
|
157 |
+
|
158 |
+
@self.router.get("/tags")
|
159 |
+
async def get_tags():
|
160 |
+
"""Return available models acting as an Ollama server"""
|
161 |
+
return OllamaTagResponse(
|
162 |
+
models=[
|
163 |
+
{
|
164 |
+
"name": self.ollama_server_infos.LIGHTRAG_MODEL,
|
165 |
+
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
166 |
+
"size": self.ollama_server_infos.LIGHTRAG_SIZE,
|
167 |
+
"digest": self.ollama_server_infos.LIGHTRAG_DIGEST,
|
168 |
+
"modified_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
169 |
+
"details": {
|
170 |
+
"parent_model": "",
|
171 |
+
"format": "gguf",
|
172 |
+
"family": self.ollama_server_infos.LIGHTRAG_NAME,
|
173 |
+
"families": [self.ollama_server_infos.LIGHTRAG_NAME],
|
174 |
+
"parameter_size": "13B",
|
175 |
+
"quantization_level": "Q4_0",
|
176 |
+
},
|
177 |
+
}
|
178 |
+
]
|
179 |
+
)
|
180 |
+
|
181 |
+
@self.router.post("/generate")
|
182 |
+
async def generate(raw_request: Request, request: OllamaGenerateRequest):
|
183 |
+
"""Handle generate completion requests acting as an Ollama model
|
184 |
+
For compatibility purpose, the request is not processed by LightRAG,
|
185 |
+
and will be handled by underlying LLM model.
|
186 |
+
"""
|
187 |
+
try:
|
188 |
+
query = request.prompt
|
189 |
+
start_time = time.time_ns()
|
190 |
+
prompt_tokens = estimate_tokens(query)
|
191 |
+
|
192 |
+
if request.system:
|
193 |
+
self.rag.llm_model_kwargs["system_prompt"] = request.system
|
194 |
+
|
195 |
+
if request.stream:
|
196 |
+
response = await self.rag.llm_model_func(
|
197 |
+
query, stream=True, **self.rag.llm_model_kwargs
|
198 |
+
)
|
199 |
+
|
200 |
+
async def stream_generator():
|
201 |
+
try:
|
202 |
+
first_chunk_time = None
|
203 |
+
last_chunk_time = None
|
204 |
+
total_response = ""
|
205 |
+
|
206 |
+
# Ensure response is an async generator
|
207 |
+
if isinstance(response, str):
|
208 |
+
# If it's a string, send in two parts
|
209 |
+
first_chunk_time = time.time_ns()
|
210 |
+
last_chunk_time = first_chunk_time
|
211 |
+
total_response = response
|
212 |
+
|
213 |
+
data = {
|
214 |
+
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
215 |
+
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
216 |
+
"response": response,
|
217 |
+
"done": False,
|
218 |
+
}
|
219 |
+
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
220 |
+
|
221 |
+
completion_tokens = estimate_tokens(total_response)
|
222 |
+
total_time = last_chunk_time - start_time
|
223 |
+
prompt_eval_time = first_chunk_time - start_time
|
224 |
+
eval_time = last_chunk_time - first_chunk_time
|
225 |
+
|
226 |
+
data = {
|
227 |
+
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
228 |
+
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
229 |
+
"done": True,
|
230 |
+
"total_duration": total_time,
|
231 |
+
"load_duration": 0,
|
232 |
+
"prompt_eval_count": prompt_tokens,
|
233 |
+
"prompt_eval_duration": prompt_eval_time,
|
234 |
+
"eval_count": completion_tokens,
|
235 |
+
"eval_duration": eval_time,
|
236 |
+
}
|
237 |
+
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
238 |
+
else:
|
239 |
+
async for chunk in response:
|
240 |
+
if chunk:
|
241 |
+
if first_chunk_time is None:
|
242 |
+
first_chunk_time = time.time_ns()
|
243 |
+
|
244 |
+
last_chunk_time = time.time_ns()
|
245 |
+
|
246 |
+
total_response += chunk
|
247 |
+
data = {
|
248 |
+
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
249 |
+
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
250 |
+
"response": chunk,
|
251 |
+
"done": False,
|
252 |
+
}
|
253 |
+
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
254 |
+
|
255 |
+
completion_tokens = estimate_tokens(total_response)
|
256 |
+
total_time = last_chunk_time - start_time
|
257 |
+
prompt_eval_time = first_chunk_time - start_time
|
258 |
+
eval_time = last_chunk_time - first_chunk_time
|
259 |
+
|
260 |
+
data = {
|
261 |
+
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
262 |
+
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
263 |
+
"done": True,
|
264 |
+
"total_duration": total_time,
|
265 |
+
"load_duration": 0,
|
266 |
+
"prompt_eval_count": prompt_tokens,
|
267 |
+
"prompt_eval_duration": prompt_eval_time,
|
268 |
+
"eval_count": completion_tokens,
|
269 |
+
"eval_duration": eval_time,
|
270 |
+
}
|
271 |
+
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
272 |
+
return
|
273 |
+
|
274 |
+
except Exception as e:
|
275 |
+
trace_exception(e)
|
276 |
+
raise
|
277 |
+
|
278 |
+
return StreamingResponse(
|
279 |
+
stream_generator(),
|
280 |
+
media_type="application/x-ndjson",
|
281 |
+
headers={
|
282 |
+
"Cache-Control": "no-cache",
|
283 |
+
"Connection": "keep-alive",
|
284 |
+
"Content-Type": "application/x-ndjson",
|
285 |
+
"Access-Control-Allow-Origin": "*",
|
286 |
+
"Access-Control-Allow-Methods": "POST, OPTIONS",
|
287 |
+
"Access-Control-Allow-Headers": "Content-Type",
|
288 |
+
},
|
289 |
+
)
|
290 |
+
else:
|
291 |
+
first_chunk_time = time.time_ns()
|
292 |
+
response_text = await self.rag.llm_model_func(
|
293 |
+
query, stream=False, **self.rag.llm_model_kwargs
|
294 |
+
)
|
295 |
+
last_chunk_time = time.time_ns()
|
296 |
+
|
297 |
+
if not response_text:
|
298 |
+
response_text = "No response generated"
|
299 |
+
|
300 |
+
completion_tokens = estimate_tokens(str(response_text))
|
301 |
+
total_time = last_chunk_time - start_time
|
302 |
+
prompt_eval_time = first_chunk_time - start_time
|
303 |
+
eval_time = last_chunk_time - first_chunk_time
|
304 |
+
|
305 |
+
return {
|
306 |
+
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
307 |
+
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
308 |
+
"response": str(response_text),
|
309 |
+
"done": True,
|
310 |
+
"total_duration": total_time,
|
311 |
+
"load_duration": 0,
|
312 |
+
"prompt_eval_count": prompt_tokens,
|
313 |
+
"prompt_eval_duration": prompt_eval_time,
|
314 |
+
"eval_count": completion_tokens,
|
315 |
+
"eval_duration": eval_time,
|
316 |
+
}
|
317 |
+
except Exception as e:
|
318 |
+
trace_exception(e)
|
319 |
+
raise HTTPException(status_code=500, detail=str(e))
|
320 |
+
|
321 |
+
@self.router.post("/chat")
|
322 |
+
async def chat(raw_request: Request, request: OllamaChatRequest):
|
323 |
+
"""Process chat completion requests acting as an Ollama model
|
324 |
+
Routes user queries through LightRAG by selecting query mode based on prefix indicators.
|
325 |
+
Detects and forwards OpenWebUI session-related requests (for meta data generation task) directly to LLM.
|
326 |
+
"""
|
327 |
+
try:
|
328 |
+
# Get all messages
|
329 |
+
messages = request.messages
|
330 |
+
if not messages:
|
331 |
+
raise HTTPException(status_code=400, detail="No messages provided")
|
332 |
+
|
333 |
+
# Get the last message as query and previous messages as history
|
334 |
+
query = messages[-1].content
|
335 |
+
# Convert OllamaMessage objects to dictionaries
|
336 |
+
conversation_history = [
|
337 |
+
{"role": msg.role, "content": msg.content} for msg in messages[:-1]
|
338 |
+
]
|
339 |
+
|
340 |
+
# Check for query prefix
|
341 |
+
cleaned_query, mode = parse_query_mode(query)
|
342 |
+
|
343 |
+
start_time = time.time_ns()
|
344 |
+
prompt_tokens = estimate_tokens(cleaned_query)
|
345 |
+
|
346 |
+
param_dict = {
|
347 |
+
"mode": mode,
|
348 |
+
"stream": request.stream,
|
349 |
+
"only_need_context": False,
|
350 |
+
"conversation_history": conversation_history,
|
351 |
+
"top_k": self.rag.args.top_k if hasattr(self.rag, "args") else 50,
|
352 |
+
}
|
353 |
+
|
354 |
+
if (
|
355 |
+
hasattr(self.rag, "args")
|
356 |
+
and self.rag.args.history_turns is not None
|
357 |
+
):
|
358 |
+
param_dict["history_turns"] = self.rag.args.history_turns
|
359 |
+
|
360 |
+
query_param = QueryParam(**param_dict)
|
361 |
+
|
362 |
+
if request.stream:
|
363 |
+
# Determine if the request is prefix with "/bypass"
|
364 |
+
if mode == SearchMode.bypass:
|
365 |
+
if request.system:
|
366 |
+
self.rag.llm_model_kwargs["system_prompt"] = request.system
|
367 |
+
response = await self.rag.llm_model_func(
|
368 |
+
cleaned_query,
|
369 |
+
stream=True,
|
370 |
+
history_messages=conversation_history,
|
371 |
+
**self.rag.llm_model_kwargs,
|
372 |
+
)
|
373 |
+
else:
|
374 |
+
response = await self.rag.aquery(
|
375 |
+
cleaned_query, param=query_param
|
376 |
+
)
|
377 |
+
|
378 |
+
async def stream_generator():
|
379 |
+
first_chunk_time = None
|
380 |
+
last_chunk_time = None
|
381 |
+
total_response = ""
|
382 |
+
|
383 |
+
try:
|
384 |
+
# Ensure response is an async generator
|
385 |
+
if isinstance(response, str):
|
386 |
+
# If it's a string, send in two parts
|
387 |
+
first_chunk_time = time.time_ns()
|
388 |
+
last_chunk_time = first_chunk_time
|
389 |
+
total_response = response
|
390 |
+
|
391 |
+
data = {
|
392 |
+
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
393 |
+
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
394 |
+
"message": {
|
395 |
+
"role": "assistant",
|
396 |
+
"content": response,
|
397 |
+
"images": None,
|
398 |
+
},
|
399 |
+
"done": False,
|
400 |
+
}
|
401 |
+
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
402 |
+
|
403 |
+
completion_tokens = estimate_tokens(total_response)
|
404 |
+
total_time = last_chunk_time - start_time
|
405 |
+
prompt_eval_time = first_chunk_time - start_time
|
406 |
+
eval_time = last_chunk_time - first_chunk_time
|
407 |
+
|
408 |
+
data = {
|
409 |
+
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
410 |
+
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
411 |
+
"done": True,
|
412 |
+
"total_duration": total_time,
|
413 |
+
"load_duration": 0,
|
414 |
+
"prompt_eval_count": prompt_tokens,
|
415 |
+
"prompt_eval_duration": prompt_eval_time,
|
416 |
+
"eval_count": completion_tokens,
|
417 |
+
"eval_duration": eval_time,
|
418 |
+
}
|
419 |
+
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
420 |
+
else:
|
421 |
+
try:
|
422 |
+
async for chunk in response:
|
423 |
+
if chunk:
|
424 |
+
if first_chunk_time is None:
|
425 |
+
first_chunk_time = time.time_ns()
|
426 |
+
|
427 |
+
last_chunk_time = time.time_ns()
|
428 |
+
|
429 |
+
total_response += chunk
|
430 |
+
data = {
|
431 |
+
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
432 |
+
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
433 |
+
"message": {
|
434 |
+
"role": "assistant",
|
435 |
+
"content": chunk,
|
436 |
+
"images": None,
|
437 |
+
},
|
438 |
+
"done": False,
|
439 |
+
}
|
440 |
+
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
441 |
+
except (asyncio.CancelledError, Exception) as e:
|
442 |
+
error_msg = str(e)
|
443 |
+
if isinstance(e, asyncio.CancelledError):
|
444 |
+
error_msg = "Stream was cancelled by server"
|
445 |
+
else:
|
446 |
+
error_msg = f"Provider error: {error_msg}"
|
447 |
+
|
448 |
+
logging.error(f"Stream error: {error_msg}")
|
449 |
+
|
450 |
+
# Send error message to client
|
451 |
+
error_data = {
|
452 |
+
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
453 |
+
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
454 |
+
"message": {
|
455 |
+
"role": "assistant",
|
456 |
+
"content": f"\n\nError: {error_msg}",
|
457 |
+
"images": None,
|
458 |
+
},
|
459 |
+
"done": False,
|
460 |
+
}
|
461 |
+
yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
|
462 |
+
|
463 |
+
# Send final message to close the stream
|
464 |
+
final_data = {
|
465 |
+
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
466 |
+
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
467 |
+
"done": True,
|
468 |
+
}
|
469 |
+
yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
|
470 |
+
return
|
471 |
+
|
472 |
+
if last_chunk_time is not None:
|
473 |
+
completion_tokens = estimate_tokens(total_response)
|
474 |
+
total_time = last_chunk_time - start_time
|
475 |
+
prompt_eval_time = first_chunk_time - start_time
|
476 |
+
eval_time = last_chunk_time - first_chunk_time
|
477 |
+
|
478 |
+
data = {
|
479 |
+
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
480 |
+
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
481 |
+
"done": True,
|
482 |
+
"total_duration": total_time,
|
483 |
+
"load_duration": 0,
|
484 |
+
"prompt_eval_count": prompt_tokens,
|
485 |
+
"prompt_eval_duration": prompt_eval_time,
|
486 |
+
"eval_count": completion_tokens,
|
487 |
+
"eval_duration": eval_time,
|
488 |
+
}
|
489 |
+
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
490 |
+
|
491 |
+
except Exception as e:
|
492 |
+
error_msg = f"Error in stream_generator: {str(e)}"
|
493 |
+
logging.error(error_msg)
|
494 |
+
|
495 |
+
# Send error message to client
|
496 |
+
error_data = {
|
497 |
+
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
498 |
+
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
499 |
+
"error": {"code": "STREAM_ERROR", "message": error_msg},
|
500 |
+
}
|
501 |
+
yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
|
502 |
+
|
503 |
+
# Ensure sending end marker
|
504 |
+
final_data = {
|
505 |
+
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
506 |
+
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
507 |
+
"done": True,
|
508 |
+
}
|
509 |
+
yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
|
510 |
+
return
|
511 |
+
|
512 |
+
return StreamingResponse(
|
513 |
+
stream_generator(),
|
514 |
+
media_type="application/x-ndjson",
|
515 |
+
headers={
|
516 |
+
"Cache-Control": "no-cache",
|
517 |
+
"Connection": "keep-alive",
|
518 |
+
"Content-Type": "application/x-ndjson",
|
519 |
+
"Access-Control-Allow-Origin": "*",
|
520 |
+
"Access-Control-Allow-Methods": "POST, OPTIONS",
|
521 |
+
"Access-Control-Allow-Headers": "Content-Type",
|
522 |
+
},
|
523 |
+
)
|
524 |
+
else:
|
525 |
+
first_chunk_time = time.time_ns()
|
526 |
+
|
527 |
+
# Determine if the request is prefix with "/bypass" or from Open WebUI's session title and session keyword generation task
|
528 |
+
match_result = re.search(
|
529 |
+
r"\n<chat_history>\nUSER:", cleaned_query, re.MULTILINE
|
530 |
+
)
|
531 |
+
if match_result or mode == SearchMode.bypass:
|
532 |
+
if request.system:
|
533 |
+
self.rag.llm_model_kwargs["system_prompt"] = request.system
|
534 |
+
|
535 |
+
response_text = await self.rag.llm_model_func(
|
536 |
+
cleaned_query,
|
537 |
+
stream=False,
|
538 |
+
history_messages=conversation_history,
|
539 |
+
**self.rag.llm_model_kwargs,
|
540 |
+
)
|
541 |
+
else:
|
542 |
+
response_text = await self.rag.aquery(
|
543 |
+
cleaned_query, param=query_param
|
544 |
+
)
|
545 |
+
|
546 |
+
last_chunk_time = time.time_ns()
|
547 |
+
|
548 |
+
if not response_text:
|
549 |
+
response_text = "No response generated"
|
550 |
+
|
551 |
+
completion_tokens = estimate_tokens(str(response_text))
|
552 |
+
total_time = last_chunk_time - start_time
|
553 |
+
prompt_eval_time = first_chunk_time - start_time
|
554 |
+
eval_time = last_chunk_time - first_chunk_time
|
555 |
+
|
556 |
+
return {
|
557 |
+
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
558 |
+
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
559 |
+
"message": {
|
560 |
+
"role": "assistant",
|
561 |
+
"content": str(response_text),
|
562 |
+
"images": None,
|
563 |
+
},
|
564 |
+
"done": True,
|
565 |
+
"total_duration": total_time,
|
566 |
+
"load_duration": 0,
|
567 |
+
"prompt_eval_count": prompt_tokens,
|
568 |
+
"prompt_eval_duration": prompt_eval_time,
|
569 |
+
"eval_count": completion_tokens,
|
570 |
+
"eval_duration": eval_time,
|
571 |
+
}
|
572 |
+
except Exception as e:
|
573 |
+
trace_exception(e)
|
574 |
+
raise HTTPException(status_code=500, detail=str(e))
|