yangdx committed on
Commit efd74df · 1 Parent(s): 79d6809

Split the Ollama API implementation into a separate file

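The change moves the emulated Ollama endpoints (/api/version, /api/tags, /api/generate, /api/chat) out of lightrag_server.py into a new lightrag/api/ollama_api.py module, which exposes an OllamaAPI class wrapping a FastAPI APIRouter. A minimal sketch of the resulting wiring, assuming a fully configured LightRAG instance; the mount_ollama_api helper is illustrative, not part of the commit:

    from fastapi import FastAPI
    from lightrag import LightRAG
    from lightrag.api.ollama_api import OllamaAPI

    def mount_ollama_api(app: FastAPI, rag: LightRAG) -> None:
        """Attach the emulated Ollama endpoints to an existing FastAPI app."""
        ollama_api = OllamaAPI(rag)  # builds the APIRouter and registers its routes
        app.include_router(ollama_api.router, prefix="/api")  # serves /api/version, /api/tags, /api/generate, /api/chat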
lightrag/api/lightrag_server.py CHANGED
@@ -4,28 +4,20 @@ from fastapi import (
4
  File,
5
  UploadFile,
6
  Form,
7
- Request,
8
  BackgroundTasks,
9
  )
10
 
11
- # Backend (Python)
12
- # Add this to store progress globally
13
- from typing import Dict
14
  import threading
15
- import asyncio
16
- import json
17
  import os
18
-
 
19
  from fastapi.staticfiles import StaticFiles
20
- from pydantic import BaseModel
21
  import logging
22
  import argparse
23
- import time
24
- import re
25
- from typing import List, Any, Optional, Union
26
  from lightrag import LightRAG, QueryParam
27
  from lightrag.api import __api_version__
28
-
29
  from lightrag.utils import EmbeddingFunc
30
  from enum import Enum
31
  from pathlib import Path
@@ -34,20 +26,30 @@ import aiofiles
34
  from ascii_colors import trace_exception, ASCIIColors
35
  import sys
36
  import configparser
37
-
38
  from fastapi import Depends, Security
39
  from fastapi.security import APIKeyHeader
40
  from fastapi.middleware.cors import CORSMiddleware
41
  from contextlib import asynccontextmanager
42
-
43
  from starlette.status import HTTP_403_FORBIDDEN
44
  import pipmaster as pm
45
-
46
  from dotenv import load_dotenv
47
 
48
  # Load environment variables
49
  load_dotenv()
50
 
51
  # Global progress tracker
52
  scan_progress: Dict = {
53
  "is_scanning": False,
@@ -76,24 +78,6 @@ def estimate_tokens(text: str) -> int:
76
  return int(tokens)
77
 
78
 
79
- class OllamaServerInfos:
80
- # Constants for emulated Ollama model information
81
- LIGHTRAG_NAME = "lightrag"
82
- LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
83
- LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
84
- LIGHTRAG_SIZE = 7365960935 # it's a dummy value
85
- LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
86
- LIGHTRAG_DIGEST = "sha256:lightrag"
87
-
88
- KV_STORAGE = "JsonKVStorage"
89
- DOC_STATUS_STORAGE = "JsonDocStatusStorage"
90
- GRAPH_STORAGE = "NetworkXStorage"
91
- VECTOR_STORAGE = "NanoVectorDBStorage"
92
-
93
-
94
- # Add infos
95
- ollama_server_infos = OllamaServerInfos()
96
-
97
  # read config.ini
98
  config = configparser.ConfigParser()
99
  config.read("config.ini", "utf-8")
@@ -101,8 +85,8 @@ config.read("config.ini", "utf-8")
101
  redis_uri = config.get("redis", "uri", fallback=None)
102
  if redis_uri:
103
  os.environ["REDIS_URI"] = redis_uri
104
- ollama_server_infos.KV_STORAGE = "RedisKVStorage"
105
- ollama_server_infos.DOC_STATUS_STORAGE = "RedisKVStorage"
106
 
107
  # Neo4j config
108
  neo4j_uri = config.get("neo4j", "uri", fallback=None)
@@ -112,7 +96,7 @@ if neo4j_uri:
112
  os.environ["NEO4J_URI"] = neo4j_uri
113
  os.environ["NEO4J_USERNAME"] = neo4j_username
114
  os.environ["NEO4J_PASSWORD"] = neo4j_password
115
- ollama_server_infos.GRAPH_STORAGE = "Neo4JStorage"
116
 
117
  # Milvus config
118
  milvus_uri = config.get("milvus", "uri", fallback=None)
@@ -124,7 +108,7 @@ if milvus_uri:
124
  os.environ["MILVUS_USER"] = milvus_user
125
  os.environ["MILVUS_PASSWORD"] = milvus_password
126
  os.environ["MILVUS_DB_NAME"] = milvus_db_name
127
- ollama_server_infos.VECTOR_STORAGE = "MilvusVectorDBStorge"
128
 
129
  # MongoDB config
130
  mongo_uri = config.get("mongodb", "uri", fallback=None)
@@ -132,8 +116,8 @@ mongo_database = config.get("mongodb", "LightRAG", fallback=None)
132
  if mongo_uri:
133
  os.environ["MONGO_URI"] = mongo_uri
134
  os.environ["MONGO_DATABASE"] = mongo_database
135
- ollama_server_infos.KV_STORAGE = "MongoKVStorage"
136
- ollama_server_infos.DOC_STATUS_STORAGE = "MongoKVStorage"
137
 
138
 
139
  def get_default_host(binding_type: str) -> str:
@@ -535,6 +519,7 @@ def parse_args() -> argparse.Namespace:
535
  help="Cosine similarity threshold (default: from env or 0.4)",
536
  )
537
 
 
538
  parser.add_argument(
539
  "--simulated-model-name",
540
  type=str,
@@ -599,84 +584,13 @@ class DocumentManager:
599
  return any(filename.lower().endswith(ext) for ext in self.supported_extensions)
600
 
601
 
602
- # Pydantic models
603
  class SearchMode(str, Enum):
604
  naive = "naive"
605
  local = "local"
606
  global_ = "global"
607
  hybrid = "hybrid"
608
  mix = "mix"
609
- bypass = "bypass"
610
-
611
-
612
- class OllamaMessage(BaseModel):
613
- role: str
614
- content: str
615
- images: Optional[List[str]] = None
616
-
617
-
618
- class OllamaChatRequest(BaseModel):
619
- model: str = ollama_server_infos.LIGHTRAG_MODEL
620
- messages: List[OllamaMessage]
621
- stream: bool = True # Default to streaming mode
622
- options: Optional[Dict[str, Any]] = None
623
- system: Optional[str] = None
624
-
625
-
626
- class OllamaChatResponse(BaseModel):
627
- model: str
628
- created_at: str
629
- message: OllamaMessage
630
- done: bool
631
-
632
-
633
- class OllamaGenerateRequest(BaseModel):
634
- model: str = ollama_server_infos.LIGHTRAG_MODEL
635
- prompt: str
636
- system: Optional[str] = None
637
- stream: bool = False
638
- options: Optional[Dict[str, Any]] = None
639
-
640
-
641
- class OllamaGenerateResponse(BaseModel):
642
- model: str
643
- created_at: str
644
- response: str
645
- done: bool
646
- context: Optional[List[int]]
647
- total_duration: Optional[int]
648
- load_duration: Optional[int]
649
- prompt_eval_count: Optional[int]
650
- prompt_eval_duration: Optional[int]
651
- eval_count: Optional[int]
652
- eval_duration: Optional[int]
653
-
654
-
655
- class OllamaVersionResponse(BaseModel):
656
- version: str
657
-
658
-
659
- class OllamaModelDetails(BaseModel):
660
- parent_model: str
661
- format: str
662
- family: str
663
- families: List[str]
664
- parameter_size: str
665
- quantization_level: str
666
-
667
-
668
- class OllamaModel(BaseModel):
669
- name: str
670
- model: str
671
- size: int
672
- digest: str
673
- modified_at: str
674
- details: OllamaModelDetails
675
-
676
-
677
- class OllamaTagResponse(BaseModel):
678
- models: List[OllamaModel]
679
-
680
 
681
  class QueryRequest(BaseModel):
682
  query: str
@@ -920,10 +834,10 @@ def create_app(args):
920
  if args.llm_binding == "lollms" or args.llm_binding == "ollama"
921
  else {},
922
  embedding_func=embedding_func,
923
- kv_storage=ollama_server_infos.KV_STORAGE,
924
- graph_storage=ollama_server_infos.GRAPH_STORAGE,
925
- vector_storage=ollama_server_infos.VECTOR_STORAGE,
926
- doc_status_storage=ollama_server_infos.DOC_STATUS_STORAGE,
927
  vector_db_storage_cls_kwargs={
928
  "cosine_better_than_threshold": args.cosine_threshold
929
  },
@@ -949,10 +863,10 @@ def create_app(args):
949
  llm_model_max_async=args.max_async,
950
  llm_model_max_token_size=args.max_tokens,
951
  embedding_func=embedding_func,
952
- kv_storage=ollama_server_infos.KV_STORAGE,
953
- graph_storage=ollama_server_infos.GRAPH_STORAGE,
954
- vector_storage=ollama_server_infos.VECTOR_STORAGE,
955
- doc_status_storage=ollama_server_infos.DOC_STATUS_STORAGE,
956
  vector_db_storage_cls_kwargs={
957
  "cosine_better_than_threshold": args.cosine_threshold
958
  },
@@ -1475,450 +1389,9 @@ def create_app(args):
1475
  async def get_graphs(label: str):
1476
  return await rag.get_graps(nodel_label=label, max_depth=100)
1477
 
1478
- # Ollama compatible API endpoints
1479
- # -------------------------------------------------
1480
- @app.get("/api/version")
1481
- async def get_version():
1482
- """Get Ollama version information"""
1483
- return OllamaVersionResponse(version="0.5.4")
1484
-
1485
- @app.get("/api/tags")
1486
- async def get_tags():
1487
- """Retrun available models acting as an Ollama server"""
1488
- return OllamaTagResponse(
1489
- models=[
1490
- {
1491
- "name": ollama_server_infos.LIGHTRAG_MODEL,
1492
- "model": ollama_server_infos.LIGHTRAG_MODEL,
1493
- "size": ollama_server_infos.LIGHTRAG_SIZE,
1494
- "digest": ollama_server_infos.LIGHTRAG_DIGEST,
1495
- "modified_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1496
- "details": {
1497
- "parent_model": "",
1498
- "format": "gguf",
1499
- "family": ollama_server_infos.LIGHTRAG_NAME,
1500
- "families": [ollama_server_infos.LIGHTRAG_NAME],
1501
- "parameter_size": "13B",
1502
- "quantization_level": "Q4_0",
1503
- },
1504
- }
1505
- ]
1506
- )
1507
-
1508
- def parse_query_mode(query: str) -> tuple[str, SearchMode]:
1509
- """Parse query prefix to determine search mode
1510
- Returns tuple of (cleaned_query, search_mode)
1511
- """
1512
- mode_map = {
1513
- "/local ": SearchMode.local,
1514
- "/global ": SearchMode.global_, # global_ is used because 'global' is a Python keyword
1515
- "/naive ": SearchMode.naive,
1516
- "/hybrid ": SearchMode.hybrid,
1517
- "/mix ": SearchMode.mix,
1518
- "/bypass ": SearchMode.bypass,
1519
- }
1520
-
1521
- for prefix, mode in mode_map.items():
1522
- if query.startswith(prefix):
1523
- # After removing prefix an leading spaces
1524
- cleaned_query = query[len(prefix) :].lstrip()
1525
- return cleaned_query, mode
1526
-
1527
- return query, SearchMode.hybrid
1528
-
1529
- @app.post("/api/generate")
1530
- async def generate(raw_request: Request, request: OllamaGenerateRequest):
1531
- """Handle generate completion requests acting as an Ollama model
1532
- For compatiblity purpuse, the request is not processed by LightRAG,
1533
- and will be handled by underlying LLM model.
1534
- """
1535
- try:
1536
- query = request.prompt
1537
- start_time = time.time_ns()
1538
- prompt_tokens = estimate_tokens(query)
1539
-
1540
- if request.system:
1541
- rag.llm_model_kwargs["system_prompt"] = request.system
1542
-
1543
- if request.stream:
1544
- from fastapi.responses import StreamingResponse
1545
-
1546
- response = await rag.llm_model_func(
1547
- query, stream=True, **rag.llm_model_kwargs
1548
- )
1549
-
1550
- async def stream_generator():
1551
- try:
1552
- first_chunk_time = None
1553
- last_chunk_time = None
1554
- total_response = ""
1555
-
1556
- # Ensure response is an async generator
1557
- if isinstance(response, str):
1558
- # If it's a string, send in two parts
1559
- first_chunk_time = time.time_ns()
1560
- last_chunk_time = first_chunk_time
1561
- total_response = response
1562
-
1563
- data = {
1564
- "model": ollama_server_infos.LIGHTRAG_MODEL,
1565
- "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1566
- "response": response,
1567
- "done": False,
1568
- }
1569
- yield f"{json.dumps(data, ensure_ascii=False)}\n"
1570
-
1571
- completion_tokens = estimate_tokens(total_response)
1572
- total_time = last_chunk_time - start_time
1573
- prompt_eval_time = first_chunk_time - start_time
1574
- eval_time = last_chunk_time - first_chunk_time
1575
-
1576
- data = {
1577
- "model": ollama_server_infos.LIGHTRAG_MODEL,
1578
- "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1579
- "done": True,
1580
- "total_duration": total_time,
1581
- "load_duration": 0,
1582
- "prompt_eval_count": prompt_tokens,
1583
- "prompt_eval_duration": prompt_eval_time,
1584
- "eval_count": completion_tokens,
1585
- "eval_duration": eval_time,
1586
- }
1587
- yield f"{json.dumps(data, ensure_ascii=False)}\n"
1588
- else:
1589
- async for chunk in response:
1590
- if chunk:
1591
- if first_chunk_time is None:
1592
- first_chunk_time = time.time_ns()
1593
-
1594
- last_chunk_time = time.time_ns()
1595
-
1596
- total_response += chunk
1597
- data = {
1598
- "model": ollama_server_infos.LIGHTRAG_MODEL,
1599
- "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1600
- "response": chunk,
1601
- "done": False,
1602
- }
1603
- yield f"{json.dumps(data, ensure_ascii=False)}\n"
1604
-
1605
- completion_tokens = estimate_tokens(total_response)
1606
- total_time = last_chunk_time - start_time
1607
- prompt_eval_time = first_chunk_time - start_time
1608
- eval_time = last_chunk_time - first_chunk_time
1609
-
1610
- data = {
1611
- "model": ollama_server_infos.LIGHTRAG_MODEL,
1612
- "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1613
- "done": True,
1614
- "total_duration": total_time,
1615
- "load_duration": 0,
1616
- "prompt_eval_count": prompt_tokens,
1617
- "prompt_eval_duration": prompt_eval_time,
1618
- "eval_count": completion_tokens,
1619
- "eval_duration": eval_time,
1620
- }
1621
- yield f"{json.dumps(data, ensure_ascii=False)}\n"
1622
- return
1623
-
1624
- except Exception as e:
1625
- logging.error(f"Error in stream_generator: {str(e)}")
1626
- raise
1627
-
1628
- return StreamingResponse(
1629
- stream_generator(),
1630
- media_type="application/x-ndjson",
1631
- headers={
1632
- "Cache-Control": "no-cache",
1633
- "Connection": "keep-alive",
1634
- "Content-Type": "application/x-ndjson",
1635
- "Access-Control-Allow-Origin": "*",
1636
- "Access-Control-Allow-Methods": "POST, OPTIONS",
1637
- "Access-Control-Allow-Headers": "Content-Type",
1638
- },
1639
- )
1640
- else:
1641
- first_chunk_time = time.time_ns()
1642
- response_text = await rag.llm_model_func(
1643
- query, stream=False, **rag.llm_model_kwargs
1644
- )
1645
- last_chunk_time = time.time_ns()
1646
-
1647
- if not response_text:
1648
- response_text = "No response generated"
1649
-
1650
- completion_tokens = estimate_tokens(str(response_text))
1651
- total_time = last_chunk_time - start_time
1652
- prompt_eval_time = first_chunk_time - start_time
1653
- eval_time = last_chunk_time - first_chunk_time
1654
-
1655
- return {
1656
- "model": ollama_server_infos.LIGHTRAG_MODEL,
1657
- "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1658
- "response": str(response_text),
1659
- "done": True,
1660
- "total_duration": total_time,
1661
- "load_duration": 0,
1662
- "prompt_eval_count": prompt_tokens,
1663
- "prompt_eval_duration": prompt_eval_time,
1664
- "eval_count": completion_tokens,
1665
- "eval_duration": eval_time,
1666
- }
1667
- except Exception as e:
1668
- trace_exception(e)
1669
- raise HTTPException(status_code=500, detail=str(e))
1670
-
1671
- @app.post("/api/chat")
1672
- async def chat(raw_request: Request, request: OllamaChatRequest):
1673
- """Process chat completion requests acting as an Ollama model
1674
- Routes user queries through LightRAG by selecting query mode based on prefix indicators.
1675
- Detects and forwards OpenWebUI session-related requests (for meta data generation task) directly to LLM.
1676
- """
1677
- try:
1678
- # Get all messages
1679
- messages = request.messages
1680
- if not messages:
1681
- raise HTTPException(status_code=400, detail="No messages provided")
1682
-
1683
- # Get the last message as query and previous messages as history
1684
- query = messages[-1].content
1685
- # Convert OllamaMessage objects to dictionaries
1686
- conversation_history = [
1687
- {"role": msg.role, "content": msg.content} for msg in messages[:-1]
1688
- ]
1689
-
1690
- # Check for query prefix
1691
- cleaned_query, mode = parse_query_mode(query)
1692
-
1693
- start_time = time.time_ns()
1694
- prompt_tokens = estimate_tokens(cleaned_query)
1695
-
1696
- param_dict = {
1697
- "mode": mode,
1698
- "stream": request.stream,
1699
- "only_need_context": False,
1700
- "conversation_history": conversation_history,
1701
- "top_k": args.top_k,
1702
- }
1703
-
1704
- if args.history_turns is not None:
1705
- param_dict["history_turns"] = args.history_turns
1706
-
1707
- query_param = QueryParam(**param_dict)
1708
-
1709
- if request.stream:
1710
- from fastapi.responses import StreamingResponse
1711
-
1712
- # Determine if the request is prefix with "/bypass"
1713
- if mode == SearchMode.bypass:
1714
- if request.system:
1715
- rag.llm_model_kwargs["system_prompt"] = request.system
1716
- response = await rag.llm_model_func(
1717
- cleaned_query,
1718
- stream=True,
1719
- history_messages=conversation_history,
1720
- **rag.llm_model_kwargs,
1721
- )
1722
- else:
1723
- response = await rag.aquery( # Need await to get async generator
1724
- cleaned_query, param=query_param
1725
- )
1726
-
1727
- async def stream_generator():
1728
- first_chunk_time = None
1729
- last_chunk_time = None
1730
- total_response = ""
1731
-
1732
- try:
1733
- # Ensure response is an async generator
1734
- if isinstance(response, str):
1735
- # If it's a string, send in two parts
1736
- first_chunk_time = time.time_ns()
1737
- last_chunk_time = first_chunk_time
1738
- total_response = response
1739
-
1740
- data = {
1741
- "model": ollama_server_infos.LIGHTRAG_MODEL,
1742
- "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1743
- "message": {
1744
- "role": "assistant",
1745
- "content": response,
1746
- "images": None,
1747
- },
1748
- "done": False,
1749
- }
1750
- yield f"{json.dumps(data, ensure_ascii=False)}\n"
1751
-
1752
- completion_tokens = estimate_tokens(total_response)
1753
- total_time = last_chunk_time - start_time
1754
- prompt_eval_time = first_chunk_time - start_time
1755
- eval_time = last_chunk_time - first_chunk_time
1756
-
1757
- data = {
1758
- "model": ollama_server_infos.LIGHTRAG_MODEL,
1759
- "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1760
- "done": True,
1761
- "total_duration": total_time,
1762
- "load_duration": 0,
1763
- "prompt_eval_count": prompt_tokens,
1764
- "prompt_eval_duration": prompt_eval_time,
1765
- "eval_count": completion_tokens,
1766
- "eval_duration": eval_time,
1767
- }
1768
- yield f"{json.dumps(data, ensure_ascii=False)}\n"
1769
- else:
1770
- try:
1771
- async for chunk in response:
1772
- if chunk:
1773
- if first_chunk_time is None:
1774
- first_chunk_time = time.time_ns()
1775
-
1776
- last_chunk_time = time.time_ns()
1777
-
1778
- total_response += chunk
1779
- data = {
1780
- "model": ollama_server_infos.LIGHTRAG_MODEL,
1781
- "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1782
- "message": {
1783
- "role": "assistant",
1784
- "content": chunk,
1785
- "images": None,
1786
- },
1787
- "done": False,
1788
- }
1789
- yield f"{json.dumps(data, ensure_ascii=False)}\n"
1790
- except (asyncio.CancelledError, Exception) as e:
1791
- error_msg = str(e)
1792
- if isinstance(e, asyncio.CancelledError):
1793
- error_msg = "Stream was cancelled by server"
1794
- else:
1795
- error_msg = f"Provider error: {error_msg}"
1796
-
1797
- logging.error(f"Stream error: {error_msg}")
1798
-
1799
- # Send error message to client
1800
- error_data = {
1801
- "model": ollama_server_infos.LIGHTRAG_MODEL,
1802
- "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1803
- "message": {
1804
- "role": "assistant",
1805
- "content": f"\n\nError: {error_msg}",
1806
- "images": None,
1807
- },
1808
- "done": False,
1809
- }
1810
- yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
1811
-
1812
- # Send final message to close the stream
1813
- final_data = {
1814
- "model": ollama_server_infos.LIGHTRAG_MODEL,
1815
- "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1816
- "done": True,
1817
- }
1818
- yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
1819
- return
1820
-
1821
- if last_chunk_time is not None:
1822
- completion_tokens = estimate_tokens(total_response)
1823
- total_time = last_chunk_time - start_time
1824
- prompt_eval_time = first_chunk_time - start_time
1825
- eval_time = last_chunk_time - first_chunk_time
1826
-
1827
- data = {
1828
- "model": ollama_server_infos.LIGHTRAG_MODEL,
1829
- "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1830
- "done": True,
1831
- "total_duration": total_time,
1832
- "load_duration": 0,
1833
- "prompt_eval_count": prompt_tokens,
1834
- "prompt_eval_duration": prompt_eval_time,
1835
- "eval_count": completion_tokens,
1836
- "eval_duration": eval_time,
1837
- }
1838
- yield f"{json.dumps(data, ensure_ascii=False)}\n"
1839
-
1840
- except Exception as e:
1841
- error_msg = f"Error in stream_generator: {str(e)}"
1842
- logging.error(error_msg)
1843
-
1844
- # Send error message to client
1845
- error_data = {
1846
- "model": ollama_server_infos.LIGHTRAG_MODEL,
1847
- "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1848
- "error": {"code": "STREAM_ERROR", "message": error_msg},
1849
- }
1850
- yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
1851
-
1852
- # Ensure sending end marker
1853
- final_data = {
1854
- "model": ollama_server_infos.LIGHTRAG_MODEL,
1855
- "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1856
- "done": True,
1857
- }
1858
- yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
1859
- return
1860
-
1861
- return StreamingResponse(
1862
- stream_generator(),
1863
- media_type="application/x-ndjson",
1864
- headers={
1865
- "Cache-Control": "no-cache",
1866
- "Connection": "keep-alive",
1867
- "Content-Type": "application/x-ndjson",
1868
- "Access-Control-Allow-Origin": "*",
1869
- "Access-Control-Allow-Methods": "POST, OPTIONS",
1870
- "Access-Control-Allow-Headers": "Content-Type",
1871
- },
1872
- )
1873
- else:
1874
- first_chunk_time = time.time_ns()
1875
-
1876
- # Determine if the request is prefix with "/bypass" or from Open WebUI's session title and session keyword generation task
1877
- match_result = re.search(
1878
- r"\n<chat_history>\nUSER:", cleaned_query, re.MULTILINE
1879
- )
1880
- if match_result or mode == SearchMode.bypass:
1881
- if request.system:
1882
- rag.llm_model_kwargs["system_prompt"] = request.system
1883
-
1884
- response_text = await rag.llm_model_func(
1885
- cleaned_query,
1886
- stream=False,
1887
- history_messages=conversation_history,
1888
- **rag.llm_model_kwargs,
1889
- )
1890
- else:
1891
- response_text = await rag.aquery(cleaned_query, param=query_param)
1892
-
1893
- last_chunk_time = time.time_ns()
1894
-
1895
- if not response_text:
1896
- response_text = "No response generated"
1897
-
1898
- completion_tokens = estimate_tokens(str(response_text))
1899
- total_time = last_chunk_time - start_time
1900
- prompt_eval_time = first_chunk_time - start_time
1901
- eval_time = last_chunk_time - first_chunk_time
1902
-
1903
- return {
1904
- "model": ollama_server_infos.LIGHTRAG_MODEL,
1905
- "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1906
- "message": {
1907
- "role": "assistant",
1908
- "content": str(response_text),
1909
- "images": None,
1910
- },
1911
- "done": True,
1912
- "total_duration": total_time,
1913
- "load_duration": 0,
1914
- "prompt_eval_count": prompt_tokens,
1915
- "prompt_eval_duration": prompt_eval_time,
1916
- "eval_count": completion_tokens,
1917
- "eval_duration": eval_time,
1918
- }
1919
- except Exception as e:
1920
- trace_exception(e)
1921
- raise HTTPException(status_code=500, detail=str(e))
1922
 
1923
  @app.get("/documents", dependencies=[Depends(optional_api_key)])
1924
  async def documents():
@@ -1945,10 +1418,10 @@ def create_app(args):
1945
  "embedding_binding_host": args.embedding_binding_host,
1946
  "embedding_model": args.embedding_model,
1947
  "max_tokens": args.max_tokens,
1948
- "kv_storage": ollama_server_infos.KV_STORAGE,
1949
- "doc_status_storage": ollama_server_infos.DOC_STATUS_STORAGE,
1950
- "graph_storage": ollama_server_infos.GRAPH_STORAGE,
1951
- "vector_storage": ollama_server_infos.VECTOR_STORAGE,
1952
  },
1953
  }
1954
 
 
4
  File,
5
  UploadFile,
6
  Form,
 
7
  BackgroundTasks,
8
  )
9
 
10
  import threading
11
  import os
12
+ import json
13
+ import re
14
  from fastapi.staticfiles import StaticFiles
 
15
  import logging
16
  import argparse
17
+ from typing import List, Any, Optional, Union, Dict
18
+ from pydantic import BaseModel
 
19
  from lightrag import LightRAG, QueryParam
20
  from lightrag.api import __api_version__
 
21
  from lightrag.utils import EmbeddingFunc
22
  from enum import Enum
23
  from pathlib import Path
 
26
  from ascii_colors import trace_exception, ASCIIColors
27
  import sys
28
  import configparser
 
29
  from fastapi import Depends, Security
30
  from fastapi.security import APIKeyHeader
31
  from fastapi.middleware.cors import CORSMiddleware
32
  from contextlib import asynccontextmanager
 
33
  from starlette.status import HTTP_403_FORBIDDEN
34
  import pipmaster as pm
 
35
  from dotenv import load_dotenv
36
+ from .ollama_api import (
37
+ OllamaAPI,
38
+ )
39
+ from .ollama_api import ollama_server_infos
40
 
41
  # Load environment variables
42
  load_dotenv()
43
 
44
+ class RAGStorageConfig:
45
+ KV_STORAGE = "JsonKVStorage"
46
+ DOC_STATUS_STORAGE = "JsonDocStatusStorage"
47
+ GRAPH_STORAGE = "NetworkXStorage"
48
+ VECTOR_STORAGE = "NanoVectorDBStorage"
49
+
50
+ # Initialize rag storage config
51
+ rag_storage_config = RAGStorageConfig()
52
+
53
  # Global progress tracker
54
  scan_progress: Dict = {
55
  "is_scanning": False,
 
78
  return int(tokens)
79
 
80
 
81
  # read config.ini
82
  config = configparser.ConfigParser()
83
  config.read("config.ini", "utf-8")
 
85
  redis_uri = config.get("redis", "uri", fallback=None)
86
  if redis_uri:
87
  os.environ["REDIS_URI"] = redis_uri
88
+ rag_storage_config.KV_STORAGE = "RedisKVStorage"
89
+ rag_storage_config.DOC_STATUS_STORAGE = "RedisKVStorage"
90
 
91
  # Neo4j config
92
  neo4j_uri = config.get("neo4j", "uri", fallback=None)
 
96
  os.environ["NEO4J_URI"] = neo4j_uri
97
  os.environ["NEO4J_USERNAME"] = neo4j_username
98
  os.environ["NEO4J_PASSWORD"] = neo4j_password
99
+ rag_storage_config.GRAPH_STORAGE = "Neo4JStorage"
100
 
101
  # Milvus config
102
  milvus_uri = config.get("milvus", "uri", fallback=None)
 
108
  os.environ["MILVUS_USER"] = milvus_user
109
  os.environ["MILVUS_PASSWORD"] = milvus_password
110
  os.environ["MILVUS_DB_NAME"] = milvus_db_name
111
+ rag_storage_config.VECTOR_STORAGE = "MilvusVectorDBStorge"
112
 
113
  # MongoDB config
114
  mongo_uri = config.get("mongodb", "uri", fallback=None)
 
116
  if mongo_uri:
117
  os.environ["MONGO_URI"] = mongo_uri
118
  os.environ["MONGO_DATABASE"] = mongo_database
119
+ rag_storage_config.KV_STORAGE = "MongoKVStorage"
120
+ rag_storage_config.DOC_STATUS_STORAGE = "MongoKVStorage"
121
 
122
 
123
  def get_default_host(binding_type: str) -> str:
 
519
  help="Cosine similarity threshold (default: from env or 0.4)",
520
  )
521
 
522
+ # Ollama model name
523
  parser.add_argument(
524
  "--simulated-model-name",
525
  type=str,
 
584
  return any(filename.lower().endswith(ext) for ext in self.supported_extensions)
585
 
586
 
587
+ # LightRAG query mode
588
  class SearchMode(str, Enum):
589
  naive = "naive"
590
  local = "local"
591
  global_ = "global"
592
  hybrid = "hybrid"
593
  mix = "mix"
594
 
595
  class QueryRequest(BaseModel):
596
  query: str
 
834
  if args.llm_binding == "lollms" or args.llm_binding == "ollama"
835
  else {},
836
  embedding_func=embedding_func,
837
+ kv_storage=rag_storage_config.KV_STORAGE,
838
+ graph_storage=rag_storage_config.GRAPH_STORAGE,
839
+ vector_storage=rag_storage_config.VECTOR_STORAGE,
840
+ doc_status_storage=rag_storage_config.DOC_STATUS_STORAGE,
841
  vector_db_storage_cls_kwargs={
842
  "cosine_better_than_threshold": args.cosine_threshold
843
  },
 
863
  llm_model_max_async=args.max_async,
864
  llm_model_max_token_size=args.max_tokens,
865
  embedding_func=embedding_func,
866
+ kv_storage=rag_storage_config.KV_STORAGE,
867
+ graph_storage=rag_storage_config.GRAPH_STORAGE,
868
+ vector_storage=rag_storage_config.VECTOR_STORAGE,
869
+ doc_status_storage=rag_storage_config.DOC_STATUS_STORAGE,
870
  vector_db_storage_cls_kwargs={
871
  "cosine_better_than_threshold": args.cosine_threshold
872
  },
 
1389
  async def get_graphs(label: str):
1390
  return await rag.get_graps(nodel_label=label, max_depth=100)
1391
 
1392
+ # Add Ollama API routes
1393
+ ollama_api = OllamaAPI(rag)
1394
+ app.include_router(ollama_api.router, prefix="/api")
1395
 
1396
  @app.get("/documents", dependencies=[Depends(optional_api_key)])
1397
  async def documents():
 
1418
  "embedding_binding_host": args.embedding_binding_host,
1419
  "embedding_model": args.embedding_model,
1420
  "max_tokens": args.max_tokens,
1421
+ "kv_storage": rag_storage_config.KV_STORAGE,
1422
+ "doc_status_storage": rag_storage_config.DOC_STATUS_STORAGE,
1423
+ "graph_storage": rag_storage_config.GRAPH_STORAGE,
1424
+ "vector_storage": rag_storage_config.VECTOR_STORAGE,
1425
  },
1426
  }
1427
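Nothing changes for clients after the split: the chat endpoint still routes queries through LightRAG and picks the query mode from a prefix (/local, /global, /naive, /hybrid, /mix, /bypass), falling back to hybrid when no prefix is present. A sketch of a non-streaming call; the base URL and model tag are assumptions, adjust them to your deployment:

    import requests

    BASE_URL = "http://localhost:9621"  # assumption: host/port of the running LightRAG server

    payload = {
        "model": "lightrag:latest",  # LIGHTRAG_NAME:LIGHTRAG_TAG as emulated by the server
        "messages": [{"role": "user", "content": "/local Which entities are mentioned in the documents?"}],
        "stream": False,  # non-streaming: the endpoint returns a single JSON object
    }

    resp = requests.post(f"{BASE_URL}/api/chat", json=payload, timeout=120)
    resp.raise_for_status()
    print(resp.json()["message"]["content"])  # assistant reply produced via the "local" query mode

With "stream": True the endpoint instead returns NDJSON, one JSON object per line, terminated by a final object with "done": true.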
 
lightrag/api/ollama_api.py ADDED
@@ -0,0 +1,554 @@
1
+ from fastapi import APIRouter, HTTPException, Request
2
+ from pydantic import BaseModel
3
+ from typing import List, Dict, Any, Optional
4
+ import logging
5
+ import time
6
+ import json
7
+ import re
8
+ import os
9
+ from enum import Enum
10
+ from fastapi.responses import StreamingResponse
11
+ import asyncio
12
+ from ascii_colors import trace_exception
13
+ from lightrag import LightRAG, QueryParam
14
+
15
+ class OllamaServerInfos:
16
+ # Constants for emulated Ollama model information
17
+ LIGHTRAG_NAME = "lightrag"
18
+ LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
19
+ LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
20
+ LIGHTRAG_SIZE = 7365960935 # it's a dummy value
21
+ LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
22
+ LIGHTRAG_DIGEST = "sha256:lightrag"
23
+
24
+ ollama_server_infos = OllamaServerInfos()
25
+
26
+ # Query mode selected by query prefix (bypass is not a LightRAG query mode)
27
+ class SearchMode(str, Enum):
28
+ naive = "naive"
29
+ local = "local"
30
+ global_ = "global"
31
+ hybrid = "hybrid"
32
+ mix = "mix"
33
+ bypass = "bypass"
34
+
35
+ class OllamaMessage(BaseModel):
36
+ role: str
37
+ content: str
38
+ images: Optional[List[str]] = None
39
+
40
+ class OllamaChatRequest(BaseModel):
41
+ model: str
42
+ messages: List[OllamaMessage]
43
+ stream: bool = True
44
+ options: Optional[Dict[str, Any]] = None
45
+ system: Optional[str] = None
46
+
47
+ class OllamaChatResponse(BaseModel):
48
+ model: str
49
+ created_at: str
50
+ message: OllamaMessage
51
+ done: bool
52
+
53
+ class OllamaGenerateRequest(BaseModel):
54
+ model: str
55
+ prompt: str
56
+ system: Optional[str] = None
57
+ stream: bool = False
58
+ options: Optional[Dict[str, Any]] = None
59
+
60
+ class OllamaGenerateResponse(BaseModel):
61
+ model: str
62
+ created_at: str
63
+ response: str
64
+ done: bool
65
+ context: Optional[List[int]]
66
+ total_duration: Optional[int]
67
+ load_duration: Optional[int]
68
+ prompt_eval_count: Optional[int]
69
+ prompt_eval_duration: Optional[int]
70
+ eval_count: Optional[int]
71
+ eval_duration: Optional[int]
72
+
73
+ class OllamaVersionResponse(BaseModel):
74
+ version: str
75
+
76
+ class OllamaModelDetails(BaseModel):
77
+ parent_model: str
78
+ format: str
79
+ family: str
80
+ families: List[str]
81
+ parameter_size: str
82
+ quantization_level: str
83
+
84
+ class OllamaModel(BaseModel):
85
+ name: str
86
+ model: str
87
+ size: int
88
+ digest: str
89
+ modified_at: str
90
+ details: OllamaModelDetails
91
+
92
+ class OllamaTagResponse(BaseModel):
93
+ models: List[OllamaModel]
94
+
95
+ def estimate_tokens(text: str) -> int:
96
+ """Estimate the number of tokens in text
97
+ Chinese characters: approximately 1.5 tokens per character
98
+ English characters: approximately 0.25 tokens per character
99
+ """
100
+ # Use regex to match Chinese and non-Chinese characters separately
101
+ chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text))
102
+ non_chinese_chars = len(re.findall(r"[^\u4e00-\u9fff]", text))
103
+
104
+ # Calculate estimated token count
105
+ tokens = chinese_chars * 1.5 + non_chinese_chars * 0.25
106
+
107
+ return int(tokens)
108
+
109
+ def parse_query_mode(query: str) -> tuple[str, SearchMode]:
110
+ """Parse query prefix to determine search mode
111
+ Returns tuple of (cleaned_query, search_mode)
112
+ """
113
+ mode_map = {
114
+ "/local ": SearchMode.local,
115
+ "/global ": SearchMode.global_, # global_ is used because 'global' is a Python keyword
116
+ "/naive ": SearchMode.naive,
117
+ "/hybrid ": SearchMode.hybrid,
118
+ "/mix ": SearchMode.mix,
119
+ "/bypass ": SearchMode.bypass,
120
+ }
121
+
122
+ for prefix, mode in mode_map.items():
123
+ if query.startswith(prefix):
124
+ # Strip the prefix and any leading spaces
125
+ cleaned_query = query[len(prefix) :].lstrip()
126
+ return cleaned_query, mode
127
+
128
+ return query, SearchMode.hybrid
129
+
130
+ class OllamaAPI:
131
+ def __init__(self, rag: LightRAG):
132
+ self.rag = rag
133
+ self.ollama_server_infos = ollama_server_infos
134
+ self.router = APIRouter()
135
+ self.setup_routes()
136
+
137
+ def setup_routes(self):
138
+ @self.router.get("/version")
139
+ async def get_version():
140
+ """Get Ollama version information"""
141
+ return OllamaVersionResponse(version="0.5.4")
142
+
143
+ @self.router.get("/tags")
144
+ async def get_tags():
145
+ """Return available models acting as an Ollama server"""
146
+ return OllamaTagResponse(
147
+ models=[
148
+ {
149
+ "name": self.ollama_server_infos.LIGHTRAG_MODEL,
150
+ "model": self.ollama_server_infos.LIGHTRAG_MODEL,
151
+ "size": self.ollama_server_infos.LIGHTRAG_SIZE,
152
+ "digest": self.ollama_server_infos.LIGHTRAG_DIGEST,
153
+ "modified_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
154
+ "details": {
155
+ "parent_model": "",
156
+ "format": "gguf",
157
+ "family": self.ollama_server_infos.LIGHTRAG_NAME,
158
+ "families": [self.ollama_server_infos.LIGHTRAG_NAME],
159
+ "parameter_size": "13B",
160
+ "quantization_level": "Q4_0",
161
+ },
162
+ }
163
+ ]
164
+ )
165
+
166
+ @self.router.post("/generate")
167
+ async def generate(raw_request: Request, request: OllamaGenerateRequest):
168
+ """Handle generate completion requests acting as an Ollama model
169
+ For compatibility purposes, the request is not processed by LightRAG,
170
+ and is handled directly by the underlying LLM model.
171
+ """
172
+ try:
173
+ query = request.prompt
174
+ start_time = time.time_ns()
175
+ prompt_tokens = estimate_tokens(query)
176
+
177
+ if request.system:
178
+ self.rag.llm_model_kwargs["system_prompt"] = request.system
179
+
180
+ if request.stream:
181
+ response = await self.rag.llm_model_func(
182
+ query, stream=True, **self.rag.llm_model_kwargs
183
+ )
184
+
185
+ async def stream_generator():
186
+ try:
187
+ first_chunk_time = None
188
+ last_chunk_time = None
189
+ total_response = ""
190
+
191
+ # Ensure response is an async generator
192
+ if isinstance(response, str):
193
+ # If it's a string, send in two parts
194
+ first_chunk_time = time.time_ns()
195
+ last_chunk_time = first_chunk_time
196
+ total_response = response
197
+
198
+ data = {
199
+ "model": self.ollama_server_infos.LIGHTRAG_MODEL,
200
+ "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
201
+ "response": response,
202
+ "done": False,
203
+ }
204
+ yield f"{json.dumps(data, ensure_ascii=False)}\n"
205
+
206
+ completion_tokens = estimate_tokens(total_response)
207
+ total_time = last_chunk_time - start_time
208
+ prompt_eval_time = first_chunk_time - start_time
209
+ eval_time = last_chunk_time - first_chunk_time
210
+
211
+ data = {
212
+ "model": self.ollama_server_infos.LIGHTRAG_MODEL,
213
+ "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
214
+ "done": True,
215
+ "total_duration": total_time,
216
+ "load_duration": 0,
217
+ "prompt_eval_count": prompt_tokens,
218
+ "prompt_eval_duration": prompt_eval_time,
219
+ "eval_count": completion_tokens,
220
+ "eval_duration": eval_time,
221
+ }
222
+ yield f"{json.dumps(data, ensure_ascii=False)}\n"
223
+ else:
224
+ async for chunk in response:
225
+ if chunk:
226
+ if first_chunk_time is None:
227
+ first_chunk_time = time.time_ns()
228
+
229
+ last_chunk_time = time.time_ns()
230
+
231
+ total_response += chunk
232
+ data = {
233
+ "model": self.ollama_server_infos.LIGHTRAG_MODEL,
234
+ "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
235
+ "response": chunk,
236
+ "done": False,
237
+ }
238
+ yield f"{json.dumps(data, ensure_ascii=False)}\n"
239
+
240
+ completion_tokens = estimate_tokens(total_response)
241
+ total_time = last_chunk_time - start_time
242
+ prompt_eval_time = first_chunk_time - start_time
243
+ eval_time = last_chunk_time - first_chunk_time
244
+
245
+ data = {
246
+ "model": self.ollama_server_infos.LIGHTRAG_MODEL,
247
+ "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
248
+ "done": True,
249
+ "total_duration": total_time,
250
+ "load_duration": 0,
251
+ "prompt_eval_count": prompt_tokens,
252
+ "prompt_eval_duration": prompt_eval_time,
253
+ "eval_count": completion_tokens,
254
+ "eval_duration": eval_time,
255
+ }
256
+ yield f"{json.dumps(data, ensure_ascii=False)}\n"
257
+ return
258
+
259
+ except Exception as e:
260
+ trace_exception(e)
261
+ raise
262
+
263
+ return StreamingResponse(
264
+ stream_generator(),
265
+ media_type="application/x-ndjson",
266
+ headers={
267
+ "Cache-Control": "no-cache",
268
+ "Connection": "keep-alive",
269
+ "Content-Type": "application/x-ndjson",
270
+ "Access-Control-Allow-Origin": "*",
271
+ "Access-Control-Allow-Methods": "POST, OPTIONS",
272
+ "Access-Control-Allow-Headers": "Content-Type",
273
+ },
274
+ )
275
+ else:
276
+ first_chunk_time = time.time_ns()
277
+ response_text = await self.rag.llm_model_func(
278
+ query, stream=False, **self.rag.llm_model_kwargs
279
+ )
280
+ last_chunk_time = time.time_ns()
281
+
282
+ if not response_text:
283
+ response_text = "No response generated"
284
+
285
+ completion_tokens = estimate_tokens(str(response_text))
286
+ total_time = last_chunk_time - start_time
287
+ prompt_eval_time = first_chunk_time - start_time
288
+ eval_time = last_chunk_time - first_chunk_time
289
+
290
+ return {
291
+ "model": self.ollama_server_infos.LIGHTRAG_MODEL,
292
+ "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
293
+ "response": str(response_text),
294
+ "done": True,
295
+ "total_duration": total_time,
296
+ "load_duration": 0,
297
+ "prompt_eval_count": prompt_tokens,
298
+ "prompt_eval_duration": prompt_eval_time,
299
+ "eval_count": completion_tokens,
300
+ "eval_duration": eval_time,
301
+ }
302
+ except Exception as e:
303
+ trace_exception(e)
304
+ raise HTTPException(status_code=500, detail=str(e))
305
+
306
+ @self.router.post("/chat")
307
+ async def chat(raw_request: Request, request: OllamaChatRequest):
308
+ """Process chat completion requests acting as an Ollama model
309
+ Routes user queries through LightRAG by selecting query mode based on prefix indicators.
310
+ Detects and forwards OpenWebUI session-related requests (for metadata generation tasks) directly to the LLM.
311
+ """
312
+ try:
313
+ # Get all messages
314
+ messages = request.messages
315
+ if not messages:
316
+ raise HTTPException(status_code=400, detail="No messages provided")
317
+
318
+ # Get the last message as query and previous messages as history
319
+ query = messages[-1].content
320
+ # Convert OllamaMessage objects to dictionaries
321
+ conversation_history = [
322
+ {"role": msg.role, "content": msg.content} for msg in messages[:-1]
323
+ ]
324
+
325
+ # Check for query prefix
326
+ cleaned_query, mode = parse_query_mode(query)
327
+
328
+ start_time = time.time_ns()
329
+ prompt_tokens = estimate_tokens(cleaned_query)
330
+
331
+ param_dict = {
332
+ "mode": mode,
333
+ "stream": request.stream,
334
+ "only_need_context": False,
335
+ "conversation_history": conversation_history,
336
+ "top_k": self.rag.args.top_k if hasattr(self.rag, 'args') else 50,
337
+ }
338
+
339
+ if hasattr(self.rag, 'args') and self.rag.args.history_turns is not None:
340
+ param_dict["history_turns"] = self.rag.args.history_turns
341
+
342
+ query_param = QueryParam(**param_dict)
343
+
344
+ if request.stream:
345
+ # Determine if the request is prefixed with "/bypass"
346
+ if mode == SearchMode.bypass:
347
+ if request.system:
348
+ self.rag.llm_model_kwargs["system_prompt"] = request.system
349
+ response = await self.rag.llm_model_func(
350
+ cleaned_query,
351
+ stream=True,
352
+ history_messages=conversation_history,
353
+ **self.rag.llm_model_kwargs,
354
+ )
355
+ else:
356
+ response = await self.rag.aquery(
357
+ cleaned_query, param=query_param
358
+ )
359
+
360
+ async def stream_generator():
361
+ first_chunk_time = None
362
+ last_chunk_time = None
363
+ total_response = ""
364
+
365
+ try:
366
+ # Ensure response is an async generator
367
+ if isinstance(response, str):
368
+ # If it's a string, send in two parts
369
+ first_chunk_time = time.time_ns()
370
+ last_chunk_time = first_chunk_time
371
+ total_response = response
372
+
373
+ data = {
374
+ "model": self.ollama_server_infos.LIGHTRAG_MODEL,
375
+ "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
376
+ "message": {
377
+ "role": "assistant",
378
+ "content": response,
379
+ "images": None,
380
+ },
381
+ "done": False,
382
+ }
383
+ yield f"{json.dumps(data, ensure_ascii=False)}\n"
384
+
385
+ completion_tokens = estimate_tokens(total_response)
386
+ total_time = last_chunk_time - start_time
387
+ prompt_eval_time = first_chunk_time - start_time
388
+ eval_time = last_chunk_time - first_chunk_time
389
+
390
+ data = {
391
+ "model": self.ollama_server_infos.LIGHTRAG_MODEL,
392
+ "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
393
+ "done": True,
394
+ "total_duration": total_time,
395
+ "load_duration": 0,
396
+ "prompt_eval_count": prompt_tokens,
397
+ "prompt_eval_duration": prompt_eval_time,
398
+ "eval_count": completion_tokens,
399
+ "eval_duration": eval_time,
400
+ }
401
+ yield f"{json.dumps(data, ensure_ascii=False)}\n"
402
+ else:
403
+ try:
404
+ async for chunk in response:
405
+ if chunk:
406
+ if first_chunk_time is None:
407
+ first_chunk_time = time.time_ns()
408
+
409
+ last_chunk_time = time.time_ns()
410
+
411
+ total_response += chunk
412
+ data = {
413
+ "model": self.ollama_server_infos.LIGHTRAG_MODEL,
414
+ "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
415
+ "message": {
416
+ "role": "assistant",
417
+ "content": chunk,
418
+ "images": None,
419
+ },
420
+ "done": False,
421
+ }
422
+ yield f"{json.dumps(data, ensure_ascii=False)}\n"
423
+ except (asyncio.CancelledError, Exception) as e:
424
+ error_msg = str(e)
425
+ if isinstance(e, asyncio.CancelledError):
426
+ error_msg = "Stream was cancelled by server"
427
+ else:
428
+ error_msg = f"Provider error: {error_msg}"
429
+
430
+ logging.error(f"Stream error: {error_msg}")
431
+
432
+ # Send error message to client
433
+ error_data = {
434
+ "model": self.ollama_server_infos.LIGHTRAG_MODEL,
435
+ "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
436
+ "message": {
437
+ "role": "assistant",
438
+ "content": f"\n\nError: {error_msg}",
439
+ "images": None,
440
+ },
441
+ "done": False,
442
+ }
443
+ yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
444
+
445
+ # Send final message to close the stream
446
+ final_data = {
447
+ "model": self.ollama_server_infos.LIGHTRAG_MODEL,
448
+ "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
449
+ "done": True,
450
+ }
451
+ yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
452
+ return
453
+
454
+ if last_chunk_time is not None:
455
+ completion_tokens = estimate_tokens(total_response)
456
+ total_time = last_chunk_time - start_time
457
+ prompt_eval_time = first_chunk_time - start_time
458
+ eval_time = last_chunk_time - first_chunk_time
459
+
460
+ data = {
461
+ "model": self.ollama_server_infos.LIGHTRAG_MODEL,
462
+ "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
463
+ "done": True,
464
+ "total_duration": total_time,
465
+ "load_duration": 0,
466
+ "prompt_eval_count": prompt_tokens,
467
+ "prompt_eval_duration": prompt_eval_time,
468
+ "eval_count": completion_tokens,
469
+ "eval_duration": eval_time,
470
+ }
471
+ yield f"{json.dumps(data, ensure_ascii=False)}\n"
472
+
473
+ except Exception as e:
474
+ error_msg = f"Error in stream_generator: {str(e)}"
475
+ logging.error(error_msg)
476
+
477
+ # Send error message to client
478
+ error_data = {
479
+ "model": self.ollama_server_infos.LIGHTRAG_MODEL,
480
+ "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
481
+ "error": {"code": "STREAM_ERROR", "message": error_msg},
482
+ }
483
+ yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
484
+
485
+ # Ensure sending end marker
486
+ final_data = {
487
+ "model": self.ollama_server_infos.LIGHTRAG_MODEL,
488
+ "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
489
+ "done": True,
490
+ }
491
+ yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
492
+ return
493
+
494
+ return StreamingResponse(
495
+ stream_generator(),
496
+ media_type="application/x-ndjson",
497
+ headers={
498
+ "Cache-Control": "no-cache",
499
+ "Connection": "keep-alive",
500
+ "Content-Type": "application/x-ndjson",
501
+ "Access-Control-Allow-Origin": "*",
502
+ "Access-Control-Allow-Methods": "POST, OPTIONS",
503
+ "Access-Control-Allow-Headers": "Content-Type",
504
+ },
505
+ )
506
+ else:
507
+ first_chunk_time = time.time_ns()
508
+
509
+ # Determine if the request is prefixed with "/bypass" or comes from Open WebUI's session title and keyword generation task
510
+ match_result = re.search(
511
+ r"\n<chat_history>\nUSER:", cleaned_query, re.MULTILINE
512
+ )
513
+ if match_result or mode == SearchMode.bypass:
514
+ if request.system:
515
+ self.rag.llm_model_kwargs["system_prompt"] = request.system
516
+
517
+ response_text = await self.rag.llm_model_func(
518
+ cleaned_query,
519
+ stream=False,
520
+ history_messages=conversation_history,
521
+ **self.rag.llm_model_kwargs,
522
+ )
523
+ else:
524
+ response_text = await self.rag.aquery(cleaned_query, param=query_param)
525
+
526
+ last_chunk_time = time.time_ns()
527
+
528
+ if not response_text:
529
+ response_text = "No response generated"
530
+
531
+ completion_tokens = estimate_tokens(str(response_text))
532
+ total_time = last_chunk_time - start_time
533
+ prompt_eval_time = first_chunk_time - start_time
534
+ eval_time = last_chunk_time - first_chunk_time
535
+
536
+ return {
537
+ "model": self.ollama_server_infos.LIGHTRAG_MODEL,
538
+ "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
539
+ "message": {
540
+ "role": "assistant",
541
+ "content": str(response_text),
542
+ "images": None,
543
+ },
544
+ "done": True,
545
+ "total_duration": total_time,
546
+ "load_duration": 0,
547
+ "prompt_eval_count": prompt_tokens,
548
+ "prompt_eval_duration": prompt_eval_time,
549
+ "eval_count": completion_tokens,
550
+ "eval_duration": eval_time,
551
+ }
552
+ except Exception as e:
553
+ trace_exception(e)
554
+ raise HTTPException(status_code=500, detail=str(e))
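parse_query_mode and estimate_tokens are module-level helpers in the new file, so the prefix routing and the token heuristic can be exercised in isolation; a small illustrative check:

    from lightrag.api.ollama_api import SearchMode, estimate_tokens, parse_query_mode

    cleaned, mode = parse_query_mode("/mix   compare the two pipelines")
    assert mode is SearchMode.mix and cleaned == "compare the two pipelines"

    cleaned, mode = parse_query_mode("no prefix here")
    assert mode is SearchMode.hybrid  # hybrid is the fallback when no prefix matches

    print(estimate_tokens("hello 世界"))  # roughly 0.25 tokens per non-Chinese char, 1.5 per Chinese char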