YanSte committed
Commit 678a87e · unverified · 2 Parent(s): 18f2249 cf6afb8

Merge pull request #798 from YanSte/api_improvment

Files changed (1)
  1. lightrag/api/lightrag_server.py +165 -87
lightrag/api/lightrag_server.py CHANGED
@@ -13,14 +13,13 @@ import re
 from fastapi.staticfiles import StaticFiles
 import logging
 import argparse
-from typing import List, Any, Optional, Dict
-from pydantic import BaseModel
+from typing import List, Any, Literal, Optional, Dict
+from pydantic import BaseModel, Field, field_validator
 from lightrag import LightRAG, QueryParam
+from lightrag.base import DocProcessingStatus, DocStatus
 from lightrag.types import GPTKeywordExtractionFormat
 from lightrag.api import __api_version__
 from lightrag.utils import EmbeddingFunc
-from lightrag.base import DocStatus, DocProcessingStatus
-from enum import Enum
 from pathlib import Path
 import shutil
 import aiofiles
@@ -637,71 +636,155 @@ class DocumentManager:
         return any(filename.lower().endswith(ext) for ext in self.supported_extensions)
 
 
-# LightRAG query mode
-class SearchMode(str, Enum):
-    naive = "naive"
-    local = "local"
-    global_ = "global"
-    hybrid = "hybrid"
-    mix = "mix"
-
-
-class QueryRequest(BaseModel):
-    query: str
-
-    """Specifies the retrieval mode"""
-    mode: SearchMode = SearchMode.hybrid
-
-    """If True, enables streaming output for real-time responses."""
-    stream: Optional[bool] = None
-
-    """If True, only returns the retrieved context without generating a response."""
-    only_need_context: Optional[bool] = None
-
-    """If True, only returns the generated prompt without producing a response."""
-    only_need_prompt: Optional[bool] = None
-
-    """Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'."""
-    response_type: Optional[str] = None
-
-    """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
-    top_k: Optional[int] = None
-
-    """Maximum number of tokens allowed for each retrieved text chunk."""
-    max_token_for_text_unit: Optional[int] = None
-
-    """Maximum number of tokens allocated for relationship descriptions in global retrieval."""
-    max_token_for_global_context: Optional[int] = None
-
-    """Maximum number of tokens allocated for entity descriptions in local retrieval."""
-    max_token_for_local_context: Optional[int] = None
-
-    """List of high-level keywords to prioritize in retrieval."""
-    hl_keywords: Optional[List[str]] = None
-
-    """List of low-level keywords to refine retrieval focus."""
-    ll_keywords: Optional[List[str]] = None
-
-    """Stores past conversation history to maintain context.
-    Format: [{"role": "user/assistant", "content": "message"}].
-    """
-    conversation_history: Optional[List[dict[str, Any]]] = None
-
-    """Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
-    history_turns: Optional[int] = None
+class QueryRequest(BaseModel):
+    query: str = Field(
+        min_length=1,
+        description="The query text",
+    )
+
+    mode: Literal["local", "global", "hybrid", "naive", "mix"] = Field(
+        default="hybrid",
+        description="Query mode",
+    )
+
+    only_need_context: Optional[bool] = Field(
+        default=None,
+        description="If True, only returns the retrieved context without generating a response.",
+    )
+
+    only_need_prompt: Optional[bool] = Field(
+        default=None,
+        description="If True, only returns the generated prompt without producing a response.",
+    )
+
+    response_type: Optional[str] = Field(
+        min_length=1,
+        default=None,
+        description="Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'.",
+    )
+
+    top_k: Optional[int] = Field(
+        ge=1,
+        default=None,
+        description="Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.",
+    )
+
+    max_token_for_text_unit: Optional[int] = Field(
+        gt=1,
+        default=None,
+        description="Maximum number of tokens allowed for each retrieved text chunk.",
+    )
+
+    max_token_for_global_context: Optional[int] = Field(
+        gt=1,
+        default=None,
+        description="Maximum number of tokens allocated for relationship descriptions in global retrieval.",
+    )
+
+    max_token_for_local_context: Optional[int] = Field(
+        gt=1,
+        default=None,
+        description="Maximum number of tokens allocated for entity descriptions in local retrieval.",
+    )
+
+    hl_keywords: Optional[List[str]] = Field(
+        default=None,
+        description="List of high-level keywords to prioritize in retrieval.",
+    )
+
+    ll_keywords: Optional[List[str]] = Field(
+        default=None,
+        description="List of low-level keywords to refine retrieval focus.",
+    )
+
+    conversation_history: Optional[List[dict[str, Any]]] = Field(
+        default=None,
+        description="Stores past conversation history to maintain context. Format: [{'role': 'user/assistant', 'content': 'message'}].",
+    )
+
+    history_turns: Optional[int] = Field(
+        ge=0,
+        default=None,
+        description="Number of complete conversation turns (user-assistant pairs) to consider in the response context.",
+    )
+
+    @field_validator("query", mode="after")
+    @classmethod
+    def query_strip_after(cls, query: str) -> str:
+        return query.strip()
+
+    @field_validator("hl_keywords", mode="after")
+    @classmethod
+    def hl_keywords_strip_after(cls, hl_keywords: List[str] | None) -> List[str] | None:
+        if hl_keywords is None:
+            return None
+        return [keyword.strip() for keyword in hl_keywords]
+
+    @field_validator("ll_keywords", mode="after")
+    @classmethod
+    def ll_keywords_strip_after(cls, ll_keywords: List[str] | None) -> List[str] | None:
+        if ll_keywords is None:
+            return None
+        return [keyword.strip() for keyword in ll_keywords]
+
+    @field_validator("conversation_history", mode="after")
+    @classmethod
+    def conversation_history_role_check(
+        cls, conversation_history: List[dict[str, Any]] | None
+    ) -> List[dict[str, Any]] | None:
+        if conversation_history is None:
+            return None
+        for msg in conversation_history:
+            if "role" not in msg or msg["role"] not in {"user", "assistant"}:
+                raise ValueError(
+                    "Each message must have a 'role' key with value 'user' or 'assistant'."
+                )
+        return conversation_history
+
+    def to_query_params(self, is_stream: bool) -> QueryParam:
+        """Converts a QueryRequest instance into a QueryParam instance."""
+        # Use Pydantic's `.model_dump(exclude_none=True)` to drop None values automatically
+        request_data = self.model_dump(exclude_none=True, exclude={"query"})
+
+        # `stream` is not a request field, so set it explicitly
+        param = QueryParam(**request_data)
+        param.stream = is_stream
+        return param
 
 
 class QueryResponse(BaseModel):
-    response: str
+    response: str = Field(
+        description="The generated response",
+    )
 
 
 class InsertTextRequest(BaseModel):
-    text: str
+    text: str = Field(
+        min_length=1,
+        description="The text to insert",
+    )
+
+    @field_validator("text", mode="after")
+    @classmethod
+    def strip_after(cls, text: str) -> str:
+        return text.strip()
+
+
+class InsertTextsRequest(BaseModel):
+    texts: list[str] = Field(
+        min_length=1,
+        description="The texts to insert",
+    )
+
+    @field_validator("texts", mode="after")
+    @classmethod
+    def strip_after(cls, texts: list[str]) -> list[str]:
+        return [text.strip() for text in texts]
 
 
 class InsertResponse(BaseModel):
-    status: str
-    message: str
+    status: str = Field(description="Status of the operation")
+    message: str = Field(description="Message describing the operation result")
 
 
 class DocStatusResponse(BaseModel):
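
The new `QueryRequest` centralizes validation (non-empty stripped query, bounded `top_k`, role-checked history) and the `QueryParam` conversion that previously lived in the standalone `QueryRequestToQueryParams` helper removed in the next hunk. A minimal sketch of the model in use; the import path is an assumption, and running it requires the server module's dependencies:

    from lightrag.api.lightrag_server import QueryRequest  # assumed import path

    request = QueryRequest(
        query="  What is LightRAG?  ",        # stripped by query_strip_after
        mode="mix",
        top_k=10,
        hl_keywords=[" retrieval ", "graph"],  # each entry stripped by hl_keywords_strip_after
    )

    # model_dump(exclude_none=True) drops the fields left at None, so QueryParam
    # keeps its own defaults for them; stream comes from the is_stream argument.
    param = request.to_query_params(is_stream=False)
    print(param.mode, param.top_k, param.stream)  # mix 10 False
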
@@ -720,33 +803,6 @@ class DocsStatusesResponse(BaseModel):
     statuses: Dict[DocStatus, List[DocStatusResponse]] = {}
 
 
-def QueryRequestToQueryParams(request: QueryRequest):
-    param = QueryParam(mode=request.mode, stream=request.stream)
-    if request.only_need_context is not None:
-        param.only_need_context = request.only_need_context
-    if request.only_need_prompt is not None:
-        param.only_need_prompt = request.only_need_prompt
-    if request.response_type is not None:
-        param.response_type = request.response_type
-    if request.top_k is not None:
-        param.top_k = request.top_k
-    if request.max_token_for_text_unit is not None:
-        param.max_token_for_text_unit = request.max_token_for_text_unit
-    if request.max_token_for_global_context is not None:
-        param.max_token_for_global_context = request.max_token_for_global_context
-    if request.max_token_for_local_context is not None:
-        param.max_token_for_local_context = request.max_token_for_local_context
-    if request.hl_keywords is not None:
-        param.hl_keywords = request.hl_keywords
-    if request.ll_keywords is not None:
-        param.ll_keywords = request.ll_keywords
-    if request.conversation_history is not None:
-        param.conversation_history = request.conversation_history
-    if request.history_turns is not None:
-        param.history_turns = request.history_turns
-    return param
-
-
 def get_api_key_dependency(api_key: Optional[str]):
     if not api_key:
         # If no API key is configured, return a dummy dependency that always succeeds
@@ -1525,6 +1581,37 @@ def create_app(args):
             logging.error(traceback.format_exc())
             raise HTTPException(status_code=500, detail=str(e))
 
+    @app.post(
+        "/documents/texts",
+        response_model=InsertResponse,
+        dependencies=[Depends(optional_api_key)],
+    )
+    async def insert_texts(
+        request: InsertTextsRequest, background_tasks: BackgroundTasks
+    ):
+        """
+        Insert texts into the Retrieval-Augmented Generation (RAG) system.
+
+        This endpoint allows you to insert text data into the RAG system for later retrieval and use in generating responses.
+
+        Args:
+            request (InsertTextsRequest): The request body containing the texts to be inserted.
+            background_tasks: FastAPI BackgroundTasks for async processing.
+
+        Returns:
+            InsertResponse: A response object containing the status of the operation and a message.
+        """
+        try:
+            background_tasks.add_task(pipeline_index_texts, request.texts)
+            return InsertResponse(
+                status="success",
+                message="Texts successfully received. Processing will continue in background.",
+            )
+        except Exception as e:
+            logging.error(f"Error /documents/texts: {str(e)}")
+            logging.error(traceback.format_exc())
+            raise HTTPException(status_code=500, detail=str(e))
+
     @app.post(
         "/documents/file",
         response_model=InsertResponse,
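
A hedged sketch of exercising the new endpoint; the base URL/port and the `X-API-Key` header name are assumptions for illustration:

    import httpx

    resp = httpx.post(
        "http://localhost:9621/documents/texts",                  # host/port assumed
        json={"texts": ["First document.", "Second document."]},
        headers={"X-API-Key": "your-api-key"},                    # only if the server enforces a key
    )
    resp.raise_for_status()
    print(resp.json())  # {"status": "success", "message": "Texts successfully received. ..."}
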
@@ -1569,7 +1656,7 @@ def create_app(args):
             raise HTTPException(status_code=500, detail=str(e))
 
     @app.post(
-        "/documents/batch",
+        "/documents/file_batch",
         response_model=InsertResponse,
         dependencies=[Depends(optional_api_key)],
    )
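
A sketch of a multipart upload against the renamed route; the `files` form-field name is an assumption based on typical FastAPI `List[UploadFile]` handlers, as is the host/port:

    import httpx

    # Two files posted under the same form field; adjust the field name if the
    # handler's UploadFile parameter is named differently (assumption).
    with open("a.txt", "rb") as fa, open("b.txt", "rb") as fb:
        resp = httpx.post(
            "http://localhost:9621/documents/file_batch",
            files=[
                ("files", ("a.txt", fa, "text/plain")),
                ("files", ("b.txt", fb, "text/plain")),
            ],
        )
    print(resp.json())
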
@@ -1673,20 +1760,14 @@
         """
         try:
             response = await rag.aquery(
-                request.query, param=QueryRequestToQueryParams(request)
+                request.query, param=request.to_query_params(False)
             )
 
             # If response is a string (e.g. cache hit), return directly
             if isinstance(response, str):
                 return QueryResponse(response=response)
 
-            # If it's an async generator, decide whether to stream based on stream parameter
-            if request.stream or hasattr(response, "__aiter__"):
-                result = ""
-                async for chunk in response:
-                    result += chunk
-                return QueryResponse(response=result)
-            elif isinstance(response, dict):
+            if isinstance(response, dict):
                 result = json.dumps(response, indent=2)
                 return QueryResponse(response=result)
             else:
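
With the converter now on the model, the old `stream` request field is gone: each endpoint fixes streaming itself via `to_query_params(False)` or `(True)`, and omitted optional fields fall back to `QueryParam` defaults. A sketch of the corresponding request; the `/query` path and host/port are assumptions:

    import httpx

    resp = httpx.post(
        "http://localhost:9621/query",
        json={
            "query": "Summarize the indexed documents.",
            "mode": "hybrid",   # one of: local, global, hybrid, naive, mix
            "top_k": 5,         # optional; must be >= 1 if present
        },
    )
    print(resp.json()["response"])
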
@@ -1708,11 +1789,8 @@
             StreamingResponse: A streaming response containing the RAG query results.
         """
         try:
-            params = QueryRequestToQueryParams(request)
-
-            params.stream = True
-            response = await rag.aquery(  # Use aquery instead of query, and add await
-                request.query, param=params
+            response = await rag.aquery(
+                request.query, param=request.to_query_params(True)
             )
 
             from fastapi.responses import StreamingResponse
@@ -1738,7 +1816,7 @@
                     "Cache-Control": "no-cache",
                     "Connection": "keep-alive",
                     "Content-Type": "application/x-ndjson",
-                    "X-Accel-Buffering": "no",  # 确保在Nginx代理时正确处理流式响应
+                    "X-Accel-Buffering": "no",  # Ensure proper handling of streaming response when proxied by Nginx
                 },
             )
         except Exception as e:
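
On the streaming route the same converter is called with `is_stream=True`, and `X-Accel-Buffering: no` keeps Nginx from buffering the NDJSON output. A sketch of consuming the stream; the `/query/stream` path and the per-line `{"response": ...}` chunk shape are assumptions based on the handler and headers above:

    import json
    import httpx

    with httpx.stream(
        "POST",
        "http://localhost:9621/query/stream",
        json={"query": "Summarize the indexed documents.", "mode": "hybrid"},
        timeout=None,  # keep the connection open for the full generation
    ) as resp:
        for line in resp.iter_lines():
            if not line:
                continue
            chunk = json.loads(line)                  # one JSON object per NDJSON line
            print(chunk.get("response", ""), end="", flush=True)
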
 