zrguo committed on
Commit
5ba9af1
·
unverified ·
2 Parent(s): 79d6809 3f47dc3

Merge pull request #717 from danielaskdd/split-ollama-api-to-separated-file

Browse files
lightrag/api/lightrag_server.py CHANGED
@@ -4,28 +4,20 @@ from fastapi import (
4
  File,
5
  UploadFile,
6
  Form,
7
- Request,
8
  BackgroundTasks,
9
  )
10
 
11
- # Backend (Python)
12
- # Add this to store progress globally
13
- from typing import Dict
14
  import threading
15
- import asyncio
16
- import json
17
  import os
18
-
 
19
  from fastapi.staticfiles import StaticFiles
20
- from pydantic import BaseModel
21
  import logging
22
  import argparse
23
- import time
24
- import re
25
- from typing import List, Any, Optional, Union
26
  from lightrag import LightRAG, QueryParam
27
  from lightrag.api import __api_version__
28
-
29
  from lightrag.utils import EmbeddingFunc
30
  from enum import Enum
31
  from pathlib import Path
@@ -34,20 +26,32 @@ import aiofiles
34
  from ascii_colors import trace_exception, ASCIIColors
35
  import sys
36
  import configparser
37
-
38
  from fastapi import Depends, Security
39
  from fastapi.security import APIKeyHeader
40
  from fastapi.middleware.cors import CORSMiddleware
41
  from contextlib import asynccontextmanager
42
-
43
  from starlette.status import HTTP_403_FORBIDDEN
44
  import pipmaster as pm
45
-
46
  from dotenv import load_dotenv
 
 
 
 
47
 
48
  # Load environment variables
49
  load_dotenv()
50
 
 
 
 
 
 
 
 
 
 
 
 
51
  # Global progress tracker
52
  scan_progress: Dict = {
53
  "is_scanning": False,
@@ -76,24 +80,6 @@ def estimate_tokens(text: str) -> int:
76
  return int(tokens)
77
 
78
 
79
- class OllamaServerInfos:
80
- # Constants for emulated Ollama model information
81
- LIGHTRAG_NAME = "lightrag"
82
- LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
83
- LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
84
- LIGHTRAG_SIZE = 7365960935 # it's a dummy value
85
- LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
86
- LIGHTRAG_DIGEST = "sha256:lightrag"
87
-
88
- KV_STORAGE = "JsonKVStorage"
89
- DOC_STATUS_STORAGE = "JsonDocStatusStorage"
90
- GRAPH_STORAGE = "NetworkXStorage"
91
- VECTOR_STORAGE = "NanoVectorDBStorage"
92
-
93
-
94
- # Add infos
95
- ollama_server_infos = OllamaServerInfos()
96
-
97
  # read config.ini
98
  config = configparser.ConfigParser()
99
  config.read("config.ini", "utf-8")
@@ -101,8 +87,8 @@ config.read("config.ini", "utf-8")
101
  redis_uri = config.get("redis", "uri", fallback=None)
102
  if redis_uri:
103
  os.environ["REDIS_URI"] = redis_uri
104
- ollama_server_infos.KV_STORAGE = "RedisKVStorage"
105
- ollama_server_infos.DOC_STATUS_STORAGE = "RedisKVStorage"
106
 
107
  # Neo4j config
108
  neo4j_uri = config.get("neo4j", "uri", fallback=None)
@@ -112,7 +98,7 @@ if neo4j_uri:
112
  os.environ["NEO4J_URI"] = neo4j_uri
113
  os.environ["NEO4J_USERNAME"] = neo4j_username
114
  os.environ["NEO4J_PASSWORD"] = neo4j_password
115
- ollama_server_infos.GRAPH_STORAGE = "Neo4JStorage"
116
 
117
  # Milvus config
118
  milvus_uri = config.get("milvus", "uri", fallback=None)
@@ -124,7 +110,7 @@ if milvus_uri:
124
  os.environ["MILVUS_USER"] = milvus_user
125
  os.environ["MILVUS_PASSWORD"] = milvus_password
126
  os.environ["MILVUS_DB_NAME"] = milvus_db_name
127
- ollama_server_infos.VECTOR_STORAGE = "MilvusVectorDBStorge"
128
 
129
  # MongoDB config
130
  mongo_uri = config.get("mongodb", "uri", fallback=None)
@@ -132,8 +118,8 @@ mongo_database = config.get("mongodb", "LightRAG", fallback=None)
132
  if mongo_uri:
133
  os.environ["MONGO_URI"] = mongo_uri
134
  os.environ["MONGO_DATABASE"] = mongo_database
135
- ollama_server_infos.KV_STORAGE = "MongoKVStorage"
136
- ollama_server_infos.DOC_STATUS_STORAGE = "MongoKVStorage"
137
 
138
 
139
  def get_default_host(binding_type: str) -> str:
@@ -535,6 +521,7 @@ def parse_args() -> argparse.Namespace:
535
  help="Cosine similarity threshold (default: from env or 0.4)",
536
  )
537
 
 
538
  parser.add_argument(
539
  "--simulated-model-name",
540
  type=str,
@@ -599,83 +586,13 @@ class DocumentManager:
599
  return any(filename.lower().endswith(ext) for ext in self.supported_extensions)
600
 
601
 
602
- # Pydantic models
603
  class SearchMode(str, Enum):
604
  naive = "naive"
605
  local = "local"
606
  global_ = "global"
607
  hybrid = "hybrid"
608
  mix = "mix"
609
- bypass = "bypass"
610
-
611
-
612
- class OllamaMessage(BaseModel):
613
- role: str
614
- content: str
615
- images: Optional[List[str]] = None
616
-
617
-
618
- class OllamaChatRequest(BaseModel):
619
- model: str = ollama_server_infos.LIGHTRAG_MODEL
620
- messages: List[OllamaMessage]
621
- stream: bool = True # Default to streaming mode
622
- options: Optional[Dict[str, Any]] = None
623
- system: Optional[str] = None
624
-
625
-
626
- class OllamaChatResponse(BaseModel):
627
- model: str
628
- created_at: str
629
- message: OllamaMessage
630
- done: bool
631
-
632
-
633
- class OllamaGenerateRequest(BaseModel):
634
- model: str = ollama_server_infos.LIGHTRAG_MODEL
635
- prompt: str
636
- system: Optional[str] = None
637
- stream: bool = False
638
- options: Optional[Dict[str, Any]] = None
639
-
640
-
641
- class OllamaGenerateResponse(BaseModel):
642
- model: str
643
- created_at: str
644
- response: str
645
- done: bool
646
- context: Optional[List[int]]
647
- total_duration: Optional[int]
648
- load_duration: Optional[int]
649
- prompt_eval_count: Optional[int]
650
- prompt_eval_duration: Optional[int]
651
- eval_count: Optional[int]
652
- eval_duration: Optional[int]
653
-
654
-
655
- class OllamaVersionResponse(BaseModel):
656
- version: str
657
-
658
-
659
- class OllamaModelDetails(BaseModel):
660
- parent_model: str
661
- format: str
662
- family: str
663
- families: List[str]
664
- parameter_size: str
665
- quantization_level: str
666
-
667
-
668
- class OllamaModel(BaseModel):
669
- name: str
670
- model: str
671
- size: int
672
- digest: str
673
- modified_at: str
674
- details: OllamaModelDetails
675
-
676
-
677
- class OllamaTagResponse(BaseModel):
678
- models: List[OllamaModel]
679
 
680
 
681
  class QueryRequest(BaseModel):
@@ -920,10 +837,10 @@ def create_app(args):
920
  if args.llm_binding == "lollms" or args.llm_binding == "ollama"
921
  else {},
922
  embedding_func=embedding_func,
923
- kv_storage=ollama_server_infos.KV_STORAGE,
924
- graph_storage=ollama_server_infos.GRAPH_STORAGE,
925
- vector_storage=ollama_server_infos.VECTOR_STORAGE,
926
- doc_status_storage=ollama_server_infos.DOC_STATUS_STORAGE,
927
  vector_db_storage_cls_kwargs={
928
  "cosine_better_than_threshold": args.cosine_threshold
929
  },
@@ -949,10 +866,10 @@ def create_app(args):
949
  llm_model_max_async=args.max_async,
950
  llm_model_max_token_size=args.max_tokens,
951
  embedding_func=embedding_func,
952
- kv_storage=ollama_server_infos.KV_STORAGE,
953
- graph_storage=ollama_server_infos.GRAPH_STORAGE,
954
- vector_storage=ollama_server_infos.VECTOR_STORAGE,
955
- doc_status_storage=ollama_server_infos.DOC_STATUS_STORAGE,
956
  vector_db_storage_cls_kwargs={
957
  "cosine_better_than_threshold": args.cosine_threshold
958
  },
@@ -1475,450 +1392,9 @@ def create_app(args):
1475
  async def get_graphs(label: str):
1476
  return await rag.get_graps(nodel_label=label, max_depth=100)
1477
 
1478
- # Ollama compatible API endpoints
1479
- # -------------------------------------------------
1480
- @app.get("/api/version")
1481
- async def get_version():
1482
- """Get Ollama version information"""
1483
- return OllamaVersionResponse(version="0.5.4")
1484
-
1485
- @app.get("/api/tags")
1486
- async def get_tags():
1487
- """Retrun available models acting as an Ollama server"""
1488
- return OllamaTagResponse(
1489
- models=[
1490
- {
1491
- "name": ollama_server_infos.LIGHTRAG_MODEL,
1492
- "model": ollama_server_infos.LIGHTRAG_MODEL,
1493
- "size": ollama_server_infos.LIGHTRAG_SIZE,
1494
- "digest": ollama_server_infos.LIGHTRAG_DIGEST,
1495
- "modified_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1496
- "details": {
1497
- "parent_model": "",
1498
- "format": "gguf",
1499
- "family": ollama_server_infos.LIGHTRAG_NAME,
1500
- "families": [ollama_server_infos.LIGHTRAG_NAME],
1501
- "parameter_size": "13B",
1502
- "quantization_level": "Q4_0",
1503
- },
1504
- }
1505
- ]
1506
- )
1507
-
1508
- def parse_query_mode(query: str) -> tuple[str, SearchMode]:
1509
- """Parse query prefix to determine search mode
1510
- Returns tuple of (cleaned_query, search_mode)
1511
- """
1512
- mode_map = {
1513
- "/local ": SearchMode.local,
1514
- "/global ": SearchMode.global_, # global_ is used because 'global' is a Python keyword
1515
- "/naive ": SearchMode.naive,
1516
- "/hybrid ": SearchMode.hybrid,
1517
- "/mix ": SearchMode.mix,
1518
- "/bypass ": SearchMode.bypass,
1519
- }
1520
-
1521
- for prefix, mode in mode_map.items():
1522
- if query.startswith(prefix):
1523
- # After removing prefix an leading spaces
1524
- cleaned_query = query[len(prefix) :].lstrip()
1525
- return cleaned_query, mode
1526
-
1527
- return query, SearchMode.hybrid
1528
-
1529
- @app.post("/api/generate")
1530
- async def generate(raw_request: Request, request: OllamaGenerateRequest):
1531
- """Handle generate completion requests acting as an Ollama model
1532
- For compatiblity purpuse, the request is not processed by LightRAG,
1533
- and will be handled by underlying LLM model.
1534
- """
1535
- try:
1536
- query = request.prompt
1537
- start_time = time.time_ns()
1538
- prompt_tokens = estimate_tokens(query)
1539
-
1540
- if request.system:
1541
- rag.llm_model_kwargs["system_prompt"] = request.system
1542
-
1543
- if request.stream:
1544
- from fastapi.responses import StreamingResponse
1545
-
1546
- response = await rag.llm_model_func(
1547
- query, stream=True, **rag.llm_model_kwargs
1548
- )
1549
-
1550
- async def stream_generator():
1551
- try:
1552
- first_chunk_time = None
1553
- last_chunk_time = None
1554
- total_response = ""
1555
-
1556
- # Ensure response is an async generator
1557
- if isinstance(response, str):
1558
- # If it's a string, send in two parts
1559
- first_chunk_time = time.time_ns()
1560
- last_chunk_time = first_chunk_time
1561
- total_response = response
1562
-
1563
- data = {
1564
- "model": ollama_server_infos.LIGHTRAG_MODEL,
1565
- "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1566
- "response": response,
1567
- "done": False,
1568
- }
1569
- yield f"{json.dumps(data, ensure_ascii=False)}\n"
1570
-
1571
- completion_tokens = estimate_tokens(total_response)
1572
- total_time = last_chunk_time - start_time
1573
- prompt_eval_time = first_chunk_time - start_time
1574
- eval_time = last_chunk_time - first_chunk_time
1575
-
1576
- data = {
1577
- "model": ollama_server_infos.LIGHTRAG_MODEL,
1578
- "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1579
- "done": True,
1580
- "total_duration": total_time,
1581
- "load_duration": 0,
1582
- "prompt_eval_count": prompt_tokens,
1583
- "prompt_eval_duration": prompt_eval_time,
1584
- "eval_count": completion_tokens,
1585
- "eval_duration": eval_time,
1586
- }
1587
- yield f"{json.dumps(data, ensure_ascii=False)}\n"
1588
- else:
1589
- async for chunk in response:
1590
- if chunk:
1591
- if first_chunk_time is None:
1592
- first_chunk_time = time.time_ns()
1593
-
1594
- last_chunk_time = time.time_ns()
1595
-
1596
- total_response += chunk
1597
- data = {
1598
- "model": ollama_server_infos.LIGHTRAG_MODEL,
1599
- "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1600
- "response": chunk,
1601
- "done": False,
1602
- }
1603
- yield f"{json.dumps(data, ensure_ascii=False)}\n"
1604
-
1605
- completion_tokens = estimate_tokens(total_response)
1606
- total_time = last_chunk_time - start_time
1607
- prompt_eval_time = first_chunk_time - start_time
1608
- eval_time = last_chunk_time - first_chunk_time
1609
-
1610
- data = {
1611
- "model": ollama_server_infos.LIGHTRAG_MODEL,
1612
- "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1613
- "done": True,
1614
- "total_duration": total_time,
1615
- "load_duration": 0,
1616
- "prompt_eval_count": prompt_tokens,
1617
- "prompt_eval_duration": prompt_eval_time,
1618
- "eval_count": completion_tokens,
1619
- "eval_duration": eval_time,
1620
- }
1621
- yield f"{json.dumps(data, ensure_ascii=False)}\n"
1622
- return
1623
-
1624
- except Exception as e:
1625
- logging.error(f"Error in stream_generator: {str(e)}")
1626
- raise
1627
-
1628
- return StreamingResponse(
1629
- stream_generator(),
1630
- media_type="application/x-ndjson",
1631
- headers={
1632
- "Cache-Control": "no-cache",
1633
- "Connection": "keep-alive",
1634
- "Content-Type": "application/x-ndjson",
1635
- "Access-Control-Allow-Origin": "*",
1636
- "Access-Control-Allow-Methods": "POST, OPTIONS",
1637
- "Access-Control-Allow-Headers": "Content-Type",
1638
- },
1639
- )
1640
- else:
1641
- first_chunk_time = time.time_ns()
1642
- response_text = await rag.llm_model_func(
1643
- query, stream=False, **rag.llm_model_kwargs
1644
- )
1645
- last_chunk_time = time.time_ns()
1646
-
1647
- if not response_text:
1648
- response_text = "No response generated"
1649
-
1650
- completion_tokens = estimate_tokens(str(response_text))
1651
- total_time = last_chunk_time - start_time
1652
- prompt_eval_time = first_chunk_time - start_time
1653
- eval_time = last_chunk_time - first_chunk_time
1654
-
1655
- return {
1656
- "model": ollama_server_infos.LIGHTRAG_MODEL,
1657
- "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1658
- "response": str(response_text),
1659
- "done": True,
1660
- "total_duration": total_time,
1661
- "load_duration": 0,
1662
- "prompt_eval_count": prompt_tokens,
1663
- "prompt_eval_duration": prompt_eval_time,
1664
- "eval_count": completion_tokens,
1665
- "eval_duration": eval_time,
1666
- }
1667
- except Exception as e:
1668
- trace_exception(e)
1669
- raise HTTPException(status_code=500, detail=str(e))
1670
-
1671
- @app.post("/api/chat")
1672
- async def chat(raw_request: Request, request: OllamaChatRequest):
1673
- """Process chat completion requests acting as an Ollama model
1674
- Routes user queries through LightRAG by selecting query mode based on prefix indicators.
1675
- Detects and forwards OpenWebUI session-related requests (for meta data generation task) directly to LLM.
1676
- """
1677
- try:
1678
- # Get all messages
1679
- messages = request.messages
1680
- if not messages:
1681
- raise HTTPException(status_code=400, detail="No messages provided")
1682
-
1683
- # Get the last message as query and previous messages as history
1684
- query = messages[-1].content
1685
- # Convert OllamaMessage objects to dictionaries
1686
- conversation_history = [
1687
- {"role": msg.role, "content": msg.content} for msg in messages[:-1]
1688
- ]
1689
-
1690
- # Check for query prefix
1691
- cleaned_query, mode = parse_query_mode(query)
1692
-
1693
- start_time = time.time_ns()
1694
- prompt_tokens = estimate_tokens(cleaned_query)
1695
-
1696
- param_dict = {
1697
- "mode": mode,
1698
- "stream": request.stream,
1699
- "only_need_context": False,
1700
- "conversation_history": conversation_history,
1701
- "top_k": args.top_k,
1702
- }
1703
-
1704
- if args.history_turns is not None:
1705
- param_dict["history_turns"] = args.history_turns
1706
-
1707
- query_param = QueryParam(**param_dict)
1708
-
1709
- if request.stream:
1710
- from fastapi.responses import StreamingResponse
1711
-
1712
- # Determine if the request is prefix with "/bypass"
1713
- if mode == SearchMode.bypass:
1714
- if request.system:
1715
- rag.llm_model_kwargs["system_prompt"] = request.system
1716
- response = await rag.llm_model_func(
1717
- cleaned_query,
1718
- stream=True,
1719
- history_messages=conversation_history,
1720
- **rag.llm_model_kwargs,
1721
- )
1722
- else:
1723
- response = await rag.aquery( # Need await to get async generator
1724
- cleaned_query, param=query_param
1725
- )
1726
-
1727
- async def stream_generator():
1728
- first_chunk_time = None
1729
- last_chunk_time = None
1730
- total_response = ""
1731
-
1732
- try:
1733
- # Ensure response is an async generator
1734
- if isinstance(response, str):
1735
- # If it's a string, send in two parts
1736
- first_chunk_time = time.time_ns()
1737
- last_chunk_time = first_chunk_time
1738
- total_response = response
1739
-
1740
- data = {
1741
- "model": ollama_server_infos.LIGHTRAG_MODEL,
1742
- "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1743
- "message": {
1744
- "role": "assistant",
1745
- "content": response,
1746
- "images": None,
1747
- },
1748
- "done": False,
1749
- }
1750
- yield f"{json.dumps(data, ensure_ascii=False)}\n"
1751
-
1752
- completion_tokens = estimate_tokens(total_response)
1753
- total_time = last_chunk_time - start_time
1754
- prompt_eval_time = first_chunk_time - start_time
1755
- eval_time = last_chunk_time - first_chunk_time
1756
-
1757
- data = {
1758
- "model": ollama_server_infos.LIGHTRAG_MODEL,
1759
- "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1760
- "done": True,
1761
- "total_duration": total_time,
1762
- "load_duration": 0,
1763
- "prompt_eval_count": prompt_tokens,
1764
- "prompt_eval_duration": prompt_eval_time,
1765
- "eval_count": completion_tokens,
1766
- "eval_duration": eval_time,
1767
- }
1768
- yield f"{json.dumps(data, ensure_ascii=False)}\n"
1769
- else:
1770
- try:
1771
- async for chunk in response:
1772
- if chunk:
1773
- if first_chunk_time is None:
1774
- first_chunk_time = time.time_ns()
1775
-
1776
- last_chunk_time = time.time_ns()
1777
-
1778
- total_response += chunk
1779
- data = {
1780
- "model": ollama_server_infos.LIGHTRAG_MODEL,
1781
- "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1782
- "message": {
1783
- "role": "assistant",
1784
- "content": chunk,
1785
- "images": None,
1786
- },
1787
- "done": False,
1788
- }
1789
- yield f"{json.dumps(data, ensure_ascii=False)}\n"
1790
- except (asyncio.CancelledError, Exception) as e:
1791
- error_msg = str(e)
1792
- if isinstance(e, asyncio.CancelledError):
1793
- error_msg = "Stream was cancelled by server"
1794
- else:
1795
- error_msg = f"Provider error: {error_msg}"
1796
-
1797
- logging.error(f"Stream error: {error_msg}")
1798
-
1799
- # Send error message to client
1800
- error_data = {
1801
- "model": ollama_server_infos.LIGHTRAG_MODEL,
1802
- "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1803
- "message": {
1804
- "role": "assistant",
1805
- "content": f"\n\nError: {error_msg}",
1806
- "images": None,
1807
- },
1808
- "done": False,
1809
- }
1810
- yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
1811
-
1812
- # Send final message to close the stream
1813
- final_data = {
1814
- "model": ollama_server_infos.LIGHTRAG_MODEL,
1815
- "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1816
- "done": True,
1817
- }
1818
- yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
1819
- return
1820
-
1821
- if last_chunk_time is not None:
1822
- completion_tokens = estimate_tokens(total_response)
1823
- total_time = last_chunk_time - start_time
1824
- prompt_eval_time = first_chunk_time - start_time
1825
- eval_time = last_chunk_time - first_chunk_time
1826
-
1827
- data = {
1828
- "model": ollama_server_infos.LIGHTRAG_MODEL,
1829
- "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1830
- "done": True,
1831
- "total_duration": total_time,
1832
- "load_duration": 0,
1833
- "prompt_eval_count": prompt_tokens,
1834
- "prompt_eval_duration": prompt_eval_time,
1835
- "eval_count": completion_tokens,
1836
- "eval_duration": eval_time,
1837
- }
1838
- yield f"{json.dumps(data, ensure_ascii=False)}\n"
1839
-
1840
- except Exception as e:
1841
- error_msg = f"Error in stream_generator: {str(e)}"
1842
- logging.error(error_msg)
1843
-
1844
- # Send error message to client
1845
- error_data = {
1846
- "model": ollama_server_infos.LIGHTRAG_MODEL,
1847
- "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1848
- "error": {"code": "STREAM_ERROR", "message": error_msg},
1849
- }
1850
- yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
1851
-
1852
- # Ensure sending end marker
1853
- final_data = {
1854
- "model": ollama_server_infos.LIGHTRAG_MODEL,
1855
- "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1856
- "done": True,
1857
- }
1858
- yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
1859
- return
1860
-
1861
- return StreamingResponse(
1862
- stream_generator(),
1863
- media_type="application/x-ndjson",
1864
- headers={
1865
- "Cache-Control": "no-cache",
1866
- "Connection": "keep-alive",
1867
- "Content-Type": "application/x-ndjson",
1868
- "Access-Control-Allow-Origin": "*",
1869
- "Access-Control-Allow-Methods": "POST, OPTIONS",
1870
- "Access-Control-Allow-Headers": "Content-Type",
1871
- },
1872
- )
1873
- else:
1874
- first_chunk_time = time.time_ns()
1875
-
1876
- # Determine if the request is prefix with "/bypass" or from Open WebUI's session title and session keyword generation task
1877
- match_result = re.search(
1878
- r"\n<chat_history>\nUSER:", cleaned_query, re.MULTILINE
1879
- )
1880
- if match_result or mode == SearchMode.bypass:
1881
- if request.system:
1882
- rag.llm_model_kwargs["system_prompt"] = request.system
1883
-
1884
- response_text = await rag.llm_model_func(
1885
- cleaned_query,
1886
- stream=False,
1887
- history_messages=conversation_history,
1888
- **rag.llm_model_kwargs,
1889
- )
1890
- else:
1891
- response_text = await rag.aquery(cleaned_query, param=query_param)
1892
-
1893
- last_chunk_time = time.time_ns()
1894
-
1895
- if not response_text:
1896
- response_text = "No response generated"
1897
-
1898
- completion_tokens = estimate_tokens(str(response_text))
1899
- total_time = last_chunk_time - start_time
1900
- prompt_eval_time = first_chunk_time - start_time
1901
- eval_time = last_chunk_time - first_chunk_time
1902
-
1903
- return {
1904
- "model": ollama_server_infos.LIGHTRAG_MODEL,
1905
- "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1906
- "message": {
1907
- "role": "assistant",
1908
- "content": str(response_text),
1909
- "images": None,
1910
- },
1911
- "done": True,
1912
- "total_duration": total_time,
1913
- "load_duration": 0,
1914
- "prompt_eval_count": prompt_tokens,
1915
- "prompt_eval_duration": prompt_eval_time,
1916
- "eval_count": completion_tokens,
1917
- "eval_duration": eval_time,
1918
- }
1919
- except Exception as e:
1920
- trace_exception(e)
1921
- raise HTTPException(status_code=500, detail=str(e))
1922
 
1923
  @app.get("/documents", dependencies=[Depends(optional_api_key)])
1924
  async def documents():
@@ -1945,10 +1421,10 @@ def create_app(args):
1945
  "embedding_binding_host": args.embedding_binding_host,
1946
  "embedding_model": args.embedding_model,
1947
  "max_tokens": args.max_tokens,
1948
- "kv_storage": ollama_server_infos.KV_STORAGE,
1949
- "doc_status_storage": ollama_server_infos.DOC_STATUS_STORAGE,
1950
- "graph_storage": ollama_server_infos.GRAPH_STORAGE,
1951
- "vector_storage": ollama_server_infos.VECTOR_STORAGE,
1952
  },
1953
  }
1954
 
 
4
  File,
5
  UploadFile,
6
  Form,
 
7
  BackgroundTasks,
8
  )
9
 
 
 
 
10
  import threading
 
 
11
  import os
12
+ import json
13
+ import re
14
  from fastapi.staticfiles import StaticFiles
 
15
  import logging
16
  import argparse
17
+ from typing import List, Any, Optional, Union, Dict
18
+ from pydantic import BaseModel
 
19
  from lightrag import LightRAG, QueryParam
20
  from lightrag.api import __api_version__
 
21
  from lightrag.utils import EmbeddingFunc
22
  from enum import Enum
23
  from pathlib import Path
 
26
  from ascii_colors import trace_exception, ASCIIColors
27
  import sys
28
  import configparser
 
29
  from fastapi import Depends, Security
30
  from fastapi.security import APIKeyHeader
31
  from fastapi.middleware.cors import CORSMiddleware
32
  from contextlib import asynccontextmanager
 
33
  from starlette.status import HTTP_403_FORBIDDEN
34
  import pipmaster as pm
 
35
  from dotenv import load_dotenv
36
+ from .ollama_api import (
37
+ OllamaAPI,
38
+ )
39
+ from .ollama_api import ollama_server_infos
40
 
41
  # Load environment variables
42
  load_dotenv()
43
 
44
+
45
+ class RAGStorageConfig:
46
+ KV_STORAGE = "JsonKVStorage"
47
+ DOC_STATUS_STORAGE = "JsonDocStatusStorage"
48
+ GRAPH_STORAGE = "NetworkXStorage"
49
+ VECTOR_STORAGE = "NanoVectorDBStorage"
50
+
51
+
52
+ # Initialize rag storage config
53
+ rag_storage_config = RAGStorageConfig()
54
+
55
  # Global progress tracker
56
  scan_progress: Dict = {
57
  "is_scanning": False,
 
80
  return int(tokens)
81
 
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  # read config.ini
84
  config = configparser.ConfigParser()
85
  config.read("config.ini", "utf-8")
 
87
  redis_uri = config.get("redis", "uri", fallback=None)
88
  if redis_uri:
89
  os.environ["REDIS_URI"] = redis_uri
90
+ rag_storage_config.KV_STORAGE = "RedisKVStorage"
91
+ rag_storage_config.DOC_STATUS_STORAGE = "RedisKVStorage"
92
 
93
  # Neo4j config
94
  neo4j_uri = config.get("neo4j", "uri", fallback=None)
 
98
  os.environ["NEO4J_URI"] = neo4j_uri
99
  os.environ["NEO4J_USERNAME"] = neo4j_username
100
  os.environ["NEO4J_PASSWORD"] = neo4j_password
101
+ rag_storage_config.GRAPH_STORAGE = "Neo4JStorage"
102
 
103
  # Milvus config
104
  milvus_uri = config.get("milvus", "uri", fallback=None)
 
110
  os.environ["MILVUS_USER"] = milvus_user
111
  os.environ["MILVUS_PASSWORD"] = milvus_password
112
  os.environ["MILVUS_DB_NAME"] = milvus_db_name
113
+ rag_storage_config.VECTOR_STORAGE = "MilvusVectorDBStorge"
114
 
115
  # MongoDB config
116
  mongo_uri = config.get("mongodb", "uri", fallback=None)
 
118
  if mongo_uri:
119
  os.environ["MONGO_URI"] = mongo_uri
120
  os.environ["MONGO_DATABASE"] = mongo_database
121
+ rag_storage_config.KV_STORAGE = "MongoKVStorage"
122
+ rag_storage_config.DOC_STATUS_STORAGE = "MongoKVStorage"
123
 
124
 
125
  def get_default_host(binding_type: str) -> str:
 
521
  help="Cosine similarity threshold (default: from env or 0.4)",
522
  )
523
 
524
+ # Ollama model name
525
  parser.add_argument(
526
  "--simulated-model-name",
527
  type=str,
 
586
  return any(filename.lower().endswith(ext) for ext in self.supported_extensions)
587
 
588
 
589
+ # LightRAG query mode
590
  class SearchMode(str, Enum):
591
  naive = "naive"
592
  local = "local"
593
  global_ = "global"
594
  hybrid = "hybrid"
595
  mix = "mix"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
596
 
597
 
598
  class QueryRequest(BaseModel):
 
837
  if args.llm_binding == "lollms" or args.llm_binding == "ollama"
838
  else {},
839
  embedding_func=embedding_func,
840
+ kv_storage=rag_storage_config.KV_STORAGE,
841
+ graph_storage=rag_storage_config.GRAPH_STORAGE,
842
+ vector_storage=rag_storage_config.VECTOR_STORAGE,
843
+ doc_status_storage=rag_storage_config.DOC_STATUS_STORAGE,
844
  vector_db_storage_cls_kwargs={
845
  "cosine_better_than_threshold": args.cosine_threshold
846
  },
 
866
  llm_model_max_async=args.max_async,
867
  llm_model_max_token_size=args.max_tokens,
868
  embedding_func=embedding_func,
869
+ kv_storage=rag_storage_config.KV_STORAGE,
870
+ graph_storage=rag_storage_config.GRAPH_STORAGE,
871
+ vector_storage=rag_storage_config.VECTOR_STORAGE,
872
+ doc_status_storage=rag_storage_config.DOC_STATUS_STORAGE,
873
  vector_db_storage_cls_kwargs={
874
  "cosine_better_than_threshold": args.cosine_threshold
875
  },
 
1392
  async def get_graphs(label: str):
1393
  return await rag.get_graps(nodel_label=label, max_depth=100)
1394
 
1395
+ # Add Ollama API routes
1396
+ ollama_api = OllamaAPI(rag)
1397
+ app.include_router(ollama_api.router, prefix="/api")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1398
 
1399
  @app.get("/documents", dependencies=[Depends(optional_api_key)])
1400
  async def documents():
 
1421
  "embedding_binding_host": args.embedding_binding_host,
1422
  "embedding_model": args.embedding_model,
1423
  "max_tokens": args.max_tokens,
1424
+ "kv_storage": rag_storage_config.KV_STORAGE,
1425
+ "doc_status_storage": rag_storage_config.DOC_STATUS_STORAGE,
1426
+ "graph_storage": rag_storage_config.GRAPH_STORAGE,
1427
+ "vector_storage": rag_storage_config.VECTOR_STORAGE,
1428
  },
1429
  }
1430
 
lightrag/api/ollama_api.py ADDED
@@ -0,0 +1,574 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, HTTPException, Request
2
+ from pydantic import BaseModel
3
+ from typing import List, Dict, Any, Optional
4
+ import logging
5
+ import time
6
+ import json
7
+ import re
8
+ import os
9
+ from enum import Enum
10
+ from fastapi.responses import StreamingResponse
11
+ import asyncio
12
+ from ascii_colors import trace_exception
13
+ from lightrag import LightRAG, QueryParam
14
+
15
+
16
+ class OllamaServerInfos:
17
+ # Constants for emulated Ollama model information
18
+ LIGHTRAG_NAME = "lightrag"
19
+ LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
20
+ LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
21
+ LIGHTRAG_SIZE = 7365960935 # it's a dummy value
22
+ LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
23
+ LIGHTRAG_DIGEST = "sha256:lightrag"
24
+
25
+
26
+ ollama_server_infos = OllamaServerInfos()
27
+
28
+
29
+ # query mode according to query prefix (bypass is not LightRAG quer mode)
30
+ class SearchMode(str, Enum):
31
+ naive = "naive"
32
+ local = "local"
33
+ global_ = "global"
34
+ hybrid = "hybrid"
35
+ mix = "mix"
36
+ bypass = "bypass"
37
+
38
+
39
+ class OllamaMessage(BaseModel):
40
+ role: str
41
+ content: str
42
+ images: Optional[List[str]] = None
43
+
44
+
45
+ class OllamaChatRequest(BaseModel):
46
+ model: str
47
+ messages: List[OllamaMessage]
48
+ stream: bool = True
49
+ options: Optional[Dict[str, Any]] = None
50
+ system: Optional[str] = None
51
+
52
+
53
+ class OllamaChatResponse(BaseModel):
54
+ model: str
55
+ created_at: str
56
+ message: OllamaMessage
57
+ done: bool
58
+
59
+
60
class OllamaGenerateRequest(BaseModel):
    """Request body of POST /generate (Ollama-compatible)."""

    model: str
    prompt: str
    system: Optional[str] = None  # optional system prompt forwarded to the LLM
    stream: bool = False  # /generate defaults to non-streaming, unlike /chat
    options: Optional[Dict[str, Any]] = None
66
+
67
+
68
class OllamaGenerateResponse(BaseModel):
    """Response body of POST /generate (Ollama-compatible).

    The timing/statistics fields are only present on the final message of a
    streamed response, so they default to ``None`` rather than being required.
    """

    model: str
    created_at: str
    response: str
    done: bool
    # Without explicit defaults, `Optional[...]` annotations are *required*
    # fields under pydantic v2, which would reject partial (streaming) payloads.
    context: Optional[List[int]] = None
    total_duration: Optional[int] = None
    load_duration: Optional[int] = None
    prompt_eval_count: Optional[int] = None
    prompt_eval_duration: Optional[int] = None
    eval_count: Optional[int] = None
    eval_duration: Optional[int] = None
80
+
81
+
82
class OllamaVersionResponse(BaseModel):
    """Response body of GET /version."""

    version: str
84
+
85
+
86
class OllamaModelDetails(BaseModel):
    """Detail section of a model entry returned by GET /tags."""

    parent_model: str
    format: str  # e.g. "gguf"
    family: str
    families: List[str]
    parameter_size: str
    quantization_level: str
93
+
94
+
95
class OllamaModel(BaseModel):
    """One model entry as listed by GET /tags."""

    name: str
    model: str
    size: int
    digest: str
    modified_at: str
    details: OllamaModelDetails
102
+
103
+
104
class OllamaTagResponse(BaseModel):
    """Response body of GET /tags: the list of advertised models."""

    models: List[OllamaModel]
106
+
107
+
108
def estimate_tokens(text: str) -> int:
    """Roughly estimate how many LLM tokens *text* will consume.

    Heuristic: each CJK ideograph (U+4E00..U+9FFF) counts as ~1.5 tokens,
    every other character as ~0.25 tokens; the sum is truncated to an int.
    """
    # Count ideographs by direct codepoint comparison; the remainder of the
    # string is "non-Chinese" by definition.
    cjk_count = sum(1 for ch in text if "\u4e00" <= ch <= "\u9fff")
    other_count = len(text) - cjk_count

    return int(cjk_count * 1.5 + other_count * 0.25)
121
+
122
+
123
def parse_query_mode(query: str) -> tuple[str, SearchMode]:
    """Strip an optional '/mode ' prefix from *query*.

    Returns a tuple ``(cleaned_query, mode)``. When no known prefix is
    present the query is returned unchanged with the hybrid mode.
    """
    prefix_table = (
        ("/local ", SearchMode.local),
        ("/global ", SearchMode.global_),  # global_ because 'global' is a Python keyword
        ("/naive ", SearchMode.naive),
        ("/hybrid ", SearchMode.hybrid),
        ("/mix ", SearchMode.mix),
        ("/bypass ", SearchMode.bypass),
    )

    for prefix, mode in prefix_table:
        if query.startswith(prefix):
            # Drop the prefix and any leading whitespace that follows it.
            return query[len(prefix):].lstrip(), mode

    return query, SearchMode.hybrid
143
+
144
+
145
class OllamaAPI:
    """FastAPI router that emulates an Ollama server on top of a LightRAG instance.

    Exposes /version, /tags, /generate and /chat so Ollama-compatible clients
    (e.g. Open WebUI) can talk to LightRAG without modification.
    """

    def __init__(self, rag: LightRAG):
        # The LightRAG instance all endpoints query against.
        self.rag = rag
        # Module-level constants describing the fake model advertised to clients.
        self.ollama_server_infos = ollama_server_infos
        self.router = APIRouter()
        self.setup_routes()

    def setup_routes(self):
        """Register all emulated Ollama endpoints on self.router."""

        @self.router.get("/version")
        async def get_version():
            """Get Ollama version information"""
            # Hard-coded Ollama version string the emulation reports.
            return OllamaVersionResponse(version="0.5.4")

        @self.router.get("/tags")
        async def get_tags():
            """Return available models acting as an Ollama server"""
            # A single synthetic model entry built from OllamaServerInfos constants.
            return OllamaTagResponse(
                models=[
                    {
                        "name": self.ollama_server_infos.LIGHTRAG_MODEL,
                        "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                        "size": self.ollama_server_infos.LIGHTRAG_SIZE,
                        "digest": self.ollama_server_infos.LIGHTRAG_DIGEST,
                        "modified_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                        "details": {
                            "parent_model": "",
                            "format": "gguf",
                            "family": self.ollama_server_infos.LIGHTRAG_NAME,
                            "families": [self.ollama_server_infos.LIGHTRAG_NAME],
                            "parameter_size": "13B",
                            "quantization_level": "Q4_0",
                        },
                    }
                ]
            )

        @self.router.post("/generate")
        async def generate(raw_request: Request, request: OllamaGenerateRequest):
            """Handle generate completion requests acting as an Ollama model
            For compatibility purpose, the request is not processed by LightRAG,
            and will be handled by underlying LLM model.
            """
            try:
                query = request.prompt
                # Nanosecond timestamps are used throughout so the reported
                # durations match Ollama's ns-based statistics fields.
                start_time = time.time_ns()
                prompt_tokens = estimate_tokens(query)

                if request.system:
                    # NOTE(review): this mutates the shared llm_model_kwargs dict on
                    # the LightRAG instance, so the system prompt persists across
                    # requests — confirm this is intended.
                    self.rag.llm_model_kwargs["system_prompt"] = request.system

                if request.stream:
                    response = await self.rag.llm_model_func(
                        query, stream=True, **self.rag.llm_model_kwargs
                    )

                    async def stream_generator():
                        """Yield NDJSON chunks in Ollama's /generate streaming format."""
                        try:
                            first_chunk_time = None
                            last_chunk_time = None
                            total_response = ""

                            # Ensure response is an async generator
                            if isinstance(response, str):
                                # If it's a string, send in two parts
                                first_chunk_time = time.time_ns()
                                last_chunk_time = first_chunk_time
                                total_response = response

                                data = {
                                    "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                                    "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                    "response": response,
                                    "done": False,
                                }
                                yield f"{json.dumps(data, ensure_ascii=False)}\n"

                                completion_tokens = estimate_tokens(total_response)
                                total_time = last_chunk_time - start_time
                                prompt_eval_time = first_chunk_time - start_time
                                eval_time = last_chunk_time - first_chunk_time

                                # Final message: done=True plus timing statistics.
                                data = {
                                    "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                                    "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                    "done": True,
                                    "total_duration": total_time,
                                    "load_duration": 0,
                                    "prompt_eval_count": prompt_tokens,
                                    "prompt_eval_duration": prompt_eval_time,
                                    "eval_count": completion_tokens,
                                    "eval_duration": eval_time,
                                }
                                yield f"{json.dumps(data, ensure_ascii=False)}\n"
                            else:
                                async for chunk in response:
                                    if chunk:
                                        if first_chunk_time is None:
                                            first_chunk_time = time.time_ns()

                                        last_chunk_time = time.time_ns()

                                        total_response += chunk
                                        data = {
                                            "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                                            "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                            "response": chunk,
                                            "done": False,
                                        }
                                        yield f"{json.dumps(data, ensure_ascii=False)}\n"

                                # NOTE(review): if the stream yields no truthy chunk,
                                # first/last_chunk_time remain None and the arithmetic
                                # below raises TypeError — confirm upstream never
                                # produces an empty stream.
                                completion_tokens = estimate_tokens(total_response)
                                total_time = last_chunk_time - start_time
                                prompt_eval_time = first_chunk_time - start_time
                                eval_time = last_chunk_time - first_chunk_time

                                data = {
                                    "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                                    "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                    "done": True,
                                    "total_duration": total_time,
                                    "load_duration": 0,
                                    "prompt_eval_count": prompt_tokens,
                                    "prompt_eval_duration": prompt_eval_time,
                                    "eval_count": completion_tokens,
                                    "eval_duration": eval_time,
                                }
                                yield f"{json.dumps(data, ensure_ascii=False)}\n"
                                return

                        except Exception as e:
                            trace_exception(e)
                            raise

                    return StreamingResponse(
                        stream_generator(),
                        media_type="application/x-ndjson",
                        headers={
                            "Cache-Control": "no-cache",
                            "Connection": "keep-alive",
                            "Content-Type": "application/x-ndjson",
                            "Access-Control-Allow-Origin": "*",
                            "Access-Control-Allow-Methods": "POST, OPTIONS",
                            "Access-Control-Allow-Headers": "Content-Type",
                        },
                    )
                else:
                    # Non-streaming path: one blocking LLM call, one JSON reply.
                    first_chunk_time = time.time_ns()
                    response_text = await self.rag.llm_model_func(
                        query, stream=False, **self.rag.llm_model_kwargs
                    )
                    last_chunk_time = time.time_ns()

                    if not response_text:
                        response_text = "No response generated"

                    completion_tokens = estimate_tokens(str(response_text))
                    total_time = last_chunk_time - start_time
                    prompt_eval_time = first_chunk_time - start_time
                    eval_time = last_chunk_time - first_chunk_time

                    return {
                        "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                        "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                        "response": str(response_text),
                        "done": True,
                        "total_duration": total_time,
                        "load_duration": 0,
                        "prompt_eval_count": prompt_tokens,
                        "prompt_eval_duration": prompt_eval_time,
                        "eval_count": completion_tokens,
                        "eval_duration": eval_time,
                    }
            except Exception as e:
                trace_exception(e)
                raise HTTPException(status_code=500, detail=str(e))

        @self.router.post("/chat")
        async def chat(raw_request: Request, request: OllamaChatRequest):
            """Process chat completion requests acting as an Ollama model
            Routes user queries through LightRAG by selecting query mode based on prefix indicators.
            Detects and forwards OpenWebUI session-related requests (for meta data generation task) directly to LLM.
            """
            try:
                # Get all messages
                messages = request.messages
                if not messages:
                    raise HTTPException(status_code=400, detail="No messages provided")

                # Get the last message as query and previous messages as history
                query = messages[-1].content
                # Convert OllamaMessage objects to dictionaries
                conversation_history = [
                    {"role": msg.role, "content": msg.content} for msg in messages[:-1]
                ]

                # Check for query prefix
                cleaned_query, mode = parse_query_mode(query)

                start_time = time.time_ns()
                prompt_tokens = estimate_tokens(cleaned_query)

                # Build the LightRAG query parameters; top_k falls back to 50
                # when no CLI args object is attached to the rag instance.
                param_dict = {
                    "mode": mode,
                    "stream": request.stream,
                    "only_need_context": False,
                    "conversation_history": conversation_history,
                    "top_k": self.rag.args.top_k if hasattr(self.rag, "args") else 50,
                }

                if (
                    hasattr(self.rag, "args")
                    and self.rag.args.history_turns is not None
                ):
                    param_dict["history_turns"] = self.rag.args.history_turns

                query_param = QueryParam(**param_dict)

                if request.stream:
                    # Determine if the request is prefix with "/bypass"
                    if mode == SearchMode.bypass:
                        if request.system:
                            # NOTE(review): mutates shared llm_model_kwargs —
                            # system prompt persists across requests; confirm intended.
                            self.rag.llm_model_kwargs["system_prompt"] = request.system
                        response = await self.rag.llm_model_func(
                            cleaned_query,
                            stream=True,
                            history_messages=conversation_history,
                            **self.rag.llm_model_kwargs,
                        )
                    else:
                        response = await self.rag.aquery(
                            cleaned_query, param=query_param
                        )

                    async def stream_generator():
                        """Yield NDJSON chunks in Ollama's /chat streaming format."""
                        first_chunk_time = None
                        last_chunk_time = None
                        total_response = ""

                        try:
                            # Ensure response is an async generator
                            if isinstance(response, str):
                                # If it's a string, send in two parts
                                first_chunk_time = time.time_ns()
                                last_chunk_time = first_chunk_time
                                total_response = response

                                data = {
                                    "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                                    "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                    "message": {
                                        "role": "assistant",
                                        "content": response,
                                        "images": None,
                                    },
                                    "done": False,
                                }
                                yield f"{json.dumps(data, ensure_ascii=False)}\n"

                                completion_tokens = estimate_tokens(total_response)
                                total_time = last_chunk_time - start_time
                                prompt_eval_time = first_chunk_time - start_time
                                eval_time = last_chunk_time - first_chunk_time

                                # Final message: done=True plus timing statistics.
                                data = {
                                    "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                                    "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                    "done": True,
                                    "total_duration": total_time,
                                    "load_duration": 0,
                                    "prompt_eval_count": prompt_tokens,
                                    "prompt_eval_duration": prompt_eval_time,
                                    "eval_count": completion_tokens,
                                    "eval_duration": eval_time,
                                }
                                yield f"{json.dumps(data, ensure_ascii=False)}\n"
                            else:
                                try:
                                    async for chunk in response:
                                        if chunk:
                                            if first_chunk_time is None:
                                                first_chunk_time = time.time_ns()

                                            last_chunk_time = time.time_ns()

                                            total_response += chunk
                                            data = {
                                                "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                                                "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                                "message": {
                                                    "role": "assistant",
                                                    "content": chunk,
                                                    "images": None,
                                                },
                                                "done": False,
                                            }
                                            yield f"{json.dumps(data, ensure_ascii=False)}\n"
                                except (asyncio.CancelledError, Exception) as e:
                                    # NOTE(review): CancelledError is BaseException in
                                    # py3.8+, so listing it alongside Exception is what
                                    # makes it catchable here.
                                    error_msg = str(e)
                                    if isinstance(e, asyncio.CancelledError):
                                        error_msg = "Stream was cancelled by server"
                                    else:
                                        error_msg = f"Provider error: {error_msg}"

                                    logging.error(f"Stream error: {error_msg}")

                                    # Send error message to client
                                    error_data = {
                                        "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                                        "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                        "message": {
                                            "role": "assistant",
                                            "content": f"\n\nError: {error_msg}",
                                            "images": None,
                                        },
                                        "done": False,
                                    }
                                    yield f"{json.dumps(error_data, ensure_ascii=False)}\n"

                                    # Send final message to close the stream
                                    final_data = {
                                        "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                                        "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                        "done": True,
                                    }
                                    yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
                                    return

                                # Only emit statistics when at least one chunk arrived.
                                if last_chunk_time is not None:
                                    completion_tokens = estimate_tokens(total_response)
                                    total_time = last_chunk_time - start_time
                                    prompt_eval_time = first_chunk_time - start_time
                                    eval_time = last_chunk_time - first_chunk_time

                                    data = {
                                        "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                                        "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                        "done": True,
                                        "total_duration": total_time,
                                        "load_duration": 0,
                                        "prompt_eval_count": prompt_tokens,
                                        "prompt_eval_duration": prompt_eval_time,
                                        "eval_count": completion_tokens,
                                        "eval_duration": eval_time,
                                    }
                                    yield f"{json.dumps(data, ensure_ascii=False)}\n"

                        except Exception as e:
                            error_msg = f"Error in stream_generator: {str(e)}"
                            logging.error(error_msg)

                            # Send error message to client
                            error_data = {
                                "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                                "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                "error": {"code": "STREAM_ERROR", "message": error_msg},
                            }
                            yield f"{json.dumps(error_data, ensure_ascii=False)}\n"

                            # Ensure sending end marker
                            final_data = {
                                "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                                "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                "done": True,
                            }
                            yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
                            return

                    return StreamingResponse(
                        stream_generator(),
                        media_type="application/x-ndjson",
                        headers={
                            "Cache-Control": "no-cache",
                            "Connection": "keep-alive",
                            "Content-Type": "application/x-ndjson",
                            "Access-Control-Allow-Origin": "*",
                            "Access-Control-Allow-Methods": "POST, OPTIONS",
                            "Access-Control-Allow-Headers": "Content-Type",
                        },
                    )
                else:
                    first_chunk_time = time.time_ns()

                    # Determine if the request is prefix with "/bypass" or from Open WebUI's session title and session keyword generation task
                    match_result = re.search(
                        r"\n<chat_history>\nUSER:", cleaned_query, re.MULTILINE
                    )
                    if match_result or mode == SearchMode.bypass:
                        if request.system:
                            # NOTE(review): mutates shared llm_model_kwargs —
                            # system prompt persists across requests; confirm intended.
                            self.rag.llm_model_kwargs["system_prompt"] = request.system

                        response_text = await self.rag.llm_model_func(
                            cleaned_query,
                            stream=False,
                            history_messages=conversation_history,
                            **self.rag.llm_model_kwargs,
                        )
                    else:
                        response_text = await self.rag.aquery(
                            cleaned_query, param=query_param
                        )

                    last_chunk_time = time.time_ns()

                    if not response_text:
                        response_text = "No response generated"

                    completion_tokens = estimate_tokens(str(response_text))
                    total_time = last_chunk_time - start_time
                    prompt_eval_time = first_chunk_time - start_time
                    eval_time = last_chunk_time - first_chunk_time

                    return {
                        "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                        "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                        "message": {
                            "role": "assistant",
                            "content": str(response_text),
                            "images": None,
                        },
                        "done": True,
                        "total_duration": total_time,
                        "load_duration": 0,
                        "prompt_eval_count": prompt_tokens,
                        "prompt_eval_duration": prompt_eval_time,
                        "eval_count": completion_tokens,
                        "eval_duration": eval_time,
                    }
            except Exception as e:
                trace_exception(e)
                raise HTTPException(status_code=500, detail=str(e))