yangdx committed
Commit 41e44a6
Parent: 2ef219d

Complete the code for the Ollama interface

Files changed (2)
  1. lightrag/api/lightrag_ollama.py +128 -7
  2. setup.py +1 -0
lightrag/api/lightrag_ollama.py CHANGED
@@ -2,14 +2,11 @@ from fastapi import FastAPI, HTTPException, File, UploadFile, Form
 from pydantic import BaseModel
 import logging
 import argparse
+from typing import List, Dict, Any, Optional
 from lightrag import LightRAG, QueryParam
-# from lightrag.llm import lollms_model_complete, lollms_embed
-# from lightrag.llm import ollama_model_complete, ollama_embed, openai_embedding
 from lightrag.llm import openai_complete_if_cache, ollama_embedding
-# from lightrag.llm import azure_openai_complete_if_cache, azure_openai_embedding
 
 from lightrag.utils import EmbeddingFunc
-from typing import Optional, List
 from enum import Enum
 from pathlib import Path
 import shutil
@@ -26,6 +23,13 @@ from starlette.status import HTTP_403_FORBIDDEN
 from dotenv import load_dotenv
 load_dotenv()
 
+# Constants for model information
+LIGHTRAG_NAME = "lightrag"
+LIGHTRAG_TAG = "latest"
+LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"  # f-string, so this renders as "lightrag:latest"
+LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
+LIGHTRAG_DIGEST = "sha256:lightrag"
+
 async def llm_model_func(
     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
@@ -219,21 +223,43 @@ class DocumentManager:
 class SearchMode(str, Enum):
     naive = "naive"
     local = "local"
-    global_ = "global"
+    global_ = "global"  # global_ because 'global' is a reserved keyword in Python; the enum value is still the string "global"
     hybrid = "hybrid"
 
+# Ollama API compatible models
+class OllamaMessage(BaseModel):
+    role: str
+    content: str
+
+class OllamaChatRequest(BaseModel):
+    model: str = LIGHTRAG_MODEL
+    messages: List[OllamaMessage]
+    stream: bool = False
+    options: Optional[Dict[str, Any]] = None
+
+class OllamaChatResponse(BaseModel):
+    model: str
+    created_at: str
+    message: OllamaMessage
+    done: bool
 
+class OllamaVersionResponse(BaseModel):
+    version: str
+    build: str = "default"
+
+class OllamaTagResponse(BaseModel):
+    models: List[Dict[str, Any]]  # mixed value types: "size" is an int, the rest are strings
+
+# Original LightRAG models
 class QueryRequest(BaseModel):
     query: str
     mode: SearchMode = SearchMode.hybrid
     stream: bool = False
     only_need_context: bool = False
 
-
 class QueryResponse(BaseModel):
     response: str
 
-
 class InsertTextRequest(BaseModel):
     text: str
     description: Optional[str] = None
@@ -555,6 +581,92 @@ def create_app(args):
         except Exception as e:
             raise HTTPException(status_code=500, detail=str(e))
 
+    # Ollama compatible API endpoints
+    @app.get("/api/version")
+    async def get_version():
+        """Get Ollama version information"""
+        return OllamaVersionResponse(version="0.1.0")
+
+    @app.get("/api/tags")
+    async def get_tags():
+        """Get available models"""
+        return OllamaTagResponse(
+            models=[{
+                "name": LIGHTRAG_NAME,
+                "tag": LIGHTRAG_TAG,
+                "size": 0,
+                "digest": LIGHTRAG_DIGEST,
+                "modified_at": LIGHTRAG_CREATED_AT,
+            }]
+        )
+
+    def parse_query_mode(query: str) -> tuple[str, SearchMode]:
+        """Parse the query prefix to determine the search mode.
+
+        Returns a tuple of (cleaned_query, search_mode).
+        """
+        mode_map = {
+            "/local ": SearchMode.local,
+            "/global ": SearchMode.global_,  # global_ because 'global' is a Python keyword
+            "/naive ": SearchMode.naive,
+            "/hybrid ": SearchMode.hybrid,
+        }
+
+        for prefix, mode in mode_map.items():
+            if query.startswith(prefix):
+                return query[len(prefix):], mode
+
+        return query, SearchMode.hybrid
+
+    @app.post("/api/chat")
+    async def chat(request: OllamaChatRequest):
+        """Handle chat completion requests"""
+        try:
+            # Use the latest chat message as the query
+            query = request.messages[-1].content if request.messages else ""
+
+            # Parse the mode prefix and strip it from the query
+            cleaned_query, mode = parse_query_mode(query)
+
+            # Call RAG with the determined mode
+            response = await rag.aquery(
+                cleaned_query,
+                param=QueryParam(mode=mode, stream=request.stream),
+            )
+
+            if request.stream:
+                # FastAPI cannot return a bare async generator, so wrap it in
+                # a StreamingResponse that emits NDJSON, as the Ollama API does
+                from fastapi.responses import StreamingResponse
+
+                async def stream_generator():
+                    async for chunk in response:
+                        yield OllamaChatResponse(
+                            model=LIGHTRAG_MODEL,
+                            created_at=LIGHTRAG_CREATED_AT,
+                            message=OllamaMessage(role="assistant", content=chunk),
+                            done=False,
+                        ).model_dump_json() + "\n"
+                    # A final message with empty content marks the end of the
+                    # stream so that clients do not receive the answer twice
+                    yield OllamaChatResponse(
+                        model=LIGHTRAG_MODEL,
+                        created_at=LIGHTRAG_CREATED_AT,
+                        message=OllamaMessage(role="assistant", content=""),
+                        done=True,
+                    ).model_dump_json() + "\n"
+
+                return StreamingResponse(stream_generator(), media_type="application/x-ndjson")
+            else:
+                return OllamaChatResponse(
+                    model=LIGHTRAG_MODEL,
+                    created_at=LIGHTRAG_CREATED_AT,
+                    message=OllamaMessage(role="assistant", content=response),
+                    done=True,
+                )
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=str(e))
+
     @app.get("/health", dependencies=[Depends(optional_api_key)])
     async def get_status():
         """Get current system status"""
 
setup.py CHANGED
@@ -101,6 +101,7 @@ setuptools.setup(
     entry_points={
         "console_scripts": [
             "lightrag-server=lightrag.api.lightrag_server:main [api]",
+            "lightrag-ollama=lightrag.api.lightrag_ollama:main [api]",
        ],
    },
)
 
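With the added entry point, installing the package with the api extra exposes a lightrag-ollama command that runs lightrag.api.lightrag_ollama:main, alongside the existing lightrag-server command. For example, from the repository root:

    pip install -e ".[api]"
    lightrag-ollama    # starts the Ollama-compatible LightRAG server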