diff --git a/README-zh.md b/README-zh.md index 94a11b61e9e66248c3732fc45b21b3c7e3a6a515..d345562f07780131e779f3bf3b255c3f9b0eecb3 100644 --- a/README-zh.md +++ b/README-zh.md @@ -11,7 +11,6 @@ - [X] [2024.12.31]🎯📢LightRAG现在支持[通过文档ID删除](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete)。 - [X] [2024.11.25]🎯📢LightRAG现在支持无缝集成[自定义知识图谱](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#insert-custom-kg),使用户能够用自己的领域专业知识增强系统。 - [X] [2024.11.19]🎯📢LightRAG的综合指南现已在[LearnOpenCV](https://learnopencv.com/lightrag)上发布。非常感谢博客作者。 -- [X] [2024.11.12]🎯📢LightRAG现在支持[Oracle Database 23ai的所有存储类型(KV、向量和图)](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_oracle_demo.py)。 - [X] [2024.11.11]🎯📢LightRAG现在支持[通过实体名称删除实体](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete)。 - [X] [2024.11.09]🎯📢推出[LightRAG Gui](https://lightrag-gui.streamlit.app),允许您插入、查询、可视化和下载LightRAG知识。 - [X] [2024.11.04]🎯📢现在您可以[使用Neo4J进行存储](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-neo4j-for-storage)。 @@ -1085,9 +1084,10 @@ rag.clear_cache(modes=["local"]) | **参数** | **类型** | **说明** | **默认值** | |--------------|----------|-----------------|-------------| | **working_dir** | `str` | 存储缓存的目录 | `lightrag_cache+timestamp` | -| **kv_storage** | `str` | 文档和文本块的存储类型。支持的类型:`JsonKVStorage`、`OracleKVStorage` | `JsonKVStorage` | -| **vector_storage** | `str` | 嵌入向量的存储类型。支持的类型:`NanoVectorDBStorage`、`OracleVectorDBStorage` | `NanoVectorDBStorage` | -| **graph_storage** | `str` | 图边和节点的存储类型。支持的类型:`NetworkXStorage`、`Neo4JStorage`、`OracleGraphStorage` | `NetworkXStorage` | +| **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`, `PGKVStorage`, `RedisKVStorage`, `MongoKVStorage` | `JsonKVStorage` | +| **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`, `PGVectorStorage`, `MilvusVectorDBStorage`, `ChromaVectorDBStorage`, `FaissVectorDBStorage`, `MongoVectorDBStorage`, `QdrantVectorDBStorage` | `NanoVectorDBStorage` | +| **graph_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`, `Neo4JStorage`, `PGGraphStorage`, `AGEStorage` | `NetworkXStorage` | +| **doc_status_storage** | `str` | Storage type for document processing status. Supported types: `JsonDocStatusStorage`, `PGDocStatusStorage`, `MongoDocStatusStorage` | `JsonDocStatusStorage` | | **chunk_token_size** | `int` | 拆分文档时每个块的最大令牌大小 | `1200` | | **chunk_overlap_token_size** | `int` | 拆分文档时两个块之间的重叠令牌大小 | `100` | | **tiktoken_model_name** | `str` | 用于计算令牌数的Tiktoken编码器的模型名称 | `gpt-4o-mini` | diff --git a/README.md b/README.md index 0d04b015f4bc8a514d519da79a00af5e1fd55be9..e154c7196fa96f562b81d34fb5c333d4850472fa 100644 --- a/README.md +++ b/README.md @@ -41,7 +41,6 @@ - [X] [2024.12.31]🎯📢LightRAG now supports [deletion by document ID](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete). - [X] [2024.11.25]🎯📢LightRAG now supports seamless integration of [custom knowledge graphs](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#insert-custom-kg), empowering users to enhance the system with their own domain expertise. - [X] [2024.11.19]🎯📢A comprehensive guide to LightRAG is now available on [LearnOpenCV](https://learnopencv.com/lightrag). Many thanks to the blog author. -- [X] [2024.11.12]🎯📢LightRAG now supports [Oracle Database 23ai for all storage types (KV, vector, and graph)](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_oracle_demo.py).
- [X] [2024.11.11]🎯📢LightRAG now supports [deleting entities by their names](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete). - [X] [2024.11.09]🎯📢Introducing the [LightRAG Gui](https://lightrag-gui.streamlit.app), which allows you to insert, query, visualize, and download LightRAG knowledge. - [X] [2024.11.04]🎯📢You can now [use Neo4J for Storage](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-neo4j-for-storage). @@ -1145,9 +1144,10 @@ Valid modes are: | **Parameter** | **Type** | **Explanation** | **Default** | |--------------|----------|-----------------|-------------| | **working_dir** | `str` | Directory where the cache will be stored | `lightrag_cache+timestamp` | -| **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`, `OracleKVStorage` | `JsonKVStorage` | -| **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`, `OracleVectorDBStorage` | `NanoVectorDBStorage` | -| **graph_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`, `Neo4JStorage`, `OracleGraphStorage` | `NetworkXStorage` | +| **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`, `PGKVStorage`, `RedisKVStorage`, `MongoKVStorage` | `JsonKVStorage` | +| **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`, `PGVectorStorage`, `MilvusVectorDBStorage`, `ChromaVectorDBStorage`, `FaissVectorDBStorage`, `MongoVectorDBStorage`, `QdrantVectorDBStorage` | `NanoVectorDBStorage` | +| **graph_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`, `Neo4JStorage`, `PGGraphStorage`, `AGEStorage` | `NetworkXStorage` | +| **doc_status_storage** | `str` | Storage type for document processing status.
Supported types: `JsonDocStatusStorage`, `PGDocStatusStorage`, `MongoDocStatusStorage` | `JsonDocStatusStorage` | | **chunk_token_size** | `int` | Maximum token size per chunk when splitting documents | `1200` | | **chunk_overlap_token_size** | `int` | Overlap token size between two chunks when splitting documents | `100` | | **tiktoken_model_name** | `str` | Model name for the Tiktoken encoder used to calculate token numbers | `gpt-4o-mini` | diff --git a/config.ini.example b/config.ini.example index 3041611e3157818fa1cbce52101f079627d5424d..5ff7cfbbdd833338f3e04ddcbb0fd9439a426899 100644 --- a/config.ini.example +++ b/config.ini.example @@ -13,23 +13,6 @@ uri=redis://localhost:6379/1 [qdrant] uri = http://localhost:16333 -[oracle] -dsn = localhost:1521/XEPDB1 -user = your_username -password = your_password -config_dir = /path/to/oracle/config -wallet_location = /path/to/wallet # 可选 -wallet_password = your_wallet_password # 可选 -workspace = default # 可选,默认为default - -[tidb] -host = localhost -port = 4000 -user = your_username -password = your_password -database = your_database -workspace = default # 可选,默认为default - [postgres] host = localhost port = 5432 diff --git a/env.example b/env.example index 20d80d43846c5b07397eda4ec6a3bf0fb023643d..d21bbef6fc437c3e893d3b20a356028b93ce8d77 100644 --- a/env.example +++ b/env.example @@ -4,11 +4,9 @@ # HOST=0.0.0.0 # PORT=9621 # WORKERS=2 -### separating data from difference Lightrag instances -# NAMESPACE_PREFIX=lightrag -### Max nodes return from grap retrieval -# MAX_GRAPH_NODES=1000 # CORS_ORIGINS=http://localhost:3000,http://localhost:8080 +WEBUI_TITLE='Graph RAG Engine' +WEBUI_DESCRIPTION="Simple and Fast Graph Based RAG System" ### Optional SSL Configuration # SSL=true @@ -22,6 +20,9 @@ ### Ollama Emulating Model Tag # OLLAMA_EMULATING_MODEL_TAG=latest +### Max nodes returned from graph retrieval + +# MAX_GRAPH_NODES=1000 + ### Logging level # LOG_LEVEL=INFO # VERBOSE=False @@ -110,24 +111,14 @@ LIGHTRAG_VECTOR_STORAGE=NanoVectorDBStorage LIGHTRAG_GRAPH_STORAGE=NetworkXStorage LIGHTRAG_DOC_STATUS_STORAGE=JsonDocStatusStorage -### Oracle Database Configuration -ORACLE_DSN=localhost:1521/XEPDB1 -ORACLE_USER=your_username -ORACLE_PASSWORD='your_password' -ORACLE_CONFIG_DIR=/path/to/oracle/config -#ORACLE_WALLET_LOCATION=/path/to/wallet -#ORACLE_WALLET_PASSWORD='your_password' -### separating all data from difference Lightrag instances(deprecating, use NAMESPACE_PREFIX in future) -#ORACLE_WORKSPACE=default - -### TiDB Configuration -TIDB_HOST=localhost -TIDB_PORT=4000 -TIDB_USER=your_username -TIDB_PASSWORD='your_password' -TIDB_DATABASE=your_database -### separating all data from difference Lightrag instances(deprecating, use NAMESPACE_PREFIX in future) -#TIDB_WORKSPACE=default +### TiDB Configuration (Deprecated) +# TIDB_HOST=localhost +# TIDB_PORT=4000 +# TIDB_USER=your_username +# TIDB_PASSWORD='your_password' +# TIDB_DATABASE=your_database +### separating all data from different LightRAG instances (deprecated) +# TIDB_WORKSPACE=default ### PostgreSQL Configuration POSTGRES_HOST=localhost @@ -135,8 +126,8 @@ POSTGRES_PORT=5432 POSTGRES_USER=your_username POSTGRES_PASSWORD='your_password' POSTGRES_DATABASE=your_database -### separating all data from difference Lightrag instances(deprecating, use NAMESPACE_PREFIX in future) -#POSTGRES_WORKSPACE=default +### separating all data from different LightRAG instances (deprecated) +# POSTGRES_WORKSPACE=default ### Independent AGM Configuration(not for AMG embedded in PostreSQL) AGE_POSTGRES_DB= @@ -145,8
+136,8 @@ AGE_POSTGRES_PASSWORD= AGE_POSTGRES_HOST= # AGE_POSTGRES_PORT=8529 -### separating all data from difference Lightrag instances(deprecating, use NAMESPACE_PREFIX in future) # AGE Graph Name(apply to PostgreSQL and independent AGM) +### AGE_GRAPH_NAME is deprecated # AGE_GRAPH_NAME=lightrag ### Neo4j Configuration @@ -157,7 +148,7 @@ NEO4J_PASSWORD='your_password' ### MongoDB Configuration MONGO_URI=mongodb://root:root@localhost:27017/ MONGO_DATABASE=LightRAG -### separating all data from difference Lightrag instances(deprecating, use NAMESPACE_PREFIX in future) +### separating all data from different LightRAG instances (deprecated) # MONGODB_GRAPH=false ### Milvus Configuration @@ -177,7 +168,9 @@ REDIS_URI=redis://localhost:6379 ### For JWT Auth # AUTH_ACCOUNTS='admin:admin123,user1:pass456' # TOKEN_SECRET=Your-Key-For-LightRAG-API-Server -# TOKEN_EXPIRE_HOURS=4 +# TOKEN_EXPIRE_HOURS=48 +# GUEST_TOKEN_EXPIRE_HOURS=24 +# JWT_ALGORITHM=HS256 ### API-Key to access LightRAG Server API # LIGHTRAG_API_KEY=your-secure-api-key-here diff --git a/examples/lightrag_api_ollama_demo.py b/examples/lightrag_api_ollama_demo.py deleted file mode 100644 index dad2a2e01db9c1cf7be78901e6179601f96c7eb1..0000000000000000000000000000000000000000 --- a/examples/lightrag_api_ollama_demo.py +++ /dev/null @@ -1,188 +0,0 @@ -from fastapi import FastAPI, HTTPException, File, UploadFile -from contextlib import asynccontextmanager -from pydantic import BaseModel -import os -from lightrag import LightRAG, QueryParam -from lightrag.llm.ollama import ollama_embed, ollama_model_complete -from lightrag.utils import EmbeddingFunc -from typing import Optional -import asyncio -import nest_asyncio -import aiofiles -from lightrag.kg.shared_storage import initialize_pipeline_status - -# Apply nest_asyncio to solve event loop issues -nest_asyncio.apply() - -DEFAULT_RAG_DIR = "index_default" - -DEFAULT_INPUT_FILE = "book.txt" -INPUT_FILE = os.environ.get("INPUT_FILE", f"{DEFAULT_INPUT_FILE}") -print(f"INPUT_FILE: {INPUT_FILE}") - -# Configure working directory -WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}") -print(f"WORKING_DIR: {WORKING_DIR}") - - -if not os.path.exists(WORKING_DIR): - os.mkdir(WORKING_DIR) - - -async def init(): - rag = LightRAG( - working_dir=WORKING_DIR, - llm_model_func=ollama_model_complete, - llm_model_name="gemma2:9b", - llm_model_max_async=4, - llm_model_max_token_size=8192, - llm_model_kwargs={ - "host": "http://localhost:11434", - "options": {"num_ctx": 8192}, - }, - embedding_func=EmbeddingFunc( - embedding_dim=768, - max_token_size=8192, - func=lambda texts: ollama_embed( - texts, embed_model="nomic-embed-text", host="http://localhost:11434" - ), - ), - ) - - # Add initialization code - await rag.initialize_storages() - await initialize_pipeline_status() - - return rag - - -@asynccontextmanager -async def lifespan(app: FastAPI): - global rag - rag = await init() - print("done!") - yield - - -app = FastAPI( - title="LightRAG API", description="API for RAG operations", lifespan=lifespan -) - - -# Data models -class QueryRequest(BaseModel): - query: str - mode: str = "hybrid" - only_need_context: bool = False - - -class InsertRequest(BaseModel): - text: str - - -class Response(BaseModel): - status: str - data: Optional[str] = None - message: Optional[str] = None - - -# API routes -@app.post("/query", response_model=Response) -async def query_endpoint(request: QueryRequest): - try: - loop = asyncio.get_event_loop() - result = await loop.run_in_executor( - None, - lambda: rag.query(
request.query, - param=QueryParam( - mode=request.mode, only_need_context=request.only_need_context - ), - ), - ) - return Response(status="success", data=result) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -# insert by text -@app.post("/insert", response_model=Response) -async def insert_endpoint(request: InsertRequest): - try: - loop = asyncio.get_event_loop() - await loop.run_in_executor(None, lambda: rag.insert(request.text)) - return Response(status="success", message="Text inserted successfully") - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -# insert by file in payload -@app.post("/insert_file", response_model=Response) -async def insert_file(file: UploadFile = File(...)): - try: - file_content = await file.read() - # Read file content - try: - content = file_content.decode("utf-8") - except UnicodeDecodeError: - # If UTF-8 decoding fails, try other encodings - content = file_content.decode("gbk") - # Insert file content - loop = asyncio.get_event_loop() - await loop.run_in_executor(None, lambda: rag.insert(content)) - - return Response( - status="success", - message=f"File content from {file.filename} inserted successfully", - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -# insert by local default file -@app.post("/insert_default_file", response_model=Response) -@app.get("/insert_default_file", response_model=Response) -async def insert_default_file(): - try: - # Read file content from book.txt - async with aiofiles.open(INPUT_FILE, "r", encoding="utf-8") as file: - content = await file.read() - print(f"read input file {INPUT_FILE} successfully") - # Insert file content - loop = asyncio.get_event_loop() - await loop.run_in_executor(None, lambda: rag.insert(content)) - - return Response( - status="success", - message=f"File content from {INPUT_FILE} inserted successfully", - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@app.get("/health") -async def health_check(): - return {"status": "healthy"} - - -if __name__ == "__main__": - import uvicorn - - uvicorn.run(app, host="0.0.0.0", port=8020) - -# Usage example -# To run the server, use the following command in your terminal: -# python lightrag_api_openai_compatible_demo.py - -# Example requests: -# 1. Query: -# curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "your query here", "mode": "hybrid"}' - -# 2. Insert text: -# curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}' - -# 3. Insert file: -# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: multipart/form-data" -F "file=@path/to/your/file.txt" - -# 4. 
Health check: -# curl -X GET "http://127.0.0.1:8020/health" diff --git a/examples/lightrag_api_openai_compatible_demo.py b/examples/lightrag_api_openai_compatible_demo.py deleted file mode 100644 index 312be872691fbf0ade0c0623dd03bec1f5967188..0000000000000000000000000000000000000000 --- a/examples/lightrag_api_openai_compatible_demo.py +++ /dev/null @@ -1,204 +0,0 @@ -from fastapi import FastAPI, HTTPException, File, UploadFile -from contextlib import asynccontextmanager -from pydantic import BaseModel -import os -from lightrag import LightRAG, QueryParam -from lightrag.llm.openai import openai_complete_if_cache, openai_embed -from lightrag.utils import EmbeddingFunc -import numpy as np -from typing import Optional -import asyncio -import nest_asyncio -from lightrag.kg.shared_storage import initialize_pipeline_status - -# Apply nest_asyncio to solve event loop issues -nest_asyncio.apply() - -DEFAULT_RAG_DIR = "index_default" -app = FastAPI(title="LightRAG API", description="API for RAG operations") - -# Configure working directory -WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}") -print(f"WORKING_DIR: {WORKING_DIR}") -LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4o-mini") -print(f"LLM_MODEL: {LLM_MODEL}") -EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large") -print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}") -EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192)) -print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}") -BASE_URL = os.environ.get("BASE_URL", "https://api.openai.com/v1") -print(f"BASE_URL: {BASE_URL}") -API_KEY = os.environ.get("API_KEY", "xxxxxxxx") -print(f"API_KEY: {API_KEY}") - -if not os.path.exists(WORKING_DIR): - os.mkdir(WORKING_DIR) - - -# LLM model function - - -async def llm_model_func( - prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs -) -> str: - return await openai_complete_if_cache( - model=LLM_MODEL, - prompt=prompt, - system_prompt=system_prompt, - history_messages=history_messages, - base_url=BASE_URL, - api_key=API_KEY, - **kwargs, - ) - - -# Embedding function - - -async def embedding_func(texts: list[str]) -> np.ndarray: - return await openai_embed( - texts=texts, - model=EMBEDDING_MODEL, - base_url=BASE_URL, - api_key=API_KEY, - ) - - -async def get_embedding_dim(): - test_text = ["This is a test sentence."] - embedding = await embedding_func(test_text) - embedding_dim = embedding.shape[1] - print(f"{embedding_dim=}") - return embedding_dim - - -# Initialize RAG instance -async def init(): - embedding_dimension = await get_embedding_dim() - - rag = LightRAG( - working_dir=WORKING_DIR, - llm_model_func=llm_model_func, - embedding_func=EmbeddingFunc( - embedding_dim=embedding_dimension, - max_token_size=EMBEDDING_MAX_TOKEN_SIZE, - func=embedding_func, - ), - ) - - await rag.initialize_storages() - await initialize_pipeline_status() - - return rag - - -@asynccontextmanager -async def lifespan(app: FastAPI): - global rag - rag = await init() - print("done!") - yield - - -app = FastAPI( - title="LightRAG API", description="API for RAG operations", lifespan=lifespan -) - -# Data models - - -class QueryRequest(BaseModel): - query: str - mode: str = "hybrid" - only_need_context: bool = False - - -class InsertRequest(BaseModel): - text: str - - -class Response(BaseModel): - status: str - data: Optional[str] = None - message: Optional[str] = None - - -# API routes - - -@app.post("/query", response_model=Response) -async def query_endpoint(request: QueryRequest): 
- try: - loop = asyncio.get_event_loop() - result = await loop.run_in_executor( - None, - lambda: rag.query( - request.query, - param=QueryParam( - mode=request.mode, only_need_context=request.only_need_context - ), - ), - ) - return Response(status="success", data=result) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@app.post("/insert", response_model=Response) -async def insert_endpoint(request: InsertRequest): - try: - loop = asyncio.get_event_loop() - await loop.run_in_executor(None, lambda: rag.insert(request.text)) - return Response(status="success", message="Text inserted successfully") - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@app.post("/insert_file", response_model=Response) -async def insert_file(file: UploadFile = File(...)): - try: - file_content = await file.read() - # Read file content - try: - content = file_content.decode("utf-8") - except UnicodeDecodeError: - # If UTF-8 decoding fails, try other encodings - content = file_content.decode("gbk") - # Insert file content - loop = asyncio.get_event_loop() - await loop.run_in_executor(None, lambda: rag.insert(content)) - - return Response( - status="success", - message=f"File content from {file.filename} inserted successfully", - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@app.get("/health") -async def health_check(): - return {"status": "healthy"} - - -if __name__ == "__main__": - import uvicorn - - uvicorn.run(app, host="0.0.0.0", port=8020) - -# Usage example -# To run the server, use the following command in your terminal: -# python lightrag_api_openai_compatible_demo.py - -# Example requests: -# 1. Query: -# curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "your query here", "mode": "hybrid"}' - -# 2. Insert text: -# curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}' - -# 3. Insert file: -# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: multipart/form-data" -F "file=@path/to/your/file.txt" - -# 4. 
Health check: -# curl -X GET "http://127.0.0.1:8020/health" diff --git a/examples/lightrag_api_oracle_demo.py b/examples/lightrag_api_oracle_demo.py deleted file mode 100644 index 3a82f479221f8ce50a00cdb51f94aa65d2a3c987..0000000000000000000000000000000000000000 --- a/examples/lightrag_api_oracle_demo.py +++ /dev/null @@ -1,267 +0,0 @@ -from fastapi import FastAPI, HTTPException, File, UploadFile -from fastapi import Query -from contextlib import asynccontextmanager -from pydantic import BaseModel -from typing import Optional, Any - -import sys -import os - - -from pathlib import Path - -import asyncio -import nest_asyncio -from lightrag import LightRAG, QueryParam -from lightrag.llm.openai import openai_complete_if_cache, openai_embed -from lightrag.utils import EmbeddingFunc -import numpy as np -from lightrag.kg.shared_storage import initialize_pipeline_status - - -print(os.getcwd()) -script_directory = Path(__file__).resolve().parent.parent -sys.path.append(os.path.abspath(script_directory)) - - -# Apply nest_asyncio to solve event loop issues -nest_asyncio.apply() - -DEFAULT_RAG_DIR = "index_default" - - -# We use OpenAI compatible API to call LLM on Oracle Cloud -# More docs here https://github.com/jin38324/OCI_GenAI_access_gateway -BASE_URL = "http://xxx.xxx.xxx.xxx:8088/v1/" -APIKEY = "ocigenerativeai" - -# Configure working directory -WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}") -print(f"WORKING_DIR: {WORKING_DIR}") -LLM_MODEL = os.environ.get("LLM_MODEL", "cohere.command-r-plus-08-2024") -print(f"LLM_MODEL: {LLM_MODEL}") -EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "cohere.embed-multilingual-v3.0") -print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}") -EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 512)) -print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}") - -if not os.path.exists(WORKING_DIR): - os.mkdir(WORKING_DIR) - -os.environ["ORACLE_USER"] = "" -os.environ["ORACLE_PASSWORD"] = "" -os.environ["ORACLE_DSN"] = "" -os.environ["ORACLE_CONFIG_DIR"] = "path_to_config_dir" -os.environ["ORACLE_WALLET_LOCATION"] = "path_to_wallet_location" -os.environ["ORACLE_WALLET_PASSWORD"] = "wallet_password" -os.environ["ORACLE_WORKSPACE"] = "company" - - -async def llm_model_func( - prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs -) -> str: - return await openai_complete_if_cache( - LLM_MODEL, - prompt, - system_prompt=system_prompt, - history_messages=history_messages, - api_key=APIKEY, - base_url=BASE_URL, - **kwargs, - ) - - -async def embedding_func(texts: list[str]) -> np.ndarray: - return await openai_embed( - texts, - model=EMBEDDING_MODEL, - api_key=APIKEY, - base_url=BASE_URL, - ) - - -async def get_embedding_dim(): - test_text = ["This is a test sentence."] - embedding = await embedding_func(test_text) - embedding_dim = embedding.shape[1] - return embedding_dim - - -async def init(): - # Detect embedding dimension - embedding_dimension = await get_embedding_dim() - print(f"Detected embedding dimension: {embedding_dimension}") - # Create Oracle DB connection - # The `config` parameter is the connection configuration of Oracle DB - # More docs here https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html - # We storage data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query - # Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud - - # Initialize LightRAG - # We use Oracle DB as 
the KV/vector/graph storage - rag = LightRAG( - enable_llm_cache=False, - working_dir=WORKING_DIR, - chunk_token_size=512, - llm_model_func=llm_model_func, - embedding_func=EmbeddingFunc( - embedding_dim=embedding_dimension, - max_token_size=512, - func=embedding_func, - ), - graph_storage="OracleGraphStorage", - kv_storage="OracleKVStorage", - vector_storage="OracleVectorDBStorage", - ) - - await rag.initialize_storages() - await initialize_pipeline_status() - - return rag - - -# Extract and Insert into LightRAG storage -# with open("./dickens/book.txt", "r", encoding="utf-8") as f: -# await rag.ainsert(f.read()) - -# # Perform search in different modes -# modes = ["naive", "local", "global", "hybrid"] -# for mode in modes: -# print("="*20, mode, "="*20) -# print(await rag.aquery("这篇文档是关于什么内容的?", param=QueryParam(mode=mode))) -# print("-"*100, "\n") - -# Data models - - -class QueryRequest(BaseModel): - query: str - mode: str = "hybrid" - only_need_context: bool = False - only_need_prompt: bool = False - - -class DataRequest(BaseModel): - limit: int = 100 - - -class InsertRequest(BaseModel): - text: str - - -class Response(BaseModel): - status: str - data: Optional[Any] = None - message: Optional[str] = None - - -# API routes - -rag = None - - -@asynccontextmanager -async def lifespan(app: FastAPI): - global rag - rag = await init() - print("done!") - yield - - -app = FastAPI( - title="LightRAG API", description="API for RAG operations", lifespan=lifespan -) - - -@app.post("/query", response_model=Response) -async def query_endpoint(request: QueryRequest): - # try: - # loop = asyncio.get_event_loop() - if request.mode == "naive": - top_k = 3 - else: - top_k = 60 - result = await rag.aquery( - request.query, - param=QueryParam( - mode=request.mode, - only_need_context=request.only_need_context, - only_need_prompt=request.only_need_prompt, - top_k=top_k, - ), - ) - return Response(status="success", data=result) - # except Exception as e: - # raise HTTPException(status_code=500, detail=str(e)) - - -@app.get("/data", response_model=Response) -async def query_all_nodes(type: str = Query("nodes"), limit: int = Query(100)): - if type == "nodes": - result = await rag.chunk_entity_relation_graph.get_all_nodes(limit=limit) - elif type == "edges": - result = await rag.chunk_entity_relation_graph.get_all_edges(limit=limit) - elif type == "statistics": - result = await rag.chunk_entity_relation_graph.get_statistics() - return Response(status="success", data=result) - - -@app.post("/insert", response_model=Response) -async def insert_endpoint(request: InsertRequest): - try: - loop = asyncio.get_event_loop() - await loop.run_in_executor(None, lambda: rag.insert(request.text)) - return Response(status="success", message="Text inserted successfully") - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@app.post("/insert_file", response_model=Response) -async def insert_file(file: UploadFile = File(...)): - try: - file_content = await file.read() - # Read file content - try: - content = file_content.decode("utf-8") - except UnicodeDecodeError: - # If UTF-8 decoding fails, try other encodings - content = file_content.decode("gbk") - # Insert file content - loop = asyncio.get_event_loop() - await loop.run_in_executor(None, lambda: rag.insert(content)) - - return Response( - status="success", - message=f"File content from {file.filename} inserted successfully", - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@app.get("/health") -async def 
health_check(): - return {"status": "healthy"} - - -if __name__ == "__main__": - import uvicorn - - uvicorn.run(app, host="127.0.0.1", port=8020) - -# Usage example -# To run the server, use the following command in your terminal: -# python lightrag_api_openai_compatible_demo.py - -# Example requests: -# 1. Query: -# curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "your query here", "mode": "hybrid"}' - -# 2. Insert text: -# curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}' - -# 3. Insert file: -# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: multipart/form-data" -F "file=@path/to/your/file.txt" - - -# 4. Health check: -# curl -X GET "http://127.0.0.1:8020/health" diff --git a/examples/lightrag_ollama_gremlin_demo.py b/examples/lightrag_ollama_gremlin_demo.py index 893b5606c1463838ed8a37c33d5edf9bab56e370..7ae6281086af5b87038b367cf24037f03d9a6964 100644 --- a/examples/lightrag_ollama_gremlin_demo.py +++ b/examples/lightrag_ollama_gremlin_demo.py @@ -1,3 +1,7 @@ +############################################## +# Gremlin storage implementation is deprecated +############################################## + import asyncio import inspect import os diff --git a/examples/lightrag_oracle_demo.py b/examples/lightrag_oracle_demo.py deleted file mode 100644 index 6663f6a134f7610df67f9c255cdd97c68659471b..0000000000000000000000000000000000000000 --- a/examples/lightrag_oracle_demo.py +++ /dev/null @@ -1,141 +0,0 @@ -import sys -import os -from pathlib import Path -import asyncio -from lightrag import LightRAG, QueryParam -from lightrag.llm.openai import openai_complete_if_cache, openai_embed -from lightrag.utils import EmbeddingFunc -import numpy as np -from lightrag.kg.shared_storage import initialize_pipeline_status - -print(os.getcwd()) -script_directory = Path(__file__).resolve().parent.parent -sys.path.append(os.path.abspath(script_directory)) - -WORKING_DIR = "./dickens" - -# We use OpenAI compatible API to call LLM on Oracle Cloud -# More docs here https://github.com/jin38324/OCI_GenAI_access_gateway -BASE_URL = "http://xxx.xxx.xxx.xxx:8088/v1/" -APIKEY = "ocigenerativeai" -CHATMODEL = "cohere.command-r-plus" -EMBEDMODEL = "cohere.embed-multilingual-v3.0" -CHUNK_TOKEN_SIZE = 1024 -MAX_TOKENS = 4000 - -if not os.path.exists(WORKING_DIR): - os.mkdir(WORKING_DIR) - -os.environ["ORACLE_USER"] = "username" -os.environ["ORACLE_PASSWORD"] = "xxxxxxxxx" -os.environ["ORACLE_DSN"] = "xxxxxxx_medium" -os.environ["ORACLE_CONFIG_DIR"] = "path_to_config_dir" -os.environ["ORACLE_WALLET_LOCATION"] = "path_to_wallet_location" -os.environ["ORACLE_WALLET_PASSWORD"] = "wallet_password" -os.environ["ORACLE_WORKSPACE"] = "company" - - -async def llm_model_func( - prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs -) -> str: - return await openai_complete_if_cache( - CHATMODEL, - prompt, - system_prompt=system_prompt, - history_messages=history_messages, - api_key=APIKEY, - base_url=BASE_URL, - **kwargs, - ) - - -async def embedding_func(texts: list[str]) -> np.ndarray: - return await openai_embed( - texts, - model=EMBEDMODEL, - api_key=APIKEY, - base_url=BASE_URL, - ) - - -async def get_embedding_dim(): - test_text = ["This is a test sentence."] - embedding = await embedding_func(test_text) - embedding_dim = embedding.shape[1] - return embedding_dim - - -async def initialize_rag(): - # Detect embedding dimension - embedding_dimension = await 
get_embedding_dim() - print(f"Detected embedding dimension: {embedding_dimension}") - - # Initialize LightRAG - # We use Oracle DB as the KV/vector/graph storage - # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt - rag = LightRAG( - # log_level="DEBUG", - working_dir=WORKING_DIR, - entity_extract_max_gleaning=1, - enable_llm_cache=True, - enable_llm_cache_for_entity_extract=True, - embedding_cache_config=None, # {"enabled": True,"similarity_threshold": 0.90}, - chunk_token_size=CHUNK_TOKEN_SIZE, - llm_model_max_token_size=MAX_TOKENS, - llm_model_func=llm_model_func, - embedding_func=EmbeddingFunc( - embedding_dim=embedding_dimension, - max_token_size=500, - func=embedding_func, - ), - graph_storage="OracleGraphStorage", - kv_storage="OracleKVStorage", - vector_storage="OracleVectorDBStorage", - addon_params={ - "example_number": 1, - "language": "Simplfied Chinese", - "entity_types": ["organization", "person", "geo", "event"], - "insert_batch_size": 2, - }, - ) - await rag.initialize_storages() - await initialize_pipeline_status() - - return rag - - -async def main(): - try: - # Initialize RAG instance - rag = await initialize_rag() - - # Extract and Insert into LightRAG storage - with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f: - all_text = f.read() - texts = [x for x in all_text.split("\n") if x] - - # New mode use pipeline - await rag.apipeline_enqueue_documents(texts) - await rag.apipeline_process_enqueue_documents() - - # Old method use ainsert - # await rag.ainsert(texts) - - # Perform search in different modes - modes = ["naive", "local", "global", "hybrid"] - for mode in modes: - print("=" * 20, mode, "=" * 20) - print( - await rag.aquery( - "What are the top themes in this story?", - param=QueryParam(mode=mode), - ) - ) - print("-" * 100, "\n") - - except Exception as e: - print(f"An error occurred: {e}") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/examples/lightrag_tidb_demo.py b/examples/lightrag_tidb_demo.py index 5269556080fbf8cb8ac8d187d03b9e50407c3627..50eac2ca24421ddae212a53f3e6e8f04a272174b 100644 --- a/examples/lightrag_tidb_demo.py +++ b/examples/lightrag_tidb_demo.py @@ -1,3 +1,7 @@ +########################################### +# TiDB storage implementation is deprecated +########################################### + import asyncio import os diff --git a/lightrag/api/README-zh.md b/lightrag/api/README-zh.md index 4bf31a61cd2eb222c1da082cf4fc08b2c7675766..0371865600c2f794c767a67edeee9d101d0175ef 100644 --- a/lightrag/api/README-zh.md +++ b/lightrag/api/README-zh.md @@ -291,11 +291,9 @@ LightRAG 使用 4 种类型的存储用于不同目的: ``` JsonKVStorage JsonFile(默认) -MongoKVStorage MogonDB -RedisKVStorage Redis -TiDBKVStorage TiDB PGKVStorage Postgres -OracleKVStorage Oracle +RedisKVStorage Redis +MongoKVStorage MogonDB ``` * GRAPH_STORAGE 支持的实现名称 @@ -303,25 +301,19 @@ OracleKVStorage Oracle ``` NetworkXStorage NetworkX(默认) Neo4JStorage Neo4J -MongoGraphStorage MongoDB -TiDBGraphStorage TiDB -AGEStorage AGE -GremlinStorage Gremlin PGGraphStorage Postgres -OracleGraphStorage Postgres +AGEStorage AGE ``` * VECTOR_STORAGE 支持的实现名称 ``` NanoVectorDBStorage NanoVector(默认) +PGVectorStorage Postgres MilvusVectorDBStorge Milvus ChromaVectorDBStorage Chroma -TiDBVectorDBStorage TiDB -PGVectorStorage Postgres FaissVectorDBStorage Faiss QdrantVectorDBStorage Qdrant -OracleVectorDBStorage Oracle MongoVectorDBStorage MongoDB ``` diff --git a/lightrag/api/README.md b/lightrag/api/README.md index 
8b2e81778e0aa4823d4eca6ec80abf924dea74c7..27f3d14aea67d697f2eefbf7f874eabd7e53a201 100644 --- a/lightrag/api/README.md +++ b/lightrag/api/README.md @@ -302,11 +302,9 @@ Each storage type have servals implementations: ``` JsonKVStorage JsonFile(default) -MongoKVStorage MogonDB -RedisKVStorage Redis -TiDBKVStorage TiDB PGKVStorage Postgres -OracleKVStorage Oracle +RedisKVStorage Redis +MongoKVStorage MongoDB ``` * GRAPH_STORAGE supported implement-name @@ -314,25 +312,19 @@ OracleKVStorage Oracle ``` NetworkXStorage NetworkX(defualt) Neo4JStorage Neo4J -MongoGraphStorage MongoDB -TiDBGraphStorage TiDB -AGEStorage AGE -GremlinStorage Gremlin PGGraphStorage Postgres -OracleGraphStorage Postgres +AGEStorage AGE ``` * VECTOR_STORAGE supported implement-name ``` NanoVectorDBStorage NanoVector(default) -MilvusVectorDBStorage Milvus -ChromaVectorDBStorage Chroma -TiDBVectorDBStorage TiDB PGVectorStorage Postgres +MilvusVectorDBStorage Milvus +ChromaVectorDBStorage Chroma FaissVectorDBStorage Faiss QdrantVectorDBStorage Qdrant -OracleVectorDBStorage Oracle MongoVectorDBStorage MongoDB ``` diff --git a/lightrag/api/__init__.py b/lightrag/api/__init__.py index ec1959de35890f2bb0aa5e174b1322992b9af508..8679dd510b0549ad1b8184dcca9d226193e398c5 100644 --- a/lightrag/api/__init__.py +++ b/lightrag/api/__init__.py @@ -1 +1 @@ -__api_version__ = "1.2.8" +__api_version__ = "0132" diff --git a/lightrag/api/auth.py b/lightrag/api/auth.py index 58175b9da7d46e4dc9d232247f0f2e12c99f14a8..0b61095d829f44dc1fd32939a6b2e403a9ddcd38 100644 --- a/lightrag/api/auth.py +++ b/lightrag/api/auth.py @@ -1,9 +1,11 @@ -import os from datetime import datetime, timedelta + import jwt +from dotenv import load_dotenv from fastapi import HTTPException, status from pydantic import BaseModel -from dotenv import load_dotenv + +from .config import global_args # use the .env that is inside the current folder # allows to use different .env file for each lightrag instance @@ -20,13 +22,12 @@ class TokenPayload(BaseModel): class AuthHandler: def __init__(self): - self.secret = os.getenv("TOKEN_SECRET", "4f85ds4f56dsf46") - self.algorithm = "HS256" - self.expire_hours = int(os.getenv("TOKEN_EXPIRE_HOURS", 4)) - self.guest_expire_hours = int(os.getenv("GUEST_TOKEN_EXPIRE_HOURS", 2)) - + self.secret = global_args.token_secret + self.algorithm = global_args.jwt_algorithm + self.expire_hours = global_args.token_expire_hours + self.guest_expire_hours = global_args.guest_token_expire_hours self.accounts = {} - auth_accounts = os.getenv("AUTH_ACCOUNTS") + auth_accounts = global_args.auth_accounts if auth_accounts: for account in auth_accounts.split(","): username, password = account.split(":", 1) diff --git a/lightrag/api/config.py b/lightrag/api/config.py new file mode 100644 index 0000000000000000000000000000000000000000..1bbdb1c97c8ee02241c87ec523e0c297b2460287 --- /dev/null +++ b/lightrag/api/config.py @@ -0,0 +1,335 @@ +""" +Configs for the LightRAG API.
+""" + +import os +import argparse +import logging +from dotenv import load_dotenv + +# use the .env that is inside the current folder +# allows to use different .env file for each lightrag instance +# the OS environment variables take precedence over the .env file +load_dotenv(dotenv_path=".env", override=False) + + +class OllamaServerInfos: + # Constants for emulated Ollama model information + LIGHTRAG_NAME = "lightrag" + LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest") + LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}" + LIGHTRAG_SIZE = 7365960935 # it's a dummy value + LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z" + LIGHTRAG_DIGEST = "sha256:lightrag" + + +ollama_server_infos = OllamaServerInfos() + + +class DefaultRAGStorageConfig: + KV_STORAGE = "JsonKVStorage" + VECTOR_STORAGE = "NanoVectorDBStorage" + GRAPH_STORAGE = "NetworkXStorage" + DOC_STATUS_STORAGE = "JsonDocStatusStorage" + + +def get_default_host(binding_type: str) -> str: + default_hosts = { + "ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"), + "lollms": os.getenv("LLM_BINDING_HOST", "http://localhost:9600"), + "azure_openai": os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com/v1"), + "openai": os.getenv("LLM_BINDING_HOST", "https://api.openai.com/v1"), + } + return default_hosts.get( + binding_type, os.getenv("LLM_BINDING_HOST", "http://localhost:11434") + ) # fallback to ollama if unknown + + +def get_env_value(env_key: str, default: any, value_type: type = str) -> any: + """ + Get value from environment variable with type conversion + + Args: + env_key (str): Environment variable key + default (any): Default value if env variable is not set + value_type (type): Type to convert the value to + + Returns: + any: Converted value from environment or default + """ + value = os.getenv(env_key) + if value is None: + return default + + if value_type is bool: + return value.lower() in ("true", "1", "yes", "t", "on") + try: + return value_type(value) + except ValueError: + return default + + +def parse_args() -> argparse.Namespace: + """ + Parse command line arguments with environment variable fallback + + Args: + is_uvicorn_mode: Whether running under uvicorn mode + + Returns: + argparse.Namespace: Parsed arguments + """ + + parser = argparse.ArgumentParser( + description="LightRAG FastAPI Server with separate working and input directories" + ) + + # Server configuration + parser.add_argument( + "--host", + default=get_env_value("HOST", "0.0.0.0"), + help="Server host (default: from env or 0.0.0.0)", + ) + parser.add_argument( + "--port", + type=int, + default=get_env_value("PORT", 9621, int), + help="Server port (default: from env or 9621)", + ) + + # Directory configuration + parser.add_argument( + "--working-dir", + default=get_env_value("WORKING_DIR", "./rag_storage"), + help="Working directory for RAG storage (default: from env or ./rag_storage)", + ) + parser.add_argument( + "--input-dir", + default=get_env_value("INPUT_DIR", "./inputs"), + help="Directory containing input documents (default: from env or ./inputs)", + ) + + def timeout_type(value): + if value is None: + return 150 + if value is None or value == "None": + return None + return int(value) + + parser.add_argument( + "--timeout", + default=get_env_value("TIMEOUT", None, timeout_type), + type=timeout_type, + help="Timeout in seconds (useful when using slow AI). 
Use None for infinite timeout", + ) + + # RAG configuration + parser.add_argument( + "--max-async", + type=int, + default=get_env_value("MAX_ASYNC", 4, int), + help="Maximum async operations (default: from env or 4)", + ) + parser.add_argument( + "--max-tokens", + type=int, + default=get_env_value("MAX_TOKENS", 32768, int), + help="Maximum token size (default: from env or 32768)", + ) + + # Logging configuration + parser.add_argument( + "--log-level", + default=get_env_value("LOG_LEVEL", "INFO"), + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="Logging level (default: from env or INFO)", + ) + parser.add_argument( + "--verbose", + action="store_true", + default=get_env_value("VERBOSE", False, bool), + help="Enable verbose debug output (only valid for DEBUG log-level)", + ) + + parser.add_argument( + "--key", + type=str, + default=get_env_value("LIGHTRAG_API_KEY", None), + help="API key for authentication. This protects the LightRAG server against unauthorized access", + ) + + # Optional https parameters + parser.add_argument( + "--ssl", + action="store_true", + default=get_env_value("SSL", False, bool), + help="Enable HTTPS (default: from env or False)", + ) + parser.add_argument( + "--ssl-certfile", + default=get_env_value("SSL_CERTFILE", None), + help="Path to SSL certificate file (required if --ssl is enabled)", + ) + parser.add_argument( + "--ssl-keyfile", + default=get_env_value("SSL_KEYFILE", None), + help="Path to SSL private key file (required if --ssl is enabled)", + ) + + parser.add_argument( + "--history-turns", + type=int, + default=get_env_value("HISTORY_TURNS", 3, int), + help="Number of conversation history turns to include (default: from env or 3)", + ) + + # Search parameters + parser.add_argument( + "--top-k", + type=int, + default=get_env_value("TOP_K", 60, int), + help="Number of most similar results to return (default: from env or 60)", + ) + parser.add_argument( + "--cosine-threshold", + type=float, + default=get_env_value("COSINE_THRESHOLD", 0.2, float), + help="Cosine similarity threshold (default: from env or 0.2)", + ) + + # Ollama model name + parser.add_argument( + "--simulated-model-name", + type=str, + default=get_env_value( + "SIMULATED_MODEL_NAME", ollama_server_infos.LIGHTRAG_MODEL + ), + help="Name of the simulated Ollama model (default: from env or lightrag:latest)", + ) + + # Namespace + parser.add_argument( + "--namespace-prefix", + type=str, + default=get_env_value("NAMESPACE_PREFIX", ""), + help="Prefix of the namespace", + ) + + parser.add_argument( + "--auto-scan-at-startup", + action="store_true", + default=False, + help="Enable automatic scanning when the program starts", + ) + + # Server workers configuration + parser.add_argument( + "--workers", + type=int, + default=get_env_value("WORKERS", 1, int), + help="Number of worker processes (default: from env or 1)", + ) + + # LLM and embedding bindings + parser.add_argument( + "--llm-binding", + type=str, + default=get_env_value("LLM_BINDING", "ollama"), + choices=["lollms", "ollama", "openai", "openai-ollama", "azure_openai"], + help="LLM binding type (default: from env or ollama)", + ) + parser.add_argument( + "--embedding-binding", + type=str, + default=get_env_value("EMBEDDING_BINDING", "ollama"), + choices=["lollms", "ollama", "openai", "azure_openai"], + help="Embedding binding type (default: from env or ollama)", + ) + + args = parser.parse_args() + + # convert relative path to absolute path + args.working_dir = os.path.abspath(args.working_dir) + args.input_dir =
os.path.abspath(args.input_dir) + + # Inject storage configuration from environment variables + args.kv_storage = get_env_value( + "LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE + ) + args.doc_status_storage = get_env_value( + "LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE + ) + args.graph_storage = get_env_value( + "LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE + ) + args.vector_storage = get_env_value( + "LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE + ) + + # Get MAX_PARALLEL_INSERT from environment + args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int) + + # Handle openai-ollama special case + if args.llm_binding == "openai-ollama": + args.llm_binding = "openai" + args.embedding_binding = "ollama" + + args.llm_binding_host = get_env_value( + "LLM_BINDING_HOST", get_default_host(args.llm_binding) + ) + args.embedding_binding_host = get_env_value( + "EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding) + ) + args.llm_binding_api_key = get_env_value("LLM_BINDING_API_KEY", None) + args.embedding_binding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "") + + # Inject model configuration + args.llm_model = get_env_value("LLM_MODEL", "mistral-nemo:latest") + args.embedding_model = get_env_value("EMBEDDING_MODEL", "bge-m3:latest") + args.embedding_dim = get_env_value("EMBEDDING_DIM", 1024, int) + args.max_embed_tokens = get_env_value("MAX_EMBED_TOKENS", 8192, int) + + # Inject chunk configuration + args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int) + args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int) + + # Inject LLM cache configuration + args.enable_llm_cache_for_extract = get_env_value( + "ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool + ) + + # Inject LLM temperature configuration + args.temperature = get_env_value("TEMPERATURE", 0.5, float) + + # Select Document loading tool (DOCLING, DEFAULT) + args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT") + + # Add environment variables that were previously read directly + args.cors_origins = get_env_value("CORS_ORIGINS", "*") + args.summary_language = get_env_value("SUMMARY_LANGUAGE", "en") + args.whitelist_paths = get_env_value("WHITELIST_PATHS", "/health,/api/*") + + # For JWT Auth + args.auth_accounts = get_env_value("AUTH_ACCOUNTS", "") + args.token_secret = get_env_value("TOKEN_SECRET", "lightrag-jwt-default-secret") + args.token_expire_hours = get_env_value("TOKEN_EXPIRE_HOURS", 48, int) + args.guest_token_expire_hours = get_env_value("GUEST_TOKEN_EXPIRE_HOURS", 24, int) + args.jwt_algorithm = get_env_value("JWT_ALGORITHM", "HS256") + + ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name + + return args + + +def update_uvicorn_mode_config(): + # If in uvicorn mode and workers > 1, force it to 1 and log warning + if global_args.workers > 1: + original_workers = global_args.workers + global_args.workers = 1 + # Log warning directly here + logging.warning( + f"In uvicorn mode, workers parameter was set to {original_workers}. 
Forcing workers=1" + ) + + +global_args = parse_args() diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 8110d6d415ea6065fdada28851667a920a0ac480..9f1e6e8a13aa0034ded551733319bd4439ff980d 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -19,11 +19,14 @@ from contextlib import asynccontextmanager from dotenv import load_dotenv from lightrag.api.utils_api import ( get_combined_auth_dependency, - parse_args, - get_default_host, display_splash_screen, check_env_file, ) +from .config import ( + global_args, + update_uvicorn_mode_config, + get_default_host, +) import sys from lightrag import LightRAG, __version__ as core_version from lightrag.api import __api_version__ @@ -52,6 +55,10 @@ from lightrag.api.auth import auth_handler # the OS environment variables take precedence over the .env file load_dotenv(dotenv_path=".env", override=False) + +webui_title = os.getenv("WEBUI_TITLE") +webui_description = os.getenv("WEBUI_DESCRIPTION") + # Initialize config parser config = configparser.ConfigParser() config.read("config.ini") @@ -164,10 +171,10 @@ def create_app(args): app = FastAPI(**app_kwargs) def get_cors_origins(): - """Get allowed origins from environment variable + """Get allowed origins from global_args Returns a list of allowed origins, defaults to ["*"] if not set """ - origins_str = os.getenv("CORS_ORIGINS", "*") + origins_str = global_args.cors_origins if origins_str == "*": return ["*"] return [origin.strip() for origin in origins_str.split(",")] @@ -315,9 +322,10 @@ def create_app(args): "similarity_threshold": 0.95, "use_llm_check": False, }, - namespace_prefix=args.namespace_prefix, + # namespace_prefix=args.namespace_prefix, auto_manage_storages_states=False, max_parallel_insert=args.max_parallel_insert, + addon_params={"language": args.summary_language}, ) else: # azure_openai rag = LightRAG( @@ -345,9 +353,10 @@ def create_app(args): "similarity_threshold": 0.95, "use_llm_check": False, }, - namespace_prefix=args.namespace_prefix, + # namespace_prefix=args.namespace_prefix, auto_manage_storages_states=False, max_parallel_insert=args.max_parallel_insert, + addon_params={"language": args.summary_language}, ) # Add routes @@ -381,6 +390,8 @@ def create_app(args): "message": "Authentication is disabled. Using guest access.", "core_version": core_version, "api_version": __api_version__, + "webui_title": webui_title, + "webui_description": webui_description, } return { @@ -388,6 +399,8 @@ def create_app(args): "auth_mode": "enabled", "core_version": core_version, "api_version": __api_version__, + "webui_title": webui_title, + "webui_description": webui_description, } @app.post("/login") @@ -404,6 +417,8 @@ def create_app(args): "message": "Authentication is disabled. 
Using guest access.", "core_version": core_version, "api_version": __api_version__, + "webui_title": webui_title, + "webui_description": webui_description, } username = form_data.username if auth_handler.accounts.get(username) != form_data.password: @@ -454,10 +469,12 @@ def create_app(args): "vector_storage": args.vector_storage, "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract, }, - "core_version": core_version, - "api_version": __api_version__, "auth_mode": auth_mode, "pipeline_busy": pipeline_status.get("busy", False), + "core_version": core_version, + "api_version": __api_version__, + "webui_title": webui_title, + "webui_description": webui_description, } except Exception as e: logger.error(f"Error getting health status: {str(e)}") @@ -490,7 +507,7 @@ def create_app(args): def get_application(args=None): """Factory function for creating the FastAPI application""" if args is None: - args = parse_args() + args = global_args return create_app(args) @@ -611,30 +628,31 @@ def main(): # Configure logging before parsing args configure_logging() - - args = parse_args(is_uvicorn_mode=True) - display_splash_screen(args) + update_uvicorn_mode_config() + display_splash_screen(global_args) # Create application instance directly instead of using factory function - app = create_app(args) + app = create_app(global_args) # Start Uvicorn in single process mode uvicorn_config = { "app": app, # Pass application instance directly instead of string path - "host": args.host, - "port": args.port, + "host": global_args.host, + "port": global_args.port, "log_config": None, # Disable default config } - if args.ssl: + if global_args.ssl: uvicorn_config.update( { - "ssl_certfile": args.ssl_certfile, - "ssl_keyfile": args.ssl_keyfile, + "ssl_certfile": global_args.ssl_certfile, + "ssl_keyfile": global_args.ssl_keyfile, } ) - print(f"Starting Uvicorn server in single-process mode on {args.host}:{args.port}") + print( + f"Starting Uvicorn server in single-process mode on {global_args.host}:{global_args.port}" + ) uvicorn.run(**uvicorn_config) diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index 445008ec8d9bb75a17b2b2d9fc510c7187e114bb..8e6640063f3f0077bf7590f6a6f3c069f41b02b7 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -10,16 +10,14 @@ import traceback import pipmaster as pm from datetime import datetime from pathlib import Path -from typing import Dict, List, Optional, Any +from typing import Dict, List, Optional, Any, Literal from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile from pydantic import BaseModel, Field, field_validator from lightrag import LightRAG from lightrag.base import DocProcessingStatus, DocStatus -from lightrag.api.utils_api import ( - get_combined_auth_dependency, - global_args, -) +from lightrag.api.utils_api import get_combined_auth_dependency +from ..config import global_args router = APIRouter( prefix="/documents", @@ -30,7 +28,37 @@ router = APIRouter( temp_prefix = "__tmp__" +class ScanResponse(BaseModel): + """Response model for document scanning operation + + Attributes: + status: Status of the scanning operation + message: Optional message with additional details + """ + + status: Literal["scanning_started"] = Field( + description="Status of the scanning operation" + ) + message: Optional[str] = Field( + default=None, description="Additional details about the scanning operation" + ) + + class Config: + json_schema_extra = { + 
"example": { + "status": "scanning_started", + "message": "Scanning process has been initiated in the background", + } + } + + class InsertTextRequest(BaseModel): + """Request model for inserting a single text document + + Attributes: + text: The text content to be inserted into the RAG system + """ + text: str = Field( min_length=1, description="The text to insert", @@ -41,8 +69,21 @@ class InsertTextRequest(BaseModel): def strip_after(cls, text: str) -> str: return text.strip() + class Config: + json_schema_extra = { + "example": { + "text": "This is a sample text to be inserted into the RAG system." + } + } + class InsertTextsRequest(BaseModel): + """Request model for inserting multiple text documents + + Attributes: + texts: List of text contents to be inserted into the RAG system + """ + texts: list[str] = Field( min_length=1, description="The texts to insert", @@ -53,11 +94,116 @@ class InsertTextsRequest(BaseModel): def strip_after(cls, texts: list[str]) -> list[str]: return [text.strip() for text in texts] + class Config: + json_schema_extra = { + "example": { + "texts": [ + "This is the first text to be inserted.", + "This is the second text to be inserted.", + ] + } + } + class InsertResponse(BaseModel): - status: str = Field(description="Status of the operation") + """Response model for document insertion operations + + Attributes: + status: Status of the operation (success, duplicated, partial_success, failure) + message: Detailed message describing the operation result + """ + + status: Literal["success", "duplicated", "partial_success", "failure"] = Field( + description="Status of the operation" + ) message: str = Field(description="Message describing the operation result") + class Config: + json_schema_extra = { + "example": { + "status": "success", + "message": "File 'document.pdf' uploaded successfully. Processing will continue in background.", + } + } + + +class ClearDocumentsResponse(BaseModel): + """Response model for document clearing operation + + Attributes: + status: Status of the clear operation + message: Detailed message describing the operation result + """ + + status: Literal["success", "partial_success", "busy", "fail"] = Field( + description="Status of the clear operation" + ) + message: str = Field(description="Message describing the operation result") + + class Config: + json_schema_extra = { + "example": { + "status": "success", + "message": "All documents cleared successfully. Deleted 15 files.", + } + } + + +class ClearCacheRequest(BaseModel): + """Request model for clearing cache + + Attributes: + modes: Optional list of cache modes to clear + """ + + modes: Optional[ + List[Literal["default", "naive", "local", "global", "hybrid", "mix"]] + ] = Field( + default=None, + description="Modes of cache to clear. 
If None, clears all cache.", + ) + + class Config: + json_schema_extra = {"example": {"modes": ["default", "naive"]}} + + +class ClearCacheResponse(BaseModel): + """Response model for cache clearing operation + + Attributes: + status: Status of the clear operation + message: Detailed message describing the operation result + """ + + status: Literal["success", "fail"] = Field( + description="Status of the clear operation" + ) + message: str = Field(description="Message describing the operation result") + + class Config: + json_schema_extra = { + "example": { + "status": "success", + "message": "Successfully cleared cache for modes: ['default', 'naive']", + } + } + + +"""Response model for document status + +Attributes: + id: Document identifier + content_summary: Summary of document content + content_length: Length of document content + status: Current processing status + created_at: Creation timestamp (ISO format string) + updated_at: Last update timestamp (ISO format string) + chunks_count: Number of chunks (optional) + error: Error message if any (optional) + metadata: Additional metadata (optional) + file_path: Path to the document file +""" + class DocStatusResponse(BaseModel): @staticmethod @@ -68,34 +214,82 @@ class DocStatusResponse(BaseModel): return dt return dt.isoformat() - """Response model for document status + id: str = Field(description="Document identifier") + content_summary: str = Field(description="Summary of document content") + content_length: int = Field(description="Length of document content in characters") + status: DocStatus = Field(description="Current processing status") + created_at: str = Field(description="Creation timestamp (ISO format string)") + updated_at: str = Field(description="Last update timestamp (ISO format string)") + chunks_count: Optional[int] = Field( + default=None, description="Number of chunks the document was split into" + ) + error: Optional[str] = Field( + default=None, description="Error message if processing failed" + ) + metadata: Optional[dict[str, Any]] = Field( + default=None, description="Additional metadata about the document" + ) + file_path: str = Field(description="Path to the document file") + + class Config: + json_schema_extra = { + "example": { + "id": "doc_123456", + "content_summary": "Research paper on machine learning", + "content_length": 15240, + "status": "PROCESSED", + "created_at": "2025-03-31T12:34:56", + "updated_at": "2025-03-31T12:35:30", + "chunks_count": 12, + "error": None, + "metadata": {"author": "John Doe", "year": 2025}, + "file_path": "research_paper.pdf", + } + } + + +class DocsStatusesResponse(BaseModel): + """Response model for document statuses Attributes: - id: Document identifier - content_summary: Summary of document content - content_length: Length of document content - status: Current processing status - created_at: Creation timestamp (ISO format string) - updated_at: Last update timestamp (ISO format string) - chunks_count: Number of chunks (optional) - error: Error message if any (optional) - metadata: Additional metadata (optional) + statuses: Dictionary mapping document status to lists of document status responses """ - id: str - content_summary: str - content_length: int - status: DocStatus - created_at: str - updated_at: str - chunks_count: Optional[int] = None - error: Optional[str] = None - metadata: Optional[dict[str, Any]] = None - file_path: str - + statuses: Dict[DocStatus, List[DocStatusResponse]] = Field( + default_factory=dict, + description="Dictionary mapping document status to 
lists of document status responses", + ) -class DocsStatusesResponse(BaseModel): - statuses: Dict[DocStatus, List[DocStatusResponse]] = {} + class Config: + json_schema_extra = { + "example": { + "statuses": { + "PENDING": [ + { + "id": "doc_123", + "content_summary": "Pending document", + "content_length": 5000, + "status": "PENDING", + "created_at": "2025-03-31T10:00:00", + "updated_at": "2025-03-31T10:00:00", + "file_path": "pending_doc.pdf", + } + ], + "PROCESSED": [ + { + "id": "doc_456", + "content_summary": "Processed document", + "content_length": 8000, + "status": "PROCESSED", + "created_at": "2025-03-31T09:00:00", + "updated_at": "2025-03-31T09:05:00", + "chunks_count": 8, + "file_path": "processed_doc.pdf", + } + ], + } + } + } class PipelineStatusResponse(BaseModel): @@ -276,7 +470,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: ) return False case ".pdf": - if global_args["main_args"].document_loading_engine == "DOCLING": + if global_args.document_loading_engine == "DOCLING": if not pm.is_installed("docling"): # type: ignore pm.install("docling") from docling.document_converter import DocumentConverter # type: ignore @@ -295,7 +489,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: for page in reader.pages: content += page.extract_text() + "\n" case ".docx": - if global_args["main_args"].document_loading_engine == "DOCLING": + if global_args.document_loading_engine == "DOCLING": if not pm.is_installed("docling"): # type: ignore pm.install("docling") from docling.document_converter import DocumentConverter # type: ignore @@ -315,7 +509,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: [paragraph.text for paragraph in doc.paragraphs] ) case ".pptx": - if global_args["main_args"].document_loading_engine == "DOCLING": + if global_args.document_loading_engine == "DOCLING": if not pm.is_installed("docling"): # type: ignore pm.install("docling") from docling.document_converter import DocumentConverter # type: ignore @@ -336,7 +530,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: if hasattr(shape, "text"): content += shape.text + "\n" case ".xlsx": - if global_args["main_args"].document_loading_engine == "DOCLING": + if global_args.document_loading_engine == "DOCLING": if not pm.is_installed("docling"): # type: ignore pm.install("docling") from docling.document_converter import DocumentConverter # type: ignore @@ -443,6 +637,7 @@ async def pipeline_index_texts(rag: LightRAG, texts: List[str]): await rag.apipeline_process_enqueue_documents() +# TODO: deprecate after /insert_file is removed async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path: """Save the uploaded file to a temporary location @@ -476,8 +671,8 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager): if not new_files: return - # Get MAX_PARALLEL_INSERT from global_args["main_args"] - max_parallel = global_args["main_args"].max_parallel_insert + # Get MAX_PARALLEL_INSERT from global_args + max_parallel = global_args.max_parallel_insert # Calculate batch size as 2 * MAX_PARALLEL_INSERT batch_size = 2 * max_parallel @@ -509,7 +704,9 @@ def create_document_routes( # Create combined auth dependency for document routes combined_auth = get_combined_auth_dependency(api_key) - @router.post("/scan", dependencies=[Depends(combined_auth)]) + @router.post( + "/scan", response_model=ScanResponse, dependencies=[Depends(combined_auth)] + ) async def 
scan_for_new_documents(background_tasks: BackgroundTasks): """ Trigger the scanning process for new documents. @@ -519,13 +716,18 @@ def create_document_routes( that fact. Returns: - dict: A dictionary containing the scanning status + ScanResponse: A response object containing the scanning status """ # Start the scanning process in the background background_tasks.add_task(run_scanning_process, rag, doc_manager) - return {"status": "scanning_started"} + return ScanResponse( + status="scanning_started", + message="Scanning process has been initiated in the background", + ) - @router.post("/upload", dependencies=[Depends(combined_auth)]) + @router.post( + "/upload", response_model=InsertResponse, dependencies=[Depends(combined_auth)] + ) async def upload_to_input_dir( background_tasks: BackgroundTasks, file: UploadFile = File(...) ): @@ -645,6 +847,7 @@ def create_document_routes( logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) + # TODO: deprecated, use /upload instead @router.post( "/file", response_model=InsertResponse, dependencies=[Depends(combined_auth)] ) @@ -688,6 +891,7 @@ def create_document_routes( logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) + # TODO: deprecated, use /upload instead @router.post( "/file_batch", response_model=InsertResponse, @@ -752,32 +956,186 @@ def create_document_routes( raise HTTPException(status_code=500, detail=str(e)) @router.delete( - "", response_model=InsertResponse, dependencies=[Depends(combined_auth)] + "", response_model=ClearDocumentsResponse, dependencies=[Depends(combined_auth)] ) async def clear_documents(): """ Clear all documents from the RAG system. - This endpoint deletes all text chunks, entities vector database, and relationships - vector database, effectively clearing all documents from the RAG system. + This endpoint deletes all documents, entities, relationships, and files from the system. + It uses the storage drop methods to properly clean up all data and removes all files + from the input directory. Returns: - InsertResponse: A response object containing the status and message. + ClearDocumentsResponse: A response object containing the status and message. + - status="success": All documents and files were successfully cleared. + - status="partial_success": Document clear job exit with some errors. + - status="busy": Operation could not be completed because the pipeline is busy. + - status="fail": All storage drop operations failed, with message + - message: Detailed information about the operation results, including counts + of deleted files and any errors encountered. Raises: - HTTPException: If an error occurs during the clearing process (500). + HTTPException: Raised when a serious error occurs during the clearing process, + with status code 500 and error details in the detail field. 
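For illustration only (not part of the patch): a minimal client-side sketch of how the clear-documents endpoint described above might be consumed, branching on the four documented statuses. The base URL, timeout, and API-key header are assumptions; only the route path and the ClearDocumentsResponse fields come from the code in this hunk.

```python
# Illustrative only: clearing all documents via the DELETE /documents endpoint described above.
# BASE_URL and the API key are assumptions; adjust them for your deployment.
import requests

BASE_URL = "http://localhost:9621"          # assumed default LightRAG server address
HEADERS = {"X-API-Key": "your-api-key"}     # omit if no API key is configured

def clear_all_documents() -> None:
    resp = requests.delete(f"{BASE_URL}/documents", headers=HEADERS, timeout=120)
    resp.raise_for_status()
    body = resp.json()  # shaped like ClearDocumentsResponse: {"status": ..., "message": ...}
    if body["status"] == "success":
        print("Cleared:", body["message"])
    elif body["status"] == "partial_success":
        print("Cleared with some errors:", body["message"])
    elif body["status"] == "busy":
        print("Pipeline busy, try again later:", body["message"])
    else:  # "fail"
        print("Clear failed:", body["message"])

if __name__ == "__main__":
    clear_all_documents()
```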
""" - try: - rag.text_chunks = [] - rag.entities_vdb = None - rag.relationships_vdb = None - return InsertResponse( - status="success", message="All documents cleared successfully" + from lightrag.kg.shared_storage import ( + get_namespace_data, + get_pipeline_status_lock, + ) + + # Get pipeline status and lock + pipeline_status = await get_namespace_data("pipeline_status") + pipeline_status_lock = get_pipeline_status_lock() + + # Check and set status with lock + async with pipeline_status_lock: + if pipeline_status.get("busy", False): + return ClearDocumentsResponse( + status="busy", + message="Cannot clear documents while pipeline is busy", + ) + # Set busy to true + pipeline_status.update( + { + "busy": True, + "job_name": "Clearing Documents", + "job_start": datetime.now().isoformat(), + "docs": 0, + "batchs": 0, + "cur_batch": 0, + "request_pending": False, # Clear any previous request + "latest_message": "Starting document clearing process", + } ) + # Cleaning history_messages without breaking it as a shared list object + del pipeline_status["history_messages"][:] + pipeline_status["history_messages"].append( + "Starting document clearing process" + ) + + try: + # Use drop method to clear all data + drop_tasks = [] + storages = [ + rag.text_chunks, + rag.full_docs, + rag.entities_vdb, + rag.relationships_vdb, + rag.chunks_vdb, + rag.chunk_entity_relation_graph, + rag.doc_status, + ] + + # Log storage drop start + if "history_messages" in pipeline_status: + pipeline_status["history_messages"].append( + "Starting to drop storage components" + ) + + for storage in storages: + if storage is not None: + drop_tasks.append(storage.drop()) + + # Wait for all drop tasks to complete + drop_results = await asyncio.gather(*drop_tasks, return_exceptions=True) + + # Check for errors and log results + errors = [] + storage_success_count = 0 + storage_error_count = 0 + + for i, result in enumerate(drop_results): + storage_name = storages[i].__class__.__name__ + if isinstance(result, Exception): + error_msg = f"Error dropping {storage_name}: {str(result)}" + errors.append(error_msg) + logger.error(error_msg) + storage_error_count += 1 + else: + logger.info(f"Successfully dropped {storage_name}") + storage_success_count += 1 + + # Log storage drop results + if "history_messages" in pipeline_status: + if storage_error_count > 0: + pipeline_status["history_messages"].append( + f"Dropped {storage_success_count} storage components with {storage_error_count} errors" + ) + else: + pipeline_status["history_messages"].append( + f"Successfully dropped all {storage_success_count} storage components" + ) + + # If all storage operations failed, return error status and don't proceed with file deletion + if storage_success_count == 0 and storage_error_count > 0: + error_message = "All storage drop operations failed. Aborting document clearing process." 
+ logger.error(error_message) + if "history_messages" in pipeline_status: + pipeline_status["history_messages"].append(error_message) + return ClearDocumentsResponse(status="fail", message=error_message) + + # Log file deletion start + if "history_messages" in pipeline_status: + pipeline_status["history_messages"].append( + "Starting to delete files in input directory" + ) + + # Delete all files in input_dir + deleted_files_count = 0 + file_errors_count = 0 + + for file_path in doc_manager.input_dir.glob("**/*"): + if file_path.is_file(): + try: + file_path.unlink() + deleted_files_count += 1 + except Exception as e: + logger.error(f"Error deleting file {file_path}: {str(e)}") + file_errors_count += 1 + + # Log file deletion results + if "history_messages" in pipeline_status: + if file_errors_count > 0: + pipeline_status["history_messages"].append( + f"Deleted {deleted_files_count} files with {file_errors_count} errors" + ) + errors.append(f"Failed to delete {file_errors_count} files") + else: + pipeline_status["history_messages"].append( + f"Successfully deleted {deleted_files_count} files" + ) + + # Prepare final result message + final_message = "" + if errors: + final_message = f"Cleared documents with some errors. Deleted {deleted_files_count} files." + status = "partial_success" + else: + final_message = f"All documents cleared successfully. Deleted {deleted_files_count} files." + status = "success" + + # Log final result + if "history_messages" in pipeline_status: + pipeline_status["history_messages"].append(final_message) + + # Return response based on results + return ClearDocumentsResponse(status=status, message=final_message) except Exception as e: - logger.error(f"Error DELETE /documents: {str(e)}") + error_msg = f"Error clearing documents: {str(e)}" + logger.error(error_msg) logger.error(traceback.format_exc()) + if "history_messages" in pipeline_status: + pipeline_status["history_messages"].append(error_msg) raise HTTPException(status_code=500, detail=str(e)) + finally: + # Reset busy status after completion + async with pipeline_status_lock: + pipeline_status["busy"] = False + completion_msg = "Document clearing process completed" + pipeline_status["latest_message"] = completion_msg + if "history_messages" in pipeline_status: + pipeline_status["history_messages"].append(completion_msg) @router.get( "/pipeline_status", @@ -850,7 +1208,9 @@ def create_document_routes( logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) - @router.get("", dependencies=[Depends(combined_auth)]) + @router.get( + "", response_model=DocsStatusesResponse, dependencies=[Depends(combined_auth)] + ) async def documents() -> DocsStatusesResponse: """ Get the status of all documents in the system. @@ -908,4 +1268,57 @@ def create_document_routes( logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) + @router.post( + "/clear_cache", + response_model=ClearCacheResponse, + dependencies=[Depends(combined_auth)], + ) + async def clear_cache(request: ClearCacheRequest): + """ + Clear cache data from the LLM response cache storage. + + This endpoint allows clearing specific modes of cache or all cache if no modes are specified. + Valid modes include: "default", "naive", "local", "global", "hybrid", "mix". + - "default" represents extraction cache. + - Other modes correspond to different query modes. + + Args: + request (ClearCacheRequest): The request body containing optional modes to clear. 
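As a hedged usage sketch for the cache-clearing endpoint documented here: the route path and request body follow the ClearCacheRequest model above, while the server address and API-key header are assumptions.

```python
# Illustrative only: calling POST /documents/clear_cache as described above.
import requests

BASE_URL = "http://localhost:9621"
HEADERS = {"X-API-Key": "your-api-key"}  # drop if authentication is disabled

def clear_cache(modes: list[str] | None = None) -> None:
    payload = {"modes": modes} if modes else {}
    resp = requests.post(
        f"{BASE_URL}/documents/clear_cache", json=payload, headers=HEADERS, timeout=60
    )
    if resp.status_code == 400:
        # Invalid mode names are rejected before any cache is touched
        print("Rejected:", resp.json()["detail"])
        return
    resp.raise_for_status()
    print(resp.json()["message"])  # e.g. "Successfully cleared cache for modes: ['default']"

if __name__ == "__main__":
    clear_cache(["default", "naive"])   # extraction cache plus naive-query cache
    clear_cache()                       # no modes -> clear all cache
```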
+ + Returns: + ClearCacheResponse: A response object containing the status and message. + + Raises: + HTTPException: If an error occurs during cache clearing (400 for invalid modes, 500 for other errors). + """ + try: + # Validate modes if provided + valid_modes = ["default", "naive", "local", "global", "hybrid", "mix"] + if request.modes and not all(mode in valid_modes for mode in request.modes): + invalid_modes = [ + mode for mode in request.modes if mode not in valid_modes + ] + raise HTTPException( + status_code=400, + detail=f"Invalid mode(s): {invalid_modes}. Valid modes are: {valid_modes}", + ) + + # Call the aclear_cache method + await rag.aclear_cache(request.modes) + + # Prepare success message + if request.modes: + message = f"Successfully cleared cache for modes: {request.modes}" + else: + message = "Successfully cleared all cache" + + return ClearCacheResponse(status="success", message=message) + except HTTPException: + # Re-raise HTTP exceptions + raise + except Exception as e: + logger.error(f"Error clearing cache: {str(e)}") + logger.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + return router diff --git a/lightrag/api/routers/graph_routes.py b/lightrag/api/routers/graph_routes.py index f9d77ff615c145d38af207d75d23ab67463b3ef6..381df90bf1db2042a08ae6ca3eae3e1a9931ef2c 100644 --- a/lightrag/api/routers/graph_routes.py +++ b/lightrag/api/routers/graph_routes.py @@ -3,7 +3,7 @@ This module contains all graph-related routes for the LightRAG API. """ from typing import Optional -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Query from ..utils_api import get_combined_auth_dependency @@ -25,23 +25,20 @@ def create_graph_routes(rag, api_key: Optional[str] = None): @router.get("/graphs", dependencies=[Depends(combined_auth)]) async def get_knowledge_graph( - label: str, max_depth: int = 3, min_degree: int = 0, inclusive: bool = False + label: str = Query(..., description="Label to get knowledge graph for"), + max_depth: int = Query(3, description="Maximum depth of graph", ge=1), + max_nodes: int = Query(1000, description="Maximum nodes to return", ge=1), ): """ Retrieve a connected subgraph of nodes where the label includes the specified label. - Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000). When reducing the number of nodes, the prioritization criteria are as follows: - 1. min_degree does not affect nodes directly connected to the matching nodes - 2. Label matching nodes take precedence - 3. Followed by nodes directly connected to the matching nodes - 4. Finally, the degree of the nodes - Maximum number of nodes is limited to env MAX_GRAPH_NODES(default: 1000) + 1. Hops(path) to the staring node take precedence + 2. Followed by the degree of the nodes Args: - label (str): Label to get knowledge graph for - max_depth (int, optional): Maximum depth of graph. Defaults to 3. - inclusive_search (bool, optional): If True, search for nodes that include the label. Defaults to False. - min_degree (int, optional): Minimum degree of nodes. Defaults to 0. 
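A hedged client-side sketch for the reworked /graphs route: the query-parameter names follow the new signature (label, max_depth, max_nodes) and the "*" wildcard comes from the base-class docstring later in this patch, while the base URL is an assumption.

```python
# Illustrative only: querying the reworked GET /graphs endpoint with its new parameters.
import requests

BASE_URL = "http://localhost:9621"  # assumed default LightRAG server address

def fetch_subgraph(label: str, max_depth: int = 3, max_nodes: int = 200) -> dict:
    resp = requests.get(
        f"{BASE_URL}/graphs",
        params={"label": label, "max_depth": max_depth, "max_nodes": max_nodes},
        # add an X-API-Key header here if your server requires one
        timeout=60,
    )
    resp.raise_for_status()
    return resp.json()  # knowledge-graph payload: nodes, edges, and a truncation flag

if __name__ == "__main__":
    graph = fetch_subgraph("*", max_depth=3, max_nodes=200)  # "*" means all nodes
    print(f"nodes: {len(graph.get('nodes', []))}, edges: {len(graph.get('edges', []))}")
```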
+ label (str): Label of the starting node + max_depth (int, optional): Maximum depth of the subgraph,Defaults to 3 + max_nodes: Maxiumu nodes to return Returns: Dict[str, List[str]]: Knowledge graph for label @@ -49,8 +46,7 @@ def create_graph_routes(rag, api_key: Optional[str] = None): return await rag.get_knowledge_graph( node_label=label, max_depth=max_depth, - inclusive=inclusive, - min_degree=min_degree, + max_nodes=max_nodes, ) return router diff --git a/lightrag/api/run_with_gunicorn.py b/lightrag/api/run_with_gunicorn.py index 065a12a121b17b2295228628f25f0337f6497100..cf902a8a32fcadfde56f219c88a096f15287aad0 100644 --- a/lightrag/api/run_with_gunicorn.py +++ b/lightrag/api/run_with_gunicorn.py @@ -7,14 +7,9 @@ import os import sys import signal import pipmaster as pm -from lightrag.api.utils_api import parse_args, display_splash_screen, check_env_file +from lightrag.api.utils_api import display_splash_screen, check_env_file from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data -from dotenv import load_dotenv - -# use the .env that is inside the current folder -# allows to use different .env file for each lightrag instance -# the OS environment variables take precedence over the .env file -load_dotenv(dotenv_path=".env", override=False) +from .config import global_args def check_and_install_dependencies(): @@ -59,20 +54,17 @@ def main(): signal.signal(signal.SIGINT, signal_handler) # Ctrl+C signal.signal(signal.SIGTERM, signal_handler) # kill command - # Parse all arguments using parse_args - args = parse_args(is_uvicorn_mode=False) - # Display startup information - display_splash_screen(args) + display_splash_screen(global_args) print("🚀 Starting LightRAG with Gunicorn") - print(f"🔄 Worker management: Gunicorn (workers={args.workers})") + print(f"🔄 Worker management: Gunicorn (workers={global_args.workers})") print("🔍 Preloading app: Enabled") print("📝 Note: Using Gunicorn's preload feature for shared data initialization") print("\n\n" + "=" * 80) print("MAIN PROCESS INITIALIZATION") print(f"Process ID: {os.getpid()}") - print(f"Workers setting: {args.workers}") + print(f"Workers setting: {global_args.workers}") print("=" * 80 + "\n") # Import Gunicorn's StandaloneApplication @@ -128,31 +120,43 @@ def main(): # Set configuration variables in gunicorn_config, prioritizing command line arguments gunicorn_config.workers = ( - args.workers if args.workers else int(os.getenv("WORKERS", 1)) + global_args.workers + if global_args.workers + else int(os.getenv("WORKERS", 1)) ) # Bind configuration prioritizes command line arguments - host = args.host if args.host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0") - port = args.port if args.port != 9621 else int(os.getenv("PORT", 9621)) + host = ( + global_args.host + if global_args.host != "0.0.0.0" + else os.getenv("HOST", "0.0.0.0") + ) + port = ( + global_args.port + if global_args.port != 9621 + else int(os.getenv("PORT", 9621)) + ) gunicorn_config.bind = f"{host}:{port}" # Log level configuration prioritizes command line arguments gunicorn_config.loglevel = ( - args.log_level.lower() - if args.log_level + global_args.log_level.lower() + if global_args.log_level else os.getenv("LOG_LEVEL", "info") ) # Timeout configuration prioritizes command line arguments gunicorn_config.timeout = ( - args.timeout if args.timeout * 2 else int(os.getenv("TIMEOUT", 150 * 2)) + global_args.timeout + if global_args.timeout * 2 + else int(os.getenv("TIMEOUT", 150 * 2)) ) # Keepalive configuration gunicorn_config.keepalive = 
int(os.getenv("KEEPALIVE", 5)) # SSL configuration prioritizes command line arguments - if args.ssl or os.getenv("SSL", "").lower() in ( + if global_args.ssl or os.getenv("SSL", "").lower() in ( "true", "1", "yes", @@ -160,12 +164,14 @@ def main(): "on", ): gunicorn_config.certfile = ( - args.ssl_certfile - if args.ssl_certfile + global_args.ssl_certfile + if global_args.ssl_certfile else os.getenv("SSL_CERTFILE") ) gunicorn_config.keyfile = ( - args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE") + global_args.ssl_keyfile + if global_args.ssl_keyfile + else os.getenv("SSL_KEYFILE") ) # Set configuration options from the module @@ -190,13 +196,13 @@ def main(): # Import the application from lightrag.api.lightrag_server import get_application - return get_application(args) + return get_application(global_args) # Create the application app = GunicornApp("") # Force workers to be an integer and greater than 1 for multi-process mode - workers_count = int(args.workers) + workers_count = int(global_args.workers) if workers_count > 1: # Set a flag to indicate we're in the main process os.environ["LIGHTRAG_MAIN_PROCESS"] = "1" diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py index c01b7a37b19655b59a890390a12d7a9838381936..2bdc8d0765bab0ff851e49395ca5d1f2b37fdc6f 100644 --- a/lightrag/api/utils_api.py +++ b/lightrag/api/utils_api.py @@ -7,15 +7,13 @@ import argparse from typing import Optional, List, Tuple import sys from ascii_colors import ASCIIColors -import logging from lightrag.api import __api_version__ as api_version from lightrag import __version__ as core_version from fastapi import HTTPException, Security, Request, status -from dotenv import load_dotenv from fastapi.security import APIKeyHeader, OAuth2PasswordBearer from starlette.status import HTTP_403_FORBIDDEN from .auth import auth_handler -from ..prompt import PROMPTS +from .config import ollama_server_infos, global_args def check_env_file(): @@ -36,16 +34,8 @@ def check_env_file(): return True -# use the .env that is inside the current folder -# allows to use different .env file for each lightrag instance -# the OS environment variables take precedence over the .env file -load_dotenv(dotenv_path=".env", override=False) - -global_args = {"main_args": None} - -# Get whitelist paths from environment variable, only once during initialization -default_whitelist = "/health,/api/*" -whitelist_paths = os.getenv("WHITELIST_PATHS", default_whitelist).split(",") +# Get whitelist paths from global_args, only once during initialization +whitelist_paths = global_args.whitelist_paths.split(",") # Pre-compile path matching patterns whitelist_patterns: List[Tuple[str, bool]] = [] @@ -63,19 +53,6 @@ for path in whitelist_paths: auth_configured = bool(auth_handler.accounts) -class OllamaServerInfos: - # Constants for emulated Ollama model information - LIGHTRAG_NAME = "lightrag" - LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest") - LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}" - LIGHTRAG_SIZE = 7365960935 # it's a dummy value - LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z" - LIGHTRAG_DIGEST = "sha256:lightrag" - - -ollama_server_infos = OllamaServerInfos() - - def get_combined_auth_dependency(api_key: Optional[str] = None): """ Create a combined authentication dependency that implements authentication logic @@ -186,299 +163,6 @@ def get_combined_auth_dependency(api_key: Optional[str] = None): return combined_dependency -class DefaultRAGStorageConfig: - KV_STORAGE = "JsonKVStorage" - VECTOR_STORAGE = 
"NanoVectorDBStorage" - GRAPH_STORAGE = "NetworkXStorage" - DOC_STATUS_STORAGE = "JsonDocStatusStorage" - - -def get_default_host(binding_type: str) -> str: - default_hosts = { - "ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"), - "lollms": os.getenv("LLM_BINDING_HOST", "http://localhost:9600"), - "azure_openai": os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com/v1"), - "openai": os.getenv("LLM_BINDING_HOST", "https://api.openai.com/v1"), - } - return default_hosts.get( - binding_type, os.getenv("LLM_BINDING_HOST", "http://localhost:11434") - ) # fallback to ollama if unknown - - -def get_env_value(env_key: str, default: any, value_type: type = str) -> any: - """ - Get value from environment variable with type conversion - - Args: - env_key (str): Environment variable key - default (any): Default value if env variable is not set - value_type (type): Type to convert the value to - - Returns: - any: Converted value from environment or default - """ - value = os.getenv(env_key) - if value is None: - return default - - if value_type is bool: - return value.lower() in ("true", "1", "yes", "t", "on") - try: - return value_type(value) - except ValueError: - return default - - -def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace: - """ - Parse command line arguments with environment variable fallback - - Args: - is_uvicorn_mode: Whether running under uvicorn mode - - Returns: - argparse.Namespace: Parsed arguments - """ - - parser = argparse.ArgumentParser( - description="LightRAG FastAPI Server with separate working and input directories" - ) - - # Server configuration - parser.add_argument( - "--host", - default=get_env_value("HOST", "0.0.0.0"), - help="Server host (default: from env or 0.0.0.0)", - ) - parser.add_argument( - "--port", - type=int, - default=get_env_value("PORT", 9621, int), - help="Server port (default: from env or 9621)", - ) - - # Directory configuration - parser.add_argument( - "--working-dir", - default=get_env_value("WORKING_DIR", "./rag_storage"), - help="Working directory for RAG storage (default: from env or ./rag_storage)", - ) - parser.add_argument( - "--input-dir", - default=get_env_value("INPUT_DIR", "./inputs"), - help="Directory containing input documents (default: from env or ./inputs)", - ) - - def timeout_type(value): - if value is None: - return 150 - if value is None or value == "None": - return None - return int(value) - - parser.add_argument( - "--timeout", - default=get_env_value("TIMEOUT", None, timeout_type), - type=timeout_type, - help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout", - ) - - # RAG configuration - parser.add_argument( - "--max-async", - type=int, - default=get_env_value("MAX_ASYNC", 4, int), - help="Maximum async operations (default: from env or 4)", - ) - parser.add_argument( - "--max-tokens", - type=int, - default=get_env_value("MAX_TOKENS", 32768, int), - help="Maximum token size (default: from env or 32768)", - ) - - # Logging configuration - parser.add_argument( - "--log-level", - default=get_env_value("LOG_LEVEL", "INFO"), - choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], - help="Logging level (default: from env or INFO)", - ) - parser.add_argument( - "--verbose", - action="store_true", - default=get_env_value("VERBOSE", False, bool), - help="Enable verbose debug output(only valid for DEBUG log-level)", - ) - - parser.add_argument( - "--key", - type=str, - default=get_env_value("LIGHTRAG_API_KEY", None), - help="API key for authentication. 
This protects lightrag server against unauthorized access", - ) - - # Optional https parameters - parser.add_argument( - "--ssl", - action="store_true", - default=get_env_value("SSL", False, bool), - help="Enable HTTPS (default: from env or False)", - ) - parser.add_argument( - "--ssl-certfile", - default=get_env_value("SSL_CERTFILE", None), - help="Path to SSL certificate file (required if --ssl is enabled)", - ) - parser.add_argument( - "--ssl-keyfile", - default=get_env_value("SSL_KEYFILE", None), - help="Path to SSL private key file (required if --ssl is enabled)", - ) - - parser.add_argument( - "--history-turns", - type=int, - default=get_env_value("HISTORY_TURNS", 3, int), - help="Number of conversation history turns to include (default: from env or 3)", - ) - - # Search parameters - parser.add_argument( - "--top-k", - type=int, - default=get_env_value("TOP_K", 60, int), - help="Number of most similar results to return (default: from env or 60)", - ) - parser.add_argument( - "--cosine-threshold", - type=float, - default=get_env_value("COSINE_THRESHOLD", 0.2, float), - help="Cosine similarity threshold (default: from env or 0.4)", - ) - - # Ollama model name - parser.add_argument( - "--simulated-model-name", - type=str, - default=get_env_value( - "SIMULATED_MODEL_NAME", ollama_server_infos.LIGHTRAG_MODEL - ), - help="Number of conversation history turns to include (default: from env or 3)", - ) - - # Namespace - parser.add_argument( - "--namespace-prefix", - type=str, - default=get_env_value("NAMESPACE_PREFIX", ""), - help="Prefix of the namespace", - ) - - parser.add_argument( - "--auto-scan-at-startup", - action="store_true", - default=False, - help="Enable automatic scanning when the program starts", - ) - - # Server workers configuration - parser.add_argument( - "--workers", - type=int, - default=get_env_value("WORKERS", 1, int), - help="Number of worker processes (default: from env or 1)", - ) - - # LLM and embedding bindings - parser.add_argument( - "--llm-binding", - type=str, - default=get_env_value("LLM_BINDING", "ollama"), - choices=["lollms", "ollama", "openai", "openai-ollama", "azure_openai"], - help="LLM binding type (default: from env or ollama)", - ) - parser.add_argument( - "--embedding-binding", - type=str, - default=get_env_value("EMBEDDING_BINDING", "ollama"), - choices=["lollms", "ollama", "openai", "azure_openai"], - help="Embedding binding type (default: from env or ollama)", - ) - - args = parser.parse_args() - - # If in uvicorn mode and workers > 1, force it to 1 and log warning - if is_uvicorn_mode and args.workers > 1: - original_workers = args.workers - args.workers = 1 - # Log warning directly here - logging.warning( - f"In uvicorn mode, workers parameter was set to {original_workers}. 
Forcing workers=1" - ) - - # convert relative path to absolute path - args.working_dir = os.path.abspath(args.working_dir) - args.input_dir = os.path.abspath(args.input_dir) - - # Inject storage configuration from environment variables - args.kv_storage = get_env_value( - "LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE - ) - args.doc_status_storage = get_env_value( - "LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE - ) - args.graph_storage = get_env_value( - "LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE - ) - args.vector_storage = get_env_value( - "LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE - ) - - # Get MAX_PARALLEL_INSERT from environment - args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int) - - # Handle openai-ollama special case - if args.llm_binding == "openai-ollama": - args.llm_binding = "openai" - args.embedding_binding = "ollama" - - args.llm_binding_host = get_env_value( - "LLM_BINDING_HOST", get_default_host(args.llm_binding) - ) - args.embedding_binding_host = get_env_value( - "EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding) - ) - args.llm_binding_api_key = get_env_value("LLM_BINDING_API_KEY", None) - args.embedding_binding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "") - - # Inject model configuration - args.llm_model = get_env_value("LLM_MODEL", "mistral-nemo:latest") - args.embedding_model = get_env_value("EMBEDDING_MODEL", "bge-m3:latest") - args.embedding_dim = get_env_value("EMBEDDING_DIM", 1024, int) - args.max_embed_tokens = get_env_value("MAX_EMBED_TOKENS", 8192, int) - - # Inject chunk configuration - args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int) - args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int) - - # Inject LLM cache configuration - args.enable_llm_cache_for_extract = get_env_value( - "ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool - ) - - # Inject LLM temperature configuration - args.temperature = get_env_value("TEMPERATURE", 0.5, float) - - # Select Document loading tool (DOCLING, DEFAULT) - args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT") - - ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name - - global_args["main_args"] = args - return args - - def display_splash_screen(args: argparse.Namespace) -> None: """ Display a colorful splash screen showing LightRAG server configuration @@ -503,7 +187,7 @@ def display_splash_screen(args: argparse.Namespace) -> None: ASCIIColors.white(" ├─ Workers: ", end="") ASCIIColors.yellow(f"{args.workers}") ASCIIColors.white(" ├─ CORS Origins: ", end="") - ASCIIColors.yellow(f"{os.getenv('CORS_ORIGINS', '*')}") + ASCIIColors.yellow(f"{args.cors_origins}") ASCIIColors.white(" ├─ SSL Enabled: ", end="") ASCIIColors.yellow(f"{args.ssl}") if args.ssl: @@ -519,8 +203,10 @@ def display_splash_screen(args: argparse.Namespace) -> None: ASCIIColors.yellow(f"{args.verbose}") ASCIIColors.white(" ├─ History Turns: ", end="") ASCIIColors.yellow(f"{args.history_turns}") - ASCIIColors.white(" └─ API Key: ", end="") + ASCIIColors.white(" ├─ API Key: ", end="") ASCIIColors.yellow("Set" if args.key else "Not Set") + ASCIIColors.white(" └─ JWT Auth: ", end="") + ASCIIColors.yellow("Enabled" if args.auth_accounts else "Disabled") # Directory Configuration ASCIIColors.magenta("\n📂 Directory Configuration:") @@ -558,10 +244,9 @@ def display_splash_screen(args: argparse.Namespace) -> None: ASCIIColors.yellow(f"{args.embedding_dim}") # RAG Configuration - 
summary_language = os.getenv("SUMMARY_LANGUAGE", PROMPTS["DEFAULT_LANGUAGE"]) ASCIIColors.magenta("\n⚙️ RAG Configuration:") ASCIIColors.white(" ├─ Summary Language: ", end="") - ASCIIColors.yellow(f"{summary_language}") + ASCIIColors.yellow(f"{args.summary_language}") ASCIIColors.white(" ├─ Max Parallel Insert: ", end="") ASCIIColors.yellow(f"{args.max_parallel_insert}") ASCIIColors.white(" ├─ Max Embed Tokens: ", end="") @@ -595,19 +280,17 @@ def display_splash_screen(args: argparse.Namespace) -> None: protocol = "https" if args.ssl else "http" if args.host == "0.0.0.0": ASCIIColors.magenta("\n🌐 Server Access Information:") - ASCIIColors.white(" ├─ Local Access: ", end="") + ASCIIColors.white(" ├─ WebUI (local): ", end="") ASCIIColors.yellow(f"{protocol}://localhost:{args.port}") ASCIIColors.white(" ├─ Remote Access: ", end="") ASCIIColors.yellow(f"{protocol}://:{args.port}") ASCIIColors.white(" ├─ API Documentation (local): ", end="") ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/docs") - ASCIIColors.white(" ├─ Alternative Documentation (local): ", end="") + ASCIIColors.white(" └─ Alternative Documentation (local): ", end="") ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/redoc") - ASCIIColors.white(" └─ WebUI (local): ", end="") - ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/webui") - ASCIIColors.yellow("\n📝 Note:") - ASCIIColors.white(""" Since the server is running on 0.0.0.0: + ASCIIColors.magenta("\n📝 Note:") + ASCIIColors.cyan(""" Since the server is running on 0.0.0.0: - Use 'localhost' or '127.0.0.1' for local access - Use your machine's IP address for remote access - To find your IP address: @@ -617,42 +300,24 @@ def display_splash_screen(args: argparse.Namespace) -> None: else: base_url = f"{protocol}://{args.host}:{args.port}" ASCIIColors.magenta("\n🌐 Server Access Information:") - ASCIIColors.white(" ├─ Base URL: ", end="") + ASCIIColors.white(" ├─ WebUI (local): ", end="") ASCIIColors.yellow(f"{base_url}") ASCIIColors.white(" ├─ API Documentation: ", end="") ASCIIColors.yellow(f"{base_url}/docs") ASCIIColors.white(" └─ Alternative Documentation: ", end="") ASCIIColors.yellow(f"{base_url}/redoc") - # Usage Examples - ASCIIColors.magenta("\n📚 Quick Start Guide:") - ASCIIColors.cyan(""" - 1. Access the Swagger UI: - Open your browser and navigate to the API documentation URL above - - 2. API Authentication:""") - if args.key: - ASCIIColors.cyan(""" Add the following header to your requests: - X-API-Key: - """) - else: - ASCIIColors.cyan(" No authentication required\n") - - ASCIIColors.cyan(""" 3. Basic Operations: - - POST /upload_document: Upload new documents to RAG - - POST /query: Query your document collection - - 4. Monitor the server: - - Check server logs for detailed operation information - - Use healthcheck endpoint: GET /health - """) - # Security Notice if args.key: ASCIIColors.yellow("\n⚠️ Security Notice:") ASCIIColors.white(""" API Key authentication is enabled. Make sure to include the X-API-Key header in all your requests. """) + if args.auth_accounts: + ASCIIColors.yellow("\n⚠️ Security Notice:") + ASCIIColors.white(""" JWT authentication is enabled. + Make sure to login before making the request, and include the 'Authorization' in the header. 
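To make the two security notices concrete, an illustrative client sketch that attaches both headers. The /login path, its form fields, and the access_token response field are assumptions inferred from the standard OAuth2 password flow; only the X-API-Key and Authorization header names come from the notices above.

```python
# Illustrative only: sending both the API key and the JWT bearer token mentioned above.
# The login route and response field names are assumptions, not a documented contract.
import requests

BASE_URL = "http://localhost:9621"
API_KEY = "your-api-key"

def get_token(username: str, password: str) -> str:
    # OAuth2 password flow posts form-encoded credentials (assumed /login path)
    resp = requests.post(
        f"{BASE_URL}/login",
        data={"username": username, "password": password},
        timeout=30,
    )
    resp.raise_for_status()
    return resp.json()["access_token"]

def auth_headers(token: str) -> dict:
    # Both mechanisms can be active at once: X-API-Key plus the JWT bearer token
    return {"X-API-Key": API_KEY, "Authorization": f"Bearer {token}"}

if __name__ == "__main__":
    token = get_token("admin", "change-me")  # placeholder credentials
    health = requests.get(f"{BASE_URL}/health", headers=auth_headers(token), timeout=30)
    print(health.json())
```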
+ """) # Ensure splash output flush to system log sys.stdout.flush() diff --git a/lightrag/api/webui/assets/index-D8zGvNlV.js b/lightrag/api/webui/assets/index-BaHKTcxB.js similarity index 65% rename from lightrag/api/webui/assets/index-D8zGvNlV.js rename to lightrag/api/webui/assets/index-BaHKTcxB.js index bbcae58ff413786fa1da15333541027bffbf7707..70170bb88fad37e4667216b21ac3a2d7a3d0f121 100644 Binary files a/lightrag/api/webui/assets/index-D8zGvNlV.js and b/lightrag/api/webui/assets/index-BaHKTcxB.js differ diff --git a/lightrag/api/webui/assets/index-CD5HxTy1.css b/lightrag/api/webui/assets/index-CD5HxTy1.css deleted file mode 100644 index a0ab321b50da1792a43e9070e300e030812b79a8..0000000000000000000000000000000000000000 Binary files a/lightrag/api/webui/assets/index-CD5HxTy1.css and /dev/null differ diff --git a/lightrag/api/webui/assets/index-f0HMqdqP.css b/lightrag/api/webui/assets/index-f0HMqdqP.css new file mode 100644 index 0000000000000000000000000000000000000000..ede578c74a1c9f7b301d4eaa87be668a1d44fd9d Binary files /dev/null and b/lightrag/api/webui/assets/index-f0HMqdqP.css differ diff --git a/lightrag/api/webui/index.html b/lightrag/api/webui/index.html index 54348636d1f8097dc1de7faeb198fe90cd25a121..5b9d4883d588d979cc2f3708790e50abad843c7f 100644 Binary files a/lightrag/api/webui/index.html and b/lightrag/api/webui/index.html differ diff --git a/lightrag/base.py b/lightrag/base.py index ad41fc586710fddf14d795d81928fe484bbd8699..5cf5ab617d00c378ed2a8593dd0aa67f7ced5d54 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -112,6 +112,32 @@ class StorageNameSpace(ABC): async def index_done_callback(self) -> None: """Commit the storage operations after indexing""" + @abstractmethod + async def drop(self) -> dict[str, str]: + """Drop all data from storage and clean up resources + + This abstract method defines the contract for dropping all data from a storage implementation. + Each storage type must implement this method to: + 1. Clear all data from memory and/or external storage + 2. Remove any associated storage files if applicable + 3. Reset the storage to its initial state + 4. Handle cleanup of any resources + 5. Notify other processes if necessary + 6. This action should persistent the data to disk immediately. + + Returns: + dict[str, str]: Operation status and message with the following format: + { + "status": str, # "success" or "error" + "message": str # "data dropped" on success, error details on failure + } + + Implementation specific: + - On success: return {"status": "success", "message": "data dropped"} + - On failure: return {"status": "error", "message": ""} + - If not supported: return {"status": "error", "message": "unsupported"} + """ + @dataclass class BaseVectorStorage(StorageNameSpace, ABC): @@ -127,15 +153,33 @@ class BaseVectorStorage(StorageNameSpace, ABC): @abstractmethod async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - """Insert or update vectors in the storage.""" + """Insert or update vectors in the storage. + + Importance notes for in-memory storage: + 1. Changes will be persisted to disk during the next index_done_callback + 2. Only one process should updating the storage at a time before index_done_callback, + KG-storage-log should be used to avoid data corruption + """ @abstractmethod async def delete_entity(self, entity_name: str) -> None: - """Delete a single entity by its name.""" + """Delete a single entity by its name. + + Importance notes for in-memory storage: + 1. 
Changes will be persisted to disk during the next index_done_callback + 2. Only one process should updating the storage at a time before index_done_callback, + KG-storage-log should be used to avoid data corruption + """ @abstractmethod async def delete_entity_relation(self, entity_name: str) -> None: - """Delete relations for a given entity.""" + """Delete relations for a given entity. + + Importance notes for in-memory storage: + 1. Changes will be persisted to disk during the next index_done_callback + 2. Only one process should updating the storage at a time before index_done_callback, + KG-storage-log should be used to avoid data corruption + """ @abstractmethod async def get_by_id(self, id: str) -> dict[str, Any] | None: @@ -161,6 +205,19 @@ class BaseVectorStorage(StorageNameSpace, ABC): """ pass + @abstractmethod + async def delete(self, ids: list[str]): + """Delete vectors with specified IDs + + Importance notes for in-memory storage: + 1. Changes will be persisted to disk during the next index_done_callback + 2. Only one process should updating the storage at a time before index_done_callback, + KG-storage-log should be used to avoid data corruption + + Args: + ids: List of vector IDs to be deleted + """ + @dataclass class BaseKVStorage(StorageNameSpace, ABC): @@ -180,7 +237,42 @@ class BaseKVStorage(StorageNameSpace, ABC): @abstractmethod async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - """Upsert data""" + """Upsert data + + Importance notes for in-memory storage: + 1. Changes will be persisted to disk during the next index_done_callback + 2. update flags to notify other processes that data persistence is needed + """ + + @abstractmethod + async def delete(self, ids: list[str]) -> None: + """Delete specific records from storage by their IDs + + Importance notes for in-memory storage: + 1. Changes will be persisted to disk during the next index_done_callback + 2. update flags to notify other processes that data persistence is needed + + Args: + ids (list[str]): List of document IDs to be deleted from storage + + Returns: + None + """ + + async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool: + """Delete specific records from storage by cache mode + + Importance notes for in-memory storage: + 1. Changes will be persisted to disk during the next index_done_callback + 2. update flags to notify other processes that data persistence is needed + + Args: + modes (list[str]): List of cache modes to be dropped from storage + + Returns: + True: if the cache drop successfully + False: if the cache drop failed, or the cache mode is not supported + """ @dataclass @@ -205,13 +297,13 @@ class BaseGraphStorage(StorageNameSpace, ABC): @abstractmethod async def get_node(self, node_id: str) -> dict[str, str] | None: - """Get an edge by its source and target node ids.""" + """Get node by its label identifier, return only node properties""" @abstractmethod async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: - """Get all edges connected to a node.""" + """Get edge properties between two nodes""" @abstractmethod async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: @@ -225,7 +317,13 @@ class BaseGraphStorage(StorageNameSpace, ABC): async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ) -> None: - """Delete a node from the graph.""" + """Delete a node from the graph. + + Importance notes for in-memory storage: + 1. 
Changes will be persisted to disk during the next index_done_callback + 2. Only one process should updating the storage at a time before index_done_callback, + KG-storage-log should be used to avoid data corruption + """ @abstractmethod async def delete_node(self, node_id: str) -> None: @@ -243,9 +341,20 @@ class BaseGraphStorage(StorageNameSpace, ABC): @abstractmethod async def get_knowledge_graph( - self, node_label: str, max_depth: int = 3 + self, node_label: str, max_depth: int = 3, max_nodes: int = 1000 ) -> KnowledgeGraph: - """Retrieve a subgraph of the knowledge graph starting from a given node.""" + """ + Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. + + Args: + node_label: Label of the starting node,* means all nodes + max_depth: Maximum depth of the subgraph, Defaults to 3 + max_nodes: Maxiumu nodes to return, Defaults to 1000(BFS if possible) + + Returns: + KnowledgeGraph object containing nodes and edges, with an is_truncated flag + indicating whether the graph was truncated due to max_nodes limit + """ class DocStatus(str, Enum): @@ -297,6 +406,10 @@ class DocStatusStorage(BaseKVStorage, ABC): ) -> dict[str, DocProcessingStatus]: """Get all documents with a specific status""" + async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool: + """Drop cache is not supported for Doc Status storage""" + return False + class StoragesStatus(str, Enum): """Storages status""" diff --git a/lightrag/kg/__init__.py b/lightrag/kg/__init__.py index 4943fc1d42c072a14f29ab62676847fd8bb991ad..bbddb2857a0a59235916670534a6a3f94bbc96d6 100644 --- a/lightrag/kg/__init__.py +++ b/lightrag/kg/__init__.py @@ -2,11 +2,10 @@ STORAGE_IMPLEMENTATIONS = { "KV_STORAGE": { "implementations": [ "JsonKVStorage", - "MongoKVStorage", "RedisKVStorage", - "TiDBKVStorage", "PGKVStorage", - "OracleKVStorage", + "MongoKVStorage", + # "TiDBKVStorage", ], "required_methods": ["get_by_id", "upsert"], }, @@ -14,12 +13,11 @@ STORAGE_IMPLEMENTATIONS = { "implementations": [ "NetworkXStorage", "Neo4JStorage", - "MongoGraphStorage", - "TiDBGraphStorage", - "AGEStorage", - "GremlinStorage", "PGGraphStorage", - "OracleGraphStorage", + # "AGEStorage", + # "MongoGraphStorage", + # "TiDBGraphStorage", + # "GremlinStorage", ], "required_methods": ["upsert_node", "upsert_edge"], }, @@ -28,12 +26,11 @@ STORAGE_IMPLEMENTATIONS = { "NanoVectorDBStorage", "MilvusVectorDBStorage", "ChromaVectorDBStorage", - "TiDBVectorDBStorage", "PGVectorStorage", "FaissVectorDBStorage", "QdrantVectorDBStorage", - "OracleVectorDBStorage", "MongoVectorDBStorage", + # "TiDBVectorDBStorage", ], "required_methods": ["query", "upsert"], }, @@ -41,7 +38,6 @@ STORAGE_IMPLEMENTATIONS = { "implementations": [ "JsonDocStatusStorage", "PGDocStatusStorage", - "PGDocStatusStorage", "MongoDocStatusStorage", ], "required_methods": ["get_docs_by_status"], @@ -54,50 +50,32 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = { "JsonKVStorage": [], "MongoKVStorage": [], "RedisKVStorage": ["REDIS_URI"], - "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"], + # "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"], "PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], - "OracleKVStorage": [ - "ORACLE_DSN", - "ORACLE_USER", - "ORACLE_PASSWORD", - "ORACLE_CONFIG_DIR", - ], # Graph Storage Implementations "NetworkXStorage": [], "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"], "MongoGraphStorage": [], - "TiDBGraphStorage": ["TIDB_USER", 
"TIDB_PASSWORD", "TIDB_DATABASE"], + # "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"], "AGEStorage": [ "AGE_POSTGRES_DB", "AGE_POSTGRES_USER", "AGE_POSTGRES_PASSWORD", ], - "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"], + # "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"], "PGGraphStorage": [ "POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE", ], - "OracleGraphStorage": [ - "ORACLE_DSN", - "ORACLE_USER", - "ORACLE_PASSWORD", - "ORACLE_CONFIG_DIR", - ], # Vector Storage Implementations "NanoVectorDBStorage": [], "MilvusVectorDBStorage": [], "ChromaVectorDBStorage": [], - "TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"], + # "TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"], "PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], "FaissVectorDBStorage": [], "QdrantVectorDBStorage": ["QDRANT_URL"], # QDRANT_API_KEY has default value None - "OracleVectorDBStorage": [ - "ORACLE_DSN", - "ORACLE_USER", - "ORACLE_PASSWORD", - "ORACLE_CONFIG_DIR", - ], "MongoVectorDBStorage": [], # Document Status Storage Implementations "JsonDocStatusStorage": [], @@ -112,9 +90,6 @@ STORAGES = { "NanoVectorDBStorage": ".kg.nano_vector_db_impl", "JsonDocStatusStorage": ".kg.json_doc_status_impl", "Neo4JStorage": ".kg.neo4j_impl", - "OracleKVStorage": ".kg.oracle_impl", - "OracleGraphStorage": ".kg.oracle_impl", - "OracleVectorDBStorage": ".kg.oracle_impl", "MilvusVectorDBStorage": ".kg.milvus_impl", "MongoKVStorage": ".kg.mongo_impl", "MongoDocStatusStorage": ".kg.mongo_impl", @@ -122,14 +97,14 @@ STORAGES = { "MongoVectorDBStorage": ".kg.mongo_impl", "RedisKVStorage": ".kg.redis_impl", "ChromaVectorDBStorage": ".kg.chroma_impl", - "TiDBKVStorage": ".kg.tidb_impl", - "TiDBVectorDBStorage": ".kg.tidb_impl", - "TiDBGraphStorage": ".kg.tidb_impl", + # "TiDBKVStorage": ".kg.tidb_impl", + # "TiDBVectorDBStorage": ".kg.tidb_impl", + # "TiDBGraphStorage": ".kg.tidb_impl", "PGKVStorage": ".kg.postgres_impl", "PGVectorStorage": ".kg.postgres_impl", "AGEStorage": ".kg.age_impl", "PGGraphStorage": ".kg.postgres_impl", - "GremlinStorage": ".kg.gremlin_impl", + # "GremlinStorage": ".kg.gremlin_impl", "PGDocStatusStorage": ".kg.postgres_impl", "FaissVectorDBStorage": ".kg.faiss_impl", "QdrantVectorDBStorage": ".kg.qdrant_impl", diff --git a/lightrag/kg/age_impl.py b/lightrag/kg/age_impl.py index 2295155453b44bf263a5c7528a8fa62fdcfcd995..b744ae1ec4d04eac3a73da5e64927827378d2836 100644 --- a/lightrag/kg/age_impl.py +++ b/lightrag/kg/age_impl.py @@ -34,9 +34,9 @@ if not pm.is_installed("psycopg-pool"): if not pm.is_installed("asyncpg"): pm.install("asyncpg") -import psycopg -from psycopg.rows import namedtuple_row -from psycopg_pool import AsyncConnectionPool, PoolTimeout +import psycopg # type: ignore +from psycopg.rows import namedtuple_row # type: ignore +from psycopg_pool import AsyncConnectionPool, PoolTimeout # type: ignore class AGEQueryException(Exception): @@ -871,3 +871,21 @@ class AGEStorage(BaseGraphStorage): async def index_done_callback(self) -> None: # AGES handles persistence automatically pass + + async def drop(self) -> dict[str, str]: + """Drop the storage by removing all nodes and relationships in the graph. 
+ + Returns: + dict[str, str]: Status of the operation with keys 'status' and 'message' + """ + try: + query = """ + MATCH (n) + DETACH DELETE n + """ + await self._query(query) + logger.info(f"Successfully dropped all data from graph {self.graph_name}") + return {"status": "success", "message": "graph data dropped"} + except Exception as e: + logger.error(f"Error dropping graph {self.graph_name}: {e}") + return {"status": "error", "message": str(e)} diff --git a/lightrag/kg/chroma_impl.py b/lightrag/kg/chroma_impl.py index 84d43326d86356e957d347ac5b6889251275061b..020e358f6b90473cdd1ce4d8f9622f30ead4f1c6 100644 --- a/lightrag/kg/chroma_impl.py +++ b/lightrag/kg/chroma_impl.py @@ -1,4 +1,5 @@ import asyncio +import os from dataclasses import dataclass from typing import Any, final import numpy as np @@ -10,8 +11,8 @@ import pipmaster as pm if not pm.is_installed("chromadb"): pm.install("chromadb") -from chromadb import HttpClient, PersistentClient -from chromadb.config import Settings +from chromadb import HttpClient, PersistentClient # type: ignore +from chromadb.config import Settings # type: ignore @final @@ -335,3 +336,28 @@ class ChromaVectorDBStorage(BaseVectorStorage): except Exception as e: logger.error(f"Error retrieving vector data for IDs {ids}: {e}") return [] + + async def drop(self) -> dict[str, str]: + """Drop all vector data from storage and clean up resources + + This method will delete all documents from the ChromaDB collection. + + Returns: + dict[str, str]: Operation status and message + - On success: {"status": "success", "message": "data dropped"} + - On failure: {"status": "error", "message": ""} + """ + try: + # Get all IDs in the collection + result = self._collection.get(include=[]) + if result and result["ids"] and len(result["ids"]) > 0: + # Delete all documents + self._collection.delete(ids=result["ids"]) + + logger.info( + f"Process {os.getpid()} drop ChromaDB collection {self.namespace}" + ) + return {"status": "success", "message": "data dropped"} + except Exception as e: + logger.error(f"Error dropping ChromaDB collection {self.namespace}: {e}") + return {"status": "error", "message": str(e)} diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index b81760376c14cc1e71cd89010d8545cc1725dca3..c51eb1bf31ab6dc1967cb0bc3d83d58a991ccf2b 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -11,16 +11,20 @@ import pipmaster as pm from lightrag.utils import logger, compute_mdhash_id from lightrag.base import BaseVectorStorage -if not pm.is_installed("faiss"): - pm.install("faiss") - -import faiss # type: ignore from .shared_storage import ( get_storage_lock, get_update_flag, set_all_update_flags, ) +import faiss # type: ignore + +USE_GPU = os.getenv("FAISS_USE_GPU", "0") == "1" +FAISS_PACKAGE = "faiss-gpu" if USE_GPU else "faiss-cpu" + +if not pm.is_installed(FAISS_PACKAGE): + pm.install(FAISS_PACKAGE) + @final @dataclass @@ -217,6 +221,11 @@ class FaissVectorDBStorage(BaseVectorStorage): async def delete(self, ids: list[str]): """ Delete vectors for the provided custom IDs. + + Importance notes: + 1. Changes will be persisted to disk during the next index_done_callback + 2. 
Only one process should updating the storage at a time before index_done_callback, + KG-storage-log should be used to avoid data corruption """ logger.info(f"Deleting {len(ids)} vectors from {self.namespace}") to_remove = [] @@ -232,13 +241,22 @@ class FaissVectorDBStorage(BaseVectorStorage): ) async def delete_entity(self, entity_name: str) -> None: + """ + Importance notes: + 1. Changes will be persisted to disk during the next index_done_callback + 2. Only one process should updating the storage at a time before index_done_callback, + KG-storage-log should be used to avoid data corruption + """ entity_id = compute_mdhash_id(entity_name, prefix="ent-") logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}") await self.delete([entity_id]) async def delete_entity_relation(self, entity_name: str) -> None: """ - Delete relations for a given entity by scanning metadata. + Importance notes: + 1. Changes will be persisted to disk during the next index_done_callback + 2. Only one process should updating the storage at a time before index_done_callback, + KG-storage-log should be used to avoid data corruption """ logger.debug(f"Searching relations for entity {entity_name}") relations = [] @@ -429,3 +447,44 @@ class FaissVectorDBStorage(BaseVectorStorage): results.append({**metadata, "id": metadata.get("__id__")}) return results + + async def drop(self) -> dict[str, str]: + """Drop all vector data from storage and clean up resources + + This method will: + 1. Remove the vector database storage file if it exists + 2. Reinitialize the vector database client + 3. Update flags to notify other processes + 4. Changes is persisted to disk immediately + + This method will remove all vectors from the Faiss index and delete the storage files. + + Returns: + dict[str, str]: Operation status and message + - On success: {"status": "success", "message": "data dropped"} + - On failure: {"status": "error", "message": ""} + """ + try: + async with self._storage_lock: + # Reset the index + self._index = faiss.IndexFlatIP(self._dim) + self._id_to_meta = {} + + # Remove storage files if they exist + if os.path.exists(self._faiss_index_file): + os.remove(self._faiss_index_file) + if os.path.exists(self._meta_file): + os.remove(self._meta_file) + + self._id_to_meta = {} + self._load_faiss_index() + + # Notify other processes + await set_all_update_flags(self.namespace) + self.storage_updated.value = False + + logger.info(f"Process {os.getpid()} drop FAISS index {self.namespace}") + return {"status": "success", "message": "data dropped"} + except Exception as e: + logger.error(f"Error dropping FAISS index {self.namespace}: {e}") + return {"status": "error", "message": str(e)} diff --git a/lightrag/kg/gremlin_impl.py b/lightrag/kg/gremlin_impl.py index ddb7559f12eca1ca8dd725fd399c507c683dd2b2..e27c561ee972c1ca06805e823577971ed2e513f9 100644 --- a/lightrag/kg/gremlin_impl.py +++ b/lightrag/kg/gremlin_impl.py @@ -24,9 +24,9 @@ from ..base import BaseGraphStorage if not pm.is_installed("gremlinpython"): pm.install("gremlinpython") -from gremlin_python.driver import client, serializer -from gremlin_python.driver.aiohttp.transport import AiohttpTransport -from gremlin_python.driver.protocol import GremlinServerError +from gremlin_python.driver import client, serializer # type: ignore +from gremlin_python.driver.aiohttp.transport import AiohttpTransport # type: ignore +from gremlin_python.driver.protocol import GremlinServerError # type: ignore @final @@ -695,3 +695,24 @@ class 
GremlinStorage(BaseGraphStorage): except Exception as e: logger.error(f"Error during edge deletion: {str(e)}") raise + + async def drop(self) -> dict[str, str]: + """Drop the storage by removing all nodes and relationships in the graph. + + This function deletes all nodes with the specified graph name property, + which automatically removes all associated edges. + + Returns: + dict[str, str]: Status of the operation with keys 'status' and 'message' + """ + try: + query = f"""g + .V().has('graph', {self.graph_name}) + .drop() + """ + await self._query(query) + logger.info(f"Successfully dropped all data from graph {self.graph_name}") + return {"status": "success", "message": "graph data dropped"} + except Exception as e: + logger.error(f"Error dropping graph {self.graph_name}: {e}") + return {"status": "error", "message": str(e)} diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index 22da07b5f5b9ab46675bdd0ba1a3fac6e13e14c8..944d57d7e99300b6eee14801028c070893301659 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -109,6 +109,11 @@ class JsonDocStatusStorage(DocStatusStorage): await clear_all_update_flags(self.namespace) async def upsert(self, data: dict[str, dict[str, Any]]) -> None: + """ + Importance notes for in-memory storage: + 1. Changes will be persisted to disk during the next index_done_callback + 2. update flags to notify other processes that data persistence is needed + """ if not data: return logger.info(f"Inserting {len(data)} records to {self.namespace}") @@ -122,16 +127,50 @@ class JsonDocStatusStorage(DocStatusStorage): async with self._storage_lock: return self._data.get(id) - async def delete(self, doc_ids: list[str]): + async def delete(self, doc_ids: list[str]) -> None: + """Delete specific records from storage by their IDs + + Importance notes for in-memory storage: + 1. Changes will be persisted to disk during the next index_done_callback + 2. update flags to notify other processes that data persistence is needed + + Args: + ids (list[str]): List of document IDs to be deleted from storage + + Returns: + None + """ async with self._storage_lock: + any_deleted = False for doc_id in doc_ids: - self._data.pop(doc_id, None) - await set_all_update_flags(self.namespace) - await self.index_done_callback() + result = self._data.pop(doc_id, None) + if result is not None: + any_deleted = True - async def drop(self) -> None: - """Drop the storage""" - async with self._storage_lock: - self._data.clear() - await set_all_update_flags(self.namespace) - await self.index_done_callback() + if any_deleted: + await set_all_update_flags(self.namespace) + + async def drop(self) -> dict[str, str]: + """Drop all document status data from storage and clean up resources + + This method will: + 1. Clear all document status data from memory + 2. Update flags to notify other processes + 3. 
Trigger index_done_callback to save the empty state
+
+        Returns:
+            dict[str, str]: Operation status and message
+            - On success: {"status": "success", "message": "data dropped"}
+            - On failure: {"status": "error", "message": ""}
+        """
+        try:
+            async with self._storage_lock:
+                self._data.clear()
+                await set_all_update_flags(self.namespace)
+
+            await self.index_done_callback()
+            logger.info(f"Process {os.getpid()} drop {self.namespace}")
+            return {"status": "success", "message": "data dropped"}
+        except Exception as e:
+            logger.error(f"Error dropping {self.namespace}: {e}")
+            return {"status": "error", "message": str(e)}
diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py
index e7deaf159415a0c5dfe35f21cde5d6d39ae45ee3..82c18d95b50eedbb19fc4124ad44eaf705ce715e 100644
--- a/lightrag/kg/json_kv_impl.py
+++ b/lightrag/kg/json_kv_impl.py
@@ -114,6 +114,11 @@ class JsonKVStorage(BaseKVStorage):
         return set(keys) - set(self._data.keys())
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
+        """
+        Important notes for in-memory storage:
+        1. Changes will be persisted to disk during the next index_done_callback
+        2. Update flags to notify other processes that data persistence is needed
+        """
         if not data:
             return
         logger.info(f"Inserting {len(data)} records to {self.namespace}")
@@ -122,8 +127,73 @@ class JsonKVStorage(BaseKVStorage):
             await set_all_update_flags(self.namespace)
 
     async def delete(self, ids: list[str]) -> None:
+        """Delete specific records from storage by their IDs
+
+        Important notes for in-memory storage:
+        1. Changes will be persisted to disk during the next index_done_callback
+        2. Update flags to notify other processes that data persistence is needed
+
+        Args:
+            ids (list[str]): List of document IDs to be deleted from storage
+
+        Returns:
+            None
+        """
         async with self._storage_lock:
+            any_deleted = False
             for doc_id in ids:
-                self._data.pop(doc_id, None)
-            await set_all_update_flags(self.namespace)
-        await self.index_done_callback()
+                result = self._data.pop(doc_id, None)
+                if result is not None:
+                    any_deleted = True
+
+            if any_deleted:
+                await set_all_update_flags(self.namespace)
+
+    async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
+        """Delete specific records from storage by cache mode
+
+        Important notes for in-memory storage:
+        1. Changes will be persisted to disk during the next index_done_callback
+        2. Update flags to notify other processes that data persistence is needed
+
+        Args:
+            modes (list[str]): List of cache modes to be dropped from storage
+
+        Returns:
+            True: if the cache was dropped successfully
+            False: if the cache drop failed
+        """
+        if not modes:
+            return False
+
+        try:
+            await self.delete(modes)
+            return True
+        except Exception:
+            return False
+
+    async def drop(self) -> dict[str, str]:
+        """Drop all data from storage and clean up resources
+        This action persists the data to disk immediately.
+
+        This method will:
+        1. Clear all data from memory
+        2. Update flags to notify other processes
+        3.
Trigger index_done_callback to save the empty state + + Returns: + dict[str, str]: Operation status and message + - On success: {"status": "success", "message": "data dropped"} + - On failure: {"status": "error", "message": ""} + """ + try: + async with self._storage_lock: + self._data.clear() + await set_all_update_flags(self.namespace) + + await self.index_done_callback() + logger.info(f"Process {os.getpid()} drop {self.namespace}") + return {"status": "success", "message": "data dropped"} + except Exception as e: + logger.error(f"Error dropping {self.namespace}: {e}") + return {"status": "error", "message": str(e)} diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index 4b4577caf2135d0b96dd15b06119299cf017ccd6..2cff0079c5d54b8b6c715df899857f221312cb18 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -15,7 +15,7 @@ if not pm.is_installed("pymilvus"): pm.install("pymilvus") import configparser -from pymilvus import MilvusClient +from pymilvus import MilvusClient # type: ignore config = configparser.ConfigParser() config.read("config.ini", "utf-8") @@ -287,3 +287,33 @@ class MilvusVectorDBStorage(BaseVectorStorage): except Exception as e: logger.error(f"Error retrieving vector data for IDs {ids}: {e}") return [] + + async def drop(self) -> dict[str, str]: + """Drop all vector data from storage and clean up resources + + This method will delete all data from the Milvus collection. + + Returns: + dict[str, str]: Operation status and message + - On success: {"status": "success", "message": "data dropped"} + - On failure: {"status": "error", "message": ""} + """ + try: + # Drop the collection and recreate it + if self._client.has_collection(self.namespace): + self._client.drop_collection(self.namespace) + + # Recreate the collection + MilvusVectorDBStorage.create_collection_if_not_exist( + self._client, + self.namespace, + dimension=self.embedding_func.embedding_dim, + ) + + logger.info( + f"Process {os.getpid()} drop Milvus collection {self.namespace}" + ) + return {"status": "success", "message": "data dropped"} + except Exception as e: + logger.error(f"Error dropping Milvus collection {self.namespace}: {e}") + return {"status": "error", "message": str(e)} diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index e4ae0a8de40b94592d94c5c6760686269293a490..dd4f7447dab50692b2da2bf00bfd87dbb3973261 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -25,13 +25,13 @@ if not pm.is_installed("pymongo"): if not pm.is_installed("motor"): pm.install("motor") -from motor.motor_asyncio import ( +from motor.motor_asyncio import ( # type: ignore AsyncIOMotorClient, AsyncIOMotorDatabase, AsyncIOMotorCollection, ) -from pymongo.operations import SearchIndexModel -from pymongo.errors import PyMongoError +from pymongo.operations import SearchIndexModel # type: ignore +from pymongo.errors import PyMongoError # type: ignore config = configparser.ConfigParser() config.read("config.ini", "utf-8") @@ -150,6 +150,66 @@ class MongoKVStorage(BaseKVStorage): # Mongo handles persistence automatically pass + async def delete(self, ids: list[str]) -> None: + """Delete documents with specified IDs + + Args: + ids: List of document IDs to be deleted + """ + if not ids: + return + + try: + result = await self._data.delete_many({"_id": {"$in": ids}}) + logger.info( + f"Deleted {result.deleted_count} documents from {self.namespace}" + ) + except PyMongoError as e: + logger.error(f"Error deleting documents from {self.namespace}: {e}") + + async 
def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool: + """Delete specific records from storage by cache mode + + Args: + modes (list[str]): List of cache modes to be dropped from storage + + Returns: + bool: True if successful, False otherwise + """ + if not modes: + return False + + try: + # Build regex pattern to match documents with the specified modes + pattern = f"^({'|'.join(modes)})_" + result = await self._data.delete_many({"_id": {"$regex": pattern}}) + logger.info(f"Deleted {result.deleted_count} documents by modes: {modes}") + return True + except Exception as e: + logger.error(f"Error deleting cache by modes {modes}: {e}") + return False + + async def drop(self) -> dict[str, str]: + """Drop the storage by removing all documents in the collection. + + Returns: + dict[str, str]: Status of the operation with keys 'status' and 'message' + """ + try: + result = await self._data.delete_many({}) + deleted_count = result.deleted_count + + logger.info( + f"Dropped {deleted_count} documents from doc status {self._collection_name}" + ) + return { + "status": "success", + "message": f"{deleted_count} documents dropped", + } + except PyMongoError as e: + logger.error(f"Error dropping doc status {self._collection_name}: {e}") + return {"status": "error", "message": str(e)} + @final @dataclass @@ -230,6 +290,27 @@ class MongoDocStatusStorage(DocStatusStorage): # Mongo handles persistence automatically pass + async def drop(self) -> dict[str, str]: + """Drop the storage by removing all documents in the collection. + + Returns: + dict[str, str]: Status of the operation with keys 'status' and 'message' + """ + try: + result = await self._data.delete_many({}) + deleted_count = result.deleted_count + + logger.info( + f"Dropped {deleted_count} documents from doc status {self._collection_name}" + ) + return { + "status": "success", + "message": f"{deleted_count} documents dropped", + } + except PyMongoError as e: + logger.error(f"Error dropping doc status {self._collection_name}: {e}") + return {"status": "error", "message": str(e)} + @final @dataclass @@ -840,6 +921,27 @@ class MongoGraphStorage(BaseGraphStorage): logger.debug(f"Successfully deleted edges: {edges}") + async def drop(self) -> dict[str, str]: + """Drop the storage by removing all documents in the collection. + + Returns: + dict[str, str]: Status of the operation with keys 'status' and 'message' + """ + try: + result = await self.collection.delete_many({}) + deleted_count = result.deleted_count + + logger.info( + f"Dropped {deleted_count} documents from graph {self._collection_name}" + ) + return { + "status": "success", + "message": f"{deleted_count} documents dropped", + } + except PyMongoError as e: + logger.error(f"Error dropping graph {self._collection_name}: {e}") + return {"status": "error", "message": str(e)} + @final @dataclass @@ -1127,6 +1229,31 @@ class MongoVectorDBStorage(BaseVectorStorage): logger.error(f"Error retrieving vector data for IDs {ids}: {e}") return [] + async def drop(self) -> dict[str, str]: + """Drop the storage by removing all documents in the collection and recreating vector index. 
+ + Returns: + dict[str, str]: Status of the operation with keys 'status' and 'message' + """ + try: + # Delete all documents + result = await self._data.delete_many({}) + deleted_count = result.deleted_count + + # Recreate vector index + await self.create_vector_index_if_not_exists() + + logger.info( + f"Dropped {deleted_count} documents from vector storage {self._collection_name} and recreated vector index" + ) + return { + "status": "success", + "message": f"{deleted_count} documents dropped and vector index recreated", + } + except PyMongoError as e: + logger.error(f"Error dropping vector storage {self._collection_name}: {e}") + return {"status": "error", "message": str(e)} + async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str): collection_names = await db.list_collection_names() diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index 553ba0b29c2cfb3f2b2c6df906d76697fcac7d9a..56a52b923c2cc3dec737bb6af9292b426f0c64d4 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -78,6 +78,13 @@ class NanoVectorDBStorage(BaseVectorStorage): return self._client async def upsert(self, data: dict[str, dict[str, Any]]) -> None: + """ + Importance notes: + 1. Changes will be persisted to disk during the next index_done_callback + 2. Only one process should updating the storage at a time before index_done_callback, + KG-storage-log should be used to avoid data corruption + """ + logger.info(f"Inserting {len(data)} to {self.namespace}") if not data: return @@ -146,6 +153,11 @@ class NanoVectorDBStorage(BaseVectorStorage): async def delete(self, ids: list[str]): """Delete vectors with specified IDs + Importance notes: + 1. Changes will be persisted to disk during the next index_done_callback + 2. Only one process should updating the storage at a time before index_done_callback, + KG-storage-log should be used to avoid data corruption + Args: ids: List of vector IDs to be deleted """ @@ -159,6 +171,13 @@ class NanoVectorDBStorage(BaseVectorStorage): logger.error(f"Error while deleting vectors from {self.namespace}: {e}") async def delete_entity(self, entity_name: str) -> None: + """ + Importance notes: + 1. Changes will be persisted to disk during the next index_done_callback + 2. Only one process should updating the storage at a time before index_done_callback, + KG-storage-log should be used to avoid data corruption + """ + try: entity_id = compute_mdhash_id(entity_name, prefix="ent-") logger.debug( @@ -176,6 +195,13 @@ class NanoVectorDBStorage(BaseVectorStorage): logger.error(f"Error deleting entity {entity_name}: {e}") async def delete_entity_relation(self, entity_name: str) -> None: + """ + Importance notes: + 1. Changes will be persisted to disk during the next index_done_callback + 2. Only one process should updating the storage at a time before index_done_callback, + KG-storage-log should be used to avoid data corruption + """ + try: client = await self._get_client() storage = getattr(client, "_NanoVectorDB__storage") @@ -280,3 +306,43 @@ class NanoVectorDBStorage(BaseVectorStorage): client = await self._get_client() return client.get(ids) + + async def drop(self) -> dict[str, str]: + """Drop all vector data from storage and clean up resources + + This method will: + 1. Remove the vector database storage file if it exists + 2. Reinitialize the vector database client + 3. Update flags to notify other processes + 4. 
Changes is persisted to disk immediately + + This method is intended for use in scenarios where all data needs to be removed, + + Returns: + dict[str, str]: Operation status and message + - On success: {"status": "success", "message": "data dropped"} + - On failure: {"status": "error", "message": ""} + """ + try: + async with self._storage_lock: + # delete _client_file_name + if os.path.exists(self._client_file_name): + os.remove(self._client_file_name) + + self._client = NanoVectorDB( + self.embedding_func.embedding_dim, + storage_file=self._client_file_name, + ) + + # Notify other processes that data has been updated + await set_all_update_flags(self.namespace) + # Reset own update flag to avoid self-reloading + self.storage_updated.value = False + + logger.info( + f"Process {os.getpid()} drop {self.namespace}(file:{self._client_file_name})" + ) + return {"status": "success", "message": "data dropped"} + except Exception as e: + logger.error(f"Error dropping {self.namespace}: {e}") + return {"status": "error", "message": str(e)} diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 2df420dffcb6e344b03c6b820effeeb89662412c..b84a0c6a6becd5d764f2068a7c1adfb13a9a6c9b 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -1,9 +1,8 @@ -import asyncio import inspect import os import re from dataclasses import dataclass -from typing import Any, final, Optional +from typing import Any, final import numpy as np import configparser @@ -29,7 +28,6 @@ from neo4j import ( # type: ignore exceptions as neo4jExceptions, AsyncDriver, AsyncManagedTransaction, - GraphDatabase, ) config = configparser.ConfigParser() @@ -52,8 +50,13 @@ class Neo4JStorage(BaseGraphStorage): embedding_func=embedding_func, ) self._driver = None - self._driver_lock = asyncio.Lock() + def __post_init__(self): + self._node_embed_algorithms = { + "node2vec": self._node2vec_embed, + } + + async def initialize(self): URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None)) USERNAME = os.environ.get( "NEO4J_USERNAME", config.get("neo4j", "username", fallback=None) @@ -86,7 +89,7 @@ class Neo4JStorage(BaseGraphStorage): ), ) DATABASE = os.environ.get( - "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace) + "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace) ) self._driver: AsyncDriver = AsyncGraphDatabase.driver( @@ -98,71 +101,92 @@ class Neo4JStorage(BaseGraphStorage): max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME, ) - # Try to connect to the database - with GraphDatabase.driver( - URI, - auth=(USERNAME, PASSWORD), - max_connection_pool_size=MAX_CONNECTION_POOL_SIZE, - connection_timeout=CONNECTION_TIMEOUT, - connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT, - ) as _sync_driver: - for database in (DATABASE, None): - self._DATABASE = database - connected = False + # Try to connect to the database and create it if it doesn't exist + for database in (DATABASE, None): + self._DATABASE = database + connected = False - try: - with _sync_driver.session(database=database) as session: - try: - session.run("MATCH (n) RETURN n LIMIT 0") - logger.info(f"Connected to {database} at {URI}") - connected = True - except neo4jExceptions.ServiceUnavailable as e: - logger.error( - f"{database} at {URI} is not available".capitalize() + try: + async with self._driver.session(database=database) as session: + try: + result = await session.run("MATCH (n) RETURN n LIMIT 0") + await result.consume() # Ensure result is consumed + logger.info(f"Connected to 
{database} at {URI}") + connected = True + except neo4jExceptions.ServiceUnavailable as e: + logger.error( + f"{database} at {URI} is not available".capitalize() + ) + raise e + except neo4jExceptions.AuthError as e: + logger.error(f"Authentication failed for {database} at {URI}") + raise e + except neo4jExceptions.ClientError as e: + if e.code == "Neo.ClientError.Database.DatabaseNotFound": + logger.info( + f"{database} at {URI} not found. Try to create specified database.".capitalize() + ) + try: + async with self._driver.session() as session: + result = await session.run( + f"CREATE DATABASE `{database}` IF NOT EXISTS" ) + await result.consume() # Ensure result is consumed + logger.info(f"{database} at {URI} created".capitalize()) + connected = True + except ( + neo4jExceptions.ClientError, + neo4jExceptions.DatabaseError, + ) as e: + if ( + e.code + == "Neo.ClientError.Statement.UnsupportedAdministrationCommand" + ) or (e.code == "Neo.DatabaseError.Statement.ExecutionFailed"): + if database is not None: + logger.warning( + "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database." + ) + if database is None: + logger.error(f"Failed to create {database} at {URI}") raise e - except neo4jExceptions.AuthError as e: - logger.error(f"Authentication failed for {database} at {URI}") - raise e - except neo4jExceptions.ClientError as e: - if e.code == "Neo.ClientError.Database.DatabaseNotFound": - logger.info( - f"{database} at {URI} not found. Try to create specified database.".capitalize() - ) + + if connected: + # Create index for base nodes on entity_id if it doesn't exist + try: + async with self._driver.session(database=database) as session: + # Check if index exists first + check_query = """ + CALL db.indexes() YIELD name, labelsOrTypes, properties + WHERE labelsOrTypes = ['base'] AND properties = ['entity_id'] + RETURN count(*) > 0 AS exists + """ try: - with _sync_driver.session() as session: - session.run( - f"CREATE DATABASE `{database}` IF NOT EXISTS" - ) - logger.info(f"{database} at {URI} created".capitalize()) - connected = True - except ( - neo4jExceptions.ClientError, - neo4jExceptions.DatabaseError, - ) as e: - if ( - e.code - == "Neo.ClientError.Statement.UnsupportedAdministrationCommand" - ) or ( - e.code == "Neo.DatabaseError.Statement.ExecutionFailed" - ): - if database is not None: - logger.warning( - "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database." 
- ) - if database is None: - logger.error(f"Failed to create {database} at {URI}") - raise e + check_result = await session.run(check_query) + record = await check_result.single() + await check_result.consume() - if connected: - break + index_exists = record and record.get("exists", False) - def __post_init__(self): - self._node_embed_algorithms = { - "node2vec": self._node2vec_embed, - } + if not index_exists: + # Create index only if it doesn't exist + result = await session.run( + "CREATE INDEX FOR (n:base) ON (n.entity_id)" + ) + await result.consume() + logger.info( + f"Created index for base nodes on entity_id in {database}" + ) + except Exception: + # Fallback if db.indexes() is not supported in this Neo4j version + result = await session.run( + "CREATE INDEX IF NOT EXISTS FOR (n:base) ON (n.entity_id)" + ) + await result.consume() + except Exception as e: + logger.warning(f"Failed to create index: {str(e)}") + break - async def close(self): + async def finalize(self): """Close the Neo4j driver and release all resources""" if self._driver: await self._driver.close() @@ -170,7 +194,7 @@ class Neo4JStorage(BaseGraphStorage): async def __aexit__(self, exc_type, exc, tb): """Ensure driver is closed when context manager exits""" - await self.close() + await self.finalize() async def index_done_callback(self) -> None: # Noe4J handles persistence automatically @@ -243,7 +267,7 @@ class Neo4JStorage(BaseGraphStorage): raise async def get_node(self, node_id: str) -> dict[str, str] | None: - """Get node by its label identifier. + """Get node by its label identifier, return only node properties Args: node_id: The node label to look up @@ -428,13 +452,8 @@ class Neo4JStorage(BaseGraphStorage): logger.debug( f"{inspect.currentframe().f_code.co_name}: No edge found between {source_node_id} and {target_node_id}" ) - # Return default edge properties when no edge found - return { - "weight": 0.0, - "source_id": None, - "description": None, - "keywords": None, - } + # Return None when no edge found + return None finally: await result.consume() # Ensure result is fully consumed @@ -526,7 +545,6 @@ class Neo4JStorage(BaseGraphStorage): """ properties = node_data entity_type = properties["entity_type"] - entity_id = properties["entity_id"] if "entity_id" not in properties: raise ValueError("Neo4j: node properties must contain an 'entity_id' field") @@ -536,15 +554,17 @@ class Neo4JStorage(BaseGraphStorage): async def execute_upsert(tx: AsyncManagedTransaction): query = ( """ - MERGE (n:base {entity_id: $properties.entity_id}) + MERGE (n:base {entity_id: $entity_id}) SET n += $properties SET n:`%s` """ % entity_type ) - result = await tx.run(query, properties=properties) + result = await tx.run( + query, entity_id=node_id, properties=properties + ) logger.debug( - f"Upserted node with entity_id '{entity_id}' and properties: {properties}" + f"Upserted node with entity_id '{node_id}' and properties: {properties}" ) await result.consume() # Ensure result is fully consumed @@ -622,25 +642,19 @@ class Neo4JStorage(BaseGraphStorage): self, node_label: str, max_depth: int = 3, - min_degree: int = 0, - inclusive: bool = False, + max_nodes: int = MAX_GRAPH_NODES, ) -> KnowledgeGraph: """ Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. - Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000). - When reducing the number of nodes, the prioritization criteria are as follows: - 1. 
min_degree does not affect nodes directly connected to the matching nodes - 2. Label matching nodes take precedence - 3. Followed by nodes directly connected to the matching nodes - 4. Finally, the degree of the nodes Args: - node_label: Label of the starting node - max_depth: Maximum depth of the subgraph - min_degree: Minimum degree of nodes to include. Defaults to 0 - inclusive: Do an inclusive search if true + node_label: Label of the starting node, * means all nodes + max_depth: Maximum depth of the subgraph, Defaults to 3 + max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000 + Returns: - KnowledgeGraph: Complete connected subgraph for specified node + KnowledgeGraph object containing nodes and edges, with an is_truncated flag + indicating whether the graph was truncated due to max_nodes limit """ result = KnowledgeGraph() seen_nodes = set() @@ -651,11 +665,27 @@ class Neo4JStorage(BaseGraphStorage): ) as session: try: if node_label == "*": + # First check total node count to determine if graph is truncated + count_query = "MATCH (n) RETURN count(n) as total" + count_result = None + try: + count_result = await session.run(count_query) + count_record = await count_result.single() + + if count_record and count_record["total"] > max_nodes: + result.is_truncated = True + logger.info( + f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}" + ) + finally: + if count_result: + await count_result.consume() + + # Run main query to get nodes with highest degree main_query = """ MATCH (n) OPTIONAL MATCH (n)-[r]-() WITH n, COALESCE(count(r), 0) AS degree - WHERE degree >= $min_degree ORDER BY degree DESC LIMIT $max_nodes WITH collect({node: n}) AS filtered_nodes @@ -666,20 +696,23 @@ class Neo4JStorage(BaseGraphStorage): RETURN filtered_nodes AS node_info, collect(DISTINCT r) AS relationships """ - result_set = await session.run( - main_query, - {"max_nodes": MAX_GRAPH_NODES, "min_degree": min_degree}, - ) + result_set = None + try: + result_set = await session.run( + main_query, + {"max_nodes": max_nodes}, + ) + record = await result_set.single() + finally: + if result_set: + await result_set.consume() else: - # Main query uses partial matching - main_query = """ + # return await self._robust_fallback(node_label, max_depth, max_nodes) + # First try without limit to check if we need to truncate + full_query = """ MATCH (start) - WHERE - CASE - WHEN $inclusive THEN start.entity_id CONTAINS $entity_id - ELSE start.entity_id = $entity_id - END + WHERE start.entity_id = $entity_id WITH start CALL apoc.path.subgraphAll(start, { relationshipFilter: '', @@ -688,78 +721,115 @@ class Neo4JStorage(BaseGraphStorage): bfs: true }) YIELD nodes, relationships - WITH start, nodes, relationships + WITH nodes, relationships, size(nodes) AS total_nodes UNWIND nodes AS node - OPTIONAL MATCH (node)-[r]-() - WITH node, COALESCE(count(r), 0) AS degree, start, nodes, relationships - WHERE node = start OR EXISTS((start)--(node)) OR degree >= $min_degree - ORDER BY - CASE - WHEN node = start THEN 3 - WHEN EXISTS((start)--(node)) THEN 2 - ELSE 1 - END DESC, - degree DESC - LIMIT $max_nodes - WITH collect({node: node}) AS filtered_nodes - UNWIND filtered_nodes AS node_info - WITH collect(node_info.node) AS kept_nodes, filtered_nodes - OPTIONAL MATCH (a)-[r]-(b) - WHERE a IN kept_nodes AND b IN kept_nodes - RETURN filtered_nodes AS node_info, - collect(DISTINCT r) AS relationships + WITH collect({node: node}) AS node_info, relationships, total_nodes + RETURN node_info, relationships, 
total_nodes """ - result_set = await session.run( - main_query, - { - "max_nodes": MAX_GRAPH_NODES, - "entity_id": node_label, - "inclusive": inclusive, - "max_depth": max_depth, - "min_degree": min_degree, - }, - ) - try: - record = await result_set.single() - - if record: - # Handle nodes (compatible with multi-label cases) - for node_info in record["node_info"]: - node = node_info["node"] - node_id = node.id - if node_id not in seen_nodes: - result.nodes.append( - KnowledgeGraphNode( - id=f"{node_id}", - labels=[node.get("entity_id")], - properties=dict(node), - ) + # Try to get full result + full_result = None + try: + full_result = await session.run( + full_query, + { + "entity_id": node_label, + "max_depth": max_depth, + }, + ) + full_record = await full_result.single() + + # If no record found, return empty KnowledgeGraph + if not full_record: + logger.debug(f"No nodes found for entity_id: {node_label}") + return result + + # If record found, check node count + total_nodes = full_record["total_nodes"] + + if total_nodes <= max_nodes: + # If node count is within limit, use full result directly + logger.debug( + f"Using full result with {total_nodes} nodes (no truncation needed)" + ) + record = full_record + else: + # If node count exceeds limit, set truncated flag and run limited query + result.is_truncated = True + logger.info( + f"Graph truncated: {total_nodes} nodes found, breadth-first search limited to {max_nodes}" + ) + + # Run limited query + limited_query = """ + MATCH (start) + WHERE start.entity_id = $entity_id + WITH start + CALL apoc.path.subgraphAll(start, { + relationshipFilter: '', + minLevel: 0, + maxLevel: $max_depth, + limit: $max_nodes, + bfs: true + }) + YIELD nodes, relationships + UNWIND nodes AS node + WITH collect({node: node}) AS node_info, relationships + RETURN node_info, relationships + """ + result_set = None + try: + result_set = await session.run( + limited_query, + { + "entity_id": node_label, + "max_depth": max_depth, + "max_nodes": max_nodes, + }, ) - seen_nodes.add(node_id) - - # Handle relationships (including direction information) - for rel in record["relationships"]: - edge_id = rel.id - if edge_id not in seen_edges: - start = rel.start_node - end = rel.end_node - result.edges.append( - KnowledgeGraphEdge( - id=f"{edge_id}", - type=rel.type, - source=f"{start.id}", - target=f"{end.id}", - properties=dict(rel), - ) + record = await result_set.single() + finally: + if result_set: + await result_set.consume() + finally: + if full_result: + await full_result.consume() + + if record: + # Handle nodes (compatible with multi-label cases) + for node_info in record["node_info"]: + node = node_info["node"] + node_id = node.id + if node_id not in seen_nodes: + result.nodes.append( + KnowledgeGraphNode( + id=f"{node_id}", + labels=[node.get("entity_id")], + properties=dict(node), + ) + ) + seen_nodes.add(node_id) + + # Handle relationships (including direction information) + for rel in record["relationships"]: + edge_id = rel.id + if edge_id not in seen_edges: + start = rel.start_node + end = rel.end_node + result.edges.append( + KnowledgeGraphEdge( + id=f"{edge_id}", + type=rel.type, + source=f"{start.id}", + target=f"{end.id}", + properties=dict(rel), ) - seen_edges.add(edge_id) + ) + seen_edges.add(edge_id) - logger.info( - f"Process {os.getpid()} graph query return: {len(result.nodes)} nodes, {len(result.edges)} edges" - ) - finally: - await result_set.consume() # Ensure result set is consumed + logger.info( + f"Subgraph query successful | Node count: 
{len(result.nodes)} | Edge count: {len(result.edges)}"
+            )
 
         except neo4jExceptions.ClientError as e:
             logger.warning(f"APOC plugin error: {str(e)}")
@@ -767,46 +837,89 @@
                 logger.warning(
                     "Neo4j: falling back to basic Cypher recursive search..."
                 )
-                if inclusive:
-                    logger.warning(
-                        "Neo4j: inclusive search mode is not supported in recursive query, using exact matching"
-                    )
-                return await self._robust_fallback(
-                    node_label, max_depth, min_degree
+                return await self._robust_fallback(node_label, max_depth, max_nodes)
+            else:
+                logger.warning(
+                    "Neo4j: APOC plugin error with wildcard query, returning empty result"
                 )
 
         return result
 
     async def _robust_fallback(
-        self, node_label: str, max_depth: int, min_degree: int = 0
+        self, node_label: str, max_depth: int, max_nodes: int
     ) -> KnowledgeGraph:
         """
         Fallback implementation when APOC plugin is not available or incompatible.
         This method implements the same functionality as get_knowledge_graph but uses
-        only basic Cypher queries and recursive traversal instead of APOC procedures.
+        only basic Cypher queries and true breadth-first traversal instead of APOC procedures.
         """
+        from collections import deque
+
         result = KnowledgeGraph()
         visited_nodes = set()
         visited_edges = set()
+        visited_edge_pairs = set()  # Track processed edge pairs as sorted (source_id, target_id) tuples
 
-        async def traverse(
-            node: KnowledgeGraphNode,
-            edge: Optional[KnowledgeGraphEdge],
-            current_depth: int,
-        ):
-            # Check traversal limits
-            if current_depth > max_depth:
-                logger.debug(f"Reached max depth: {max_depth}")
-                return
-            if len(visited_nodes) >= MAX_GRAPH_NODES:
-                logger.debug(f"Reached max nodes limit: {MAX_GRAPH_NODES}")
-                return
+        # Get the starting node's data
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
+            query = """
+            MATCH (n:base {entity_id: $entity_id})
+            RETURN id(n) as node_id, n
+            """
+            node_result = await session.run(query, entity_id=node_label)
+            try:
+                node_record = await node_result.single()
+                if not node_record:
+                    return result
+
+                # Create initial KnowledgeGraphNode
+                start_node = KnowledgeGraphNode(
+                    id=f"{node_record['n'].get('entity_id')}",
+                    labels=[node_record["n"].get("entity_id")],
+                    properties=dict(node_record["n"]._properties),
+                )
+            finally:
+                await node_result.consume()  # Ensure results are consumed
 
-            # Check if node already visited
-            if node.id in visited_nodes:
-                return
+        # Initialize queue for BFS with (node, edge, depth) tuples
+        # edge is None for the starting node
+        queue = deque([(start_node, None, 0)])
 
-            # Get all edges and target nodes
+        # True BFS implementation using a queue
+        while queue and len(visited_nodes) < max_nodes:
+            # Dequeue the next node to process
+            current_node, current_edge, current_depth = queue.popleft()
+
+            # Skip if already visited or exceeds max depth
+            if current_node.id in visited_nodes:
+                continue
+
+            if current_depth > max_depth:
+                logger.debug(
+                    f"Skipping node at depth {current_depth} (max_depth: {max_depth})"
+                )
+                continue
+
+            # Add current node to result
+            result.nodes.append(current_node)
+            visited_nodes.add(current_node.id)
+
+            # Add edge to result if it exists and not already added
+            if current_edge and current_edge.id not in visited_edges:
+                result.edges.append(current_edge)
+                visited_edges.add(current_edge.id)
+
+            # Stop if we've reached the node limit
+            if len(visited_nodes) >= max_nodes:
+                result.is_truncated = True
+                logger.info(
+                    f"Graph truncated: breadth-first search limited to {max_nodes} nodes"
+                )
+                break
+
+            # Get
all edges and target nodes for the current node (even at max_depth)
             async with self._driver.session(
                 database=self._DATABASE, default_access_mode="READ"
             ) as session:
@@ -815,32 +928,17 @@ class Neo4JStorage(BaseGraphStorage):
                 WITH r, b, id(r) as edge_id, id(b) as target_id
                 RETURN r, b, edge_id, target_id
                 """
-                results = await session.run(query, entity_id=node.id)
+                results = await session.run(query, entity_id=current_node.id)
 
                 # Get all records and release database connection
-                records = await results.fetch(
-                    1000
-                )  # Max neighbour nodes we can handled
+                records = await results.fetch(1000)  # Max neighbor nodes we can handle
                 await results.consume()  # Ensure results are consumed
 
-            # Nodes not connected to start node need to check degree
-            if current_depth > 1 and len(records) < min_degree:
-                return
-
-            # Add current node to result
-            result.nodes.append(node)
-            visited_nodes.add(node.id)
-
-            # Add edge to result if it exists and not already added
-            if edge and edge.id not in visited_edges:
-                result.edges.append(edge)
-                visited_edges.add(edge.id)
-
-            # Prepare nodes and edges for recursive processing
-            nodes_to_process = []
+            # Process all neighbors - capture all edges but only queue unvisited nodes
             for record in records:
                 rel = record["r"]
                 edge_id = str(record["edge_id"])
+
                 if edge_id not in visited_edges:
                     b_node = record["b"]
                     target_id = b_node.get("entity_id")
@@ -849,55 +947,59 @@ class Neo4JStorage(BaseGraphStorage):
                         # Create KnowledgeGraphNode for target
                         target_node = KnowledgeGraphNode(
                             id=f"{target_id}",
-                            labels=list(f"{target_id}"),
-                            properties=dict(b_node.properties),
+                            labels=[target_id],
+                            properties=dict(b_node._properties),
                         )
 
                         # Create KnowledgeGraphEdge
                         target_edge = KnowledgeGraphEdge(
                             id=f"{edge_id}",
                             type=rel.type,
-                            source=f"{node.id}",
+                            source=f"{current_node.id}",
                             target=f"{target_id}",
                             properties=dict(rel),
                         )
 
-                        nodes_to_process.append((target_node, target_edge))
+                        # Sort source_id and target_id so that (A,B) and (B,A) are treated as the same edge
+                        sorted_pair = tuple(sorted([current_node.id, target_id]))
+
+                        # Check whether this edge was already handled (edges are undirected)
+                        if sorted_pair not in visited_edge_pairs:
+                            # Only add the edge if the target node is already in the result or will be added to it
+                            if target_id in visited_nodes or (
+                                target_id not in visited_nodes
+                                and current_depth < max_depth
+                            ):
+                                result.edges.append(target_edge)
+                                visited_edges.add(edge_id)
+                                visited_edge_pairs.add(sorted_pair)
+
+                        # Only add unvisited nodes to the queue for further expansion
+                        if target_id not in visited_nodes:
+                            # Only add to queue if we're not at max depth yet
+                            if current_depth < max_depth:
+                                # Add node to queue with incremented depth
+                                # Edge is already added to result, so we pass None as edge
+                                queue.append((target_node, None, current_depth + 1))
+                            else:
+                                # At max depth, we've already added the edge but we don't add the node
+                                # This prevents adding nodes beyond max_depth to the result
+                                logger.debug(
+                                    f"Node {target_id} beyond max depth {max_depth}, edge added but node not included"
+                                )
+                        else:
+                            # If target node already exists in result, we don't need to add it again
+                            logger.debug(
+                                f"Node {target_id} already visited, edge added but node not queued"
+                            )
                     else:
                         logger.warning(
-                            f"Skipping edge {edge_id} due to missing labels on target node"
+                            f"Skipping edge {edge_id} due to missing entity_id on target node"
                         )
 
-            # Process nodes after releasing database connection
-            for target_node, target_edge in nodes_to_process:
-                await traverse(target_node, target_edge, current_depth + 1)
-
-        # Get the starting node's data
-        async with self._driver.session(
-            database=self._DATABASE,
default_access_mode="READ" - ) as session: - query = """ - MATCH (n:base {entity_id: $entity_id}) - RETURN id(n) as node_id, n - """ - node_result = await session.run(query, entity_id=node_label) - try: - node_record = await node_result.single() - if not node_record: - return result - - # Create initial KnowledgeGraphNode - start_node = KnowledgeGraphNode( - id=f"{node_record['n'].get('entity_id')}", - labels=list(f"{node_record['n'].get('entity_id')}"), - properties=dict(node_record["n"].properties), - ) - finally: - await node_result.consume() # Ensure results are consumed - - # Start traversal with the initial node - await traverse(start_node, None, 0) - + logger.info( + f"BFS subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + ) return result async def get_all_labels(self) -> list[str]: @@ -914,7 +1016,7 @@ class Neo4JStorage(BaseGraphStorage): # Method 2: Query compatible with older versions query = """ - MATCH (n) + MATCH (n:base) WHERE n.entity_id IS NOT NULL RETURN DISTINCT n.entity_id AS label ORDER BY label @@ -1028,3 +1130,28 @@ class Neo4JStorage(BaseGraphStorage): self, algorithm: str ) -> tuple[np.ndarray[Any, Any], list[str]]: raise NotImplementedError + + async def drop(self) -> dict[str, str]: + """Drop all data from storage and clean up resources + + This method will delete all nodes and relationships in the Neo4j database. + + Returns: + dict[str, str]: Operation status and message + - On success: {"status": "success", "message": "data dropped"} + - On failure: {"status": "error", "message": ""} + """ + try: + async with self._driver.session(database=self._DATABASE) as session: + # Delete all nodes and relationships + query = "MATCH (n) DETACH DELETE n" + result = await session.run(query) + await result.consume() # Ensure result is fully consumed + + logger.info( + f"Process {os.getpid()} drop Neo4j database {self._DATABASE}" + ) + return {"status": "success", "message": "data dropped"} + except Exception as e: + logger.error(f"Error dropping Neo4j database {self._DATABASE}: {e}") + return {"status": "error", "message": str(e)} diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index 324fe7af89d043d620924fd5bc59fa3675ce4354..2bea85a3c5fa84e3c767028339a44c38f42bb133 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -42,6 +42,7 @@ class NetworkXStorage(BaseGraphStorage): ) nx.write_graphml(graph, file_name) + # TODO:deprecated, remove later @staticmethod def _stabilize_graph(graph: nx.Graph) -> nx.Graph: """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py @@ -155,16 +156,34 @@ class NetworkXStorage(BaseGraphStorage): return None async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: + """ + Importance notes: + 1. Changes will be persisted to disk during the next index_done_callback + 2. Only one process should updating the storage at a time before index_done_callback, + KG-storage-log should be used to avoid data corruption + """ graph = await self._get_graph() graph.add_node(node_id, **node_data) async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ) -> None: + """ + Importance notes: + 1. Changes will be persisted to disk during the next index_done_callback + 2. 
Only one process should be updating the storage at a time before index_done_callback,
+           KG-storage-log should be used to avoid data corruption
+        """
         graph = await self._get_graph()
         graph.add_edge(source_node_id, target_node_id, **edge_data)
 
     async def delete_node(self, node_id: str) -> None:
+        """
+        Important notes:
+        1. Changes will be persisted to disk during the next index_done_callback
+        2. Only one process should be updating the storage at a time before index_done_callback,
+           KG-storage-log should be used to avoid data corruption
+        """
         graph = await self._get_graph()
         if graph.has_node(node_id):
             graph.remove_node(node_id)
@@ -172,6 +191,7 @@ class NetworkXStorage(BaseGraphStorage):
         else:
             logger.warning(f"Node {node_id} not found in the graph for deletion.")
 
+    # TODO: NOT USED
     async def embed_nodes(
         self, algorithm: str
     ) -> tuple[np.ndarray[Any, Any], list[str]]:
@@ -192,6 +212,11 @@ class NetworkXStorage(BaseGraphStorage):
     async def remove_nodes(self, nodes: list[str]):
         """Delete multiple nodes
 
+        Important notes:
+        1. Changes will be persisted to disk during the next index_done_callback
+        2. Only one process should be updating the storage at a time before index_done_callback,
+           KG-storage-log should be used to avoid data corruption
+
         Args:
             nodes: List of node IDs to be deleted
         """
@@ -203,6 +228,11 @@ class NetworkXStorage(BaseGraphStorage):
     async def remove_edges(self, edges: list[tuple[str, str]]):
         """Delete multiple edges
 
+        Important notes:
+        1. Changes will be persisted to disk during the next index_done_callback
+        2. Only one process should be updating the storage at a time before index_done_callback,
+           KG-storage-log should be used to avoid data corruption
+
         Args:
             edges: List of edges to be deleted, each edge is a (source, target) tuple
         """
@@ -229,118 +259,81 @@ class NetworkXStorage(BaseGraphStorage):
         self,
         node_label: str,
         max_depth: int = 3,
-        min_degree: int = 0,
-        inclusive: bool = False,
+        max_nodes: int = MAX_GRAPH_NODES,
     ) -> KnowledgeGraph:
         """
         Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
-        Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
-        When reducing the number of nodes, the prioritization criteria are as follows:
-        1. min_degree does not affect nodes directly connected to the matching nodes
-        2. Label matching nodes take precedence
-        3. Followed by nodes directly connected to the matching nodes
-        4. Finally, the degree of the nodes
 
         Args:
-            node_label: Label of the starting node
-            max_depth: Maximum depth of the subgraph
-            min_degree: Minimum degree of nodes to include.
Defaults to 0 - inclusive: Do an inclusive search if true + node_label: Label of the starting node,* means all nodes + max_depth: Maximum depth of the subgraph, Defaults to 3 + max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000 Returns: - KnowledgeGraph object containing nodes and edges + KnowledgeGraph object containing nodes and edges, with an is_truncated flag + indicating whether the graph was truncated due to max_nodes limit """ - result = KnowledgeGraph() - seen_nodes = set() - seen_edges = set() - graph = await self._get_graph() - # Initialize sets for start nodes and direct connected nodes - start_nodes = set() - direct_connected_nodes = set() + result = KnowledgeGraph() # Handle special case for "*" label if node_label == "*": - # For "*", return the entire graph including all nodes and edges - subgraph = ( - graph.copy() - ) # Create a copy to avoid modifying the original graph + # Get degrees of all nodes + degrees = dict(graph.degree()) + # Sort nodes by degree in descending order and take top max_nodes + sorted_nodes = sorted(degrees.items(), key=lambda x: x[1], reverse=True) + + # Check if graph is truncated + if len(sorted_nodes) > max_nodes: + result.is_truncated = True + logger.info( + f"Graph truncated: {len(sorted_nodes)} nodes found, limited to {max_nodes}" + ) + + limited_nodes = [node for node, _ in sorted_nodes[:max_nodes]] + # Create subgraph with the highest degree nodes + subgraph = graph.subgraph(limited_nodes) else: - # Find nodes with matching node id based on search_mode - nodes_to_explore = [] - for n, attr in graph.nodes(data=True): - node_str = str(n) - if not inclusive: - if node_label == node_str: # Use exact matching - nodes_to_explore.append(n) - else: # inclusive mode - if node_label in node_str: # Use partial matching - nodes_to_explore.append(n) - - if not nodes_to_explore: - logger.warning(f"No nodes found with label {node_label}") - return result - - # Get subgraph using ego_graph from all matching nodes - combined_subgraph = nx.Graph() - for start_node in nodes_to_explore: - node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth) - combined_subgraph = nx.compose(combined_subgraph, node_subgraph) - - # Get start nodes and direct connected nodes - if nodes_to_explore: - start_nodes = set(nodes_to_explore) - # Get nodes directly connected to all start nodes - for start_node in start_nodes: - direct_connected_nodes.update( - combined_subgraph.neighbors(start_node) - ) - - # Remove start nodes from directly connected nodes (avoid duplicates) - direct_connected_nodes -= start_nodes - - subgraph = combined_subgraph - - # Filter nodes based on min_degree, but keep start nodes and direct connected nodes - if min_degree > 0: - nodes_to_keep = [ - node - for node, degree in subgraph.degree() - if node in start_nodes - or node in direct_connected_nodes - or degree >= min_degree - ] - subgraph = subgraph.subgraph(nodes_to_keep) - - # Check if number of nodes exceeds max_graph_nodes - if len(subgraph.nodes()) > MAX_GRAPH_NODES: - origin_nodes = len(subgraph.nodes()) - node_degrees = dict(subgraph.degree()) - - def priority_key(node_item): - node, degree = node_item - # Priority order: start(2) > directly connected(1) > other nodes(0) - if node in start_nodes: - priority = 2 - elif node in direct_connected_nodes: - priority = 1 - else: - priority = 0 - return (priority, degree) - - # Sort by priority and degree and select top MAX_GRAPH_NODES nodes - top_nodes = sorted(node_degrees.items(), key=priority_key, reverse=True)[ - :MAX_GRAPH_NODES - 
] - top_node_ids = [node[0] for node in top_nodes] - # Create new subgraph and keep nodes only with most degree - subgraph = subgraph.subgraph(top_node_ids) - logger.info( - f"Reduced graph from {origin_nodes} nodes to {MAX_GRAPH_NODES} nodes (depth={max_depth})" - ) + # Check if node exists + if node_label not in graph: + logger.warning(f"Node {node_label} not found in the graph") + return KnowledgeGraph() # Return empty graph + + # Use BFS to get nodes + bfs_nodes = [] + visited = set() + queue = [(node_label, 0)] # (node, depth) tuple + + # Breadth-first search + while queue and len(bfs_nodes) < max_nodes: + current, depth = queue.pop(0) + if current not in visited: + visited.add(current) + bfs_nodes.append(current) + + # Only explore neighbors if we haven't reached max_depth + if depth < max_depth: + # Add neighbor nodes to queue with incremented depth + neighbors = list(graph.neighbors(current)) + queue.extend( + [(n, depth + 1) for n in neighbors if n not in visited] + ) + + # Check if graph is truncated - if we still have nodes in the queue + # and we've reached max_nodes, then the graph is truncated + if queue and len(bfs_nodes) >= max_nodes: + result.is_truncated = True + logger.info( + f"Graph truncated: breadth-first search limited to {max_nodes} nodes" + ) + + # Create subgraph with BFS discovered nodes + subgraph = graph.subgraph(bfs_nodes) # Add nodes to result + seen_nodes = set() + seen_edges = set() for node in subgraph.nodes(): if str(node) in seen_nodes: continue @@ -368,7 +361,7 @@ class NetworkXStorage(BaseGraphStorage): for edge in subgraph.edges(): source, target = edge # Esure unique edge_id for undirect graph - if source > target: + if str(source) > str(target): source, target = target, source edge_id = f"{source}-{target}" if edge_id in seen_edges: @@ -424,3 +417,35 @@ class NetworkXStorage(BaseGraphStorage): return False # Return error return True + + async def drop(self) -> dict[str, str]: + """Drop all graph data from storage and clean up resources + + This method will: + 1. Remove the graph storage file if it exists + 2. Reset the graph to an empty state + 3. Update flags to notify other processes + 4. 
Changes is persisted to disk immediately + + Returns: + dict[str, str]: Operation status and message + - On success: {"status": "success", "message": "data dropped"} + - On failure: {"status": "error", "message": ""} + """ + try: + async with self._storage_lock: + # delete _client_file_name + if os.path.exists(self._graphml_xml_file): + os.remove(self._graphml_xml_file) + self._graph = nx.Graph() + # Notify other processes that data has been updated + await set_all_update_flags(self.namespace) + # Reset own update flag to avoid self-reloading + self.storage_updated.value = False + logger.info( + f"Process {os.getpid()} drop graph {self.namespace} (file:{self._graphml_xml_file})" + ) + return {"status": "success", "message": "data dropped"} + except Exception as e: + logger.error(f"Error dropping graph {self.namespace}: {e}") + return {"status": "error", "message": str(e)} diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py deleted file mode 100644 index c42f0f767bf58797455914cf7884cab6cc6c3407..0000000000000000000000000000000000000000 --- a/lightrag/kg/oracle_impl.py +++ /dev/null @@ -1,1346 +0,0 @@ -import array -import asyncio - -# import html -import os -from dataclasses import dataclass, field -from typing import Any, Union, final -import numpy as np -import configparser - -from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge - -from ..base import ( - BaseGraphStorage, - BaseKVStorage, - BaseVectorStorage, -) -from ..namespace import NameSpace, is_namespace -from ..utils import logger - -import pipmaster as pm - -if not pm.is_installed("graspologic"): - pm.install("graspologic") - -if not pm.is_installed("oracledb"): - pm.install("oracledb") - -from graspologic import embed -import oracledb - - -class OracleDB: - def __init__(self, config, **kwargs): - self.host = config.get("host", None) - self.port = config.get("port", None) - self.user = config.get("user", None) - self.password = config.get("password", None) - self.dsn = config.get("dsn", None) - self.config_dir = config.get("config_dir", None) - self.wallet_location = config.get("wallet_location", None) - self.wallet_password = config.get("wallet_password", None) - self.workspace = config.get("workspace", None) - self.max = 12 - self.increment = 1 - logger.info(f"Using the label {self.workspace} for Oracle Graph as identifier") - if self.user is None or self.password is None: - raise ValueError("Missing database user or password") - - try: - oracledb.defaults.fetch_lobs = False - - self.pool = oracledb.create_pool_async( - user=self.user, - password=self.password, - dsn=self.dsn, - config_dir=self.config_dir, - wallet_location=self.wallet_location, - wallet_password=self.wallet_password, - min=1, - max=self.max, - increment=self.increment, - ) - logger.info(f"Connected to Oracle database at {self.dsn}") - except Exception as e: - logger.error(f"Failed to connect to Oracle database at {self.dsn}") - logger.error(f"Oracle database error: {e}") - raise - - def numpy_converter_in(self, value): - """Convert numpy array to array.array""" - if value.dtype == np.float64: - dtype = "d" - elif value.dtype == np.float32: - dtype = "f" - else: - dtype = "b" - return array.array(dtype, value) - - def input_type_handler(self, cursor, value, arraysize): - """Set the type handler for the input data""" - if isinstance(value, np.ndarray): - return cursor.var( - oracledb.DB_TYPE_VECTOR, - arraysize=arraysize, - inconverter=self.numpy_converter_in, - ) - - def numpy_converter_out(self, value): - """Convert 
array.array to numpy array""" - if value.typecode == "b": - dtype = np.int8 - elif value.typecode == "f": - dtype = np.float32 - else: - dtype = np.float64 - return np.array(value, copy=False, dtype=dtype) - - def output_type_handler(self, cursor, metadata): - """Set the type handler for the output data""" - if metadata.type_code is oracledb.DB_TYPE_VECTOR: - return cursor.var( - metadata.type_code, - arraysize=cursor.arraysize, - outconverter=self.numpy_converter_out, - ) - - async def check_tables(self): - for k, v in TABLES.items(): - try: - if k.lower() == "lightrag_graph": - await self.query( - "SELECT id FROM GRAPH_TABLE (lightrag_graph MATCH (a) COLUMNS (a.id)) fetch first row only" - ) - else: - await self.query(f"SELECT 1 FROM {k}") - except Exception as e: - logger.error(f"Failed to check table {k} in Oracle database") - logger.error(f"Oracle database error: {e}") - try: - # print(v["ddl"]) - await self.execute(v["ddl"]) - logger.info(f"Created table {k} in Oracle database") - except Exception as e: - logger.error(f"Failed to create table {k} in Oracle database") - logger.error(f"Oracle database error: {e}") - - logger.info("Finished check all tables in Oracle database") - - async def query( - self, sql: str, params: dict = None, multirows: bool = False - ) -> Union[dict, None]: - async with self.pool.acquire() as connection: - connection.inputtypehandler = self.input_type_handler - connection.outputtypehandler = self.output_type_handler - with connection.cursor() as cursor: - try: - await cursor.execute(sql, params) - except Exception as e: - logger.error(f"Oracle database error: {e}") - raise - columns = [column[0].lower() for column in cursor.description] - if multirows: - rows = await cursor.fetchall() - if rows: - data = [dict(zip(columns, row)) for row in rows] - else: - data = [] - else: - row = await cursor.fetchone() - if row: - data = dict(zip(columns, row)) - else: - data = None - return data - - async def execute(self, sql: str, data: Union[list, dict] = None): - # logger.info("go into OracleDB execute method") - try: - async with self.pool.acquire() as connection: - connection.inputtypehandler = self.input_type_handler - connection.outputtypehandler = self.output_type_handler - with connection.cursor() as cursor: - if data is None: - await cursor.execute(sql) - else: - await cursor.execute(sql, data) - await connection.commit() - except Exception as e: - logger.error(f"Oracle database error: {e}") - raise - - -class ClientManager: - _instances: dict[str, Any] = {"db": None, "ref_count": 0} - _lock = asyncio.Lock() - - @staticmethod - def get_config() -> dict[str, Any]: - config = configparser.ConfigParser() - config.read("config.ini", "utf-8") - - return { - "user": os.environ.get( - "ORACLE_USER", - config.get("oracle", "user", fallback=None), - ), - "password": os.environ.get( - "ORACLE_PASSWORD", - config.get("oracle", "password", fallback=None), - ), - "dsn": os.environ.get( - "ORACLE_DSN", - config.get("oracle", "dsn", fallback=None), - ), - "config_dir": os.environ.get( - "ORACLE_CONFIG_DIR", - config.get("oracle", "config_dir", fallback=None), - ), - "wallet_location": os.environ.get( - "ORACLE_WALLET_LOCATION", - config.get("oracle", "wallet_location", fallback=None), - ), - "wallet_password": os.environ.get( - "ORACLE_WALLET_PASSWORD", - config.get("oracle", "wallet_password", fallback=None), - ), - "workspace": os.environ.get( - "ORACLE_WORKSPACE", - config.get("oracle", "workspace", fallback="default"), - ), - } - - @classmethod - async def get_client(cls) 
-> OracleDB: - async with cls._lock: - if cls._instances["db"] is None: - config = ClientManager.get_config() - db = OracleDB(config) - await db.check_tables() - cls._instances["db"] = db - cls._instances["ref_count"] = 0 - cls._instances["ref_count"] += 1 - return cls._instances["db"] - - @classmethod - async def release_client(cls, db: OracleDB): - async with cls._lock: - if db is not None: - if db is cls._instances["db"]: - cls._instances["ref_count"] -= 1 - if cls._instances["ref_count"] == 0: - await db.pool.close() - logger.info("Closed OracleDB database connection pool") - cls._instances["db"] = None - else: - await db.pool.close() - - -@final -@dataclass -class OracleKVStorage(BaseKVStorage): - db: OracleDB = field(default=None) - meta_fields = None - - def __post_init__(self): - self._data = {} - self._max_batch_size = self.global_config.get("embedding_batch_num", 10) - - async def initialize(self): - if self.db is None: - self.db = await ClientManager.get_client() - - async def finalize(self): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None - - ################ QUERY METHODS ################ - - async def get_by_id(self, id: str) -> dict[str, Any] | None: - """Get doc_full data based on id.""" - SQL = SQL_TEMPLATES["get_by_id_" + self.namespace] - params = {"workspace": self.db.workspace, "id": id} - # print("get_by_id:"+SQL) - if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): - array_res = await self.db.query(SQL, params, multirows=True) - res = {} - for row in array_res: - res[row["id"]] = row - if res: - return res - else: - return None - else: - return await self.db.query(SQL, params) - - async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: - """Specifically for llm_response_cache.""" - SQL = SQL_TEMPLATES["get_by_mode_id_" + self.namespace] - params = {"workspace": self.db.workspace, "cache_mode": mode, "id": id} - if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): - array_res = await self.db.query(SQL, params, multirows=True) - res = {} - for row in array_res: - res[row["id"]] = row - return res - else: - return None - - async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: - """Get doc_chunks data based on id""" - SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( - ids=",".join([f"'{id}'" for id in ids]) - ) - params = {"workspace": self.db.workspace} - # print("get_by_ids:"+SQL) - res = await self.db.query(SQL, params, multirows=True) - if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): - modes = set() - dict_res: dict[str, dict] = {} - for row in res: - modes.add(row["mode"]) - for mode in modes: - if mode not in dict_res: - dict_res[mode] = {} - for row in res: - dict_res[row["mode"]][row["id"]] = row - res = [{k: v} for k, v in dict_res.items()] - return res - - async def filter_keys(self, keys: set[str]) -> set[str]: - """Return keys that don't exist in storage""" - SQL = SQL_TEMPLATES["filter_keys"].format( - table_name=namespace_to_table_name(self.namespace), - ids=",".join([f"'{id}'" for id in keys]), - ) - params = {"workspace": self.db.workspace} - res = await self.db.query(SQL, params, multirows=True) - if res: - exist_keys = [key["id"] for key in res] - data = set([s for s in keys if s not in exist_keys]) - return data - else: - return set(keys) - - ################ INSERT METHODS ################ - async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - logger.info(f"Inserting {len(data)} to 
{self.namespace}") - if not data: - return - - if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): - list_data = [ - { - "id": k, - **{k1: v1 for k1, v1 in v.items()}, - } - for k, v in data.items() - ] - contents = [v["content"] for v in data.values()] - batches = [ - contents[i : i + self._max_batch_size] - for i in range(0, len(contents), self._max_batch_size) - ] - embeddings_list = await asyncio.gather( - *[self.embedding_func(batch) for batch in batches] - ) - embeddings = np.concatenate(embeddings_list) - for i, d in enumerate(list_data): - d["__vector__"] = embeddings[i] - - merge_sql = SQL_TEMPLATES["merge_chunk"] - for item in list_data: - _data = { - "id": item["id"], - "content": item["content"], - "workspace": self.db.workspace, - "tokens": item["tokens"], - "chunk_order_index": item["chunk_order_index"], - "full_doc_id": item["full_doc_id"], - "content_vector": item["__vector__"], - "status": item["status"], - } - await self.db.execute(merge_sql, _data) - if is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS): - for k, v in data.items(): - # values.clear() - merge_sql = SQL_TEMPLATES["merge_doc_full"] - _data = { - "id": k, - "content": v["content"], - "workspace": self.db.workspace, - } - await self.db.execute(merge_sql, _data) - - if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): - for mode, items in data.items(): - for k, v in items.items(): - upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"] - _data = { - "workspace": self.db.workspace, - "id": k, - "original_prompt": v["original_prompt"], - "return_value": v["return"], - "cache_mode": mode, - } - - await self.db.execute(upsert_sql, _data) - - async def index_done_callback(self) -> None: - # Oracle handles persistence automatically - pass - - -@final -@dataclass -class OracleVectorDBStorage(BaseVectorStorage): - db: OracleDB | None = field(default=None) - - def __post_init__(self): - config = self.global_config.get("vector_db_storage_cls_kwargs", {}) - cosine_threshold = config.get("cosine_better_than_threshold") - if cosine_threshold is None: - raise ValueError( - "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" - ) - self.cosine_better_than_threshold = cosine_threshold - - async def initialize(self): - if self.db is None: - self.db = await ClientManager.get_client() - - async def finalize(self): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None - - #################### query method ############### - async def query( - self, query: str, top_k: int, ids: list[str] | None = None - ) -> list[dict[str, Any]]: - embeddings = await self.embedding_func([query]) - embedding = embeddings[0] - # 转换精度 - dtype = str(embedding.dtype).upper() - dimension = embedding.shape[0] - embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]" - - SQL = SQL_TEMPLATES[self.namespace].format(dimension=dimension, dtype=dtype) - params = { - "embedding_string": embedding_string, - "workspace": self.db.workspace, - "top_k": top_k, - "better_than_threshold": self.cosine_better_than_threshold, - } - results = await self.db.query(SQL, params=params, multirows=True) - return results - - async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - raise NotImplementedError - - async def index_done_callback(self) -> None: - # Oracles handles persistence automatically - pass - - async def delete(self, ids: list[str]) -> None: - """Delete vectors with specified IDs - - Args: - ids: List of vector IDs to be deleted - """ - 
if not ids: - return - - try: - SQL = SQL_TEMPLATES["delete_vectors"].format( - ids=",".join([f"'{id}'" for id in ids]) - ) - params = {"workspace": self.db.workspace} - await self.db.execute(SQL, params) - logger.info( - f"Successfully deleted {len(ids)} vectors from {self.namespace}" - ) - except Exception as e: - logger.error(f"Error while deleting vectors from {self.namespace}: {e}") - raise - - async def delete_entity(self, entity_name: str) -> None: - """Delete entity by name - - Args: - entity_name: Name of the entity to delete - """ - try: - SQL = SQL_TEMPLATES["delete_entity"] - params = {"workspace": self.db.workspace, "entity_name": entity_name} - await self.db.execute(SQL, params) - logger.info(f"Successfully deleted entity {entity_name}") - except Exception as e: - logger.error(f"Error deleting entity {entity_name}: {e}") - raise - - async def delete_entity_relation(self, entity_name: str) -> None: - """Delete all relations connected to an entity - - Args: - entity_name: Name of the entity whose relations should be deleted - """ - try: - SQL = SQL_TEMPLATES["delete_entity_relations"] - params = {"workspace": self.db.workspace, "entity_name": entity_name} - await self.db.execute(SQL, params) - logger.info(f"Successfully deleted relations for entity {entity_name}") - except Exception as e: - logger.error(f"Error deleting relations for entity {entity_name}: {e}") - raise - - async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]: - """Search for records with IDs starting with a specific prefix. - - Args: - prefix: The prefix to search for in record IDs - - Returns: - List of records with matching ID prefixes - """ - try: - # Determine the appropriate table based on namespace - table_name = namespace_to_table_name(self.namespace) - - # Create SQL query to find records with IDs starting with prefix - search_sql = f""" - SELECT * FROM {table_name} - WHERE workspace = :workspace - AND id LIKE :prefix_pattern - ORDER BY id - """ - - params = {"workspace": self.db.workspace, "prefix_pattern": f"{prefix}%"} - - # Execute query and get results - results = await self.db.query(search_sql, params, multirows=True) - - logger.debug( - f"Found {len(results) if results else 0} records with prefix '{prefix}'" - ) - return results or [] - - except Exception as e: - logger.error(f"Error searching records with prefix '{prefix}': {e}") - return [] - - async def get_by_id(self, id: str) -> dict[str, Any] | None: - """Get vector data by its ID - - Args: - id: The unique identifier of the vector - - Returns: - The vector data if found, or None if not found - """ - try: - # Determine the table name based on namespace - table_name = namespace_to_table_name(self.namespace) - if not table_name: - logger.error(f"Unknown namespace for ID lookup: {self.namespace}") - return None - - # Create the appropriate ID field name based on namespace - id_field = "entity_id" if "NODES" in table_name else "relation_id" - if "CHUNKS" in table_name: - id_field = "chunk_id" - - # Prepare and execute the query - query = f""" - SELECT * FROM {table_name} - WHERE {id_field} = :id AND workspace = :workspace - """ - params = {"id": id, "workspace": self.db.workspace} - - result = await self.db.query(query, params) - return result - except Exception as e: - logger.error(f"Error retrieving vector data for ID {id}: {e}") - return None - - async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: - """Get multiple vector data by their IDs - - Args: - ids: List of unique identifiers - - Returns: - List of 
vector data objects that were found - """ - if not ids: - return [] - - try: - # Determine the table name based on namespace - table_name = namespace_to_table_name(self.namespace) - if not table_name: - logger.error(f"Unknown namespace for IDs lookup: {self.namespace}") - return [] - - # Create the appropriate ID field name based on namespace - id_field = "entity_id" if "NODES" in table_name else "relation_id" - if "CHUNKS" in table_name: - id_field = "chunk_id" - - # Format the list of IDs for SQL IN clause - ids_list = ", ".join([f"'{id}'" for id in ids]) - - # Prepare and execute the query - query = f""" - SELECT * FROM {table_name} - WHERE {id_field} IN ({ids_list}) AND workspace = :workspace - """ - params = {"workspace": self.db.workspace} - - results = await self.db.query(query, params, multirows=True) - return results or [] - except Exception as e: - logger.error(f"Error retrieving vector data for IDs {ids}: {e}") - return [] - - -@final -@dataclass -class OracleGraphStorage(BaseGraphStorage): - db: OracleDB = field(default=None) - - def __post_init__(self): - self._max_batch_size = self.global_config.get("embedding_batch_num", 10) - - async def initialize(self): - if self.db is None: - self.db = await ClientManager.get_client() - - async def finalize(self): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None - - #################### insert method ################ - - async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: - entity_name = node_id - entity_type = node_data["entity_type"] - description = node_data["description"] - source_id = node_data["source_id"] - logger.debug(f"entity_name:{entity_name}, entity_type:{entity_type}") - - content = entity_name + description - contents = [content] - batches = [ - contents[i : i + self._max_batch_size] - for i in range(0, len(contents), self._max_batch_size) - ] - embeddings_list = await asyncio.gather( - *[self.embedding_func(batch) for batch in batches] - ) - embeddings = np.concatenate(embeddings_list) - content_vector = embeddings[0] - merge_sql = SQL_TEMPLATES["merge_node"] - data = { - "workspace": self.db.workspace, - "name": entity_name, - "entity_type": entity_type, - "description": description, - "source_chunk_id": source_id, - "content": content, - "content_vector": content_vector, - } - await self.db.execute(merge_sql, data) - # self._graph.add_node(node_id, **node_data) - - async def upsert_edge( - self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] - ) -> None: - """插入或更新边""" - # print("go into upsert edge method") - source_name = source_node_id - target_name = target_node_id - weight = edge_data["weight"] - keywords = edge_data["keywords"] - description = edge_data["description"] - source_chunk_id = edge_data["source_id"] - logger.debug( - f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}" - ) - - content = keywords + source_name + target_name + description - contents = [content] - batches = [ - contents[i : i + self._max_batch_size] - for i in range(0, len(contents), self._max_batch_size) - ] - embeddings_list = await asyncio.gather( - *[self.embedding_func(batch) for batch in batches] - ) - embeddings = np.concatenate(embeddings_list) - content_vector = embeddings[0] - merge_sql = SQL_TEMPLATES["merge_edge"] - data = { - "workspace": self.db.workspace, - "source_name": source_name, - "target_name": target_name, - "weight": weight, - "keywords": keywords, - "description": description, - "source_chunk_id": 
source_chunk_id, - "content": content, - "content_vector": content_vector, - } - # print(merge_sql) - await self.db.execute(merge_sql, data) - # self._graph.add_edge(source_node_id, target_node_id, **edge_data) - - async def embed_nodes( - self, algorithm: str - ) -> tuple[np.ndarray[Any, Any], list[str]]: - if algorithm not in self._node_embed_algorithms: - raise ValueError(f"Node embedding algorithm {algorithm} not supported") - return await self._node_embed_algorithms[algorithm]() - - async def _node2vec_embed(self): - """为节点生成向量""" - embeddings, nodes = embed.node2vec_embed( - self._graph, - **self.config["node2vec_params"], - ) - - nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes] - return embeddings, nodes_ids - - async def index_done_callback(self) -> None: - # Oracles handles persistence automatically - pass - - #################### query method ################# - async def has_node(self, node_id: str) -> bool: - """根据节点id检查节点是否存在""" - SQL = SQL_TEMPLATES["has_node"] - params = {"workspace": self.db.workspace, "node_id": node_id} - res = await self.db.query(SQL, params) - if res: - # print("Node exist!",res) - return True - else: - # print("Node not exist!") - return False - - async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: - SQL = SQL_TEMPLATES["has_edge"] - params = { - "workspace": self.db.workspace, - "source_node_id": source_node_id, - "target_node_id": target_node_id, - } - res = await self.db.query(SQL, params) - if res: - # print("Edge exist!",res) - return True - else: - # print("Edge not exist!") - return False - - async def node_degree(self, node_id: str) -> int: - SQL = SQL_TEMPLATES["node_degree"] - params = {"workspace": self.db.workspace, "node_id": node_id} - res = await self.db.query(SQL, params) - if res: - return res["degree"] - else: - return 0 - - async def edge_degree(self, src_id: str, tgt_id: str) -> int: - """根据源和目标节点id获取边的度""" - degree = await self.node_degree(src_id) + await self.node_degree(tgt_id) - return degree - - async def get_node(self, node_id: str) -> dict[str, str] | None: - """根据节点id获取节点数据""" - SQL = SQL_TEMPLATES["get_node"] - params = {"workspace": self.db.workspace, "node_id": node_id} - res = await self.db.query(SQL, params) - if res: - return res - else: - return None - - async def get_edge( - self, source_node_id: str, target_node_id: str - ) -> dict[str, str] | None: - SQL = SQL_TEMPLATES["get_edge"] - params = { - "workspace": self.db.workspace, - "source_node_id": source_node_id, - "target_node_id": target_node_id, - } - res = await self.db.query(SQL, params) - if res: - # print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0]) - return res - else: - # print("Edge not exist!",self.db.workspace, source_node_id, target_node_id) - return None - - async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: - if await self.has_node(source_node_id): - SQL = SQL_TEMPLATES["get_node_edges"] - params = {"workspace": self.db.workspace, "source_node_id": source_node_id} - res = await self.db.query(sql=SQL, params=params, multirows=True) - if res: - data = [(i["source_name"], i["target_name"]) for i in res] - # print("Get node edge!",self.db.workspace, source_node_id,data) - return data - else: - # print("Node Edge not exist!",self.db.workspace, source_node_id) - return [] - - async def get_all_nodes(self, limit: int): - """查询所有节点""" - SQL = SQL_TEMPLATES["get_all_nodes"] - params = {"workspace": self.db.workspace, "limit": str(limit)} - res = await 
self.db.query(sql=SQL, params=params, multirows=True) - if res: - return res - - async def get_all_edges(self, limit: int): - """查询所有边""" - SQL = SQL_TEMPLATES["get_all_edges"] - params = {"workspace": self.db.workspace, "limit": str(limit)} - res = await self.db.query(sql=SQL, params=params, multirows=True) - if res: - return res - - async def get_statistics(self): - SQL = SQL_TEMPLATES["get_statistics"] - params = {"workspace": self.db.workspace} - res = await self.db.query(sql=SQL, params=params, multirows=True) - if res: - return res - - async def delete_node(self, node_id: str) -> None: - """Delete a node from the graph - - Args: - node_id: ID of the node to delete - """ - try: - # First delete all relations connected to this node - delete_relations_sql = SQL_TEMPLATES["delete_entity_relations"] - params_relations = {"workspace": self.db.workspace, "entity_name": node_id} - await self.db.execute(delete_relations_sql, params_relations) - - # Then delete the node itself - delete_node_sql = SQL_TEMPLATES["delete_entity"] - params_node = {"workspace": self.db.workspace, "entity_name": node_id} - await self.db.execute(delete_node_sql, params_node) - - logger.info( - f"Successfully deleted node {node_id} and all its relationships" - ) - except Exception as e: - logger.error(f"Error deleting node {node_id}: {e}") - raise - - async def remove_nodes(self, nodes: list[str]) -> None: - """Delete multiple nodes from the graph - - Args: - nodes: List of node IDs to be deleted - """ - if not nodes: - return - - try: - for node in nodes: - # For each node, first delete all its relationships - delete_relations_sql = SQL_TEMPLATES["delete_entity_relations"] - params_relations = {"workspace": self.db.workspace, "entity_name": node} - await self.db.execute(delete_relations_sql, params_relations) - - # Then delete the node itself - delete_node_sql = SQL_TEMPLATES["delete_entity"] - params_node = {"workspace": self.db.workspace, "entity_name": node} - await self.db.execute(delete_node_sql, params_node) - - logger.info( - f"Successfully deleted {len(nodes)} nodes and their relationships" - ) - except Exception as e: - logger.error(f"Error during batch node deletion: {e}") - raise - - async def remove_edges(self, edges: list[tuple[str, str]]) -> None: - """Delete multiple edges from the graph - - Args: - edges: List of edges to be deleted, each edge is a (source, target) tuple - """ - if not edges: - return - - try: - for source, target in edges: - # Check if the edge exists before attempting to delete - if await self.has_edge(source, target): - # Delete the edge using a SQL query that matches both source and target - delete_edge_sql = """ - DELETE FROM LIGHTRAG_GRAPH_EDGES - WHERE workspace = :workspace - AND source_name = :source_name - AND target_name = :target_name - """ - params = { - "workspace": self.db.workspace, - "source_name": source, - "target_name": target, - } - await self.db.execute(delete_edge_sql, params) - - logger.info(f"Successfully deleted {len(edges)} edges from the graph") - except Exception as e: - logger.error(f"Error during batch edge deletion: {e}") - raise - - async def get_all_labels(self) -> list[str]: - """Get all unique entity types (labels) in the graph - - Returns: - List of unique entity types/labels - """ - try: - SQL = """ - SELECT DISTINCT entity_type - FROM LIGHTRAG_GRAPH_NODES - WHERE workspace = :workspace - ORDER BY entity_type - """ - params = {"workspace": self.db.workspace} - results = await self.db.query(SQL, params, multirows=True) - - if results: - labels = 
[row["entity_type"] for row in results] - return labels - else: - return [] - except Exception as e: - logger.error(f"Error retrieving entity types: {e}") - return [] - - async def get_knowledge_graph( - self, node_label: str, max_depth: int = 5 - ) -> KnowledgeGraph: - """Retrieve a connected subgraph starting from nodes matching the given label - - Maximum number of nodes is constrained by MAX_GRAPH_NODES environment variable. - Prioritizes nodes by: - 1. Nodes matching the specified label - 2. Nodes directly connected to matching nodes - 3. Node degree (number of connections) - - Args: - node_label: Label to match for starting nodes (use "*" for all nodes) - max_depth: Maximum depth of traversal from starting nodes - - Returns: - KnowledgeGraph object containing nodes and edges - """ - result = KnowledgeGraph() - - try: - # Define maximum number of nodes to return - max_graph_nodes = int(os.environ.get("MAX_GRAPH_NODES", 1000)) - - if node_label == "*": - # For "*" label, get all nodes up to the limit - nodes_sql = """ - SELECT name, entity_type, description, source_chunk_id - FROM LIGHTRAG_GRAPH_NODES - WHERE workspace = :workspace - ORDER BY id - FETCH FIRST :limit ROWS ONLY - """ - nodes_params = { - "workspace": self.db.workspace, - "limit": max_graph_nodes, - } - nodes = await self.db.query(nodes_sql, nodes_params, multirows=True) - else: - # For specific label, find matching nodes and related nodes - nodes_sql = """ - WITH matching_nodes AS ( - SELECT name - FROM LIGHTRAG_GRAPH_NODES - WHERE workspace = :workspace - AND (name LIKE '%' || :node_label || '%' OR entity_type LIKE '%' || :node_label || '%') - ) - SELECT n.name, n.entity_type, n.description, n.source_chunk_id, - CASE - WHEN n.name IN (SELECT name FROM matching_nodes) THEN 2 - WHEN EXISTS ( - SELECT 1 FROM LIGHTRAG_GRAPH_EDGES e - WHERE workspace = :workspace - AND ((e.source_name = n.name AND e.target_name IN (SELECT name FROM matching_nodes)) - OR (e.target_name = n.name AND e.source_name IN (SELECT name FROM matching_nodes))) - ) THEN 1 - ELSE 0 - END AS priority, - (SELECT COUNT(*) FROM LIGHTRAG_GRAPH_EDGES e - WHERE workspace = :workspace - AND (e.source_name = n.name OR e.target_name = n.name)) AS degree - FROM LIGHTRAG_GRAPH_NODES n - WHERE workspace = :workspace - ORDER BY priority DESC, degree DESC - FETCH FIRST :limit ROWS ONLY - """ - nodes_params = { - "workspace": self.db.workspace, - "node_label": node_label, - "limit": max_graph_nodes, - } - nodes = await self.db.query(nodes_sql, nodes_params, multirows=True) - - if not nodes: - logger.warning(f"No nodes found matching '{node_label}'") - return result - - # Create mapping of node IDs to be used to filter edges - node_names = [node["name"] for node in nodes] - - # Add nodes to result - seen_nodes = set() - for node in nodes: - node_id = node["name"] - if node_id in seen_nodes: - continue - - # Create node properties dictionary - properties = { - "entity_type": node["entity_type"], - "description": node["description"] or "", - "source_id": node["source_chunk_id"] or "", - } - - # Add node to result - result.nodes.append( - KnowledgeGraphNode( - id=node_id, labels=[node["entity_type"]], properties=properties - ) - ) - seen_nodes.add(node_id) - - # Get edges between these nodes - edges_sql = """ - SELECT source_name, target_name, weight, keywords, description, source_chunk_id - FROM LIGHTRAG_GRAPH_EDGES - WHERE workspace = :workspace - AND source_name IN (SELECT COLUMN_VALUE FROM TABLE(CAST(:node_names AS SYS.ODCIVARCHAR2LIST))) - AND target_name IN (SELECT 
COLUMN_VALUE FROM TABLE(CAST(:node_names AS SYS.ODCIVARCHAR2LIST))) - ORDER BY id - """ - edges_params = {"workspace": self.db.workspace, "node_names": node_names} - edges = await self.db.query(edges_sql, edges_params, multirows=True) - - # Add edges to result - seen_edges = set() - for edge in edges: - source = edge["source_name"] - target = edge["target_name"] - edge_id = f"{source}-{target}" - - if edge_id in seen_edges: - continue - - # Create edge properties dictionary - properties = { - "weight": edge["weight"] or 0.0, - "keywords": edge["keywords"] or "", - "description": edge["description"] or "", - "source_id": edge["source_chunk_id"] or "", - } - - # Add edge to result - result.edges.append( - KnowledgeGraphEdge( - id=edge_id, - type="RELATED", - source=source, - target=target, - properties=properties, - ) - ) - seen_edges.add(edge_id) - - logger.info( - f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" - ) - - except Exception as e: - logger.error(f"Error retrieving knowledge graph: {e}") - - return result - - -N_T = { - NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", - NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS", - NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_DOC_CHUNKS", - NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_GRAPH_NODES", - NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_GRAPH_EDGES", -} - - -def namespace_to_table_name(namespace: str) -> str: - for k, v in N_T.items(): - if is_namespace(namespace, k): - return v - - -TABLES = { - "LIGHTRAG_DOC_FULL": { - "ddl": """CREATE TABLE LIGHTRAG_DOC_FULL ( - id varchar(256), - workspace varchar(1024), - doc_name varchar(1024), - content CLOB, - meta JSON, - content_summary varchar(1024), - content_length NUMBER, - status varchar(256), - chunks_count NUMBER, - createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updatetime TIMESTAMP DEFAULT NULL, - error varchar(4096) - )""" - }, - "LIGHTRAG_DOC_CHUNKS": { - "ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS ( - id varchar(256), - workspace varchar(1024), - full_doc_id varchar(256), - status varchar(256), - chunk_order_index NUMBER, - tokens NUMBER, - content CLOB, - content_vector VECTOR, - createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updatetime TIMESTAMP DEFAULT NULL - )""" - }, - "LIGHTRAG_GRAPH_NODES": { - "ddl": """CREATE TABLE LIGHTRAG_GRAPH_NODES ( - id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY, - workspace varchar(1024), - name varchar(2048), - entity_type varchar(1024), - description CLOB, - source_chunk_id varchar(256), - content CLOB, - content_vector VECTOR, - createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updatetime TIMESTAMP DEFAULT NULL - )""" - }, - "LIGHTRAG_GRAPH_EDGES": { - "ddl": """CREATE TABLE LIGHTRAG_GRAPH_EDGES ( - id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY, - workspace varchar(1024), - source_name varchar(2048), - target_name varchar(2048), - weight NUMBER, - keywords CLOB, - description CLOB, - source_chunk_id varchar(256), - content CLOB, - content_vector VECTOR, - createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updatetime TIMESTAMP DEFAULT NULL - )""" - }, - "LIGHTRAG_LLM_CACHE": { - "ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE ( - id varchar(256) PRIMARY KEY, - workspace varchar(1024), - cache_mode varchar(256), - model_name varchar(256), - original_prompt clob, - return_value clob, - embedding CLOB, - embedding_shape NUMBER, - embedding_min NUMBER, - embedding_max NUMBER, - createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updatetime TIMESTAMP DEFAULT NULL - )""" - }, - 
"LIGHTRAG_GRAPH": { - "ddl": """CREATE OR REPLACE PROPERTY GRAPH lightrag_graph - VERTEX TABLES ( - lightrag_graph_nodes KEY (id) - LABEL entity - PROPERTIES (id,workspace,name) -- ,entity_type,description,source_chunk_id) - ) - EDGE TABLES ( - lightrag_graph_edges KEY (id) - SOURCE KEY (source_name) REFERENCES lightrag_graph_nodes(name) - DESTINATION KEY (target_name) REFERENCES lightrag_graph_nodes(name) - LABEL has_relation - PROPERTIES (id,workspace,source_name,target_name) -- ,weight, keywords,description,source_chunk_id) - ) OPTIONS(ALLOW MIXED PROPERTY TYPES)""" - }, -} - - -SQL_TEMPLATES = { - # SQL for KVStorage - "get_by_id_full_docs": "select ID,content,status from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id", - "get_by_id_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id", - "get_by_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode" - FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id=:id""", - "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode" - FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND cache_mode=:cache_mode AND id=:id""", - "get_by_ids_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode" - FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id IN ({ids})""", - "get_by_ids_full_docs": "select t.*,createtime as created_at from LIGHTRAG_DOC_FULL t where workspace=:workspace and ID in ({ids})", - "get_by_ids_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})", - "get_by_status_ids_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status and ID in ({ids})", - "get_by_status_ids_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status ID in ({ids})", - "get_by_status_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status", - "get_by_status_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status", - "filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})", - "merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a - USING DUAL - ON (a.id = :id and a.workspace = :workspace) - WHEN NOT MATCHED THEN - INSERT(id,content,workspace) values(:id,:content,:workspace)""", - "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS - USING DUAL - ON (id = :id and workspace = :workspace) - WHEN NOT MATCHED THEN INSERT - (id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector,status) - values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector,:status) """, - "upsert_llm_response_cache": """MERGE INTO LIGHTRAG_LLM_CACHE a - USING DUAL - ON (a.id = :id) - WHEN NOT MATCHED THEN - INSERT (workspace,id,original_prompt,return_value,cache_mode) - VALUES (:workspace,:id,:original_prompt,:return_value,:cache_mode) - WHEN MATCHED THEN UPDATE - SET original_prompt = :original_prompt, - return_value = :return_value, - cache_mode = :cache_mode, - updatetime = SYSDATE""", - # SQL for VectorStorage - "entities": """SELECT name as entity_name FROM - (SELECT id,name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance - FROM 
LIGHTRAG_GRAPH_NODES WHERE workspace=:workspace) - WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""", - "relationships": """SELECT source_name as src_id, target_name as tgt_id FROM - (SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance - FROM LIGHTRAG_GRAPH_EDGES WHERE workspace=:workspace) - WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""", - "chunks": """SELECT id FROM - (SELECT id,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance - FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=:workspace) - WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""", - # SQL for GraphStorage - "has_node": """SELECT * FROM GRAPH_TABLE (lightrag_graph - MATCH (a) - WHERE a.workspace=:workspace AND a.name=:node_id - COLUMNS (a.name))""", - "has_edge": """SELECT * FROM GRAPH_TABLE (lightrag_graph - MATCH (a) -[e]-> (b) - WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace - AND a.name=:source_node_id AND b.name=:target_node_id - COLUMNS (e.source_name,e.target_name) )""", - "node_degree": """SELECT count(1) as degree FROM GRAPH_TABLE (lightrag_graph - MATCH (a)-[e]->(b) - WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace - AND a.name=:node_id or b.name = :node_id - COLUMNS (a.name))""", - "get_node": """SELECT t1.name,t2.entity_type,t2.source_chunk_id as source_id,NVL(t2.description,'') AS description - FROM GRAPH_TABLE (lightrag_graph - MATCH (a) - WHERE a.workspace=:workspace AND a.name=:node_id - COLUMNS (a.name) - ) t1 JOIN LIGHTRAG_GRAPH_NODES t2 on t1.name=t2.name - WHERE t2.workspace=:workspace""", - "get_edge": """SELECT t1.source_id,t2.weight,t2.source_chunk_id as source_id,t2.keywords, - NVL(t2.description,'') AS description,NVL(t2.KEYWORDS,'') AS keywords - FROM GRAPH_TABLE (lightrag_graph - MATCH (a)-[e]->(b) - WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace - AND a.name=:source_node_id and b.name = :target_node_id - COLUMNS (e.id,a.name as source_id) - ) t1 JOIN LIGHTRAG_GRAPH_EDGES t2 on t1.id=t2.id""", - "get_node_edges": """SELECT source_name,target_name - FROM GRAPH_TABLE (lightrag_graph - MATCH (a)-[e]->(b) - WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace - AND a.name=:source_node_id - COLUMNS (a.name as source_name,b.name as target_name))""", - "merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a - USING DUAL - ON (a.workspace=:workspace and a.name=:name) - WHEN NOT MATCHED THEN - INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector) - values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector) - WHEN MATCHED THEN - UPDATE SET - entity_type=:entity_type,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""", - "merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a - USING DUAL - ON (a.workspace=:workspace and a.source_name=:source_name and a.target_name=:target_name) - WHEN NOT MATCHED THEN - INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector) - values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) - WHEN MATCHED THEN - UPDATE SET - 
weight=:weight,keywords=:keywords,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""", - "get_all_nodes": """WITH t0 AS ( - SELECT name AS id, entity_type AS label, entity_type, description, - '["' || replace(source_chunk_id, '', '","') || '"]' source_chunk_ids - FROM lightrag_graph_nodes - WHERE workspace = :workspace - ORDER BY createtime DESC fetch first :limit rows only - ), t1 AS ( - SELECT t0.id, source_chunk_id - FROM t0, JSON_TABLE ( source_chunk_ids, '$[*]' COLUMNS ( source_chunk_id PATH '$' ) ) - ), t2 AS ( - SELECT t1.id, LISTAGG(t2.content, '\n') content - FROM t1 LEFT JOIN lightrag_doc_chunks t2 ON t1.source_chunk_id = t2.id - GROUP BY t1.id - ) - SELECT t0.id, label, entity_type, description, t2.content - FROM t0 LEFT JOIN t2 ON t0.id = t2.id""", - "get_all_edges": """SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target, - t1.weight,t1.DESCRIPTION,t2.content - FROM LIGHTRAG_GRAPH_EDGES t1 - LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id - WHERE t1.workspace=:workspace - order by t1.CREATETIME DESC - fetch first :limit rows only""", - "get_statistics": """select count(distinct CASE WHEN type='node' THEN id END) as nodes_count, - count(distinct CASE WHEN type='edge' THEN id END) as edges_count - FROM ( - select 'node' as type, id FROM GRAPH_TABLE (lightrag_graph - MATCH (a) WHERE a.workspace=:workspace columns(a.name as id)) - UNION - select 'edge' as type, TO_CHAR(id) id FROM GRAPH_TABLE (lightrag_graph - MATCH (a)-[e]->(b) WHERE e.workspace=:workspace columns(e.id)) - )""", - # SQL for deletion - "delete_vectors": "DELETE FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=:workspace AND id IN ({ids})", - "delete_entity": "DELETE FROM LIGHTRAG_GRAPH_NODES WHERE workspace=:workspace AND name=:entity_name", - "delete_entity_relations": "DELETE FROM LIGHTRAG_GRAPH_EDGES WHERE workspace=:workspace AND (source_name=:entity_name OR target_name=:entity_name)", - "delete_node": """DELETE FROM GRAPH_TABLE (lightrag_graph - MATCH (a) - WHERE a.workspace=:workspace AND a.name=:node_id - ACTION DELETE a)""", -} diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 4ff34e1309f9edc36c2f4f28a5a598113aacf675..29bf2e454f5b62fe87f3dc1c7e954ba1fc6d28d7 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -9,7 +9,6 @@ import configparser from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge -import sys from tenacity import ( retry, retry_if_exception_type, @@ -28,11 +27,6 @@ from ..base import ( from ..namespace import NameSpace, is_namespace from ..utils import logger -if sys.platform.startswith("win"): - import asyncio.windows_events - - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - import pipmaster as pm if not pm.is_installed("asyncpg"): @@ -41,6 +35,9 @@ if not pm.is_installed("asyncpg"): import asyncpg # type: ignore from asyncpg import Pool # type: ignore +# Get maximum number of graph nodes from environment variable, default is 1000 +MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) + class PostgreSQLDB: def __init__(self, config: dict[str, Any], **kwargs: Any): @@ -118,6 +115,25 @@ class PostgreSQLDB: ) raise e + # Create index for id column in each table + try: + index_name = f"idx_{k.lower()}_id" + check_index_sql = f""" + SELECT 1 FROM pg_indexes + WHERE indexname = '{index_name}' + AND tablename = '{k.lower()}' + """ + index_exists = await 
self.query(check_index_sql) + + if not index_exists: + create_index_sql = f"CREATE INDEX {index_name} ON {k}(id)" + logger.info(f"PostgreSQL, Creating index {index_name} on table {k}") + await self.execute(create_index_sql) + except Exception as e: + logger.error( + f"PostgreSQL, Failed to create index on table {k}, Got: {e}" + ) + async def query( self, sql: str, @@ -254,8 +270,6 @@ class PGKVStorage(BaseKVStorage): db: PostgreSQLDB = field(default=None) def __post_init__(self): - namespace_prefix = self.global_config.get("namespace_prefix") - self.base_namespace = self.namespace.replace(namespace_prefix, "") self._max_batch_size = self.global_config["embedding_batch_num"] async def initialize(self): @@ -271,7 +285,7 @@ class PGKVStorage(BaseKVStorage): async def get_by_id(self, id: str) -> dict[str, Any] | None: """Get doc_full data by id.""" - sql = SQL_TEMPLATES["get_by_id_" + self.base_namespace] + sql = SQL_TEMPLATES["get_by_id_" + self.namespace] params = {"workspace": self.db.workspace, "id": id} if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): array_res = await self.db.query(sql, params, multirows=True) @@ -285,7 +299,7 @@ class PGKVStorage(BaseKVStorage): async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: """Specifically for llm_response_cache.""" - sql = SQL_TEMPLATES["get_by_mode_id_" + self.base_namespace] + sql = SQL_TEMPLATES["get_by_mode_id_" + self.namespace] params = {"workspace": self.db.workspace, mode: mode, "id": id} if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): array_res = await self.db.query(sql, params, multirows=True) @@ -299,7 +313,7 @@ class PGKVStorage(BaseKVStorage): # Query by id async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: """Get doc_chunks data by id""" - sql = SQL_TEMPLATES["get_by_ids_" + self.base_namespace].format( + sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( ids=",".join([f"'{id}'" for id in ids]) ) params = {"workspace": self.db.workspace} @@ -320,7 +334,7 @@ class PGKVStorage(BaseKVStorage): async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]: """Specifically for llm_response_cache.""" - SQL = SQL_TEMPLATES["get_by_status_" + self.base_namespace] + SQL = SQL_TEMPLATES["get_by_status_" + self.namespace] params = {"workspace": self.db.workspace, "status": status} return await self.db.query(SQL, params, multirows=True) @@ -380,10 +394,85 @@ class PGKVStorage(BaseKVStorage): # PG handles persistence automatically pass - async def drop(self) -> None: + async def delete(self, ids: list[str]) -> None: + """Delete specific records from storage by their IDs + + Args: + ids (list[str]): List of document IDs to be deleted from storage + + Returns: + None + """ + if not ids: + return + + table_name = namespace_to_table_name(self.namespace) + if not table_name: + logger.error(f"Unknown namespace for deletion: {self.namespace}") + return + + delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)" + + try: + await self.db.execute( + delete_sql, {"workspace": self.db.workspace, "ids": ids} + ) + logger.debug( + f"Successfully deleted {len(ids)} records from {self.namespace}" + ) + except Exception as e: + logger.error(f"Error while deleting records from {self.namespace}: {e}") + + async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool: + """Delete specific records from storage by cache mode + + Args: + modes (list[str]): List of cache modes to be dropped from storage + + Returns: + 
bool: True if successful, False otherwise + """ + if not modes: + return False + + try: + table_name = namespace_to_table_name(self.namespace) + if not table_name: + return False + + if table_name != "LIGHTRAG_LLM_CACHE": + return False + + sql = f""" + DELETE FROM {table_name} + WHERE workspace = $1 AND mode = ANY($2) + """ + params = {"workspace": self.db.workspace, "modes": modes} + + logger.info(f"Deleting cache by modes: {modes}") + await self.db.execute(sql, params) + return True + except Exception as e: + logger.error(f"Error deleting cache by modes {modes}: {e}") + return False + + async def drop(self) -> dict[str, str]: """Drop the storage""" - drop_sql = SQL_TEMPLATES["drop_all"] - await self.db.execute(drop_sql) + try: + table_name = namespace_to_table_name(self.namespace) + if not table_name: + return { + "status": "error", + "message": f"Unknown namespace: {self.namespace}", + } + + drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( + table_name=table_name + ) + await self.db.execute(drop_sql, {"workspace": self.db.workspace}) + return {"status": "success", "message": "data dropped"} + except Exception as e: + return {"status": "error", "message": str(e)} @final @@ -393,8 +482,6 @@ class PGVectorStorage(BaseVectorStorage): def __post_init__(self): self._max_batch_size = self.global_config["embedding_batch_num"] - namespace_prefix = self.global_config.get("namespace_prefix") - self.base_namespace = self.namespace.replace(namespace_prefix, "") config = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = config.get("cosine_better_than_threshold") if cosine_threshold is None: @@ -523,7 +610,7 @@ class PGVectorStorage(BaseVectorStorage): else: formatted_ids = "NULL" - sql = SQL_TEMPLATES[self.base_namespace].format( + sql = SQL_TEMPLATES[self.namespace].format( embedding_string=embedding_string, doc_ids=formatted_ids ) params = { @@ -552,13 +639,12 @@ class PGVectorStorage(BaseVectorStorage): logger.error(f"Unknown namespace for vector deletion: {self.namespace}") return - ids_list = ",".join([f"'{id}'" for id in ids]) - delete_sql = ( - f"DELETE FROM {table_name} WHERE workspace=$1 AND id IN ({ids_list})" - ) + delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)" try: - await self.db.execute(delete_sql, {"workspace": self.db.workspace}) + await self.db.execute( + delete_sql, {"workspace": self.db.workspace, "ids": ids} + ) logger.debug( f"Successfully deleted {len(ids)} vectors from {self.namespace}" ) @@ -690,6 +776,24 @@ class PGVectorStorage(BaseVectorStorage): logger.error(f"Error retrieving vector data for IDs {ids}: {e}") return [] + async def drop(self) -> dict[str, str]: + """Drop the storage""" + try: + table_name = namespace_to_table_name(self.namespace) + if not table_name: + return { + "status": "error", + "message": f"Unknown namespace: {self.namespace}", + } + + drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( + table_name=table_name + ) + await self.db.execute(drop_sql, {"workspace": self.db.workspace}) + return {"status": "success", "message": "data dropped"} + except Exception as e: + return {"status": "error", "message": str(e)} + @final @dataclass @@ -810,6 +914,35 @@ class PGDocStatusStorage(DocStatusStorage): # PG handles persistence automatically pass + async def delete(self, ids: list[str]) -> None: + """Delete specific records from storage by their IDs + + Args: + ids (list[str]): List of document IDs to be deleted from storage + + Returns: + None + """ + if not ids: + return 
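As an aside on the maintenance helpers this diff adds to the PostgreSQL storages (`delete`, `drop_cache_by_modes`, and the workspace-scoped `drop`): they are plain async methods that issue parameterized `DELETE` statements (`id = ANY($2)`) against the table mapped from the storage namespace. A minimal usage sketch, assuming an already-initialized `PGKVStorage` instance bound to the LLM-response-cache namespace; the instance name and ids below are hypothetical:

```python
# Illustrative only: `kv_cache` stands for an initialized PGKVStorage bound to the
# LLM-response-cache namespace; construction and database setup are not shown here.
async def prune_llm_cache(kv_cache) -> None:
    # Delete individual cache records by id (workspace-scoped, uses id = ANY($2))
    await kv_cache.delete(["cache-id-1", "cache-id-2"])

    # Drop every cached entry written under the given query modes
    ok = await kv_cache.drop_cache_by_modes(["local", "global"])
    print("cache modes dropped:", ok)

    # Remove all rows for this workspace from the backing table
    result = await kv_cache.drop()
    print(result)  # expected shape: {"status": "success", "message": "data dropped"}
```

The same `delete`/`drop` pattern is reused by `PGVectorStorage` and `PGDocStatusStorage` in the hunks that follow.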
+ + table_name = namespace_to_table_name(self.namespace) + if not table_name: + logger.error(f"Unknown namespace for deletion: {self.namespace}") + return + + delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)" + + try: + await self.db.execute( + delete_sql, {"workspace": self.db.workspace, "ids": ids} + ) + logger.debug( + f"Successfully deleted {len(ids)} records from {self.namespace}" + ) + except Exception as e: + logger.error(f"Error while deleting records from {self.namespace}: {e}") + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: """Update or insert document status @@ -846,10 +979,23 @@ class PGDocStatusStorage(DocStatusStorage): }, ) - async def drop(self) -> None: + async def drop(self) -> dict[str, str]: """Drop the storage""" - drop_sql = SQL_TEMPLATES["drop_doc_full"] - await self.db.execute(drop_sql) + try: + table_name = namespace_to_table_name(self.namespace) + if not table_name: + return { + "status": "error", + "message": f"Unknown namespace: {self.namespace}", + } + + drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( + table_name=table_name + ) + await self.db.execute(drop_sql, {"workspace": self.db.workspace}) + return {"status": "success", "message": "data dropped"} + except Exception as e: + return {"status": "error", "message": str(e)} class PGGraphQueryException(Exception): @@ -937,31 +1083,11 @@ class PGGraphStorage(BaseGraphStorage): if v.startswith("[") and v.endswith("]"): if "::vertex" in v: v = v.replace("::vertex", "") - vertexes = json.loads(v) - dl = [] - for vertex in vertexes: - prop = vertex.get("properties") - if not prop: - prop = {} - prop["label"] = PGGraphStorage._decode_graph_label( - prop["node_id"] - ) - dl.append(prop) - d[k] = dl + d[k] = json.loads(v) elif "::edge" in v: v = v.replace("::edge", "") - edges = json.loads(v) - dl = [] - for edge in edges: - dl.append( - ( - vertices[edge["start_id"]], - edge["label"], - vertices[edge["end_id"]], - ) - ) - d[k] = dl + d[k] = json.loads(v) else: print("WARNING: unsupported type") continue @@ -970,32 +1096,19 @@ class PGGraphStorage(BaseGraphStorage): dtype = v.split("::")[-1] v = v.split("::")[0] if dtype == "vertex": - vertex = json.loads(v) - field = vertex.get("properties") - if not field: - field = {} - field["label"] = PGGraphStorage._decode_graph_label( - field["node_id"] - ) - d[k] = field - # convert edge from id-label->id by replacing id with node information - # we only do this if the vertex was also returned in the query - # this is an attempt to be consistent with neo4j implementation + d[k] = json.loads(v) elif dtype == "edge": - edge = json.loads(v) - d[k] = ( - vertices.get(edge["start_id"], {}), - edge[ - "label" - ], # we don't use decode_graph_label(), since edge label is always "DIRECTED" - vertices.get(edge["end_id"], {}), - ) + d[k] = json.loads(v) else: - d[k] = ( - json.loads(v) - if isinstance(v, str) and ("{" in v or "[" in v) - else v - ) + try: + d[k] = ( + json.loads(v) + if isinstance(v, str) + and (v.startswith("{") or v.startswith("[")) + else v + ) + except json.JSONDecodeError: + d[k] = v return d @@ -1025,56 +1138,6 @@ class PGGraphStorage(BaseGraphStorage): ) return "{" + ", ".join(props) + "}" - @staticmethod - def _encode_graph_label(label: str) -> str: - """ - Since AGE supports only alphanumerical labels, we will encode generic label as HEX string - - Args: - label (str): the original label - - Returns: - str: the encoded label - """ - return "x" + label.encode().hex() - - @staticmethod - def 
_decode_graph_label(encoded_label: str) -> str: - """ - Since AGE supports only alphanumerical labels, we will encode generic label as HEX string - - Args: - encoded_label (str): the encoded label - - Returns: - str: the decoded label - """ - return bytes.fromhex(encoded_label.removeprefix("x")).decode() - - @staticmethod - def _get_col_name(field: str, idx: int) -> str: - """ - Convert a cypher return field to a pgsql select field - If possible keep the cypher column name, but create a generic name if necessary - - Args: - field (str): a return field from a cypher query to be formatted for pgsql - idx (int): the position of the field in the return statement - - Returns: - str: the field to be used in the pgsql select statement - """ - # remove white space - field = field.strip() - # if an alias is provided for the field, use it - if " as " in field: - return field.split(" as ")[-1].strip() - # if the return value is an unnamed primitive, give it a generic name - if field.isnumeric() or field in ("true", "false", "null"): - return f"column_{idx}" - # otherwise return the value stripping out some common special chars - return field.replace("(", "_").replace(")", "") - async def _query( self, query: str, @@ -1125,10 +1188,10 @@ class PGGraphStorage(BaseGraphStorage): return result async def has_node(self, node_id: str) -> bool: - entity_name_label = self._encode_graph_label(node_id.strip('"')) + entity_name_label = node_id.strip('"') query = """SELECT * FROM cypher('%s', $$ - MATCH (n:Entity {node_id: "%s"}) + MATCH (n:base {entity_id: "%s"}) RETURN count(n) > 0 AS node_exists $$) AS (node_exists bool)""" % (self.graph_name, entity_name_label) @@ -1137,11 +1200,11 @@ class PGGraphStorage(BaseGraphStorage): return single_result["node_exists"] async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: - src_label = self._encode_graph_label(source_node_id.strip('"')) - tgt_label = self._encode_graph_label(target_node_id.strip('"')) + src_label = source_node_id.strip('"') + tgt_label = target_node_id.strip('"') query = """SELECT * FROM cypher('%s', $$ - MATCH (a:Entity {node_id: "%s"})-[r]-(b:Entity {node_id: "%s"}) + MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"}) RETURN COUNT(r) > 0 AS edge_exists $$) AS (edge_exists bool)""" % ( self.graph_name, @@ -1154,30 +1217,31 @@ class PGGraphStorage(BaseGraphStorage): return single_result["edge_exists"] async def get_node(self, node_id: str) -> dict[str, str] | None: - label = self._encode_graph_label(node_id.strip('"')) + """Get node by its label identifier, return only node properties""" + + label = node_id.strip('"') query = """SELECT * FROM cypher('%s', $$ - MATCH (n:Entity {node_id: "%s"}) + MATCH (n:base {entity_id: "%s"}) RETURN n $$) AS (n agtype)""" % (self.graph_name, label) record = await self._query(query) if record: node = record[0] - node_dict = node["n"] + node_dict = node["n"]["properties"] return node_dict return None async def node_degree(self, node_id: str) -> int: - label = self._encode_graph_label(node_id.strip('"')) + label = node_id.strip('"') query = """SELECT * FROM cypher('%s', $$ - MATCH (n:Entity {node_id: "%s"})-[]->(x) + MATCH (n:base {entity_id: "%s"})-[]-(x) RETURN count(x) AS total_edge_count $$) AS (total_edge_count integer)""" % (self.graph_name, label) record = (await self._query(query))[0] if record: edge_count = int(record["total_edge_count"]) - return edge_count async def edge_degree(self, src_id: str, tgt_id: str) -> int: @@ -1195,11 +1259,13 @@ class 
PGGraphStorage(BaseGraphStorage): async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: - src_label = self._encode_graph_label(source_node_id.strip('"')) - tgt_label = self._encode_graph_label(target_node_id.strip('"')) + """Get edge properties between two nodes""" + + src_label = source_node_id.strip('"') + tgt_label = target_node_id.strip('"') query = """SELECT * FROM cypher('%s', $$ - MATCH (a:Entity {node_id: "%s"})-[r]->(b:Entity {node_id: "%s"}) + MATCH (a:base {entity_id: "%s"})-[r]->(b:base {entity_id: "%s"}) RETURN properties(r) as edge_properties LIMIT 1 $$) AS (edge_properties agtype)""" % ( @@ -1218,11 +1284,11 @@ class PGGraphStorage(BaseGraphStorage): Retrieves all edges (relationships) for a particular node identified by its label. :return: list of dictionaries containing edge information """ - label = self._encode_graph_label(source_node_id.strip('"')) + label = source_node_id.strip('"') query = """SELECT * FROM cypher('%s', $$ - MATCH (n:Entity {node_id: "%s"}) - OPTIONAL MATCH (n)-[]-(connected) + MATCH (n:base {entity_id: "%s"}) + OPTIONAL MATCH (n)-[]-(connected:base) RETURN n, connected $$) AS (n agtype, connected agtype)""" % ( self.graph_name, @@ -1235,24 +1301,17 @@ class PGGraphStorage(BaseGraphStorage): source_node = record["n"] if record["n"] else None connected_node = record["connected"] if record["connected"] else None - source_label = ( - source_node["node_id"] - if source_node and source_node["node_id"] - else None - ) - target_label = ( - connected_node["node_id"] - if connected_node and connected_node["node_id"] - else None - ) + if ( + source_node + and connected_node + and "properties" in source_node + and "properties" in connected_node + ): + source_label = source_node["properties"].get("entity_id") + target_label = connected_node["properties"].get("entity_id") - if source_label and target_label: - edges.append( - ( - self._decode_graph_label(source_label), - self._decode_graph_label(target_label), - ) - ) + if source_label and target_label: + edges.append((source_label, target_label)) return edges @@ -1262,24 +1321,36 @@ class PGGraphStorage(BaseGraphStorage): retry=retry_if_exception_type((PGGraphQueryException,)), ) async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: - label = self._encode_graph_label(node_id.strip('"')) - properties = node_data + """ + Upsert a node in the Neo4j database. 
+ + Args: + node_id: The unique identifier for the node (used as label) + node_data: Dictionary of node properties + """ + if "entity_id" not in node_data: + raise ValueError( + "PostgreSQL: node properties must contain an 'entity_id' field" + ) + + label = node_id.strip('"') + properties = self._format_properties(node_data) query = """SELECT * FROM cypher('%s', $$ - MERGE (n:Entity {node_id: "%s"}) + MERGE (n:base {entity_id: "%s"}) SET n += %s RETURN n $$) AS (n agtype)""" % ( self.graph_name, label, - self._format_properties(properties), + properties, ) try: await self._query(query, readonly=False, upsert=True) - except Exception as e: - logger.error("POSTGRES, Error during upsert: {%s}", e) + except Exception: + logger.error(f"POSTGRES, upsert_node error on node_id: `{node_id}`") raise @retry( @@ -1298,14 +1369,14 @@ class PGGraphStorage(BaseGraphStorage): target_node_id (str): Label of the target node (used as identifier) edge_data (dict): dictionary of properties to set on the edge """ - src_label = self._encode_graph_label(source_node_id.strip('"')) - tgt_label = self._encode_graph_label(target_node_id.strip('"')) - edge_properties = edge_data + src_label = source_node_id.strip('"') + tgt_label = target_node_id.strip('"') + edge_properties = self._format_properties(edge_data) query = """SELECT * FROM cypher('%s', $$ - MATCH (source:Entity {node_id: "%s"}) + MATCH (source:base {entity_id: "%s"}) WITH source - MATCH (target:Entity {node_id: "%s"}) + MATCH (target:base {entity_id: "%s"}) MERGE (source)-[r:DIRECTED]->(target) SET r += %s RETURN r @@ -1313,14 +1384,16 @@ class PGGraphStorage(BaseGraphStorage): self.graph_name, src_label, tgt_label, - self._format_properties(edge_properties), + edge_properties, ) try: await self._query(query, readonly=False, upsert=True) - except Exception as e: - logger.error("Error during edge upsert: {%s}", e) + except Exception: + logger.error( + f"POSTGRES, upsert_edge error on edge: `{source_node_id}`-`{target_node_id}`" + ) raise async def _node2vec_embed(self): @@ -1333,10 +1406,10 @@ class PGGraphStorage(BaseGraphStorage): Args: node_id (str): The ID of the node to delete. """ - label = self._encode_graph_label(node_id.strip('"')) + label = node_id.strip('"') query = """SELECT * FROM cypher('%s', $$ - MATCH (n:Entity {node_id: "%s"}) + MATCH (n:base {entity_id: "%s"}) DETACH DELETE n $$) AS (n agtype)""" % (self.graph_name, label) @@ -1353,14 +1426,12 @@ class PGGraphStorage(BaseGraphStorage): Args: node_ids (list[str]): A list of node IDs to remove. """ - encoded_node_ids = [ - self._encode_graph_label(node_id.strip('"')) for node_id in node_ids - ] - node_id_list = ", ".join([f'"{node_id}"' for node_id in encoded_node_ids]) + node_ids = [node_id.strip('"') for node_id in node_ids] + node_id_list = ", ".join([f'"{node_id}"' for node_id in node_ids]) query = """SELECT * FROM cypher('%s', $$ - MATCH (n:Entity) - WHERE n.node_id IN [%s] + MATCH (n:base) + WHERE n.entity_id IN [%s] DETACH DELETE n $$) AS (n agtype)""" % (self.graph_name, node_id_list) @@ -1377,26 +1448,21 @@ class PGGraphStorage(BaseGraphStorage): Args: edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id). 
""" - encoded_edges = [ - ( - self._encode_graph_label(src.strip('"')), - self._encode_graph_label(tgt.strip('"')), - ) - for src, tgt in edges - ] - edge_list = ", ".join([f'["{src}", "{tgt}"]' for src, tgt in encoded_edges]) + for source, target in edges: + src_label = source.strip('"') + tgt_label = target.strip('"') - query = """SELECT * FROM cypher('%s', $$ - MATCH (a:Entity)-[r]->(b:Entity) - WHERE [a.node_id, b.node_id] IN [%s] - DELETE r - $$) AS (r agtype)""" % (self.graph_name, edge_list) + query = """SELECT * FROM cypher('%s', $$ + MATCH (a:base {entity_id: "%s"})-[r]->(b:base {entity_id: "%s"}) + DELETE r + $$) AS (r agtype)""" % (self.graph_name, src_label, tgt_label) - try: - await self._query(query, readonly=False) - except Exception as e: - logger.error("Error during edge removal: {%s}", e) - raise + try: + await self._query(query, readonly=False) + logger.debug(f"Deleted edge from '{source}' to '{target}'") + except Exception as e: + logger.error(f"Error during edge deletion: {str(e)}") + raise async def get_all_labels(self) -> list[str]: """ @@ -1407,15 +1473,16 @@ class PGGraphStorage(BaseGraphStorage): """ query = ( """SELECT * FROM cypher('%s', $$ - MATCH (n:Entity) - RETURN DISTINCT n.node_id AS label + MATCH (n:base) + WHERE n.entity_id IS NOT NULL + RETURN DISTINCT n.entity_id AS label + ORDER BY n.entity_id $$) AS (label text)""" % self.graph_name ) results = await self._query(query) - labels = [self._decode_graph_label(result["label"]) for result in results] - + labels = [result["label"] for result in results] return labels async def embed_nodes( @@ -1437,105 +1504,135 @@ class PGGraphStorage(BaseGraphStorage): return await embed_func() async def get_knowledge_graph( - self, node_label: str, max_depth: int = 5 + self, + node_label: str, + max_depth: int = 3, + max_nodes: int = MAX_GRAPH_NODES, ) -> KnowledgeGraph: """ - Retrieve a subgraph containing the specified node and its neighbors up to the specified depth. + Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. Args: - node_label (str): The label of the node to start from. If "*", the entire graph is returned. - max_depth (int): The maximum depth to traverse from the starting node. + node_label: Label of the starting node, * means all nodes + max_depth: Maximum depth of the subgraph, Defaults to 3 + max_nodes: Maxiumu nodes to return, Defaults to 1000 (not BFS nor DFS garanteed) Returns: - KnowledgeGraph: The retrieved subgraph. + KnowledgeGraph object containing nodes and edges, with an is_truncated flag + indicating whether the graph was truncated due to max_nodes limit """ - MAX_GRAPH_NODES = 1000 - - # Build the query based on whether we want the full graph or a specific subgraph. 
+ # First, count the total number of nodes that would be returned without limit + if node_label == "*": + count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + MATCH (n:base) + RETURN count(distinct n) AS total_nodes + $$) AS (total_nodes bigint)""" + else: + strip_label = node_label.strip('"') + count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + MATCH (n:base {{entity_id: "{strip_label}"}}) + OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m) + RETURN count(distinct m) AS total_nodes + $$) AS (total_nodes bigint)""" + + count_result = await self._query(count_query) + total_nodes = count_result[0]["total_nodes"] if count_result else 0 + is_truncated = total_nodes > max_nodes + + # Now get the actual data with limit if node_label == "*": query = f"""SELECT * FROM cypher('{self.graph_name}', $$ - MATCH (n:Entity) - OPTIONAL MATCH (n)-[r]->(m:Entity) - RETURN n, r, m - LIMIT {MAX_GRAPH_NODES} - $$) AS (n agtype, r agtype, m agtype)""" + MATCH (n:base) + OPTIONAL MATCH (n)-[r]->(target:base) + RETURN collect(distinct n) AS n, collect(distinct r) AS r + LIMIT {max_nodes} + $$) AS (n agtype, r agtype)""" else: - encoded_label = self._encode_graph_label(node_label.strip('"')) + strip_label = node_label.strip('"') query = f"""SELECT * FROM cypher('{self.graph_name}', $$ - MATCH (n:Entity {{node_id: "{encoded_label}"}}) - OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m) - RETURN nodes(p) AS nodes, relationships(p) AS relationships - LIMIT {MAX_GRAPH_NODES} - $$) AS (nodes agtype, relationships agtype)""" + MATCH (n:base {{entity_id: "{strip_label}"}}) + OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m) + RETURN nodes(p) AS n, relationships(p) AS r + LIMIT {max_nodes} + $$) AS (n agtype, r agtype)""" results = await self._query(query) - nodes = {} - edges = [] - unique_edge_ids = set() - - def add_node(node_data: dict): - node_id = self._decode_graph_label(node_data["node_id"]) - if node_id not in nodes: - nodes[node_id] = node_data - - def add_edge(edge_data: list): - src_id = self._decode_graph_label(edge_data[0]["node_id"]) - tgt_id = self._decode_graph_label(edge_data[2]["node_id"]) - edge_key = f"{src_id},{tgt_id}" - if edge_key not in unique_edge_ids: - unique_edge_ids.add(edge_key) - edges.append( - ( - edge_key, - src_id, - tgt_id, - {"source": edge_data[0], "target": edge_data[2]}, + # Process the query results with deduplication by node and edge IDs + nodes_dict = {} + edges_dict = {} + for result in results: + # Handle single node cases + if result.get("n") and isinstance(result["n"], dict): + node_id = str(result["n"]["id"]) + if node_id not in nodes_dict: + nodes_dict[node_id] = KnowledgeGraphNode( + id=node_id, + labels=[result["n"]["properties"]["entity_id"]], + properties=result["n"]["properties"], ) - ) + # Handle node list cases + elif result.get("n") and isinstance(result["n"], list): + for node in result["n"]: + if isinstance(node, dict) and "id" in node: + node_id = str(node["id"]) + if node_id not in nodes_dict and "properties" in node: + nodes_dict[node_id] = KnowledgeGraphNode( + id=node_id, + labels=[node["properties"]["entity_id"]], + properties=node["properties"], + ) - # Process the query results. 
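A usage sketch of the new two-step retrieval (count first, then fetch up to max_nodes), assuming an initialized PGGraphStorage instance named `storage` (illustrative only, not part of the patch):

# Illustrative sketch only; `storage` is assumed to be an initialized PGGraphStorage.
async def preview_graph(storage) -> None:
    kg = await storage.get_knowledge_graph(node_label="*", max_depth=3, max_nodes=1000)
    print(f"nodes={len(kg.nodes)} edges={len(kg.edges)} truncated={kg.is_truncated}")
    if kg.is_truncated:
        # More nodes matched than max_nodes allows; narrow the label or raise the cap.
        kg = await storage.get_knowledge_graph(node_label="Alan Turing", max_depth=2, max_nodes=1000)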
- if node_label == "*": - for result in results: - if result.get("n"): - add_node(result["n"]) - if result.get("m"): - add_node(result["m"]) - if result.get("r"): - add_edge(result["r"]) - else: - for result in results: - for node in result.get("nodes", []): - add_node(node) - for edge in result.get("relationships", []): - add_edge(edge) + # Handle single edge cases + if result.get("r") and isinstance(result["r"], dict): + edge_id = str(result["r"]["id"]) + if edge_id not in edges_dict: + edges_dict[edge_id] = KnowledgeGraphEdge( + id=edge_id, + type="DIRECTED", + source=str(result["r"]["start_id"]), + target=str(result["r"]["end_id"]), + properties=result["r"]["properties"], + ) + # Handle edge list cases + elif result.get("r") and isinstance(result["r"], list): + for edge in result["r"]: + if isinstance(edge, dict) and "id" in edge: + edge_id = str(edge["id"]) + if edge_id not in edges_dict: + edges_dict[edge_id] = KnowledgeGraphEdge( + id=edge_id, + type="DIRECTED", + source=str(edge["start_id"]), + target=str(edge["end_id"]), + properties=edge["properties"], + ) - # Construct and return the KnowledgeGraph. + # Construct and return the KnowledgeGraph with deduplicated nodes and edges kg = KnowledgeGraph( - nodes=[ - KnowledgeGraphNode(id=node_id, labels=[node_id], properties=node_data) - for node_id, node_data in nodes.items() - ], - edges=[ - KnowledgeGraphEdge( - id=edge_id, - type="DIRECTED", - source=src, - target=tgt, - properties=props, - ) - for edge_id, src, tgt, props in edges - ], + nodes=list(nodes_dict.values()), + edges=list(edges_dict.values()), + is_truncated=is_truncated, ) + logger.info( + f"Subgraph query successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}" + ) return kg - async def drop(self) -> None: + async def drop(self) -> dict[str, str]: """Drop the storage""" - drop_sql = SQL_TEMPLATES["drop_vdb_entity"] - await self.db.execute(drop_sql) - drop_sql = SQL_TEMPLATES["drop_vdb_relation"] - await self.db.execute(drop_sql) + try: + drop_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + MATCH (n) + DETACH DELETE n + $$) AS (result agtype)""" + + await self._query(drop_query, readonly=False) + return {"status": "success", "message": "graph data dropped"} + except Exception as e: + logger.error(f"Error dropping graph: {e}") + return {"status": "error", "message": str(e)} NAMESPACE_TABLE_MAP = { @@ -1693,6 +1790,7 @@ SQL_TEMPLATES = { file_path=EXCLUDED.file_path, update_time = CURRENT_TIMESTAMP """, + # SQL for VectorStorage "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, content_vector, chunk_ids, file_path) VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7) @@ -1716,45 +1814,6 @@ SQL_TEMPLATES = { file_path=EXCLUDED.file_path, update_time = CURRENT_TIMESTAMP """, - # SQL for VectorStorage - # "entities": """SELECT entity_name FROM - # (SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance - # FROM LIGHTRAG_VDB_ENTITY where workspace=$1) - # WHERE distance>$2 ORDER BY distance DESC LIMIT $3 - # """, - # "relationships": """SELECT source_id as src_id, target_id as tgt_id FROM - # (SELECT id, source_id,target_id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance - # FROM LIGHTRAG_VDB_RELATION where workspace=$1) - # WHERE distance>$2 ORDER BY distance DESC LIMIT $3 - # """, - # "chunks": """SELECT id FROM - # (SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance - # FROM LIGHTRAG_DOC_CHUNKS where workspace=$1) - # WHERE 
distance>$2 ORDER BY distance DESC LIMIT $3 - # """, - # DROP tables - "drop_all": """ - DROP TABLE IF EXISTS LIGHTRAG_DOC_FULL CASCADE; - DROP TABLE IF EXISTS LIGHTRAG_DOC_CHUNKS CASCADE; - DROP TABLE IF EXISTS LIGHTRAG_LLM_CACHE CASCADE; - DROP TABLE IF EXISTS LIGHTRAG_VDB_ENTITY CASCADE; - DROP TABLE IF EXISTS LIGHTRAG_VDB_RELATION CASCADE; - """, - "drop_doc_full": """ - DROP TABLE IF EXISTS LIGHTRAG_DOC_FULL CASCADE; - """, - "drop_doc_chunks": """ - DROP TABLE IF EXISTS LIGHTRAG_DOC_CHUNKS CASCADE; - """, - "drop_llm_cache": """ - DROP TABLE IF EXISTS LIGHTRAG_LLM_CACHE CASCADE; - """, - "drop_vdb_entity": """ - DROP TABLE IF EXISTS LIGHTRAG_VDB_ENTITY CASCADE; - """, - "drop_vdb_relation": """ - DROP TABLE IF EXISTS LIGHTRAG_VDB_RELATION CASCADE; - """, "relationships": """ WITH relevant_chunks AS ( SELECT id as chunk_id @@ -1795,9 +1854,9 @@ SQL_TEMPLATES = { FROM LIGHTRAG_DOC_CHUNKS WHERE {doc_ids} IS NULL OR full_doc_id = ANY(ARRAY[{doc_ids}]) ) - SELECT id FROM + SELECT id, content, file_path FROM ( - SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance + SELECT id, content, file_path, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance FROM LIGHTRAG_DOC_CHUNKS where workspace=$1 AND id IN (SELECT chunk_id FROM relevant_chunks) @@ -1806,4 +1865,8 @@ SQL_TEMPLATES = { ORDER BY distance DESC LIMIT $3 """, + # DROP tables + "drop_specifiy_table_workspace": """ + DELETE FROM {table_name} WHERE workspace=$1 + """, } diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index e32c43351d433d47dd9704d80b70d898a79ddee7..d758ca5c419914972deb23e4b2c6062e214d5c67 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -8,17 +8,15 @@ import uuid from ..utils import logger from ..base import BaseVectorStorage import configparser - - -config = configparser.ConfigParser() -config.read("config.ini", "utf-8") - import pipmaster as pm if not pm.is_installed("qdrant-client"): pm.install("qdrant-client") -from qdrant_client import QdrantClient, models +from qdrant_client import QdrantClient, models # type: ignore + +config = configparser.ConfigParser() +config.read("config.ini", "utf-8") def compute_mdhash_id_for_qdrant( @@ -275,3 +273,92 @@ class QdrantVectorDBStorage(BaseVectorStorage): except Exception as e: logger.error(f"Error searching for prefix '{prefix}': {e}") return [] + + async def get_by_id(self, id: str) -> dict[str, Any] | None: + """Get vector data by its ID + + Args: + id: The unique identifier of the vector + + Returns: + The vector data if found, or None if not found + """ + try: + # Convert to Qdrant compatible ID + qdrant_id = compute_mdhash_id_for_qdrant(id) + + # Retrieve the point by ID + result = self._client.retrieve( + collection_name=self.namespace, + ids=[qdrant_id], + with_payload=True, + ) + + if not result: + return None + + return result[0].payload + except Exception as e: + logger.error(f"Error retrieving vector data for ID {id}: {e}") + return None + + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: + """Get multiple vector data by their IDs + + Args: + ids: List of unique identifiers + + Returns: + List of vector data objects that were found + """ + if not ids: + return [] + + try: + # Convert to Qdrant compatible IDs + qdrant_ids = [compute_mdhash_id_for_qdrant(id) for id in ids] + + # Retrieve the points by IDs + results = self._client.retrieve( + collection_name=self.namespace, + ids=qdrant_ids, + with_payload=True, + ) + + return [point.payload for point in 
results] + except Exception as e: + logger.error(f"Error retrieving vector data for IDs {ids}: {e}") + return [] + + async def drop(self) -> dict[str, str]: + """Drop all vector data from storage and clean up resources + + This method will delete all data from the Qdrant collection. + + Returns: + dict[str, str]: Operation status and message + - On success: {"status": "success", "message": "data dropped"} + - On failure: {"status": "error", "message": ""} + """ + try: + # Delete the collection and recreate it + if self._client.collection_exists(self.namespace): + self._client.delete_collection(self.namespace) + + # Recreate the collection + QdrantVectorDBStorage.create_collection_if_not_exist( + self._client, + self.namespace, + vectors_config=models.VectorParams( + size=self.embedding_func.embedding_dim, + distance=models.Distance.COSINE, + ), + ) + + logger.info( + f"Process {os.getpid()} drop Qdrant collection {self.namespace}" + ) + return {"status": "success", "message": "data dropped"} + except Exception as e: + logger.error(f"Error dropping Qdrant collection {self.namespace}: {e}") + return {"status": "error", "message": str(e)} diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index 3feb4985b00061dae96a5ddf14fa74e267d1c418..4452d55f876724fe4a8c4b7b5026bbc397c923ff 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -8,8 +8,8 @@ if not pm.is_installed("redis"): pm.install("redis") # aioredis is a depricated library, replaced with redis -from redis.asyncio import Redis -from lightrag.utils import logger, compute_mdhash_id +from redis.asyncio import Redis # type: ignore +from lightrag.utils import logger from lightrag.base import BaseKVStorage import json @@ -84,66 +84,50 @@ class RedisKVStorage(BaseKVStorage): f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}" ) - async def delete_entity(self, entity_name: str) -> None: - """Delete an entity by name + async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool: + """Delete specific records from storage by cache mode + + Important notes for Redis storage: + 1. This will immediately delete the specified cache modes from Redis Args: - entity_name: Name of the entity to delete + modes (list[str]): List of cache modes to be dropped from storage + + Returns: + True: if the cache was dropped successfully + False: if the cache drop failed """ + if not modes: + return False try: - entity_id = compute_mdhash_id(entity_name, prefix="ent-") - logger.debug( - f"Attempting to delete entity {entity_name} with ID {entity_id}" - ) - - # Delete the entity - result = await self._redis.delete(f"{self.namespace}:{entity_id}") + await self.delete(modes) + return True + except Exception: + return False - if result: - logger.debug(f"Successfully deleted entity {entity_name}") - else: - logger.debug(f"Entity {entity_name} not found in storage") - except Exception as e: - logger.error(f"Error deleting entity {entity_name}: {e}") + async def drop(self) -> dict[str, str]: + """Drop the storage by removing all keys under the current namespace.
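The new Qdrant accessors take the original LightRAG hash IDs and convert them to Qdrant-compatible UUIDs internally; a hedged usage sketch follows (the `vdb` instance and the IDs are assumptions for illustration, not part of this change):

# Illustrative sketch only; `vdb` is assumed to be an initialized QdrantVectorDBStorage.
async def inspect_vectors(vdb) -> None:
    payload = await vdb.get_by_id("ent-0123456789abcdef0123456789abcdef")      # payload dict or None
    payloads = await vdb.get_by_ids(["ent-0123456789abcdef0123456789abcdef"])  # list of payload dicts
    print(payload, len(payloads))
    status = await vdb.drop()  # deletes and recreates the collection
    assert status["status"] in ("success", "error")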
- async def delete_entity_relation(self, entity_name: str) -> None: - """Delete all relations associated with an entity - - Args: - entity_name: Name of the entity whose relations should be deleted + Returns: + dict[str, str]: Status of the operation with keys 'status' and 'message' """ try: - # Get all keys in this namespace - cursor = 0 - relation_keys = [] - pattern = f"{self.namespace}:*" + keys = await self._redis.keys(f"{self.namespace}:*") - while True: - cursor, keys = await self._redis.scan(cursor, match=pattern) - - # For each key, get the value and check if it's related to entity_name + if keys: + pipe = self._redis.pipeline() for key in keys: - value = await self._redis.get(key) - if value: - data = json.loads(value) - # Check if this is a relation involving the entity - if ( - data.get("src_id") == entity_name - or data.get("tgt_id") == entity_name - ): - relation_keys.append(key) - - # Exit loop when cursor returns to 0 - if cursor == 0: - break - - # Delete the relation keys - if relation_keys: - deleted = await self._redis.delete(*relation_keys) - logger.debug(f"Deleted {deleted} relations for {entity_name}") + pipe.delete(key) + results = await pipe.execute() + deleted_count = sum(results) + + logger.info(f"Dropped {deleted_count} keys from {self.namespace}") + return {"status": "success", "message": f"{deleted_count} keys dropped"} else: - logger.debug(f"No relations found for entity {entity_name}") + logger.info(f"No keys found to drop in {self.namespace}") + return {"status": "success", "message": "no keys to drop"} except Exception as e: - logger.error(f"Error deleting relations for {entity_name}: {e}") + logger.error(f"Error dropping keys from {self.namespace}: {e}") + return {"status": "error", "message": str(e)} diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 0982c9140e5a968e3038da3c257313b2e5a1e51b..e57357de34bfecac994eae2abd71c46f4efe550c 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -20,7 +20,7 @@ if not pm.is_installed("pymysql"): if not pm.is_installed("sqlalchemy"): pm.install("sqlalchemy") -from sqlalchemy import create_engine, text +from sqlalchemy import create_engine, text # type: ignore class TiDB: @@ -278,6 +278,86 @@ class TiDBKVStorage(BaseKVStorage): # Ti handles persistence automatically pass + async def delete(self, ids: list[str]) -> None: + """Delete records with specified IDs from the storage. 
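A sketch of how the new cache hooks are meant to be driven (the `kv` instance and the mode names are assumptions for illustration; LightRAG.aclear_cache, further down in this diff, calls the same drop_cache_by_modes method):

# Illustrative sketch only; `kv` is assumed to be a RedisKVStorage bound to the
# LLM response cache namespace.
async def clear_cache_modes(kv) -> None:
    ok = await kv.drop_cache_by_modes(["local", "global"])  # returns bool
    if not ok:
        print("cache drop failed or nothing to drop")
    status = await kv.drop()  # removes every key under the namespace
    print(status)  # e.g. {"status": "success", "message": "3 keys dropped"}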
+ + Args: + ids: List of record IDs to be deleted + """ + if not ids: + return + + try: + table_name = namespace_to_table_name(self.namespace) + id_field = namespace_to_id(self.namespace) + + if not table_name or not id_field: + logger.error(f"Unknown namespace for deletion: {self.namespace}") + return + + ids_list = ",".join([f"'{id}'" for id in ids]) + delete_sql = f"DELETE FROM {table_name} WHERE workspace = :workspace AND {id_field} IN ({ids_list})" + + await self.db.execute(delete_sql, {"workspace": self.db.workspace}) + logger.info( + f"Successfully deleted {len(ids)} records from {self.namespace}" + ) + except Exception as e: + logger.error(f"Error deleting records from {self.namespace}: {e}") + + async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool: + """Delete specific records from storage by cache mode + + Args: + modes (list[str]): List of cache modes to be dropped from storage + + Returns: + bool: True if successful, False otherwise + """ + if not modes: + return False + + try: + table_name = namespace_to_table_name(self.namespace) + if not table_name: + return False + + if table_name != "LIGHTRAG_LLM_CACHE": + return False + + # 构建MySQL风格的IN查询 + modes_list = ", ".join([f"'{mode}'" for mode in modes]) + sql = f""" + DELETE FROM {table_name} + WHERE workspace = :workspace + AND mode IN ({modes_list}) + """ + + logger.info(f"Deleting cache by modes: {modes}") + await self.db.execute(sql, {"workspace": self.db.workspace}) + return True + except Exception as e: + logger.error(f"Error deleting cache by modes {modes}: {e}") + return False + + async def drop(self) -> dict[str, str]: + """Drop the storage""" + try: + table_name = namespace_to_table_name(self.namespace) + if not table_name: + return { + "status": "error", + "message": f"Unknown namespace: {self.namespace}", + } + + drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( + table_name=table_name + ) + await self.db.execute(drop_sql, {"workspace": self.db.workspace}) + return {"status": "success", "message": "data dropped"} + except Exception as e: + return {"status": "error", "message": str(e)} + @final @dataclass @@ -406,16 +486,91 @@ class TiDBVectorDBStorage(BaseVectorStorage): params = {"workspace": self.db.workspace, "status": status} return await self.db.query(SQL, params, multirows=True) + async def delete(self, ids: list[str]) -> None: + """Delete vectors with specified IDs from the storage. + + Args: + ids: List of vector IDs to be deleted + """ + if not ids: + return + + table_name = namespace_to_table_name(self.namespace) + id_field = namespace_to_id(self.namespace) + + if not table_name or not id_field: + logger.error(f"Unknown namespace for vector deletion: {self.namespace}") + return + + ids_list = ",".join([f"'{id}'" for id in ids]) + delete_sql = f"DELETE FROM {table_name} WHERE workspace = :workspace AND {id_field} IN ({ids_list})" + + try: + await self.db.execute(delete_sql, {"workspace": self.db.workspace}) + logger.debug( + f"Successfully deleted {len(ids)} vectors from {self.namespace}" + ) + except Exception as e: + logger.error(f"Error while deleting vectors from {self.namespace}: {e}") + async def delete_entity(self, entity_name: str) -> None: - raise NotImplementedError + """Delete an entity by its name from the vector storage. 
+ + Args: + entity_name: The name of the entity to delete + """ + try: + # Construct SQL to delete the entity + delete_sql = """DELETE FROM LIGHTRAG_GRAPH_NODES + WHERE workspace = :workspace AND name = :entity_name""" + + await self.db.execute( + delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name} + ) + logger.debug(f"Successfully deleted entity {entity_name}") + except Exception as e: + logger.error(f"Error deleting entity {entity_name}: {e}") async def delete_entity_relation(self, entity_name: str) -> None: - raise NotImplementedError + """Delete all relations associated with an entity. + + Args: + entity_name: The name of the entity whose relations should be deleted + """ + try: + # Delete relations where the entity is either the source or target + delete_sql = """DELETE FROM LIGHTRAG_GRAPH_EDGES + WHERE workspace = :workspace AND (source_name = :entity_name OR target_name = :entity_name)""" + + await self.db.execute( + delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name} + ) + logger.debug(f"Successfully deleted relations for entity {entity_name}") + except Exception as e: + logger.error(f"Error deleting relations for entity {entity_name}: {e}") async def index_done_callback(self) -> None: # Ti handles persistence automatically pass + async def drop(self) -> dict[str, str]: + """Drop the storage""" + try: + table_name = namespace_to_table_name(self.namespace) + if not table_name: + return { + "status": "error", + "message": f"Unknown namespace: {self.namespace}", + } + + drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( + table_name=table_name + ) + await self.db.execute(drop_sql, {"workspace": self.db.workspace}) + return {"status": "success", "message": "data dropped"} + except Exception as e: + return {"status": "error", "message": str(e)} + async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]: """Search for records with IDs starting with a specific prefix. 
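Note that the new drop() implementations no longer drop tables; they delete only the current workspace's rows through the shared template added at the end of each SQL_TEMPLATES map in this diff. Roughly, the expansion looks like this (illustrative sketch; the table name shown is one possible resolution):

# Illustrative expansion of the shared template; table name resolution happens at runtime.
SQL_TEMPLATES = {
    "drop_specifiy_table_workspace": "DELETE FROM {table_name} WHERE workspace = :workspace",
}
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
    table_name="LIGHTRAG_DOC_CHUNKS"  # resolved via namespace_to_table_name(namespace)
)
# -> "DELETE FROM LIGHTRAG_DOC_CHUNKS WHERE workspace = :workspace"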
@@ -710,6 +865,18 @@ class TiDBGraphStorage(BaseGraphStorage): # Ti handles persistence automatically pass + async def drop(self) -> dict[str, str]: + """Drop the storage""" + try: + drop_sql = """ + DELETE FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace; + DELETE FROM LIGHTRAG_GRAPH_NODES WHERE workspace = :workspace; + """ + await self.db.execute(drop_sql, {"workspace": self.db.workspace}) + return {"status": "success", "message": "graph data dropped"} + except Exception as e: + return {"status": "error", "message": str(e)} + async def delete_node(self, node_id: str) -> None: """Delete a node and all its related edges @@ -1129,4 +1296,6 @@ SQL_TEMPLATES = { FROM LIGHTRAG_DOC_CHUNKS WHERE chunk_id LIKE :prefix_pattern AND workspace = :workspace """, + # Drop tables + "drop_specifiy_table_workspace": "DELETE FROM {table_name} WHERE workspace = :workspace", } diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index dece78b4e0c2a8806a2a3f0e0c944c9b5da48106..50ee079ad3ca514dff513434f260a53422aedd68 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -13,7 +13,6 @@ import pandas as pd from lightrag.kg import ( - STORAGE_ENV_REQUIREMENTS, STORAGES, verify_storage_implementation, ) @@ -230,6 +229,7 @@ class LightRAG: vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict) """Additional parameters for vector database storage.""" + # TODO:deprecated, remove in the future, use WORKSPACE instead namespace_prefix: str = field(default="") """Prefix for namespacing stored data across different environments.""" @@ -510,36 +510,22 @@ class LightRAG: self, node_label: str, max_depth: int = 3, - min_degree: int = 0, - inclusive: bool = False, + max_nodes: int = 1000, ) -> KnowledgeGraph: """Get knowledge graph for a given label Args: node_label (str): Label to get knowledge graph for max_depth (int): Maximum depth of graph - min_degree (int, optional): Minimum degree of nodes to include. Defaults to 0. - inclusive (bool, optional): Whether to use inclusive search mode. Defaults to False. + max_nodes (int, optional): Maximum number of nodes to return. Defaults to 1000. 
Returns: KnowledgeGraph: Knowledge graph containing nodes and edges """ - # get params supported by get_knowledge_graph of specified storage - import inspect - storage_params = inspect.signature( - self.chunk_entity_relation_graph.get_knowledge_graph - ).parameters - - kwargs = {"node_label": node_label, "max_depth": max_depth} - - if "min_degree" in storage_params and min_degree > 0: - kwargs["min_degree"] = min_degree - - if "inclusive" in storage_params: - kwargs["inclusive"] = inclusive - - return await self.chunk_entity_relation_graph.get_knowledge_graph(**kwargs) + return await self.chunk_entity_relation_graph.get_knowledge_graph( + node_label, max_depth, max_nodes + ) def _get_storage_class(self, storage_name: str) -> Callable[..., Any]: import_path = STORAGES[storage_name] @@ -1449,6 +1435,7 @@ class LightRAG: loop = always_get_an_event_loop() return loop.run_until_complete(self.adelete_by_entity(entity_name)) + # TODO: Lock all KG relative DB to esure consistency across multiple processes async def adelete_by_entity(self, entity_name: str) -> None: try: await self.entities_vdb.delete_entity(entity_name) @@ -1486,6 +1473,7 @@ class LightRAG: self.adelete_by_relation(source_entity, target_entity) ) + # TODO: Lock all KG relative DB to esure consistency across multiple processes async def adelete_by_relation(self, source_entity: str, target_entity: str) -> None: """Asynchronously delete a relation between two entities. @@ -1494,6 +1482,7 @@ class LightRAG: target_entity: Name of the target entity """ try: + # TODO: check if has_edge function works on reverse relation # Check if the relation exists edge_exists = await self.chunk_entity_relation_graph.has_edge( source_entity, target_entity @@ -1554,6 +1543,7 @@ class LightRAG: """ return await self.doc_status.get_docs_by_status(status) + # TODO: Lock all KG relative DB to esure consistency across multiple processes async def adelete_by_doc_id(self, doc_id: str) -> None: """Delete a document and all its related data @@ -1586,6 +1576,8 @@ class LightRAG: chunk_ids = set(related_chunks.keys()) logger.debug(f"Found {len(chunk_ids)} chunks to delete") + # TODO: self.entities_vdb.client_storage only works for local storage, need to fix this + # 3. Before deleting, check the related entities and relationships for these chunks for chunk_id in chunk_ids: # Check entities @@ -1857,24 +1849,6 @@ class LightRAG: return result - def check_storage_env_vars(self, storage_name: str) -> None: - """Check if all required environment variables for storage implementation exist - - Args: - storage_name: Storage implementation name - - Raises: - ValueError: If required environment variables are missing - """ - required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, []) - missing_vars = [var for var in required_vars if var not in os.environ] - - if missing_vars: - raise ValueError( - f"Storage implementation '{storage_name}' requires the following " - f"environment variables: {', '.join(missing_vars)}" - ) - async def aclear_cache(self, modes: list[str] | None = None) -> None: """Clear cache data from the LLM response cache storage. 
@@ -1906,12 +1880,18 @@ class LightRAG: try: # Reset the cache storage for specified mode if modes: - await self.llm_response_cache.delete(modes) - logger.info(f"Cleared cache for modes: {modes}") + success = await self.llm_response_cache.drop_cache_by_modes(modes) + if success: + logger.info(f"Cleared cache for modes: {modes}") + else: + logger.warning(f"Failed to clear cache for modes: {modes}") else: # Clear all modes - await self.llm_response_cache.delete(valid_modes) - logger.info("Cleared all cache") + success = await self.llm_response_cache.drop_cache_by_modes(valid_modes) + if success: + logger.info("Cleared all cache") + else: + logger.warning("Failed to clear all cache") await self.llm_response_cache.index_done_callback() @@ -1922,6 +1902,7 @@ class LightRAG: """Synchronous version of aclear_cache.""" return always_get_an_event_loop().run_until_complete(self.aclear_cache(modes)) + # TODO: Lock all KG relative DB to esure consistency across multiple processes async def aedit_entity( self, entity_name: str, updated_data: dict[str, str], allow_rename: bool = True ) -> dict[str, Any]: @@ -2134,6 +2115,7 @@ class LightRAG: ] ) + # TODO: Lock all KG relative DB to esure consistency across multiple processes async def aedit_relation( self, source_entity: str, target_entity: str, updated_data: dict[str, Any] ) -> dict[str, Any]: @@ -2448,6 +2430,7 @@ class LightRAG: self.acreate_relation(source_entity, target_entity, relation_data) ) + # TODO: Lock all KG relative DB to esure consistency across multiple processes async def amerge_entities( self, source_entities: list[str], diff --git a/lightrag/operate.py b/lightrag/operate.py index 088ca6175dd630481b72c08233c5af5d50a87b5f..634326a735a24c471bdef40953e5e1599eb7806c 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -25,7 +25,6 @@ from .utils import ( CacheData, statistic_data, get_conversation_turns, - verbose_debug, ) from .base import ( BaseGraphStorage, @@ -441,6 +440,13 @@ async def extract_entities( processed_chunks = 0 total_chunks = len(ordered_chunks) + total_entities_count = 0 + total_relations_count = 0 + + # Get lock manager from shared storage + from .kg.shared_storage import get_graph_db_lock + + graph_db_lock = get_graph_db_lock(enable_logging=False) async def _user_llm_func_with_cache( input_text: str, history_messages: list[dict[str, str]] = None @@ -539,7 +545,7 @@ async def extract_entities( chunk_key_dp (tuple[str, TextChunkSchema]): ("chunk-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int}) """ - nonlocal processed_chunks + nonlocal processed_chunks, total_entities_count, total_relations_count chunk_key = chunk_key_dp[0] chunk_dp = chunk_key_dp[1] content = chunk_dp["content"] @@ -597,102 +603,74 @@ async def extract_entities( async with pipeline_status_lock: pipeline_status["latest_message"] = log_message pipeline_status["history_messages"].append(log_message) - return dict(maybe_nodes), dict(maybe_edges) - - tasks = [_process_single_content(c) for c in ordered_chunks] - results = await asyncio.gather(*tasks) - maybe_nodes = defaultdict(list) - maybe_edges = defaultdict(list) - for m_nodes, m_edges in results: - for k, v in m_nodes.items(): - maybe_nodes[k].extend(v) - for k, v in m_edges.items(): - maybe_edges[tuple(sorted(k))].extend(v) + # Use graph database lock to ensure atomic merges and updates + chunk_entities_data = [] + chunk_relationships_data = [] - from .kg.shared_storage import get_graph_db_lock - - graph_db_lock = get_graph_db_lock(enable_logging=False) - 
- # Ensure that nodes and edges are merged and upserted atomically - async with graph_db_lock: - all_entities_data = await asyncio.gather( - *[ - _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config) - for k, v in maybe_nodes.items() - ] - ) - - all_relationships_data = await asyncio.gather( - *[ - _merge_edges_then_upsert( - k[0], k[1], v, knowledge_graph_inst, global_config + async with graph_db_lock: + # Process and update entities + for entity_name, entities in maybe_nodes.items(): + entity_data = await _merge_nodes_then_upsert( + entity_name, entities, knowledge_graph_inst, global_config ) - for k, v in maybe_edges.items() - ] - ) + chunk_entities_data.append(entity_data) + + # Process and update relationships + for edge_key, edges in maybe_edges.items(): + # Ensure edge direction consistency + sorted_edge_key = tuple(sorted(edge_key)) + edge_data = await _merge_edges_then_upsert( + sorted_edge_key[0], + sorted_edge_key[1], + edges, + knowledge_graph_inst, + global_config, + ) + chunk_relationships_data.append(edge_data) + + # Update vector database (within the same lock to ensure atomicity) + if entity_vdb is not None and chunk_entities_data: + data_for_vdb = { + compute_mdhash_id(dp["entity_name"], prefix="ent-"): { + "entity_name": dp["entity_name"], + "entity_type": dp["entity_type"], + "content": f"{dp['entity_name']}\n{dp['description']}", + "source_id": dp["source_id"], + "file_path": dp.get("file_path", "unknown_source"), + } + for dp in chunk_entities_data + } + await entity_vdb.upsert(data_for_vdb) + + if relationships_vdb is not None and chunk_relationships_data: + data_for_vdb = { + compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): { + "src_id": dp["src_id"], + "tgt_id": dp["tgt_id"], + "keywords": dp["keywords"], + "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}", + "source_id": dp["source_id"], + "file_path": dp.get("file_path", "unknown_source"), + } + for dp in chunk_relationships_data + } + await relationships_vdb.upsert(data_for_vdb) - if not (all_entities_data or all_relationships_data): - log_message = "Didn't extract any entities and relationships." 
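The per-chunk merge above relies on the shared graph-database lock so concurrent chunks cannot interleave graph and vector updates. A minimal sketch of the same pattern for any custom writer (the module path and lock call mirror this diff; the function and its arguments are illustrative assumptions):

# Illustrative sketch only; mirrors the locking pattern used by extract_entities.
from lightrag.kg.shared_storage import get_graph_db_lock

async def guarded_upsert(knowledge_graph_inst, node_id: str, node_data: dict) -> None:
    graph_db_lock = get_graph_db_lock(enable_logging=False)
    # Serialize merge + upsert so parallel chunk workers stay consistent.
    async with graph_db_lock:
        await knowledge_graph_inst.upsert_node(node_id, node_data)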
- logger.info(log_message) - if pipeline_status is not None: - async with pipeline_status_lock: - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) - return + # Update counters + total_entities_count += len(chunk_entities_data) + total_relations_count += len(chunk_relationships_data) - if not all_entities_data: - log_message = "Didn't extract any entities" - logger.info(log_message) - if pipeline_status is not None: - async with pipeline_status_lock: - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) - if not all_relationships_data: - log_message = "Didn't extract any relationships" - logger.info(log_message) - if pipeline_status is not None: - async with pipeline_status_lock: - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) + # Handle all chunks in parallel + tasks = [_process_single_content(c) for c in ordered_chunks] + await asyncio.gather(*tasks) - log_message = f"Extracted {len(all_entities_data)} entities + {len(all_relationships_data)} relationships (deduplicated)" + log_message = f"Extracted {total_entities_count} entities + {total_relations_count} relationships (total)" logger.info(log_message) if pipeline_status is not None: async with pipeline_status_lock: pipeline_status["latest_message"] = log_message pipeline_status["history_messages"].append(log_message) - verbose_debug( - f"New entities:{all_entities_data}, relationships:{all_relationships_data}" - ) - verbose_debug(f"New relationships:{all_relationships_data}") - - if entity_vdb is not None: - data_for_vdb = { - compute_mdhash_id(dp["entity_name"], prefix="ent-"): { - "entity_name": dp["entity_name"], - "entity_type": dp["entity_type"], - "content": f"{dp['entity_name']}\n{dp['description']}", - "source_id": dp["source_id"], - "file_path": dp.get("file_path", "unknown_source"), - } - for dp in all_entities_data - } - await entity_vdb.upsert(data_for_vdb) - - if relationships_vdb is not None: - data_for_vdb = { - compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): { - "src_id": dp["src_id"], - "tgt_id": dp["tgt_id"], - "keywords": dp["keywords"], - "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}", - "source_id": dp["source_id"], - "file_path": dp.get("file_path", "unknown_source"), - } - for dp in all_relationships_data - } - await relationships_vdb.upsert(data_for_vdb) async def kg_query( @@ -1367,7 +1345,9 @@ async def _get_node_data( text_units_section_list = [["id", "content", "file_path"]] for i, t in enumerate(use_text_units): - text_units_section_list.append([i, t["content"], t["file_path"]]) + text_units_section_list.append( + [i, t["content"], t.get("file_path", "unknown_source")] + ) text_units_context = list_of_list_to_csv(text_units_section_list) return entities_context, relations_context, text_units_context diff --git a/lightrag/types.py b/lightrag/types.py index 5e3d2948ab0933833f76ab8d33f84734782c8855..a18f2d3cd81a6ad2f9242ef5d103804b32dba4ab 100644 --- a/lightrag/types.py +++ b/lightrag/types.py @@ -26,3 +26,4 @@ class KnowledgeGraphEdge(BaseModel): class KnowledgeGraph(BaseModel): nodes: list[KnowledgeGraphNode] = [] edges: list[KnowledgeGraphEdge] = [] + is_truncated: bool = False diff --git a/lightrag_webui/src/AppRouter.tsx b/lightrag_webui/src/AppRouter.tsx index 9aec0a14feb9d19a3ac49846e7b9f0e6a8d8d328..e7130ad6dcd813f1ac93b45163f2617f1ae9dbdb 100644 --- a/lightrag_webui/src/AppRouter.tsx 
+++ b/lightrag_webui/src/AppRouter.tsx @@ -80,7 +80,12 @@ const AppRouter = () => { - + ) diff --git a/lightrag_webui/src/api/lightrag.ts b/lightrag_webui/src/api/lightrag.ts index 364ecb44b55e13c78398e65d6d2f0d86e1fe2b20..bf208f8b2bb034532e105d76f8e2897d6f95a30e 100644 --- a/lightrag_webui/src/api/lightrag.ts +++ b/lightrag_webui/src/api/lightrag.ts @@ -3,6 +3,7 @@ import { backendBaseUrl } from '@/lib/constants' import { errorMessage } from '@/lib/utils' import { useSettingsStore } from '@/stores/settings' import { navigationService } from '@/services/navigation' +import { useAuthStore } from '@/stores/state' // Types export type LightragNodeType = { @@ -46,6 +47,8 @@ export type LightragStatus = { api_version?: string auth_mode?: 'enabled' | 'disabled' pipeline_busy: boolean + webui_title?: string + webui_description?: string } export type LightragDocumentsScanProgress = { @@ -140,6 +143,8 @@ export type AuthStatusResponse = { message?: string core_version?: string api_version?: string + webui_title?: string + webui_description?: string } export type PipelineStatusResponse = { @@ -163,6 +168,8 @@ export type LoginResponse = { message?: string // Optional message core_version?: string api_version?: string + webui_title?: string + webui_description?: string } export const InvalidApiKeyError = 'Invalid API Key' @@ -221,9 +228,9 @@ axiosInstance.interceptors.response.use( export const queryGraphs = async ( label: string, maxDepth: number, - minDegree: number + maxNodes: number ): Promise => { - const response = await axiosInstance.get(`/graphs?label=${encodeURIComponent(label)}&max_depth=${maxDepth}&min_degree=${minDegree}`) + const response = await axiosInstance.get(`/graphs?label=${encodeURIComponent(label)}&max_depth=${maxDepth}&max_nodes=${maxNodes}`) return response.data } @@ -382,6 +389,14 @@ export const clearDocuments = async (): Promise => { return response.data } +export const clearCache = async (modes?: string[]): Promise<{ + status: 'success' | 'fail' + message: string +}> => { + const response = await axiosInstance.post('/documents/clear_cache', { modes }) + return response.data +} + export const getAuthStatus = async (): Promise => { try { // Add a timeout to the request to prevent hanging @@ -411,12 +426,26 @@ export const getAuthStatus = async (): Promise => { // For unconfigured auth, ensure we have an access token if (!response.data.auth_configured) { if (response.data.access_token && typeof response.data.access_token === 'string') { + // Update custom title if available + if ('webui_title' in response.data || 'webui_description' in response.data) { + useAuthStore.getState().setCustomTitle( + 'webui_title' in response.data ? (response.data.webui_title ?? null) : null, + 'webui_description' in response.data ? (response.data.webui_description ?? null) : null + ); + } return response.data; } else { console.warn('Auth not configured but no valid access token provided'); } } else { // For configured auth, just return the data + // Update custom title if available + if ('webui_title' in response.data || 'webui_description' in response.data) { + useAuthStore.getState().setCustomTitle( + 'webui_title' in response.data ? (response.data.webui_title ?? null) : null, + 'webui_description' in response.data ? (response.data.webui_description ?? 
null) : null + ); + } return response.data; } } @@ -455,5 +484,13 @@ export const loginToServer = async (username: string, password: string): Promise } }); + // Update custom title if available + if ('webui_title' in response.data || 'webui_description' in response.data) { + useAuthStore.getState().setCustomTitle( + 'webui_title' in response.data ? (response.data.webui_title ?? null) : null, + 'webui_description' in response.data ? (response.data.webui_description ?? null) : null + ); + } + return response.data; } diff --git a/lightrag_webui/src/components/documents/ClearDocumentsDialog.tsx b/lightrag_webui/src/components/documents/ClearDocumentsDialog.tsx index cc11ac5dde0431bbf7650701133a7f803fc46292..bad2978877012783e9323c864f85eb98bfa35420 100644 --- a/lightrag_webui/src/components/documents/ClearDocumentsDialog.tsx +++ b/lightrag_webui/src/components/documents/ClearDocumentsDialog.tsx @@ -1,4 +1,4 @@ -import { useState, useCallback } from 'react' +import { useState, useCallback, useEffect } from 'react' import Button from '@/components/ui/Button' import { Dialog, @@ -6,32 +6,88 @@ import { DialogDescription, DialogHeader, DialogTitle, - DialogTrigger + DialogTrigger, + DialogFooter } from '@/components/ui/Dialog' +import Input from '@/components/ui/Input' +import Checkbox from '@/components/ui/Checkbox' import { toast } from 'sonner' import { errorMessage } from '@/lib/utils' -import { clearDocuments } from '@/api/lightrag' +import { clearDocuments, clearCache } from '@/api/lightrag' -import { EraserIcon } from 'lucide-react' +import { EraserIcon, AlertTriangleIcon } from 'lucide-react' import { useTranslation } from 'react-i18next' -export default function ClearDocumentsDialog() { +// 简单的Label组件 +const Label = ({ + htmlFor, + className, + children, + ...props +}: React.LabelHTMLAttributes) => ( + +) + +interface ClearDocumentsDialogProps { + onDocumentsCleared?: () => Promise +} + +export default function ClearDocumentsDialog({ onDocumentsCleared }: ClearDocumentsDialogProps) { const { t } = useTranslation() const [open, setOpen] = useState(false) + const [confirmText, setConfirmText] = useState('') + const [clearCacheOption, setClearCacheOption] = useState(false) + const isConfirmEnabled = confirmText.toLowerCase() === 'yes' + + // 重置状态当对话框关闭时 + useEffect(() => { + if (!open) { + setConfirmText('') + setClearCacheOption(false) + } + }, [open]) const handleClear = useCallback(async () => { + if (!isConfirmEnabled) return + try { const result = await clearDocuments() - if (result.status === 'success') { - toast.success(t('documentPanel.clearDocuments.success')) - setOpen(false) - } else { + + if (result.status !== 'success') { toast.error(t('documentPanel.clearDocuments.failed', { message: result.message })) + setConfirmText('') + return } + + toast.success(t('documentPanel.clearDocuments.success')) + + if (clearCacheOption) { + try { + await clearCache() + toast.success(t('documentPanel.clearDocuments.cacheCleared')) + } catch (cacheErr) { + toast.error(t('documentPanel.clearDocuments.cacheClearFailed', { error: errorMessage(cacheErr) })) + } + } + + // Refresh document list if provided + if (onDocumentsCleared) { + onDocumentsCleared().catch(console.error) + } + + // 所有操作成功后关闭对话框 + setOpen(false) } catch (err) { toast.error(t('documentPanel.clearDocuments.error', { error: errorMessage(err) })) + setConfirmText('') } - }, [setOpen, t]) + }, [isConfirmEnabled, clearCacheOption, setOpen, t, onDocumentsCleared]) return ( @@ -42,12 +98,58 @@ export default function ClearDocumentsDialog() 
{ e.preventDefault()}> - {t('documentPanel.clearDocuments.title')} - {t('documentPanel.clearDocuments.confirm')} + + + {t('documentPanel.clearDocuments.title')} + + +
+ {t('documentPanel.clearDocuments.warning')} +
+
+ {t('documentPanel.clearDocuments.confirm')} +
+
- + +
+
+ + ) => setConfirmText(e.target.value)} + placeholder={t('documentPanel.clearDocuments.confirmPlaceholder')} + className="w-full" + /> +
+ +
+ setClearCacheOption(checked === true)} + /> + +
+
+ + + + +
) diff --git a/lightrag_webui/src/components/documents/UploadDocumentsDialog.tsx b/lightrag_webui/src/components/documents/UploadDocumentsDialog.tsx index 977c403091c881e21a41ace70dd4af697ab86712..5785a7d32ade7ee7a19b428b4ccdd87561ca267d 100644 --- a/lightrag_webui/src/components/documents/UploadDocumentsDialog.tsx +++ b/lightrag_webui/src/components/documents/UploadDocumentsDialog.tsx @@ -17,7 +17,11 @@ import { uploadDocument } from '@/api/lightrag' import { UploadIcon } from 'lucide-react' import { useTranslation } from 'react-i18next' -export default function UploadDocumentsDialog() { +interface UploadDocumentsDialogProps { + onDocumentsUploaded?: () => Promise +} + +export default function UploadDocumentsDialog({ onDocumentsUploaded }: UploadDocumentsDialogProps) { const { t } = useTranslation() const [open, setOpen] = useState(false) const [isUploading, setIsUploading] = useState(false) @@ -55,6 +59,7 @@ export default function UploadDocumentsDialog() { const handleDocumentsUpload = useCallback( async (filesToUpload: File[]) => { setIsUploading(true) + let hasSuccessfulUpload = false // Only clear errors for files that are being uploaded, keep errors for rejected files setFileErrors(prev => { @@ -101,6 +106,9 @@ export default function UploadDocumentsDialog() { ...prev, [file.name]: result.message })) + } else { + // Mark that we had at least one successful upload + hasSuccessfulUpload = true } } catch (err) { console.error(`Upload failed for ${file.name}:`, err) @@ -142,6 +150,16 @@ export default function UploadDocumentsDialog() { } else { toast.success(t('documentPanel.uploadDocuments.batch.success'), { id: toastId }) } + + // Only update if at least one file was uploaded successfully + if (hasSuccessfulUpload) { + // Refresh document list + if (onDocumentsUploaded) { + onDocumentsUploaded().catch(err => { + console.error('Error refreshing documents:', err) + }) + } + } } catch (err) { console.error('Unexpected error during upload:', err) toast.error(t('documentPanel.uploadDocuments.generalError', { error: errorMessage(err) }), { id: toastId }) @@ -149,7 +167,7 @@ export default function UploadDocumentsDialog() { setIsUploading(false) } }, - [setIsUploading, setProgresses, setFileErrors, t] + [setIsUploading, setProgresses, setFileErrors, t, onDocumentsUploaded] ) return ( diff --git a/lightrag_webui/src/components/graph/Settings.tsx b/lightrag_webui/src/components/graph/Settings.tsx index 1989a01e8422b3d481d6567adf67186fc49ca97e..e4085cefb397e7f5b3105668cfa1d9cb3ea4decf 100644 --- a/lightrag_webui/src/components/graph/Settings.tsx +++ b/lightrag_webui/src/components/graph/Settings.tsx @@ -8,7 +8,7 @@ import Input from '@/components/ui/Input' import { controlButtonVariant } from '@/lib/constants' import { useSettingsStore } from '@/stores/settings' -import { SettingsIcon } from 'lucide-react' +import { SettingsIcon, Undo2 } from 'lucide-react' import { useTranslation } from 'react-i18next'; /** @@ -44,14 +44,17 @@ const LabeledNumberInput = ({ onEditFinished, label, min, - max + max, + defaultValue }: { value: number onEditFinished: (value: number) => void label: string min: number max?: number + defaultValue?: number }) => { + const { t } = useTranslation(); const [currentValue, setCurrentValue] = useState(value) const onValueChange = useCallback( @@ -81,6 +84,13 @@ const LabeledNumberInput = ({ } }, [value, currentValue, onEditFinished]) + const handleReset = useCallback(() => { + if (defaultValue !== undefined && value !== defaultValue) { + setCurrentValue(defaultValue) + 
onEditFinished(defaultValue) + } + }, [defaultValue, value, onEditFinished]) + return (
- { - if (e.key === 'Enter') { - onBlur() - } - }} - /> +
+ { + if (e.key === 'Enter') { + onBlur() + } + }} + /> + {defaultValue !== undefined && ( + + )} +
) } @@ -121,7 +145,7 @@ export default function Settings() { const enableHideUnselectedEdges = useSettingsStore.use.enableHideUnselectedEdges() const showEdgeLabel = useSettingsStore.use.showEdgeLabel() const graphQueryMaxDepth = useSettingsStore.use.graphQueryMaxDepth() - const graphMinDegree = useSettingsStore.use.graphMinDegree() + const graphMaxNodes = useSettingsStore.use.graphMaxNodes() const graphLayoutMaxIterations = useSettingsStore.use.graphLayoutMaxIterations() const enableHealthCheck = useSettingsStore.use.enableHealthCheck() @@ -180,15 +204,14 @@ export default function Settings() { }, 300) }, []) - const setGraphMinDegree = useCallback((degree: number) => { - if (degree < 0) return - useSettingsStore.setState({ graphMinDegree: degree }) + const setGraphMaxNodes = useCallback((nodes: number) => { + if (nodes < 1 || nodes > 1000) return + useSettingsStore.setState({ graphMaxNodes: nodes }) const currentLabel = useSettingsStore.getState().queryLabel useSettingsStore.getState().setQueryLabel('') setTimeout(() => { useSettingsStore.getState().setQueryLabel(currentLabel) }, 300) - }, []) const setGraphLayoutMaxIterations = useCallback((iterations: number) => { @@ -274,19 +297,23 @@ export default function Settings() { label={t('graphPanel.sideBar.settings.maxQueryDepth')} min={1} value={graphQueryMaxDepth} + defaultValue={3} onEditFinished={setGraphQueryMaxDepth} /> diff --git a/lightrag_webui/src/components/graph/SettingsDisplay.tsx b/lightrag_webui/src/components/graph/SettingsDisplay.tsx index dec44c11d62f15214b41d9d288979f42e2dd5be3..93fc0e017d937436af82e9aeaf926bf34ae0fa29 100644 --- a/lightrag_webui/src/components/graph/SettingsDisplay.tsx +++ b/lightrag_webui/src/components/graph/SettingsDisplay.tsx @@ -8,12 +8,12 @@ import { useTranslation } from 'react-i18next' const SettingsDisplay = () => { const { t } = useTranslation() const graphQueryMaxDepth = useSettingsStore.use.graphQueryMaxDepth() - const graphMinDegree = useSettingsStore.use.graphMinDegree() + const graphMaxNodes = useSettingsStore.use.graphMaxNodes() return (
{t('graphPanel.sideBar.settings.depth')}: {graphQueryMaxDepth}
-
{t('graphPanel.sideBar.settings.degree')}: {graphMinDegree}
+
{t('graphPanel.sideBar.settings.max')}: {graphMaxNodes}
) } diff --git a/lightrag_webui/src/components/status/StatusCard.tsx b/lightrag_webui/src/components/status/StatusCard.tsx index e67cbd300ef4b128d73aeab02827a842a2381f91..c9e64db9641f4f3f4871a574fea180eb11de29b1 100644 --- a/lightrag_webui/src/components/status/StatusCard.tsx +++ b/lightrag_webui/src/components/status/StatusCard.tsx @@ -4,14 +4,14 @@ import { useTranslation } from 'react-i18next' const StatusCard = ({ status }: { status: LightragStatus | null }) => { const { t } = useTranslation() if (!status) { - return
{t('graphPanel.statusCard.unavailable')}
+ return
{t('graphPanel.statusCard.unavailable')}
} return ( -
+

{t('graphPanel.statusCard.storageInfo')}

-
+
{t('graphPanel.statusCard.workingDirectory')}: {status.working_directory} {t('graphPanel.statusCard.inputDirectory')}: @@ -21,7 +21,7 @@ const StatusCard = ({ status }: { status: LightragStatus | null }) => {

{t('graphPanel.statusCard.llmConfig')}

-
+
{t('graphPanel.statusCard.llmBinding')}: {status.configuration.llm_binding} {t('graphPanel.statusCard.llmBindingHost')}: @@ -35,7 +35,7 @@ const StatusCard = ({ status }: { status: LightragStatus | null }) => {

{t('graphPanel.statusCard.embeddingConfig')}

-
+
{t('graphPanel.statusCard.embeddingBinding')}: {status.configuration.embedding_binding} {t('graphPanel.statusCard.embeddingBindingHost')}: @@ -47,7 +47,7 @@ const StatusCard = ({ status }: { status: LightragStatus | null }) => {

{t('graphPanel.statusCard.storageConfig')}

-
+
{t('graphPanel.statusCard.kvStorage')}: {status.configuration.kv_storage} {t('graphPanel.statusCard.docStatusStorage')}: diff --git a/lightrag_webui/src/components/status/StatusDialog.tsx b/lightrag_webui/src/components/status/StatusDialog.tsx new file mode 100644 index 0000000000000000000000000000000000000000..48eaa4f7fe95ab0eccc3c69d3fc465236b772205 --- /dev/null +++ b/lightrag_webui/src/components/status/StatusDialog.tsx @@ -0,0 +1,32 @@ +import { LightragStatus } from '@/api/lightrag' +import { useTranslation } from 'react-i18next' +import { + Dialog, + DialogContent, + DialogHeader, + DialogTitle, +} from '@/components/ui/Dialog' +import StatusCard from './StatusCard' + +interface StatusDialogProps { + open: boolean + onOpenChange: (open: boolean) => void + status: LightragStatus | null +} + +const StatusDialog = ({ open, onOpenChange, status }: StatusDialogProps) => { + const { t } = useTranslation() + + return ( + + + + {t('graphPanel.statusDialog.title')} + + + + + ) +} + +export default StatusDialog diff --git a/lightrag_webui/src/components/status/StatusIndicator.tsx b/lightrag_webui/src/components/status/StatusIndicator.tsx index 263bb99e9ea2118b7466d2d8ac3a24ffe666b61f..5a9fc751f7bcdcbaeb4e3226ab84ac0b51a8f345 100644 --- a/lightrag_webui/src/components/status/StatusIndicator.tsx +++ b/lightrag_webui/src/components/status/StatusIndicator.tsx @@ -1,8 +1,7 @@ import { cn } from '@/lib/utils' import { useBackendState } from '@/stores/state' import { useEffect, useState } from 'react' -import { Popover, PopoverContent, PopoverTrigger } from '@/components/ui/Popover' -import StatusCard from '@/components/status/StatusCard' +import StatusDialog from './StatusDialog' import { useTranslation } from 'react-i18next' const StatusIndicator = () => { @@ -11,6 +10,7 @@ const StatusIndicator = () => { const lastCheckTime = useBackendState.use.lastCheckTime() const status = useBackendState.use.status() const [animate, setAnimate] = useState(false) + const [dialogOpen, setDialogOpen] = useState(false) // listen to health change useEffect(() => { @@ -21,28 +21,30 @@ const StatusIndicator = () => { return (
- - -
-
- - {health ? t('graphPanel.statusIndicator.connected') : t('graphPanel.statusIndicator.disconnected')} - -
- - - - - +
setDialogOpen(true)} + > +
+ + {health ? t('graphPanel.statusIndicator.connected') : t('graphPanel.statusIndicator.disconnected')} + +
+ +
) } diff --git a/lightrag_webui/src/components/ui/Checkbox.tsx b/lightrag_webui/src/components/ui/Checkbox.tsx index 36ebe6e0a288eb803d3afca5449b976abf6d86cd..c9d4fafeefea6076e3213c3656b0f46cc9fd6a36 100644 --- a/lightrag_webui/src/components/ui/Checkbox.tsx +++ b/lightrag_webui/src/components/ui/Checkbox.tsx @@ -11,7 +11,7 @@ const Checkbox = React.forwardRef< >( { + const sortDocuments = useCallback((documents: DocStatusResponse[]) => { return [...documents].sort((a, b) => { let valueA, valueB; @@ -194,7 +194,7 @@ export default function DocumentManager() { return sortMultiplier * (valueA > valueB ? 1 : valueA < valueB ? -1 : 0); } }); - } + }, [sortField, sortDirection, showFileName]); const filteredAndSortedDocs = useMemo(() => { if (!docs) return null; @@ -223,7 +223,7 @@ export default function DocumentManager() { }, {} as DocsStatusesResponse['statuses']); return { ...filteredDocs, statuses: sortedStatuses }; - }, [docs, sortField, sortDirection, statusFilter]); + }, [docs, sortField, sortDirection, statusFilter, sortDocuments]); // Calculate document counts for each status const documentCounts = useMemo(() => { @@ -435,8 +435,8 @@ export default function DocumentManager() {
- - + + setStatusFilter('all')} + className={cn( + statusFilter === 'all' && 'bg-gray-100 dark:bg-gray-900 font-medium border border-gray-400 dark:border-gray-500 shadow-sm' + )} > {t('documentPanel.documentManager.status.all')} ({documentCounts.all}) @@ -461,7 +464,10 @@ export default function DocumentManager() { size="sm" variant={statusFilter === 'processed' ? 'secondary' : 'outline'} onClick={() => setStatusFilter('processed')} - className={documentCounts.processed > 0 ? 'text-green-600' : 'text-gray-500'} + className={cn( + documentCounts.processed > 0 ? 'text-green-600' : 'text-gray-500', + statusFilter === 'processed' && 'bg-green-100 dark:bg-green-900/30 font-medium border border-green-400 dark:border-green-600 shadow-sm' + )} > {t('documentPanel.documentManager.status.completed')} ({documentCounts.processed || 0}) @@ -469,7 +475,10 @@ export default function DocumentManager() { size="sm" variant={statusFilter === 'processing' ? 'secondary' : 'outline'} onClick={() => setStatusFilter('processing')} - className={documentCounts.processing > 0 ? 'text-blue-600' : 'text-gray-500'} + className={cn( + documentCounts.processing > 0 ? 'text-blue-600' : 'text-gray-500', + statusFilter === 'processing' && 'bg-blue-100 dark:bg-blue-900/30 font-medium border border-blue-400 dark:border-blue-600 shadow-sm' + )} > {t('documentPanel.documentManager.status.processing')} ({documentCounts.processing || 0}) @@ -477,7 +486,10 @@ export default function DocumentManager() { size="sm" variant={statusFilter === 'pending' ? 'secondary' : 'outline'} onClick={() => setStatusFilter('pending')} - className={documentCounts.pending > 0 ? 'text-yellow-600' : 'text-gray-500'} + className={cn( + documentCounts.pending > 0 ? 'text-yellow-600' : 'text-gray-500', + statusFilter === 'pending' && 'bg-yellow-100 dark:bg-yellow-900/30 font-medium border border-yellow-400 dark:border-yellow-600 shadow-sm' + )} > {t('documentPanel.documentManager.status.pending')} ({documentCounts.pending || 0}) @@ -485,7 +497,10 @@ export default function DocumentManager() { size="sm" variant={statusFilter === 'failed' ? 'secondary' : 'outline'} onClick={() => setStatusFilter('failed')} - className={documentCounts.failed > 0 ? 'text-red-600' : 'text-gray-500'} + className={cn( + documentCounts.failed > 0 ? 
'text-red-600' : 'text-gray-500', + statusFilter === 'failed' && 'bg-red-100 dark:bg-red-900/30 font-medium border border-red-400 dark:border-red-600 shadow-sm' + )} > {t('documentPanel.documentManager.status.failed')} ({documentCounts.failed || 0}) diff --git a/lightrag_webui/src/features/SiteHeader.tsx b/lightrag_webui/src/features/SiteHeader.tsx index 8077d39010350ebcc2d0ec6f2852c395b8651906..4881e4b6b079b244f90c46298e9164196a56cc19 100644 --- a/lightrag_webui/src/features/SiteHeader.tsx +++ b/lightrag_webui/src/features/SiteHeader.tsx @@ -8,6 +8,7 @@ import { cn } from '@/lib/utils' import { useTranslation } from 'react-i18next' import { navigationService } from '@/services/navigation' import { ZapIcon, GithubIcon, LogOutIcon } from 'lucide-react' +import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from '@/components/ui/Tooltip' interface NavigationTabProps { value: string @@ -55,7 +56,7 @@ function TabsNavigation() { export default function SiteHeader() { const { t } = useTranslation() - const { isGuestMode, coreVersion, apiVersion, username } = useAuthStore() + const { isGuestMode, coreVersion, apiVersion, username, webuiTitle, webuiDescription } = useAuthStore() const versionDisplay = (coreVersion && apiVersion) ? `${coreVersion}/${apiVersion}` @@ -67,17 +68,31 @@ export default function SiteHeader() { return (
-
+
+ {webuiTitle && ( +
+ | + + + + + {webuiTitle} + + + {webuiDescription && ( + + {webuiDescription} + + )} + + +
+ )}
@@ -91,6 +106,11 @@ export default function SiteHeader() {