Daniel.y committed · Commit 9b50052 · unverified · 2 Parent(s): 01576a1 0b7db86

Merge pull request #1237 from danielaskdd/clear-doc

This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. README-zh.md +4 -4
  2. README.md +4 -4
  3. config.ini.example +0 -17
  4. env.example +20 -27
  5. examples/lightrag_api_ollama_demo.py +0 -188
  6. examples/lightrag_api_openai_compatible_demo.py +0 -204
  7. examples/lightrag_api_oracle_demo.py +0 -267
  8. examples/lightrag_ollama_gremlin_demo.py +4 -0
  9. examples/lightrag_oracle_demo.py +0 -141
  10. examples/lightrag_tidb_demo.py +4 -0
  11. lightrag/api/README-zh.md +4 -12
  12. lightrag/api/README.md +5 -13
  13. lightrag/api/__init__.py +1 -1
  14. lightrag/api/auth.py +9 -8
  15. lightrag/api/config.py +335 -0
  16. lightrag/api/lightrag_server.py +37 -19
  17. lightrag/api/routers/document_routes.py +465 -52
  18. lightrag/api/routers/graph_routes.py +10 -14
  19. lightrag/api/run_with_gunicorn.py +31 -25
  20. lightrag/api/utils_api.py +18 -353
  21. lightrag/api/webui/assets/{index-D8zGvNlV.js → index-BaHKTcxB.js} +0 -0
  22. lightrag/api/webui/assets/index-CD5HxTy1.css +0 -0
  23. lightrag/api/webui/assets/index-f0HMqdqP.css +0 -0
  24. lightrag/api/webui/index.html +0 -0
  25. lightrag/base.py +122 -9
  26. lightrag/kg/__init__.py +15 -40
  27. lightrag/kg/age_impl.py +21 -3
  28. lightrag/kg/chroma_impl.py +28 -2
  29. lightrag/kg/faiss_impl.py +64 -5
  30. lightrag/kg/gremlin_impl.py +24 -3
  31. lightrag/kg/json_doc_status_impl.py +49 -10
  32. lightrag/kg/json_kv_impl.py +73 -3
  33. lightrag/kg/milvus_impl.py +31 -1
  34. lightrag/kg/mongo_impl.py +130 -3
  35. lightrag/kg/nano_vector_db_impl.py +66 -0
  36. lightrag/kg/neo4j_impl.py +373 -246
  37. lightrag/kg/networkx_impl.py +122 -97
  38. lightrag/kg/oracle_impl.py +0 -1346
  39. lightrag/kg/postgres_impl.py +380 -317
  40. lightrag/kg/qdrant_impl.py +93 -6
  41. lightrag/kg/redis_impl.py +35 -51
  42. lightrag/kg/tidb_impl.py +172 -3
  43. lightrag/lightrag.py +25 -42
  44. lightrag/operate.py +68 -88
  45. lightrag/types.py +1 -0
  46. lightrag_webui/src/AppRouter.tsx +6 -1
  47. lightrag_webui/src/api/lightrag.ts +39 -2
  48. lightrag_webui/src/components/documents/ClearDocumentsDialog.tsx +117 -15
  49. lightrag_webui/src/components/documents/UploadDocumentsDialog.tsx +20 -2
  50. lightrag_webui/src/components/graph/Settings.tsx +52 -25
README-zh.md CHANGED
@@ -11,7 +11,6 @@
  - [X] [2024.12.31]🎯📢LightRAG现在支持[通过文档ID删除](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete)。
  - [X] [2024.11.25]🎯📢LightRAG现在支持无缝集成[自定义知识图谱](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#insert-custom-kg),使用户能够用自己的领域专业知识增强系统。
  - [X] [2024.11.19]🎯📢LightRAG的综合指南现已在[LearnOpenCV](https://learnopencv.com/lightrag)上发布。非常感谢博客作者。
- - [X] [2024.11.12]🎯📢LightRAG现在支持[Oracle Database 23ai的所有存储类型(KV、向量和图)](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_oracle_demo.py)。
  - [X] [2024.11.11]🎯📢LightRAG现在支持[通过实体名称删除实体](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete)。
  - [X] [2024.11.09]🎯📢推出[LightRAG Gui](https://lightrag-gui.streamlit.app),允许您插入、查询、可视化和下载LightRAG知识。
  - [X] [2024.11.04]🎯📢现在您可以[使用Neo4J进行存储](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-neo4j-for-storage)。
@@ -1085,9 +1084,10 @@ rag.clear_cache(modes=["local"])
  | **参数** | **类型** | **说明** | **默认值** |
  |--------------|----------|-----------------|-------------|
  | **working_dir** | `str` | 存储缓存的目录 | `lightrag_cache+timestamp` |
- | **kv_storage** | `str` | 文档和文本块的存储类型。支持的类型:`JsonKVStorage`、`OracleKVStorage` | `JsonKVStorage` |
- | **vector_storage** | `str` | 嵌入向量的存储类型。支持的类型:`NanoVectorDBStorage`、`OracleVectorDBStorage` | `NanoVectorDBStorage` |
- | **graph_storage** | `str` | 图边和节点的存储类型。支持的类型:`NetworkXStorage`、`Neo4JStorage`、`OracleGraphStorage` | `NetworkXStorage` |
+ | **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`,`PGKVStorage`,`RedisKVStorage`,`MongoKVStorage` | `JsonKVStorage` |
+ | **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`,`PGVectorStorage`,`MilvusVectorDBStorage`,`ChromaVectorDBStorage`,`FaissVectorDBStorage`,`MongoVectorDBStorage`,`QdrantVectorDBStorage` | `NanoVectorDBStorage` |
+ | **graph_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`,`Neo4JStorage`,`PGGraphStorage`,`AGEStorage` | `NetworkXStorage` |
+ | **doc_status_storage** | `str` | Storage type for documents process status. Supported types: `JsonDocStatusStorage`,`PGDocStatusStorage`,`MongoDocStatusStorage` | `JsonDocStatusStorage` |
  | **chunk_token_size** | `int` | 拆分文档时每个块的最大令牌大小 | `1200` |
  | **chunk_overlap_token_size** | `int` | 拆分文档时两个块之间的重叠令牌大小 | `100` |
  | **tiktoken_model_name** | `str` | 用于计算令牌数的Tiktoken编码器的模型名称 | `gpt-4o-mini` |
README.md CHANGED
@@ -41,7 +41,6 @@
  - [X] [2024.12.31]🎯📢LightRAG now supports [deletion by document ID](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete).
  - [X] [2024.11.25]🎯📢LightRAG now supports seamless integration of [custom knowledge graphs](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#insert-custom-kg), empowering users to enhance the system with their own domain expertise.
  - [X] [2024.11.19]🎯📢A comprehensive guide to LightRAG is now available on [LearnOpenCV](https://learnopencv.com/lightrag). Many thanks to the blog author.
- - [X] [2024.11.12]🎯📢LightRAG now supports [Oracle Database 23ai for all storage types (KV, vector, and graph)](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_oracle_demo.py).
  - [X] [2024.11.11]🎯📢LightRAG now supports [deleting entities by their names](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete).
  - [X] [2024.11.09]🎯📢Introducing the [LightRAG Gui](https://lightrag-gui.streamlit.app), which allows you to insert, query, visualize, and download LightRAG knowledge.
  - [X] [2024.11.04]🎯📢You can now [use Neo4J for Storage](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-neo4j-for-storage).
@@ -1145,9 +1144,10 @@ Valid modes are:
  | **Parameter** | **Type** | **Explanation** | **Default** |
  |--------------|----------|-----------------|-------------|
  | **working_dir** | `str` | Directory where the cache will be stored | `lightrag_cache+timestamp` |
- | **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`, `OracleKVStorage` | `JsonKVStorage` |
- | **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`, `OracleVectorDBStorage` | `NanoVectorDBStorage` |
- | **graph_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`, `Neo4JStorage`, `OracleGraphStorage` | `NetworkXStorage` |
+ | **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`,`PGKVStorage`,`RedisKVStorage`,`MongoKVStorage` | `JsonKVStorage` |
+ | **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`,`PGVectorStorage`,`MilvusVectorDBStorage`,`ChromaVectorDBStorage`,`FaissVectorDBStorage`,`MongoVectorDBStorage`,`QdrantVectorDBStorage` | `NanoVectorDBStorage` |
+ | **graph_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`,`Neo4JStorage`,`PGGraphStorage`,`AGEStorage` | `NetworkXStorage` |
+ | **doc_status_storage** | `str` | Storage type for documents process status. Supported types: `JsonDocStatusStorage`,`PGDocStatusStorage`,`MongoDocStatusStorage` | `JsonDocStatusStorage` |
  | **chunk_token_size** | `int` | Maximum token size per chunk when splitting documents | `1200` |
  | **chunk_overlap_token_size** | `int` | Overlap token size between two chunks when splitting documents | `100` |
  | **tiktoken_model_name** | `str` | Model name for the Tiktoken encoder used to calculate token numbers | `gpt-4o-mini` |
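
For orientation, here is a minimal sketch (not part of this diff) of how the four storage parameters documented above are passed to the `LightRAG` constructor. It is adapted from the Ollama API demo that this commit removes; the model names and backend choices are illustrative, and `doc_status_storage` is assumed to be accepted exactly as the table documents it.

```python
import asyncio

from lightrag import LightRAG
from lightrag.llm.ollama import ollama_embed, ollama_model_complete
from lightrag.utils import EmbeddingFunc
from lightrag.kg.shared_storage import initialize_pipeline_status


async def build_rag() -> LightRAG:
    rag = LightRAG(
        working_dir="./rag_storage",
        llm_model_func=ollama_model_complete,
        llm_model_name="gemma2:9b",  # illustrative model choice
        embedding_func=EmbeddingFunc(
            embedding_dim=768,
            max_token_size=8192,
            func=lambda texts: ollama_embed(
                texts, embed_model="nomic-embed-text", host="http://localhost:11434"
            ),
        ),
        # Storage backends: any of the supported type names from the table above.
        kv_storage="JsonKVStorage",
        vector_storage="NanoVectorDBStorage",
        graph_storage="NetworkXStorage",
        doc_status_storage="JsonDocStatusStorage",
    )
    await rag.initialize_storages()
    await initialize_pipeline_status()
    return rag


if __name__ == "__main__":
    asyncio.run(build_rag())
```

Non-default backends additionally need their connection settings from `env.example` / `config.ini.example`, whose changes follow below.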
config.ini.example CHANGED
@@ -13,23 +13,6 @@ uri=redis://localhost:6379/1
  [qdrant]
  uri = http://localhost:16333

- [oracle]
- dsn = localhost:1521/XEPDB1
- user = your_username
- password = your_password
- config_dir = /path/to/oracle/config
- wallet_location = /path/to/wallet # 可选
- wallet_password = your_wallet_password # 可选
- workspace = default # 可选,默认为default
-
- [tidb]
- host = localhost
- port = 4000
- user = your_username
- password = your_password
- database = your_database
- workspace = default # 可选,默认为default
-
  [postgres]
  host = localhost
  port = 5432
env.example CHANGED
@@ -4,11 +4,9 @@
  # HOST=0.0.0.0
  # PORT=9621
  # WORKERS=2
- ### separating data from difference Lightrag instances
- # NAMESPACE_PREFIX=lightrag
- ### Max nodes return from grap retrieval
- # MAX_GRAPH_NODES=1000
  # CORS_ORIGINS=http://localhost:3000,http://localhost:8080
+ WEBUI_TITLE='Graph RAG Engine'
+ WEBUI_DESCRIPTION="Simple and Fast Graph Based RAG System"

  ### Optional SSL Configuration
  # SSL=true
@@ -22,6 +20,9 @@
  ### Ollama Emulating Model Tag
  # OLLAMA_EMULATING_MODEL_TAG=latest

+ ### Max nodes return from grap retrieval
+ # MAX_GRAPH_NODES=1000
+
  ### Logging level
  # LOG_LEVEL=INFO
  # VERBOSE=False
@@ -110,24 +111,14 @@ LIGHTRAG_VECTOR_STORAGE=NanoVectorDBStorage
  LIGHTRAG_GRAPH_STORAGE=NetworkXStorage
  LIGHTRAG_DOC_STATUS_STORAGE=JsonDocStatusStorage

- ### Oracle Database Configuration
- ORACLE_DSN=localhost:1521/XEPDB1
- ORACLE_USER=your_username
- ORACLE_PASSWORD='your_password'
- ORACLE_CONFIG_DIR=/path/to/oracle/config
- #ORACLE_WALLET_LOCATION=/path/to/wallet
- #ORACLE_WALLET_PASSWORD='your_password'
- ### separating all data from difference Lightrag instances(deprecating, use NAMESPACE_PREFIX in future)
- #ORACLE_WORKSPACE=default
-
- ### TiDB Configuration
- TIDB_HOST=localhost
- TIDB_PORT=4000
- TIDB_USER=your_username
- TIDB_PASSWORD='your_password'
- TIDB_DATABASE=your_database
- ### separating all data from difference Lightrag instances(deprecating, use NAMESPACE_PREFIX in future)
- #TIDB_WORKSPACE=default
+ ### TiDB Configuration (Deprecated)
+ # TIDB_HOST=localhost
+ # TIDB_PORT=4000
+ # TIDB_USER=your_username
+ # TIDB_PASSWORD='your_password'
+ # TIDB_DATABASE=your_database
+ ### separating all data from difference Lightrag instances(deprecating)
+ # TIDB_WORKSPACE=default

  ### PostgreSQL Configuration
  POSTGRES_HOST=localhost
@@ -135,8 +126,8 @@ POSTGRES_PORT=5432
  POSTGRES_USER=your_username
  POSTGRES_PASSWORD='your_password'
  POSTGRES_DATABASE=your_database
- ### separating all data from difference Lightrag instances(deprecating, use NAMESPACE_PREFIX in future)
- #POSTGRES_WORKSPACE=default
+ ### separating all data from difference Lightrag instances(deprecating)
+ # POSTGRES_WORKSPACE=default

  ### Independent AGM Configuration(not for AMG embedded in PostreSQL)
  AGE_POSTGRES_DB=
@@ -145,8 +136,8 @@ AGE_POSTGRES_PASSWORD=
  AGE_POSTGRES_HOST=
  # AGE_POSTGRES_PORT=8529

- ### separating all data from difference Lightrag instances(deprecating, use NAMESPACE_PREFIX in future)
  # AGE Graph Name(apply to PostgreSQL and independent AGM)
+ ### AGE_GRAPH_NAME is precated
  # AGE_GRAPH_NAME=lightrag

  ### Neo4j Configuration
@@ -157,7 +148,7 @@ NEO4J_PASSWORD='your_password'
  ### MongoDB Configuration
  MONGO_URI=mongodb://root:root@localhost:27017/
  MONGO_DATABASE=LightRAG
- ### separating all data from difference Lightrag instances(deprecating, use NAMESPACE_PREFIX in future)
+ ### separating all data from difference Lightrag instances(deprecating)
  # MONGODB_GRAPH=false

  ### Milvus Configuration
@@ -177,7 +168,9 @@ REDIS_URI=redis://localhost:6379
  ### For JWT Auth
  # AUTH_ACCOUNTS='admin:admin123,user1:pass456'
  # TOKEN_SECRET=Your-Key-For-LightRAG-API-Server
- # TOKEN_EXPIRE_HOURS=4
+ # TOKEN_EXPIRE_HOURS=48
+ # GUEST_TOKEN_EXPIRE_HOURS=24
+ # JWT_ALGORITHM=HS256

  ### API-Key to access LightRAG Server API
  # LIGHTRAG_API_KEY=your-secure-api-key-here
examples/lightrag_api_ollama_demo.py DELETED
@@ -1,188 +0,0 @@
1
- from fastapi import FastAPI, HTTPException, File, UploadFile
2
- from contextlib import asynccontextmanager
3
- from pydantic import BaseModel
4
- import os
5
- from lightrag import LightRAG, QueryParam
6
- from lightrag.llm.ollama import ollama_embed, ollama_model_complete
7
- from lightrag.utils import EmbeddingFunc
8
- from typing import Optional
9
- import asyncio
10
- import nest_asyncio
11
- import aiofiles
12
- from lightrag.kg.shared_storage import initialize_pipeline_status
13
-
14
- # Apply nest_asyncio to solve event loop issues
15
- nest_asyncio.apply()
16
-
17
- DEFAULT_RAG_DIR = "index_default"
18
-
19
- DEFAULT_INPUT_FILE = "book.txt"
20
- INPUT_FILE = os.environ.get("INPUT_FILE", f"{DEFAULT_INPUT_FILE}")
21
- print(f"INPUT_FILE: {INPUT_FILE}")
22
-
23
- # Configure working directory
24
- WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
25
- print(f"WORKING_DIR: {WORKING_DIR}")
26
-
27
-
28
- if not os.path.exists(WORKING_DIR):
29
- os.mkdir(WORKING_DIR)
30
-
31
-
32
- async def init():
33
- rag = LightRAG(
34
- working_dir=WORKING_DIR,
35
- llm_model_func=ollama_model_complete,
36
- llm_model_name="gemma2:9b",
37
- llm_model_max_async=4,
38
- llm_model_max_token_size=8192,
39
- llm_model_kwargs={
40
- "host": "http://localhost:11434",
41
- "options": {"num_ctx": 8192},
42
- },
43
- embedding_func=EmbeddingFunc(
44
- embedding_dim=768,
45
- max_token_size=8192,
46
- func=lambda texts: ollama_embed(
47
- texts, embed_model="nomic-embed-text", host="http://localhost:11434"
48
- ),
49
- ),
50
- )
51
-
52
- # Add initialization code
53
- await rag.initialize_storages()
54
- await initialize_pipeline_status()
55
-
56
- return rag
57
-
58
-
59
- @asynccontextmanager
60
- async def lifespan(app: FastAPI):
61
- global rag
62
- rag = await init()
63
- print("done!")
64
- yield
65
-
66
-
67
- app = FastAPI(
68
- title="LightRAG API", description="API for RAG operations", lifespan=lifespan
69
- )
70
-
71
-
72
- # Data models
73
- class QueryRequest(BaseModel):
74
- query: str
75
- mode: str = "hybrid"
76
- only_need_context: bool = False
77
-
78
-
79
- class InsertRequest(BaseModel):
80
- text: str
81
-
82
-
83
- class Response(BaseModel):
84
- status: str
85
- data: Optional[str] = None
86
- message: Optional[str] = None
87
-
88
-
89
- # API routes
90
- @app.post("/query", response_model=Response)
91
- async def query_endpoint(request: QueryRequest):
92
- try:
93
- loop = asyncio.get_event_loop()
94
- result = await loop.run_in_executor(
95
- None,
96
- lambda: rag.query(
97
- request.query,
98
- param=QueryParam(
99
- mode=request.mode, only_need_context=request.only_need_context
100
- ),
101
- ),
102
- )
103
- return Response(status="success", data=result)
104
- except Exception as e:
105
- raise HTTPException(status_code=500, detail=str(e))
106
-
107
-
108
- # insert by text
109
- @app.post("/insert", response_model=Response)
110
- async def insert_endpoint(request: InsertRequest):
111
- try:
112
- loop = asyncio.get_event_loop()
113
- await loop.run_in_executor(None, lambda: rag.insert(request.text))
114
- return Response(status="success", message="Text inserted successfully")
115
- except Exception as e:
116
- raise HTTPException(status_code=500, detail=str(e))
117
-
118
-
119
- # insert by file in payload
120
- @app.post("/insert_file", response_model=Response)
121
- async def insert_file(file: UploadFile = File(...)):
122
- try:
123
- file_content = await file.read()
124
- # Read file content
125
- try:
126
- content = file_content.decode("utf-8")
127
- except UnicodeDecodeError:
128
- # If UTF-8 decoding fails, try other encodings
129
- content = file_content.decode("gbk")
130
- # Insert file content
131
- loop = asyncio.get_event_loop()
132
- await loop.run_in_executor(None, lambda: rag.insert(content))
133
-
134
- return Response(
135
- status="success",
136
- message=f"File content from {file.filename} inserted successfully",
137
- )
138
- except Exception as e:
139
- raise HTTPException(status_code=500, detail=str(e))
140
-
141
-
142
- # insert by local default file
143
- @app.post("/insert_default_file", response_model=Response)
144
- @app.get("/insert_default_file", response_model=Response)
145
- async def insert_default_file():
146
- try:
147
- # Read file content from book.txt
148
- async with aiofiles.open(INPUT_FILE, "r", encoding="utf-8") as file:
149
- content = await file.read()
150
- print(f"read input file {INPUT_FILE} successfully")
151
- # Insert file content
152
- loop = asyncio.get_event_loop()
153
- await loop.run_in_executor(None, lambda: rag.insert(content))
154
-
155
- return Response(
156
- status="success",
157
- message=f"File content from {INPUT_FILE} inserted successfully",
158
- )
159
- except Exception as e:
160
- raise HTTPException(status_code=500, detail=str(e))
161
-
162
-
163
- @app.get("/health")
164
- async def health_check():
165
- return {"status": "healthy"}
166
-
167
-
168
- if __name__ == "__main__":
169
- import uvicorn
170
-
171
- uvicorn.run(app, host="0.0.0.0", port=8020)
172
-
173
- # Usage example
174
- # To run the server, use the following command in your terminal:
175
- # python lightrag_api_openai_compatible_demo.py
176
-
177
- # Example requests:
178
- # 1. Query:
179
- # curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "your query here", "mode": "hybrid"}'
180
-
181
- # 2. Insert text:
182
- # curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}'
183
-
184
- # 3. Insert file:
185
- # curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: multipart/form-data" -F "file=@path/to/your/file.txt"
186
-
187
- # 4. Health check:
188
- # curl -X GET "http://127.0.0.1:8020/health"
 
examples/lightrag_api_openai_compatible_demo.py DELETED
@@ -1,204 +0,0 @@
1
- from fastapi import FastAPI, HTTPException, File, UploadFile
2
- from contextlib import asynccontextmanager
3
- from pydantic import BaseModel
4
- import os
5
- from lightrag import LightRAG, QueryParam
6
- from lightrag.llm.openai import openai_complete_if_cache, openai_embed
7
- from lightrag.utils import EmbeddingFunc
8
- import numpy as np
9
- from typing import Optional
10
- import asyncio
11
- import nest_asyncio
12
- from lightrag.kg.shared_storage import initialize_pipeline_status
13
-
14
- # Apply nest_asyncio to solve event loop issues
15
- nest_asyncio.apply()
16
-
17
- DEFAULT_RAG_DIR = "index_default"
18
- app = FastAPI(title="LightRAG API", description="API for RAG operations")
19
-
20
- # Configure working directory
21
- WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
22
- print(f"WORKING_DIR: {WORKING_DIR}")
23
- LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4o-mini")
24
- print(f"LLM_MODEL: {LLM_MODEL}")
25
- EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large")
26
- print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
27
- EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
28
- print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
29
- BASE_URL = os.environ.get("BASE_URL", "https://api.openai.com/v1")
30
- print(f"BASE_URL: {BASE_URL}")
31
- API_KEY = os.environ.get("API_KEY", "xxxxxxxx")
32
- print(f"API_KEY: {API_KEY}")
33
-
34
- if not os.path.exists(WORKING_DIR):
35
- os.mkdir(WORKING_DIR)
36
-
37
-
38
- # LLM model function
39
-
40
-
41
- async def llm_model_func(
42
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
43
- ) -> str:
44
- return await openai_complete_if_cache(
45
- model=LLM_MODEL,
46
- prompt=prompt,
47
- system_prompt=system_prompt,
48
- history_messages=history_messages,
49
- base_url=BASE_URL,
50
- api_key=API_KEY,
51
- **kwargs,
52
- )
53
-
54
-
55
- # Embedding function
56
-
57
-
58
- async def embedding_func(texts: list[str]) -> np.ndarray:
59
- return await openai_embed(
60
- texts=texts,
61
- model=EMBEDDING_MODEL,
62
- base_url=BASE_URL,
63
- api_key=API_KEY,
64
- )
65
-
66
-
67
- async def get_embedding_dim():
68
- test_text = ["This is a test sentence."]
69
- embedding = await embedding_func(test_text)
70
- embedding_dim = embedding.shape[1]
71
- print(f"{embedding_dim=}")
72
- return embedding_dim
73
-
74
-
75
- # Initialize RAG instance
76
- async def init():
77
- embedding_dimension = await get_embedding_dim()
78
-
79
- rag = LightRAG(
80
- working_dir=WORKING_DIR,
81
- llm_model_func=llm_model_func,
82
- embedding_func=EmbeddingFunc(
83
- embedding_dim=embedding_dimension,
84
- max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
85
- func=embedding_func,
86
- ),
87
- )
88
-
89
- await rag.initialize_storages()
90
- await initialize_pipeline_status()
91
-
92
- return rag
93
-
94
-
95
- @asynccontextmanager
96
- async def lifespan(app: FastAPI):
97
- global rag
98
- rag = await init()
99
- print("done!")
100
- yield
101
-
102
-
103
- app = FastAPI(
104
- title="LightRAG API", description="API for RAG operations", lifespan=lifespan
105
- )
106
-
107
- # Data models
108
-
109
-
110
- class QueryRequest(BaseModel):
111
- query: str
112
- mode: str = "hybrid"
113
- only_need_context: bool = False
114
-
115
-
116
- class InsertRequest(BaseModel):
117
- text: str
118
-
119
-
120
- class Response(BaseModel):
121
- status: str
122
- data: Optional[str] = None
123
- message: Optional[str] = None
124
-
125
-
126
- # API routes
127
-
128
-
129
- @app.post("/query", response_model=Response)
130
- async def query_endpoint(request: QueryRequest):
131
- try:
132
- loop = asyncio.get_event_loop()
133
- result = await loop.run_in_executor(
134
- None,
135
- lambda: rag.query(
136
- request.query,
137
- param=QueryParam(
138
- mode=request.mode, only_need_context=request.only_need_context
139
- ),
140
- ),
141
- )
142
- return Response(status="success", data=result)
143
- except Exception as e:
144
- raise HTTPException(status_code=500, detail=str(e))
145
-
146
-
147
- @app.post("/insert", response_model=Response)
148
- async def insert_endpoint(request: InsertRequest):
149
- try:
150
- loop = asyncio.get_event_loop()
151
- await loop.run_in_executor(None, lambda: rag.insert(request.text))
152
- return Response(status="success", message="Text inserted successfully")
153
- except Exception as e:
154
- raise HTTPException(status_code=500, detail=str(e))
155
-
156
-
157
- @app.post("/insert_file", response_model=Response)
158
- async def insert_file(file: UploadFile = File(...)):
159
- try:
160
- file_content = await file.read()
161
- # Read file content
162
- try:
163
- content = file_content.decode("utf-8")
164
- except UnicodeDecodeError:
165
- # If UTF-8 decoding fails, try other encodings
166
- content = file_content.decode("gbk")
167
- # Insert file content
168
- loop = asyncio.get_event_loop()
169
- await loop.run_in_executor(None, lambda: rag.insert(content))
170
-
171
- return Response(
172
- status="success",
173
- message=f"File content from {file.filename} inserted successfully",
174
- )
175
- except Exception as e:
176
- raise HTTPException(status_code=500, detail=str(e))
177
-
178
-
179
- @app.get("/health")
180
- async def health_check():
181
- return {"status": "healthy"}
182
-
183
-
184
- if __name__ == "__main__":
185
- import uvicorn
186
-
187
- uvicorn.run(app, host="0.0.0.0", port=8020)
188
-
189
- # Usage example
190
- # To run the server, use the following command in your terminal:
191
- # python lightrag_api_openai_compatible_demo.py
192
-
193
- # Example requests:
194
- # 1. Query:
195
- # curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "your query here", "mode": "hybrid"}'
196
-
197
- # 2. Insert text:
198
- # curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}'
199
-
200
- # 3. Insert file:
201
- # curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: multipart/form-data" -F "file=@path/to/your/file.txt"
202
-
203
- # 4. Health check:
204
- # curl -X GET "http://127.0.0.1:8020/health"
 
examples/lightrag_api_oracle_demo.py DELETED
@@ -1,267 +0,0 @@
1
- from fastapi import FastAPI, HTTPException, File, UploadFile
2
- from fastapi import Query
3
- from contextlib import asynccontextmanager
4
- from pydantic import BaseModel
5
- from typing import Optional, Any
6
-
7
- import sys
8
- import os
9
-
10
-
11
- from pathlib import Path
12
-
13
- import asyncio
14
- import nest_asyncio
15
- from lightrag import LightRAG, QueryParam
16
- from lightrag.llm.openai import openai_complete_if_cache, openai_embed
17
- from lightrag.utils import EmbeddingFunc
18
- import numpy as np
19
- from lightrag.kg.shared_storage import initialize_pipeline_status
20
-
21
-
22
- print(os.getcwd())
23
- script_directory = Path(__file__).resolve().parent.parent
24
- sys.path.append(os.path.abspath(script_directory))
25
-
26
-
27
- # Apply nest_asyncio to solve event loop issues
28
- nest_asyncio.apply()
29
-
30
- DEFAULT_RAG_DIR = "index_default"
31
-
32
-
33
- # We use OpenAI compatible API to call LLM on Oracle Cloud
34
- # More docs here https://github.com/jin38324/OCI_GenAI_access_gateway
35
- BASE_URL = "http://xxx.xxx.xxx.xxx:8088/v1/"
36
- APIKEY = "ocigenerativeai"
37
-
38
- # Configure working directory
39
- WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
40
- print(f"WORKING_DIR: {WORKING_DIR}")
41
- LLM_MODEL = os.environ.get("LLM_MODEL", "cohere.command-r-plus-08-2024")
42
- print(f"LLM_MODEL: {LLM_MODEL}")
43
- EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "cohere.embed-multilingual-v3.0")
44
- print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
45
- EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 512))
46
- print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
47
-
48
- if not os.path.exists(WORKING_DIR):
49
- os.mkdir(WORKING_DIR)
50
-
51
- os.environ["ORACLE_USER"] = ""
52
- os.environ["ORACLE_PASSWORD"] = ""
53
- os.environ["ORACLE_DSN"] = ""
54
- os.environ["ORACLE_CONFIG_DIR"] = "path_to_config_dir"
55
- os.environ["ORACLE_WALLET_LOCATION"] = "path_to_wallet_location"
56
- os.environ["ORACLE_WALLET_PASSWORD"] = "wallet_password"
57
- os.environ["ORACLE_WORKSPACE"] = "company"
58
-
59
-
60
- async def llm_model_func(
61
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
62
- ) -> str:
63
- return await openai_complete_if_cache(
64
- LLM_MODEL,
65
- prompt,
66
- system_prompt=system_prompt,
67
- history_messages=history_messages,
68
- api_key=APIKEY,
69
- base_url=BASE_URL,
70
- **kwargs,
71
- )
72
-
73
-
74
- async def embedding_func(texts: list[str]) -> np.ndarray:
75
- return await openai_embed(
76
- texts,
77
- model=EMBEDDING_MODEL,
78
- api_key=APIKEY,
79
- base_url=BASE_URL,
80
- )
81
-
82
-
83
- async def get_embedding_dim():
84
- test_text = ["This is a test sentence."]
85
- embedding = await embedding_func(test_text)
86
- embedding_dim = embedding.shape[1]
87
- return embedding_dim
88
-
89
-
90
- async def init():
91
- # Detect embedding dimension
92
- embedding_dimension = await get_embedding_dim()
93
- print(f"Detected embedding dimension: {embedding_dimension}")
94
- # Create Oracle DB connection
95
- # The `config` parameter is the connection configuration of Oracle DB
96
- # More docs here https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html
97
- # We storage data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
98
- # Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
99
-
100
- # Initialize LightRAG
101
- # We use Oracle DB as the KV/vector/graph storage
102
- rag = LightRAG(
103
- enable_llm_cache=False,
104
- working_dir=WORKING_DIR,
105
- chunk_token_size=512,
106
- llm_model_func=llm_model_func,
107
- embedding_func=EmbeddingFunc(
108
- embedding_dim=embedding_dimension,
109
- max_token_size=512,
110
- func=embedding_func,
111
- ),
112
- graph_storage="OracleGraphStorage",
113
- kv_storage="OracleKVStorage",
114
- vector_storage="OracleVectorDBStorage",
115
- )
116
-
117
- await rag.initialize_storages()
118
- await initialize_pipeline_status()
119
-
120
- return rag
121
-
122
-
123
- # Extract and Insert into LightRAG storage
124
- # with open("./dickens/book.txt", "r", encoding="utf-8") as f:
125
- # await rag.ainsert(f.read())
126
-
127
- # # Perform search in different modes
128
- # modes = ["naive", "local", "global", "hybrid"]
129
- # for mode in modes:
130
- # print("="*20, mode, "="*20)
131
- # print(await rag.aquery("这篇文档是关于什么内容的?", param=QueryParam(mode=mode)))
132
- # print("-"*100, "\n")
133
-
134
- # Data models
135
-
136
-
137
- class QueryRequest(BaseModel):
138
- query: str
139
- mode: str = "hybrid"
140
- only_need_context: bool = False
141
- only_need_prompt: bool = False
142
-
143
-
144
- class DataRequest(BaseModel):
145
- limit: int = 100
146
-
147
-
148
- class InsertRequest(BaseModel):
149
- text: str
150
-
151
-
152
- class Response(BaseModel):
153
- status: str
154
- data: Optional[Any] = None
155
- message: Optional[str] = None
156
-
157
-
158
- # API routes
159
-
160
- rag = None
161
-
162
-
163
- @asynccontextmanager
164
- async def lifespan(app: FastAPI):
165
- global rag
166
- rag = await init()
167
- print("done!")
168
- yield
169
-
170
-
171
- app = FastAPI(
172
- title="LightRAG API", description="API for RAG operations", lifespan=lifespan
173
- )
174
-
175
-
176
- @app.post("/query", response_model=Response)
177
- async def query_endpoint(request: QueryRequest):
178
- # try:
179
- # loop = asyncio.get_event_loop()
180
- if request.mode == "naive":
181
- top_k = 3
182
- else:
183
- top_k = 60
184
- result = await rag.aquery(
185
- request.query,
186
- param=QueryParam(
187
- mode=request.mode,
188
- only_need_context=request.only_need_context,
189
- only_need_prompt=request.only_need_prompt,
190
- top_k=top_k,
191
- ),
192
- )
193
- return Response(status="success", data=result)
194
- # except Exception as e:
195
- # raise HTTPException(status_code=500, detail=str(e))
196
-
197
-
198
- @app.get("/data", response_model=Response)
199
- async def query_all_nodes(type: str = Query("nodes"), limit: int = Query(100)):
200
- if type == "nodes":
201
- result = await rag.chunk_entity_relation_graph.get_all_nodes(limit=limit)
202
- elif type == "edges":
203
- result = await rag.chunk_entity_relation_graph.get_all_edges(limit=limit)
204
- elif type == "statistics":
205
- result = await rag.chunk_entity_relation_graph.get_statistics()
206
- return Response(status="success", data=result)
207
-
208
-
209
- @app.post("/insert", response_model=Response)
210
- async def insert_endpoint(request: InsertRequest):
211
- try:
212
- loop = asyncio.get_event_loop()
213
- await loop.run_in_executor(None, lambda: rag.insert(request.text))
214
- return Response(status="success", message="Text inserted successfully")
215
- except Exception as e:
216
- raise HTTPException(status_code=500, detail=str(e))
217
-
218
-
219
- @app.post("/insert_file", response_model=Response)
220
- async def insert_file(file: UploadFile = File(...)):
221
- try:
222
- file_content = await file.read()
223
- # Read file content
224
- try:
225
- content = file_content.decode("utf-8")
226
- except UnicodeDecodeError:
227
- # If UTF-8 decoding fails, try other encodings
228
- content = file_content.decode("gbk")
229
- # Insert file content
230
- loop = asyncio.get_event_loop()
231
- await loop.run_in_executor(None, lambda: rag.insert(content))
232
-
233
- return Response(
234
- status="success",
235
- message=f"File content from {file.filename} inserted successfully",
236
- )
237
- except Exception as e:
238
- raise HTTPException(status_code=500, detail=str(e))
239
-
240
-
241
- @app.get("/health")
242
- async def health_check():
243
- return {"status": "healthy"}
244
-
245
-
246
- if __name__ == "__main__":
247
- import uvicorn
248
-
249
- uvicorn.run(app, host="127.0.0.1", port=8020)
250
-
251
- # Usage example
252
- # To run the server, use the following command in your terminal:
253
- # python lightrag_api_openai_compatible_demo.py
254
-
255
- # Example requests:
256
- # 1. Query:
257
- # curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "your query here", "mode": "hybrid"}'
258
-
259
- # 2. Insert text:
260
- # curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}'
261
-
262
- # 3. Insert file:
263
- # curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: multipart/form-data" -F "file=@path/to/your/file.txt"
264
-
265
-
266
- # 4. Health check:
267
- # curl -X GET "http://127.0.0.1:8020/health"
 
examples/lightrag_ollama_gremlin_demo.py CHANGED
@@ -1,3 +1,7 @@
+ ##############################################
+ # Gremlin storage implementation is deprecated
+ ##############################################
+
  import asyncio
  import inspect
  import os
examples/lightrag_oracle_demo.py DELETED
@@ -1,141 +0,0 @@
1
- import sys
2
- import os
3
- from pathlib import Path
4
- import asyncio
5
- from lightrag import LightRAG, QueryParam
6
- from lightrag.llm.openai import openai_complete_if_cache, openai_embed
7
- from lightrag.utils import EmbeddingFunc
8
- import numpy as np
9
- from lightrag.kg.shared_storage import initialize_pipeline_status
10
-
11
- print(os.getcwd())
12
- script_directory = Path(__file__).resolve().parent.parent
13
- sys.path.append(os.path.abspath(script_directory))
14
-
15
- WORKING_DIR = "./dickens"
16
-
17
- # We use OpenAI compatible API to call LLM on Oracle Cloud
18
- # More docs here https://github.com/jin38324/OCI_GenAI_access_gateway
19
- BASE_URL = "http://xxx.xxx.xxx.xxx:8088/v1/"
20
- APIKEY = "ocigenerativeai"
21
- CHATMODEL = "cohere.command-r-plus"
22
- EMBEDMODEL = "cohere.embed-multilingual-v3.0"
23
- CHUNK_TOKEN_SIZE = 1024
24
- MAX_TOKENS = 4000
25
-
26
- if not os.path.exists(WORKING_DIR):
27
- os.mkdir(WORKING_DIR)
28
-
29
- os.environ["ORACLE_USER"] = "username"
30
- os.environ["ORACLE_PASSWORD"] = "xxxxxxxxx"
31
- os.environ["ORACLE_DSN"] = "xxxxxxx_medium"
32
- os.environ["ORACLE_CONFIG_DIR"] = "path_to_config_dir"
33
- os.environ["ORACLE_WALLET_LOCATION"] = "path_to_wallet_location"
34
- os.environ["ORACLE_WALLET_PASSWORD"] = "wallet_password"
35
- os.environ["ORACLE_WORKSPACE"] = "company"
36
-
37
-
38
- async def llm_model_func(
39
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
40
- ) -> str:
41
- return await openai_complete_if_cache(
42
- CHATMODEL,
43
- prompt,
44
- system_prompt=system_prompt,
45
- history_messages=history_messages,
46
- api_key=APIKEY,
47
- base_url=BASE_URL,
48
- **kwargs,
49
- )
50
-
51
-
52
- async def embedding_func(texts: list[str]) -> np.ndarray:
53
- return await openai_embed(
54
- texts,
55
- model=EMBEDMODEL,
56
- api_key=APIKEY,
57
- base_url=BASE_URL,
58
- )
59
-
60
-
61
- async def get_embedding_dim():
62
- test_text = ["This is a test sentence."]
63
- embedding = await embedding_func(test_text)
64
- embedding_dim = embedding.shape[1]
65
- return embedding_dim
66
-
67
-
68
- async def initialize_rag():
69
- # Detect embedding dimension
70
- embedding_dimension = await get_embedding_dim()
71
- print(f"Detected embedding dimension: {embedding_dimension}")
72
-
73
- # Initialize LightRAG
74
- # We use Oracle DB as the KV/vector/graph storage
75
- # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt
76
- rag = LightRAG(
77
- # log_level="DEBUG",
78
- working_dir=WORKING_DIR,
79
- entity_extract_max_gleaning=1,
80
- enable_llm_cache=True,
81
- enable_llm_cache_for_entity_extract=True,
82
- embedding_cache_config=None, # {"enabled": True,"similarity_threshold": 0.90},
83
- chunk_token_size=CHUNK_TOKEN_SIZE,
84
- llm_model_max_token_size=MAX_TOKENS,
85
- llm_model_func=llm_model_func,
86
- embedding_func=EmbeddingFunc(
87
- embedding_dim=embedding_dimension,
88
- max_token_size=500,
89
- func=embedding_func,
90
- ),
91
- graph_storage="OracleGraphStorage",
92
- kv_storage="OracleKVStorage",
93
- vector_storage="OracleVectorDBStorage",
94
- addon_params={
95
- "example_number": 1,
96
- "language": "Simplfied Chinese",
97
- "entity_types": ["organization", "person", "geo", "event"],
98
- "insert_batch_size": 2,
99
- },
100
- )
101
- await rag.initialize_storages()
102
- await initialize_pipeline_status()
103
-
104
- return rag
105
-
106
-
107
- async def main():
108
- try:
109
- # Initialize RAG instance
110
- rag = await initialize_rag()
111
-
112
- # Extract and Insert into LightRAG storage
113
- with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:
114
- all_text = f.read()
115
- texts = [x for x in all_text.split("\n") if x]
116
-
117
- # New mode use pipeline
118
- await rag.apipeline_enqueue_documents(texts)
119
- await rag.apipeline_process_enqueue_documents()
120
-
121
- # Old method use ainsert
122
- # await rag.ainsert(texts)
123
-
124
- # Perform search in different modes
125
- modes = ["naive", "local", "global", "hybrid"]
126
- for mode in modes:
127
- print("=" * 20, mode, "=" * 20)
128
- print(
129
- await rag.aquery(
130
- "What are the top themes in this story?",
131
- param=QueryParam(mode=mode),
132
- )
133
- )
134
- print("-" * 100, "\n")
135
-
136
- except Exception as e:
137
- print(f"An error occurred: {e}")
138
-
139
-
140
- if __name__ == "__main__":
141
- asyncio.run(main())
 
examples/lightrag_tidb_demo.py CHANGED
@@ -1,3 +1,7 @@
+ ###########################################
+ # TiDB storage implementation is deprecated
+ ###########################################
+
  import asyncio
  import os

lightrag/api/README-zh.md CHANGED
@@ -291,11 +291,9 @@ LightRAG 使用 4 种类型的存储用于不同目的:

  ```
  JsonKVStorage JsonFile(默认)
- MongoKVStorage MogonDB
- RedisKVStorage Redis
- TiDBKVStorage TiDB
  PGKVStorage Postgres
- OracleKVStorage Oracle
+ RedisKVStorage Redis
+ MongoKVStorage MogonDB
  ```

  * GRAPH_STORAGE 支持的实现名称
@@ -303,25 +301,19 @@ OracleKVStorage Oracle
  ```
  NetworkXStorage NetworkX(默认)
  Neo4JStorage Neo4J
- MongoGraphStorage MongoDB
- TiDBGraphStorage TiDB
- AGEStorage AGE
- GremlinStorage Gremlin
  PGGraphStorage Postgres
- OracleGraphStorage Postgres
+ AGEStorage AGE
  ```

  * VECTOR_STORAGE 支持的实现名称

  ```
  NanoVectorDBStorage NanoVector(默认)
+ PGVectorStorage Postgres
  MilvusVectorDBStorge Milvus
  ChromaVectorDBStorage Chroma
- TiDBVectorDBStorage TiDB
- PGVectorStorage Postgres
  FaissVectorDBStorage Faiss
  QdrantVectorDBStorage Qdrant
- OracleVectorDBStorage Oracle
  MongoVectorDBStorage MongoDB
  ```

lightrag/api/README.md CHANGED
@@ -302,11 +302,9 @@ Each storage type have servals implementations:

  ```
  JsonKVStorage JsonFile(default)
- MongoKVStorage MogonDB
- RedisKVStorage Redis
- TiDBKVStorage TiDB
  PGKVStorage Postgres
- OracleKVStorage Oracle
+ RedisKVStorage Redis
+ MongoKVStorage MogonDB
  ```

  * GRAPH_STORAGE supported implement-name
@@ -314,25 +312,19 @@ OracleKVStorage Oracle
  ```
  NetworkXStorage NetworkX(defualt)
  Neo4JStorage Neo4J
- MongoGraphStorage MongoDB
- TiDBGraphStorage TiDB
- AGEStorage AGE
- GremlinStorage Gremlin
  PGGraphStorage Postgres
- OracleGraphStorage Postgres
+ AGEStorage AGE
  ```

  * VECTOR_STORAGE supported implement-name

  ```
  NanoVectorDBStorage NanoVector(default)
- MilvusVectorDBStorage Milvus
- ChromaVectorDBStorage Chroma
- TiDBVectorDBStorage TiDB
  PGVectorStorage Postgres
+ MilvusVectorDBStorge Milvus
+ ChromaVectorDBStorage Chroma
  FaissVectorDBStorage Faiss
  QdrantVectorDBStorage Qdrant
- OracleVectorDBStorage Oracle
  MongoVectorDBStorage MongoDB
  ```

lightrag/api/__init__.py CHANGED
@@ -1 +1 @@
- __api_version__ = "1.2.8"
+ __api_version__ = "0132"
lightrag/api/auth.py CHANGED
@@ -1,9 +1,11 @@
- import os
  from datetime import datetime, timedelta
+
  import jwt
+ from dotenv import load_dotenv
  from fastapi import HTTPException, status
  from pydantic import BaseModel
- from dotenv import load_dotenv
+
+ from .config import global_args

  # use the .env that is inside the current folder
  # allows to use different .env file for each lightrag instance
@@ -20,13 +22,12 @@ class TokenPayload(BaseModel):

  class AuthHandler:
      def __init__(self):
-         self.secret = os.getenv("TOKEN_SECRET", "4f85ds4f56dsf46")
-         self.algorithm = "HS256"
-         self.expire_hours = int(os.getenv("TOKEN_EXPIRE_HOURS", 4))
-         self.guest_expire_hours = int(os.getenv("GUEST_TOKEN_EXPIRE_HOURS", 2))
-
+         self.secret = global_args.token_secret
+         self.algorithm = global_args.jwt_algorithm
+         self.expire_hours = global_args.token_expire_hours
+         self.guest_expire_hours = global_args.guest_token_expire_hours
          self.accounts = {}
-         auth_accounts = os.getenv("AUTH_ACCOUNTS")
+         auth_accounts = global_args.auth_accounts
          if auth_accounts:
              for account in auth_accounts.split(","):
                  username, password = account.split(":", 1)
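
For context, a minimal PyJWT sketch of the token lifecycle these settings drive. This is not the repository's implementation; the helper names and payload keys are hypothetical, and only the configuration fields (`secret`, `algorithm`, `expire_hours`) mirror the `AuthHandler` attributes above.

```python
from datetime import datetime, timedelta

import jwt  # PyJWT, the same library auth.py imports

SECRET = "change-me"   # stands in for global_args.token_secret (TOKEN_SECRET)
ALGORITHM = "HS256"    # stands in for global_args.jwt_algorithm (JWT_ALGORITHM)
EXPIRE_HOURS = 48      # stands in for global_args.token_expire_hours


def create_token(username: str, role: str = "user") -> str:
    # Sign a token that expires after EXPIRE_HOURS.
    payload = {
        "sub": username,
        "role": role,
        "exp": datetime.utcnow() + timedelta(hours=EXPIRE_HOURS),
    }
    return jwt.encode(payload, SECRET, algorithm=ALGORITHM)


def validate_token(token: str) -> dict:
    # Raises jwt.ExpiredSignatureError / jwt.InvalidTokenError on bad tokens.
    return jwt.decode(token, SECRET, algorithms=[ALGORITHM])
```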
lightrag/api/config.py ADDED
@@ -0,0 +1,335 @@
1
+ """
2
+ Configs for the LightRAG API.
3
+ """
4
+
5
+ import os
6
+ import argparse
7
+ import logging
8
+ from dotenv import load_dotenv
9
+
10
+ # use the .env that is inside the current folder
11
+ # allows to use different .env file for each lightrag instance
12
+ # the OS environment variables take precedence over the .env file
13
+ load_dotenv(dotenv_path=".env", override=False)
14
+
15
+
16
+ class OllamaServerInfos:
17
+ # Constants for emulated Ollama model information
18
+ LIGHTRAG_NAME = "lightrag"
19
+ LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
20
+ LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
21
+ LIGHTRAG_SIZE = 7365960935 # it's a dummy value
22
+ LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
23
+ LIGHTRAG_DIGEST = "sha256:lightrag"
24
+
25
+
26
+ ollama_server_infos = OllamaServerInfos()
27
+
28
+
29
+ class DefaultRAGStorageConfig:
30
+ KV_STORAGE = "JsonKVStorage"
31
+ VECTOR_STORAGE = "NanoVectorDBStorage"
32
+ GRAPH_STORAGE = "NetworkXStorage"
33
+ DOC_STATUS_STORAGE = "JsonDocStatusStorage"
34
+
35
+
36
+ def get_default_host(binding_type: str) -> str:
37
+ default_hosts = {
38
+ "ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
39
+ "lollms": os.getenv("LLM_BINDING_HOST", "http://localhost:9600"),
40
+ "azure_openai": os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com/v1"),
41
+ "openai": os.getenv("LLM_BINDING_HOST", "https://api.openai.com/v1"),
42
+ }
43
+ return default_hosts.get(
44
+ binding_type, os.getenv("LLM_BINDING_HOST", "http://localhost:11434")
45
+ ) # fallback to ollama if unknown
46
+
47
+
48
+ def get_env_value(env_key: str, default: any, value_type: type = str) -> any:
49
+ """
50
+ Get value from environment variable with type conversion
51
+
52
+ Args:
53
+ env_key (str): Environment variable key
54
+ default (any): Default value if env variable is not set
55
+ value_type (type): Type to convert the value to
56
+
57
+ Returns:
58
+ any: Converted value from environment or default
59
+ """
60
+ value = os.getenv(env_key)
61
+ if value is None:
62
+ return default
63
+
64
+ if value_type is bool:
65
+ return value.lower() in ("true", "1", "yes", "t", "on")
66
+ try:
67
+ return value_type(value)
68
+ except ValueError:
69
+ return default
70
+
71
+
72
+ def parse_args() -> argparse.Namespace:
73
+ """
74
+ Parse command line arguments with environment variable fallback
75
+
76
+ Args:
77
+ is_uvicorn_mode: Whether running under uvicorn mode
78
+
79
+ Returns:
80
+ argparse.Namespace: Parsed arguments
81
+ """
82
+
83
+ parser = argparse.ArgumentParser(
84
+ description="LightRAG FastAPI Server with separate working and input directories"
85
+ )
86
+
87
+ # Server configuration
88
+ parser.add_argument(
89
+ "--host",
90
+ default=get_env_value("HOST", "0.0.0.0"),
91
+ help="Server host (default: from env or 0.0.0.0)",
92
+ )
93
+ parser.add_argument(
94
+ "--port",
95
+ type=int,
96
+ default=get_env_value("PORT", 9621, int),
97
+ help="Server port (default: from env or 9621)",
98
+ )
99
+
100
+ # Directory configuration
101
+ parser.add_argument(
102
+ "--working-dir",
103
+ default=get_env_value("WORKING_DIR", "./rag_storage"),
104
+ help="Working directory for RAG storage (default: from env or ./rag_storage)",
105
+ )
106
+ parser.add_argument(
107
+ "--input-dir",
108
+ default=get_env_value("INPUT_DIR", "./inputs"),
109
+ help="Directory containing input documents (default: from env or ./inputs)",
110
+ )
111
+
112
+ def timeout_type(value):
113
+ if value is None:
114
+ return 150
115
+ if value is None or value == "None":
116
+ return None
117
+ return int(value)
118
+
119
+ parser.add_argument(
120
+ "--timeout",
121
+ default=get_env_value("TIMEOUT", None, timeout_type),
122
+ type=timeout_type,
123
+ help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout",
124
+ )
125
+
126
+ # RAG configuration
127
+ parser.add_argument(
128
+ "--max-async",
129
+ type=int,
130
+ default=get_env_value("MAX_ASYNC", 4, int),
131
+ help="Maximum async operations (default: from env or 4)",
132
+ )
133
+ parser.add_argument(
134
+ "--max-tokens",
135
+ type=int,
136
+ default=get_env_value("MAX_TOKENS", 32768, int),
137
+ help="Maximum token size (default: from env or 32768)",
138
+ )
139
+
140
+ # Logging configuration
141
+ parser.add_argument(
142
+ "--log-level",
143
+ default=get_env_value("LOG_LEVEL", "INFO"),
144
+ choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
145
+ help="Logging level (default: from env or INFO)",
146
+ )
147
+ parser.add_argument(
148
+ "--verbose",
149
+ action="store_true",
150
+ default=get_env_value("VERBOSE", False, bool),
151
+ help="Enable verbose debug output(only valid for DEBUG log-level)",
152
+ )
153
+
154
+ parser.add_argument(
155
+ "--key",
156
+ type=str,
157
+ default=get_env_value("LIGHTRAG_API_KEY", None),
158
+ help="API key for authentication. This protects lightrag server against unauthorized access",
159
+ )
160
+
161
+ # Optional https parameters
162
+ parser.add_argument(
163
+ "--ssl",
164
+ action="store_true",
165
+ default=get_env_value("SSL", False, bool),
166
+ help="Enable HTTPS (default: from env or False)",
167
+ )
168
+ parser.add_argument(
169
+ "--ssl-certfile",
170
+ default=get_env_value("SSL_CERTFILE", None),
171
+ help="Path to SSL certificate file (required if --ssl is enabled)",
172
+ )
173
+ parser.add_argument(
174
+ "--ssl-keyfile",
175
+ default=get_env_value("SSL_KEYFILE", None),
176
+ help="Path to SSL private key file (required if --ssl is enabled)",
177
+ )
178
+
179
+ parser.add_argument(
180
+ "--history-turns",
181
+ type=int,
182
+ default=get_env_value("HISTORY_TURNS", 3, int),
183
+ help="Number of conversation history turns to include (default: from env or 3)",
184
+ )
185
+
186
+ # Search parameters
187
+ parser.add_argument(
188
+ "--top-k",
189
+ type=int,
190
+ default=get_env_value("TOP_K", 60, int),
191
+ help="Number of most similar results to return (default: from env or 60)",
192
+ )
193
+ parser.add_argument(
194
+ "--cosine-threshold",
195
+ type=float,
196
+ default=get_env_value("COSINE_THRESHOLD", 0.2, float),
197
+ help="Cosine similarity threshold (default: from env or 0.4)",
198
+ )
199
+
200
+ # Ollama model name
201
+ parser.add_argument(
202
+ "--simulated-model-name",
203
+ type=str,
204
+ default=get_env_value(
205
+ "SIMULATED_MODEL_NAME", ollama_server_infos.LIGHTRAG_MODEL
206
+ ),
207
+ help="Number of conversation history turns to include (default: from env or 3)",
208
+ )
209
+
210
+ # Namespace
211
+ parser.add_argument(
212
+ "--namespace-prefix",
213
+ type=str,
214
+ default=get_env_value("NAMESPACE_PREFIX", ""),
215
+ help="Prefix of the namespace",
216
+ )
217
+
218
+ parser.add_argument(
219
+ "--auto-scan-at-startup",
220
+ action="store_true",
221
+ default=False,
222
+ help="Enable automatic scanning when the program starts",
223
+ )
224
+
225
+ # Server workers configuration
226
+ parser.add_argument(
227
+ "--workers",
228
+ type=int,
229
+ default=get_env_value("WORKERS", 1, int),
230
+ help="Number of worker processes (default: from env or 1)",
231
+ )
232
+
233
+ # LLM and embedding bindings
234
+ parser.add_argument(
235
+ "--llm-binding",
236
+ type=str,
237
+ default=get_env_value("LLM_BINDING", "ollama"),
238
+ choices=["lollms", "ollama", "openai", "openai-ollama", "azure_openai"],
239
+ help="LLM binding type (default: from env or ollama)",
240
+ )
241
+ parser.add_argument(
242
+ "--embedding-binding",
243
+ type=str,
244
+ default=get_env_value("EMBEDDING_BINDING", "ollama"),
245
+ choices=["lollms", "ollama", "openai", "azure_openai"],
246
+ help="Embedding binding type (default: from env or ollama)",
247
+ )
248
+
249
+ args = parser.parse_args()
250
+
251
+ # convert relative path to absolute path
252
+ args.working_dir = os.path.abspath(args.working_dir)
253
+ args.input_dir = os.path.abspath(args.input_dir)
254
+
255
+ # Inject storage configuration from environment variables
256
+ args.kv_storage = get_env_value(
257
+ "LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
258
+ )
259
+ args.doc_status_storage = get_env_value(
260
+ "LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
261
+ )
262
+ args.graph_storage = get_env_value(
263
+ "LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
264
+ )
265
+ args.vector_storage = get_env_value(
266
+ "LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
267
+ )
268
+
269
+ # Get MAX_PARALLEL_INSERT from environment
270
+ args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int)
271
+
272
+ # Handle openai-ollama special case
273
+ if args.llm_binding == "openai-ollama":
274
+ args.llm_binding = "openai"
275
+ args.embedding_binding = "ollama"
276
+
277
+ args.llm_binding_host = get_env_value(
278
+ "LLM_BINDING_HOST", get_default_host(args.llm_binding)
279
+ )
280
+ args.embedding_binding_host = get_env_value(
281
+ "EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding)
282
+ )
283
+ args.llm_binding_api_key = get_env_value("LLM_BINDING_API_KEY", None)
284
+ args.embedding_binding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "")
285
+
286
+ # Inject model configuration
287
+ args.llm_model = get_env_value("LLM_MODEL", "mistral-nemo:latest")
288
+ args.embedding_model = get_env_value("EMBEDDING_MODEL", "bge-m3:latest")
289
+ args.embedding_dim = get_env_value("EMBEDDING_DIM", 1024, int)
290
+ args.max_embed_tokens = get_env_value("MAX_EMBED_TOKENS", 8192, int)
291
+
292
+ # Inject chunk configuration
293
+ args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int)
294
+ args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)
295
+
296
+ # Inject LLM cache configuration
297
+ args.enable_llm_cache_for_extract = get_env_value(
298
+ "ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool
299
+ )
300
+
301
+ # Inject LLM temperature configuration
302
+ args.temperature = get_env_value("TEMPERATURE", 0.5, float)
303
+
304
+ # Select document loading engine (DOCLING or DEFAULT)
305
+ args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT")
306
+
307
+ # Add environment variables that were previously read directly
308
+ args.cors_origins = get_env_value("CORS_ORIGINS", "*")
309
+ args.summary_language = get_env_value("SUMMARY_LANGUAGE", "en")
310
+ args.whitelist_paths = get_env_value("WHITELIST_PATHS", "/health,/api/*")
311
+
312
+ # For JWT Auth
313
+ args.auth_accounts = get_env_value("AUTH_ACCOUNTS", "")
314
+ args.token_secret = get_env_value("TOKEN_SECRET", "lightrag-jwt-default-secret")
315
+ args.token_expire_hours = get_env_value("TOKEN_EXPIRE_HOURS", 48, int)
316
+ args.guest_token_expire_hours = get_env_value("GUEST_TOKEN_EXPIRE_HOURS", 24, int)
317
+ args.jwt_algorithm = get_env_value("JWT_ALGORITHM", "HS256")
318
+
319
+ ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
320
+
321
+ return args
322
+
323
+
324
+ def update_uvicorn_mode_config():
325
+ # If in uvicorn mode and workers > 1, force it to 1 and log warning
326
+ if global_args.workers > 1:
327
+ original_workers = global_args.workers
328
+ global_args.workers = 1
329
+ # Log warning directly here
330
+ logging.warning(
331
+ f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1"
332
+ )
333
+
334
+
335
+ global_args = parse_args()
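Every option above follows the same precedence: `get_env_value()` turns an environment variable (or `.env` entry) into the argparse default, and an explicit command-line flag still overrides it. Below is a minimal, self-contained sketch of that behaviour, reusing the `get_env_value` helper as it appears in the removed `utils_api.py` code further down this diff; `TOP_K` is a real option name, the rest of the snippet is purely illustrative.

```python
import argparse
import os


def get_env_value(env_key: str, default, value_type: type = str):
    """Return the environment value converted to value_type, or default if unset/invalid."""
    value = os.getenv(env_key)
    if value is None:
        return default
    if value_type is bool:
        return value.lower() in ("true", "1", "yes", "t", "on")
    try:
        return value_type(value)
    except ValueError:
        return default


os.environ["TOP_K"] = "40"  # pretend this came from .env

parser = argparse.ArgumentParser()
parser.add_argument("--top-k", type=int, default=get_env_value("TOP_K", 60, int))

print(parser.parse_args([]).top_k)                 # 40 -> env value replaces the built-in default
print(parser.parse_args(["--top-k", "80"]).top_k)  # 80 -> an explicit CLI flag still wins
```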
lightrag/api/lightrag_server.py CHANGED
@@ -19,11 +19,14 @@ from contextlib import asynccontextmanager
19
  from dotenv import load_dotenv
20
  from lightrag.api.utils_api import (
21
  get_combined_auth_dependency,
22
- parse_args,
23
- get_default_host,
24
  display_splash_screen,
25
  check_env_file,
26
  )
 
 
 
 
 
27
  import sys
28
  from lightrag import LightRAG, __version__ as core_version
29
  from lightrag.api import __api_version__
@@ -52,6 +55,10 @@ from lightrag.api.auth import auth_handler
52
  # the OS environment variables take precedence over the .env file
53
  load_dotenv(dotenv_path=".env", override=False)
54
 
 
 
 
 
55
  # Initialize config parser
56
  config = configparser.ConfigParser()
57
  config.read("config.ini")
@@ -164,10 +171,10 @@ def create_app(args):
164
  app = FastAPI(**app_kwargs)
165
 
166
  def get_cors_origins():
167
- """Get allowed origins from environment variable
168
  Returns a list of allowed origins, defaults to ["*"] if not set
169
  """
170
- origins_str = os.getenv("CORS_ORIGINS", "*")
171
  if origins_str == "*":
172
  return ["*"]
173
  return [origin.strip() for origin in origins_str.split(",")]
@@ -315,9 +322,10 @@ def create_app(args):
315
  "similarity_threshold": 0.95,
316
  "use_llm_check": False,
317
  },
318
- namespace_prefix=args.namespace_prefix,
319
  auto_manage_storages_states=False,
320
  max_parallel_insert=args.max_parallel_insert,
 
321
  )
322
  else: # azure_openai
323
  rag = LightRAG(
@@ -345,9 +353,10 @@ def create_app(args):
345
  "similarity_threshold": 0.95,
346
  "use_llm_check": False,
347
  },
348
- namespace_prefix=args.namespace_prefix,
349
  auto_manage_storages_states=False,
350
  max_parallel_insert=args.max_parallel_insert,
 
351
  )
352
 
353
  # Add routes
@@ -381,6 +390,8 @@ def create_app(args):
381
  "message": "Authentication is disabled. Using guest access.",
382
  "core_version": core_version,
383
  "api_version": __api_version__,
 
 
384
  }
385
 
386
  return {
@@ -388,6 +399,8 @@ def create_app(args):
388
  "auth_mode": "enabled",
389
  "core_version": core_version,
390
  "api_version": __api_version__,
 
 
391
  }
392
 
393
  @app.post("/login")
@@ -404,6 +417,8 @@ def create_app(args):
404
  "message": "Authentication is disabled. Using guest access.",
405
  "core_version": core_version,
406
  "api_version": __api_version__,
 
 
407
  }
408
  username = form_data.username
409
  if auth_handler.accounts.get(username) != form_data.password:
@@ -454,10 +469,12 @@ def create_app(args):
454
  "vector_storage": args.vector_storage,
455
  "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
456
  },
457
- "core_version": core_version,
458
- "api_version": __api_version__,
459
  "auth_mode": auth_mode,
460
  "pipeline_busy": pipeline_status.get("busy", False),
 
 
 
 
461
  }
462
  except Exception as e:
463
  logger.error(f"Error getting health status: {str(e)}")
@@ -490,7 +507,7 @@ def create_app(args):
490
  def get_application(args=None):
491
  """Factory function for creating the FastAPI application"""
492
  if args is None:
493
- args = parse_args()
494
  return create_app(args)
495
 
496
 
@@ -611,30 +628,31 @@ def main():
611
 
612
  # Configure logging before parsing args
613
  configure_logging()
614
-
615
- args = parse_args(is_uvicorn_mode=True)
616
- display_splash_screen(args)
617
 
618
  # Create application instance directly instead of using factory function
619
- app = create_app(args)
620
 
621
  # Start Uvicorn in single process mode
622
  uvicorn_config = {
623
  "app": app, # Pass application instance directly instead of string path
624
- "host": args.host,
625
- "port": args.port,
626
  "log_config": None, # Disable default config
627
  }
628
 
629
- if args.ssl:
630
  uvicorn_config.update(
631
  {
632
- "ssl_certfile": args.ssl_certfile,
633
- "ssl_keyfile": args.ssl_keyfile,
634
  }
635
  )
636
 
637
- print(f"Starting Uvicorn server in single-process mode on {args.host}:{args.port}")
 
 
638
  uvicorn.run(**uvicorn_config)
639
 
640
 
 
19
  from dotenv import load_dotenv
20
  from lightrag.api.utils_api import (
21
  get_combined_auth_dependency,
 
 
22
  display_splash_screen,
23
  check_env_file,
24
  )
25
+ from .config import (
26
+ global_args,
27
+ update_uvicorn_mode_config,
28
+ get_default_host,
29
+ )
30
  import sys
31
  from lightrag import LightRAG, __version__ as core_version
32
  from lightrag.api import __api_version__
 
55
  # the OS environment variables take precedence over the .env file
56
  load_dotenv(dotenv_path=".env", override=False)
57
 
58
+
59
+ webui_title = os.getenv("WEBUI_TITLE")
60
+ webui_description = os.getenv("WEBUI_DESCRIPTION")
61
+
62
  # Initialize config parser
63
  config = configparser.ConfigParser()
64
  config.read("config.ini")
 
171
  app = FastAPI(**app_kwargs)
172
 
173
  def get_cors_origins():
174
+ """Get allowed origins from global_args
175
  Returns a list of allowed origins, defaults to ["*"] if not set
176
  """
177
+ origins_str = global_args.cors_origins
178
  if origins_str == "*":
179
  return ["*"]
180
  return [origin.strip() for origin in origins_str.split(",")]
 
322
  "similarity_threshold": 0.95,
323
  "use_llm_check": False,
324
  },
325
+ # namespace_prefix=args.namespace_prefix,
326
  auto_manage_storages_states=False,
327
  max_parallel_insert=args.max_parallel_insert,
328
+ addon_params={"language": args.summary_language},
329
  )
330
  else: # azure_openai
331
  rag = LightRAG(
 
353
  "similarity_threshold": 0.95,
354
  "use_llm_check": False,
355
  },
356
+ # namespace_prefix=args.namespace_prefix,
357
  auto_manage_storages_states=False,
358
  max_parallel_insert=args.max_parallel_insert,
359
+ addon_params={"language": args.summary_language},
360
  )
361
 
362
  # Add routes
 
390
  "message": "Authentication is disabled. Using guest access.",
391
  "core_version": core_version,
392
  "api_version": __api_version__,
393
+ "webui_title": webui_title,
394
+ "webui_description": webui_description,
395
  }
396
 
397
  return {
 
399
  "auth_mode": "enabled",
400
  "core_version": core_version,
401
  "api_version": __api_version__,
402
+ "webui_title": webui_title,
403
+ "webui_description": webui_description,
404
  }
405
 
406
  @app.post("/login")
 
417
  "message": "Authentication is disabled. Using guest access.",
418
  "core_version": core_version,
419
  "api_version": __api_version__,
420
+ "webui_title": webui_title,
421
+ "webui_description": webui_description,
422
  }
423
  username = form_data.username
424
  if auth_handler.accounts.get(username) != form_data.password:
 
469
  "vector_storage": args.vector_storage,
470
  "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
471
  },
 
 
472
  "auth_mode": auth_mode,
473
  "pipeline_busy": pipeline_status.get("busy", False),
474
+ "core_version": core_version,
475
+ "api_version": __api_version__,
476
+ "webui_title": webui_title,
477
+ "webui_description": webui_description,
478
  }
479
  except Exception as e:
480
  logger.error(f"Error getting health status: {str(e)}")
 
507
  def get_application(args=None):
508
  """Factory function for creating the FastAPI application"""
509
  if args is None:
510
+ args = global_args
511
  return create_app(args)
512
 
513
 
 
628
 
629
  # Configure logging before parsing args
630
  configure_logging()
631
+ update_uvicorn_mode_config()
632
+ display_splash_screen(global_args)
 
633
 
634
  # Create application instance directly instead of using factory function
635
+ app = create_app(global_args)
636
 
637
  # Start Uvicorn in single process mode
638
  uvicorn_config = {
639
  "app": app, # Pass application instance directly instead of string path
640
+ "host": global_args.host,
641
+ "port": global_args.port,
642
  "log_config": None, # Disable default config
643
  }
644
 
645
+ if global_args.ssl:
646
  uvicorn_config.update(
647
  {
648
+ "ssl_certfile": global_args.ssl_certfile,
649
+ "ssl_keyfile": global_args.ssl_keyfile,
650
  }
651
  )
652
 
653
+ print(
654
+ f"Starting Uvicorn server in single-process mode on {global_args.host}:{global_args.port}"
655
+ )
656
  uvicorn.run(**uvicorn_config)
657
 
658
 
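The server changes above add optional WebUI branding, taken from the `WEBUI_TITLE` and `WEBUI_DESCRIPTION` environment variables, to the auth-status, login, and health payloads, and keep the version fields alongside them. A rough client-side sketch for reading the extended `/health` response; the default port 9621, the absence of an API key, and the use of `requests` are assumptions, not requirements of the diff:

```python
import requests

# Query the health endpoint of a locally running LightRAG server.
resp = requests.get("http://localhost:9621/health", timeout=10)
resp.raise_for_status()
info = resp.json()

# Fields touched by this PR: version info plus optional WebUI branding.
print(info.get("core_version"), info.get("api_version"))
print(info.get("webui_title"), info.get("webui_description"))
print("pipeline busy:", info.get("pipeline_busy"))
```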
lightrag/api/routers/document_routes.py CHANGED
@@ -10,16 +10,14 @@ import traceback
10
  import pipmaster as pm
11
  from datetime import datetime
12
  from pathlib import Path
13
- from typing import Dict, List, Optional, Any
14
  from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
15
  from pydantic import BaseModel, Field, field_validator
16
 
17
  from lightrag import LightRAG
18
  from lightrag.base import DocProcessingStatus, DocStatus
19
- from lightrag.api.utils_api import (
20
- get_combined_auth_dependency,
21
- global_args,
22
- )
23
 
24
  router = APIRouter(
25
  prefix="/documents",
@@ -30,7 +28,37 @@ router = APIRouter(
30
  temp_prefix = "__tmp__"
31
 
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  class InsertTextRequest(BaseModel):
 
 
 
 
 
 
34
  text: str = Field(
35
  min_length=1,
36
  description="The text to insert",
@@ -41,8 +69,21 @@ class InsertTextRequest(BaseModel):
41
  def strip_after(cls, text: str) -> str:
42
  return text.strip()
43
 
 
 
 
 
 
 
 
44
 
45
  class InsertTextsRequest(BaseModel):
 
 
 
 
 
 
46
  texts: list[str] = Field(
47
  min_length=1,
48
  description="The texts to insert",
@@ -53,11 +94,116 @@ class InsertTextsRequest(BaseModel):
53
  def strip_after(cls, texts: list[str]) -> list[str]:
54
  return [text.strip() for text in texts]
55
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  class InsertResponse(BaseModel):
58
- status: str = Field(description="Status of the operation")
 
 
 
 
 
 
 
 
 
59
  message: str = Field(description="Message describing the operation result")
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  class DocStatusResponse(BaseModel):
63
  @staticmethod
@@ -68,34 +214,82 @@ class DocStatusResponse(BaseModel):
68
  return dt
69
  return dt.isoformat()
70
 
71
- """Response model for document status
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
  Attributes:
74
- id: Document identifier
75
- content_summary: Summary of document content
76
- content_length: Length of document content
77
- status: Current processing status
78
- created_at: Creation timestamp (ISO format string)
79
- updated_at: Last update timestamp (ISO format string)
80
- chunks_count: Number of chunks (optional)
81
- error: Error message if any (optional)
82
- metadata: Additional metadata (optional)
83
  """
84
 
85
- id: str
86
- content_summary: str
87
- content_length: int
88
- status: DocStatus
89
- created_at: str
90
- updated_at: str
91
- chunks_count: Optional[int] = None
92
- error: Optional[str] = None
93
- metadata: Optional[dict[str, Any]] = None
94
- file_path: str
95
-
96
 
97
- class DocsStatusesResponse(BaseModel):
98
- statuses: Dict[DocStatus, List[DocStatusResponse]] = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
 
101
  class PipelineStatusResponse(BaseModel):
@@ -276,7 +470,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
276
  )
277
  return False
278
  case ".pdf":
279
- if global_args["main_args"].document_loading_engine == "DOCLING":
280
  if not pm.is_installed("docling"): # type: ignore
281
  pm.install("docling")
282
  from docling.document_converter import DocumentConverter # type: ignore
@@ -295,7 +489,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
295
  for page in reader.pages:
296
  content += page.extract_text() + "\n"
297
  case ".docx":
298
- if global_args["main_args"].document_loading_engine == "DOCLING":
299
  if not pm.is_installed("docling"): # type: ignore
300
  pm.install("docling")
301
  from docling.document_converter import DocumentConverter # type: ignore
@@ -315,7 +509,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
315
  [paragraph.text for paragraph in doc.paragraphs]
316
  )
317
  case ".pptx":
318
- if global_args["main_args"].document_loading_engine == "DOCLING":
319
  if not pm.is_installed("docling"): # type: ignore
320
  pm.install("docling")
321
  from docling.document_converter import DocumentConverter # type: ignore
@@ -336,7 +530,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
336
  if hasattr(shape, "text"):
337
  content += shape.text + "\n"
338
  case ".xlsx":
339
- if global_args["main_args"].document_loading_engine == "DOCLING":
340
  if not pm.is_installed("docling"): # type: ignore
341
  pm.install("docling")
342
  from docling.document_converter import DocumentConverter # type: ignore
@@ -443,6 +637,7 @@ async def pipeline_index_texts(rag: LightRAG, texts: List[str]):
443
  await rag.apipeline_process_enqueue_documents()
444
 
445
 
 
446
  async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path:
447
  """Save the uploaded file to a temporary location
448
 
@@ -476,8 +671,8 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
476
  if not new_files:
477
  return
478
 
479
- # Get MAX_PARALLEL_INSERT from global_args["main_args"]
480
- max_parallel = global_args["main_args"].max_parallel_insert
481
  # Calculate batch size as 2 * MAX_PARALLEL_INSERT
482
  batch_size = 2 * max_parallel
483
 
@@ -509,7 +704,9 @@ def create_document_routes(
509
  # Create combined auth dependency for document routes
510
  combined_auth = get_combined_auth_dependency(api_key)
511
 
512
- @router.post("/scan", dependencies=[Depends(combined_auth)])
 
 
513
  async def scan_for_new_documents(background_tasks: BackgroundTasks):
514
  """
515
  Trigger the scanning process for new documents.
@@ -519,13 +716,18 @@ def create_document_routes(
519
  that fact.
520
 
521
  Returns:
522
- dict: A dictionary containing the scanning status
523
  """
524
  # Start the scanning process in the background
525
  background_tasks.add_task(run_scanning_process, rag, doc_manager)
526
- return {"status": "scanning_started"}
 
 
 
527
 
528
- @router.post("/upload", dependencies=[Depends(combined_auth)])
 
 
529
  async def upload_to_input_dir(
530
  background_tasks: BackgroundTasks, file: UploadFile = File(...)
531
  ):
@@ -645,6 +847,7 @@ def create_document_routes(
645
  logger.error(traceback.format_exc())
646
  raise HTTPException(status_code=500, detail=str(e))
647
 
 
648
  @router.post(
649
  "/file", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
650
  )
@@ -688,6 +891,7 @@ def create_document_routes(
688
  logger.error(traceback.format_exc())
689
  raise HTTPException(status_code=500, detail=str(e))
690
 
 
691
  @router.post(
692
  "/file_batch",
693
  response_model=InsertResponse,
@@ -752,32 +956,186 @@ def create_document_routes(
752
  raise HTTPException(status_code=500, detail=str(e))
753
 
754
  @router.delete(
755
- "", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
756
  )
757
  async def clear_documents():
758
  """
759
  Clear all documents from the RAG system.
760
 
761
- This endpoint deletes all text chunks, entities vector database, and relationships
762
- vector database, effectively clearing all documents from the RAG system.
 
763
 
764
  Returns:
765
- InsertResponse: A response object containing the status and message.
 
 
 
 
 
 
766
 
767
  Raises:
768
- HTTPException: If an error occurs during the clearing process (500).
 
769
  """
770
- try:
771
- rag.text_chunks = []
772
- rag.entities_vdb = None
773
- rag.relationships_vdb = None
774
- return InsertResponse(
775
- status="success", message="All documents cleared successfully"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
776
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
777
  except Exception as e:
778
- logger.error(f"Error DELETE /documents: {str(e)}")
 
779
  logger.error(traceback.format_exc())
 
 
780
  raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
781
 
782
  @router.get(
783
  "/pipeline_status",
@@ -850,7 +1208,9 @@ def create_document_routes(
850
  logger.error(traceback.format_exc())
851
  raise HTTPException(status_code=500, detail=str(e))
852
 
853
- @router.get("", dependencies=[Depends(combined_auth)])
 
 
854
  async def documents() -> DocsStatusesResponse:
855
  """
856
  Get the status of all documents in the system.
@@ -908,4 +1268,57 @@ def create_document_routes(
908
  logger.error(traceback.format_exc())
909
  raise HTTPException(status_code=500, detail=str(e))
910
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
911
  return router
 
10
  import pipmaster as pm
11
  from datetime import datetime
12
  from pathlib import Path
13
+ from typing import Dict, List, Optional, Any, Literal
14
  from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
15
  from pydantic import BaseModel, Field, field_validator
16
 
17
  from lightrag import LightRAG
18
  from lightrag.base import DocProcessingStatus, DocStatus
19
+ from lightrag.api.utils_api import get_combined_auth_dependency
20
+ from ..config import global_args
 
 
21
 
22
  router = APIRouter(
23
  prefix="/documents",
 
28
  temp_prefix = "__tmp__"
29
 
30
 
31
+ class ScanResponse(BaseModel):
32
+ """Response model for document scanning operation
33
+
34
+ Attributes:
35
+ status: Status of the scanning operation
36
+ message: Optional message with additional details
37
+ """
38
+
39
+ status: Literal["scanning_started"] = Field(
40
+ description="Status of the scanning operation"
41
+ )
42
+ message: Optional[str] = Field(
43
+ default=None, description="Additional details about the scanning operation"
44
+ )
45
+
46
+ class Config:
47
+ json_schema_extra = {
48
+ "example": {
49
+ "status": "scanning_started",
50
+ "message": "Scanning process has been initiated in the background",
51
+ }
52
+ }
53
+
54
+
55
  class InsertTextRequest(BaseModel):
56
+ """Request model for inserting a single text document
57
+
58
+ Attributes:
59
+ text: The text content to be inserted into the RAG system
60
+ """
61
+
62
  text: str = Field(
63
  min_length=1,
64
  description="The text to insert",
 
69
  def strip_after(cls, text: str) -> str:
70
  return text.strip()
71
 
72
+ class Config:
73
+ json_schema_extra = {
74
+ "example": {
75
+ "text": "This is a sample text to be inserted into the RAG system."
76
+ }
77
+ }
78
+
79
 
80
  class InsertTextsRequest(BaseModel):
81
+ """Request model for inserting multiple text documents
82
+
83
+ Attributes:
84
+ texts: List of text contents to be inserted into the RAG system
85
+ """
86
+
87
  texts: list[str] = Field(
88
  min_length=1,
89
  description="The texts to insert",
 
94
  def strip_after(cls, texts: list[str]) -> list[str]:
95
  return [text.strip() for text in texts]
96
 
97
+ class Config:
98
+ json_schema_extra = {
99
+ "example": {
100
+ "texts": [
101
+ "This is the first text to be inserted.",
102
+ "This is the second text to be inserted.",
103
+ ]
104
+ }
105
+ }
106
+
107
 
108
  class InsertResponse(BaseModel):
109
+ """Response model for document insertion operations
110
+
111
+ Attributes:
112
+ status: Status of the operation (success, duplicated, partial_success, failure)
113
+ message: Detailed message describing the operation result
114
+ """
115
+
116
+ status: Literal["success", "duplicated", "partial_success", "failure"] = Field(
117
+ description="Status of the operation"
118
+ )
119
  message: str = Field(description="Message describing the operation result")
120
 
121
+ class Config:
122
+ json_schema_extra = {
123
+ "example": {
124
+ "status": "success",
125
+ "message": "File 'document.pdf' uploaded successfully. Processing will continue in background.",
126
+ }
127
+ }
128
+
129
+
130
+ class ClearDocumentsResponse(BaseModel):
131
+ """Response model for document clearing operation
132
+
133
+ Attributes:
134
+ status: Status of the clear operation
135
+ message: Detailed message describing the operation result
136
+ """
137
+
138
+ status: Literal["success", "partial_success", "busy", "fail"] = Field(
139
+ description="Status of the clear operation"
140
+ )
141
+ message: str = Field(description="Message describing the operation result")
142
+
143
+ class Config:
144
+ json_schema_extra = {
145
+ "example": {
146
+ "status": "success",
147
+ "message": "All documents cleared successfully. Deleted 15 files.",
148
+ }
149
+ }
150
+
151
+
152
+ class ClearCacheRequest(BaseModel):
153
+ """Request model for clearing cache
154
+
155
+ Attributes:
156
+ modes: Optional list of cache modes to clear
157
+ """
158
+
159
+ modes: Optional[
160
+ List[Literal["default", "naive", "local", "global", "hybrid", "mix"]]
161
+ ] = Field(
162
+ default=None,
163
+ description="Modes of cache to clear. If None, clears all cache.",
164
+ )
165
+
166
+ class Config:
167
+ json_schema_extra = {"example": {"modes": ["default", "naive"]}}
168
+
169
+
170
+ class ClearCacheResponse(BaseModel):
171
+ """Response model for cache clearing operation
172
+
173
+ Attributes:
174
+ status: Status of the clear operation
175
+ message: Detailed message describing the operation result
176
+ """
177
+
178
+ status: Literal["success", "fail"] = Field(
179
+ description="Status of the clear operation"
180
+ )
181
+ message: str = Field(description="Message describing the operation result")
182
+
183
+ class Config:
184
+ json_schema_extra = {
185
+ "example": {
186
+ "status": "success",
187
+ "message": "Successfully cleared cache for modes: ['default', 'naive']",
188
+ }
189
+ }
190
+
191
+
192
+ """Response model for document status
193
+
194
+ Attributes:
195
+ id: Document identifier
196
+ content_summary: Summary of document content
197
+ content_length: Length of document content
198
+ status: Current processing status
199
+ created_at: Creation timestamp (ISO format string)
200
+ updated_at: Last update timestamp (ISO format string)
201
+ chunks_count: Number of chunks (optional)
202
+ error: Error message if any (optional)
203
+ metadata: Additional metadata (optional)
204
+ file_path: Path to the document file
205
+ """
206
+
207
 
208
  class DocStatusResponse(BaseModel):
209
  @staticmethod
 
214
  return dt
215
  return dt.isoformat()
216
 
217
+ id: str = Field(description="Document identifier")
218
+ content_summary: str = Field(description="Summary of document content")
219
+ content_length: int = Field(description="Length of document content in characters")
220
+ status: DocStatus = Field(description="Current processing status")
221
+ created_at: str = Field(description="Creation timestamp (ISO format string)")
222
+ updated_at: str = Field(description="Last update timestamp (ISO format string)")
223
+ chunks_count: Optional[int] = Field(
224
+ default=None, description="Number of chunks the document was split into"
225
+ )
226
+ error: Optional[str] = Field(
227
+ default=None, description="Error message if processing failed"
228
+ )
229
+ metadata: Optional[dict[str, Any]] = Field(
230
+ default=None, description="Additional metadata about the document"
231
+ )
232
+ file_path: str = Field(description="Path to the document file")
233
+
234
+ class Config:
235
+ json_schema_extra = {
236
+ "example": {
237
+ "id": "doc_123456",
238
+ "content_summary": "Research paper on machine learning",
239
+ "content_length": 15240,
240
+ "status": "PROCESSED",
241
+ "created_at": "2025-03-31T12:34:56",
242
+ "updated_at": "2025-03-31T12:35:30",
243
+ "chunks_count": 12,
244
+ "error": None,
245
+ "metadata": {"author": "John Doe", "year": 2025},
246
+ "file_path": "research_paper.pdf",
247
+ }
248
+ }
249
+
250
+
251
+ class DocsStatusesResponse(BaseModel):
252
+ """Response model for document statuses
253
 
254
  Attributes:
255
+ statuses: Dictionary mapping document status to lists of document status responses
 
 
 
 
 
 
 
 
256
  """
257
 
258
+ statuses: Dict[DocStatus, List[DocStatusResponse]] = Field(
259
+ default_factory=dict,
260
+ description="Dictionary mapping document status to lists of document status responses",
261
+ )
 
 
 
 
 
 
 
262
 
263
+ class Config:
264
+ json_schema_extra = {
265
+ "example": {
266
+ "statuses": {
267
+ "PENDING": [
268
+ {
269
+ "id": "doc_123",
270
+ "content_summary": "Pending document",
271
+ "content_length": 5000,
272
+ "status": "PENDING",
273
+ "created_at": "2025-03-31T10:00:00",
274
+ "updated_at": "2025-03-31T10:00:00",
275
+ "file_path": "pending_doc.pdf",
276
+ }
277
+ ],
278
+ "PROCESSED": [
279
+ {
280
+ "id": "doc_456",
281
+ "content_summary": "Processed document",
282
+ "content_length": 8000,
283
+ "status": "PROCESSED",
284
+ "created_at": "2025-03-31T09:00:00",
285
+ "updated_at": "2025-03-31T09:05:00",
286
+ "chunks_count": 8,
287
+ "file_path": "processed_doc.pdf",
288
+ }
289
+ ],
290
+ }
291
+ }
292
+ }
293
 
294
 
295
  class PipelineStatusResponse(BaseModel):
 
470
  )
471
  return False
472
  case ".pdf":
473
+ if global_args.document_loading_engine == "DOCLING":
474
  if not pm.is_installed("docling"): # type: ignore
475
  pm.install("docling")
476
  from docling.document_converter import DocumentConverter # type: ignore
 
489
  for page in reader.pages:
490
  content += page.extract_text() + "\n"
491
  case ".docx":
492
+ if global_args.document_loading_engine == "DOCLING":
493
  if not pm.is_installed("docling"): # type: ignore
494
  pm.install("docling")
495
  from docling.document_converter import DocumentConverter # type: ignore
 
509
  [paragraph.text for paragraph in doc.paragraphs]
510
  )
511
  case ".pptx":
512
+ if global_args.document_loading_engine == "DOCLING":
513
  if not pm.is_installed("docling"): # type: ignore
514
  pm.install("docling")
515
  from docling.document_converter import DocumentConverter # type: ignore
 
530
  if hasattr(shape, "text"):
531
  content += shape.text + "\n"
532
  case ".xlsx":
533
+ if global_args.document_loading_engine == "DOCLING":
534
  if not pm.is_installed("docling"): # type: ignore
535
  pm.install("docling")
536
  from docling.document_converter import DocumentConverter # type: ignore
 
637
  await rag.apipeline_process_enqueue_documents()
638
 
639
 
640
+ # TODO: deprecate after /insert_file is removed
641
  async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path:
642
  """Save the uploaded file to a temporary location
643
 
 
671
  if not new_files:
672
  return
673
 
674
+ # Get MAX_PARALLEL_INSERT from global_args
675
+ max_parallel = global_args.max_parallel_insert
676
  # Calculate batch size as 2 * MAX_PARALLEL_INSERT
677
  batch_size = 2 * max_parallel
678
 
 
704
  # Create combined auth dependency for document routes
705
  combined_auth = get_combined_auth_dependency(api_key)
706
 
707
+ @router.post(
708
+ "/scan", response_model=ScanResponse, dependencies=[Depends(combined_auth)]
709
+ )
710
  async def scan_for_new_documents(background_tasks: BackgroundTasks):
711
  """
712
  Trigger the scanning process for new documents.
 
716
  that fact.
717
 
718
  Returns:
719
+ ScanResponse: A response object containing the scanning status
720
  """
721
  # Start the scanning process in the background
722
  background_tasks.add_task(run_scanning_process, rag, doc_manager)
723
+ return ScanResponse(
724
+ status="scanning_started",
725
+ message="Scanning process has been initiated in the background",
726
+ )
727
 
728
+ @router.post(
729
+ "/upload", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
730
+ )
731
  async def upload_to_input_dir(
732
  background_tasks: BackgroundTasks, file: UploadFile = File(...)
733
  ):
 
847
  logger.error(traceback.format_exc())
848
  raise HTTPException(status_code=500, detail=str(e))
849
 
850
+ # TODO: deprecated, use /upload instead
851
  @router.post(
852
  "/file", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
853
  )
 
891
  logger.error(traceback.format_exc())
892
  raise HTTPException(status_code=500, detail=str(e))
893
 
894
+ # TODO: deprecated, use /upload instead
895
  @router.post(
896
  "/file_batch",
897
  response_model=InsertResponse,
 
956
  raise HTTPException(status_code=500, detail=str(e))
957
 
958
  @router.delete(
959
+ "", response_model=ClearDocumentsResponse, dependencies=[Depends(combined_auth)]
960
  )
961
  async def clear_documents():
962
  """
963
  Clear all documents from the RAG system.
964
 
965
+ This endpoint deletes all documents, entities, relationships, and files from the system.
966
+ It uses the storage drop methods to properly clean up all data and removes all files
967
+ from the input directory.
968
 
969
  Returns:
970
+ ClearDocumentsResponse: A response object containing the status and message.
971
+ - status="success": All documents and files were successfully cleared.
972
+ - status="partial_success": Document clear job exit with some errors.
973
+ - status="busy": Operation could not be completed because the pipeline is busy.
974
+ - status="fail": All storage drop operations failed, with message
975
+ - message: Detailed information about the operation results, including counts
976
+ of deleted files and any errors encountered.
977
 
978
  Raises:
979
+ HTTPException: Raised when a serious error occurs during the clearing process,
980
+ with status code 500 and error details in the detail field.
981
  """
982
+ from lightrag.kg.shared_storage import (
983
+ get_namespace_data,
984
+ get_pipeline_status_lock,
985
+ )
986
+
987
+ # Get pipeline status and lock
988
+ pipeline_status = await get_namespace_data("pipeline_status")
989
+ pipeline_status_lock = get_pipeline_status_lock()
990
+
991
+ # Check and set status with lock
992
+ async with pipeline_status_lock:
993
+ if pipeline_status.get("busy", False):
994
+ return ClearDocumentsResponse(
995
+ status="busy",
996
+ message="Cannot clear documents while pipeline is busy",
997
+ )
998
+ # Set busy to true
999
+ pipeline_status.update(
1000
+ {
1001
+ "busy": True,
1002
+ "job_name": "Clearing Documents",
1003
+ "job_start": datetime.now().isoformat(),
1004
+ "docs": 0,
1005
+ "batchs": 0,
1006
+ "cur_batch": 0,
1007
+ "request_pending": False, # Clear any previous request
1008
+ "latest_message": "Starting document clearing process",
1009
+ }
1010
  )
1011
+ # Cleaning history_messages without breaking it as a shared list object
1012
+ del pipeline_status["history_messages"][:]
1013
+ pipeline_status["history_messages"].append(
1014
+ "Starting document clearing process"
1015
+ )
1016
+
1017
+ try:
1018
+ # Use drop method to clear all data
1019
+ drop_tasks = []
1020
+ storages = [
1021
+ rag.text_chunks,
1022
+ rag.full_docs,
1023
+ rag.entities_vdb,
1024
+ rag.relationships_vdb,
1025
+ rag.chunks_vdb,
1026
+ rag.chunk_entity_relation_graph,
1027
+ rag.doc_status,
1028
+ ]
1029
+
1030
+ # Log storage drop start
1031
+ if "history_messages" in pipeline_status:
1032
+ pipeline_status["history_messages"].append(
1033
+ "Starting to drop storage components"
1034
+ )
1035
+
1036
+ for storage in storages:
1037
+ if storage is not None:
1038
+ drop_tasks.append(storage.drop())
1039
+
1040
+ # Wait for all drop tasks to complete
1041
+ drop_results = await asyncio.gather(*drop_tasks, return_exceptions=True)
1042
+
1043
+ # Check for errors and log results
1044
+ errors = []
1045
+ storage_success_count = 0
1046
+ storage_error_count = 0
1047
+
1048
+ for i, result in enumerate(drop_results):
1049
+ storage_name = storages[i].__class__.__name__
1050
+ if isinstance(result, Exception):
1051
+ error_msg = f"Error dropping {storage_name}: {str(result)}"
1052
+ errors.append(error_msg)
1053
+ logger.error(error_msg)
1054
+ storage_error_count += 1
1055
+ else:
1056
+ logger.info(f"Successfully dropped {storage_name}")
1057
+ storage_success_count += 1
1058
+
1059
+ # Log storage drop results
1060
+ if "history_messages" in pipeline_status:
1061
+ if storage_error_count > 0:
1062
+ pipeline_status["history_messages"].append(
1063
+ f"Dropped {storage_success_count} storage components with {storage_error_count} errors"
1064
+ )
1065
+ else:
1066
+ pipeline_status["history_messages"].append(
1067
+ f"Successfully dropped all {storage_success_count} storage components"
1068
+ )
1069
+
1070
+ # If all storage operations failed, return error status and don't proceed with file deletion
1071
+ if storage_success_count == 0 and storage_error_count > 0:
1072
+ error_message = "All storage drop operations failed. Aborting document clearing process."
1073
+ logger.error(error_message)
1074
+ if "history_messages" in pipeline_status:
1075
+ pipeline_status["history_messages"].append(error_message)
1076
+ return ClearDocumentsResponse(status="fail", message=error_message)
1077
+
1078
+ # Log file deletion start
1079
+ if "history_messages" in pipeline_status:
1080
+ pipeline_status["history_messages"].append(
1081
+ "Starting to delete files in input directory"
1082
+ )
1083
+
1084
+ # Delete all files in input_dir
1085
+ deleted_files_count = 0
1086
+ file_errors_count = 0
1087
+
1088
+ for file_path in doc_manager.input_dir.glob("**/*"):
1089
+ if file_path.is_file():
1090
+ try:
1091
+ file_path.unlink()
1092
+ deleted_files_count += 1
1093
+ except Exception as e:
1094
+ logger.error(f"Error deleting file {file_path}: {str(e)}")
1095
+ file_errors_count += 1
1096
+
1097
+ # Log file deletion results
1098
+ if "history_messages" in pipeline_status:
1099
+ if file_errors_count > 0:
1100
+ pipeline_status["history_messages"].append(
1101
+ f"Deleted {deleted_files_count} files with {file_errors_count} errors"
1102
+ )
1103
+ errors.append(f"Failed to delete {file_errors_count} files")
1104
+ else:
1105
+ pipeline_status["history_messages"].append(
1106
+ f"Successfully deleted {deleted_files_count} files"
1107
+ )
1108
+
1109
+ # Prepare final result message
1110
+ final_message = ""
1111
+ if errors:
1112
+ final_message = f"Cleared documents with some errors. Deleted {deleted_files_count} files."
1113
+ status = "partial_success"
1114
+ else:
1115
+ final_message = f"All documents cleared successfully. Deleted {deleted_files_count} files."
1116
+ status = "success"
1117
+
1118
+ # Log final result
1119
+ if "history_messages" in pipeline_status:
1120
+ pipeline_status["history_messages"].append(final_message)
1121
+
1122
+ # Return response based on results
1123
+ return ClearDocumentsResponse(status=status, message=final_message)
1124
  except Exception as e:
1125
+ error_msg = f"Error clearing documents: {str(e)}"
1126
+ logger.error(error_msg)
1127
  logger.error(traceback.format_exc())
1128
+ if "history_messages" in pipeline_status:
1129
+ pipeline_status["history_messages"].append(error_msg)
1130
  raise HTTPException(status_code=500, detail=str(e))
1131
+ finally:
1132
+ # Reset busy status after completion
1133
+ async with pipeline_status_lock:
1134
+ pipeline_status["busy"] = False
1135
+ completion_msg = "Document clearing process completed"
1136
+ pipeline_status["latest_message"] = completion_msg
1137
+ if "history_messages" in pipeline_status:
1138
+ pipeline_status["history_messages"].append(completion_msg)
1139
 
1140
  @router.get(
1141
  "/pipeline_status",
 
1208
  logger.error(traceback.format_exc())
1209
  raise HTTPException(status_code=500, detail=str(e))
1210
 
1211
+ @router.get(
1212
+ "", response_model=DocsStatusesResponse, dependencies=[Depends(combined_auth)]
1213
+ )
1214
  async def documents() -> DocsStatusesResponse:
1215
  """
1216
  Get the status of all documents in the system.
 
1268
  logger.error(traceback.format_exc())
1269
  raise HTTPException(status_code=500, detail=str(e))
1270
 
1271
+ @router.post(
1272
+ "/clear_cache",
1273
+ response_model=ClearCacheResponse,
1274
+ dependencies=[Depends(combined_auth)],
1275
+ )
1276
+ async def clear_cache(request: ClearCacheRequest):
1277
+ """
1278
+ Clear cache data from the LLM response cache storage.
1279
+
1280
+ This endpoint allows clearing specific modes of cache or all cache if no modes are specified.
1281
+ Valid modes include: "default", "naive", "local", "global", "hybrid", "mix".
1282
+ - "default" represents extraction cache.
1283
+ - Other modes correspond to different query modes.
1284
+
1285
+ Args:
1286
+ request (ClearCacheRequest): The request body containing optional modes to clear.
1287
+
1288
+ Returns:
1289
+ ClearCacheResponse: A response object containing the status and message.
1290
+
1291
+ Raises:
1292
+ HTTPException: If an error occurs during cache clearing (400 for invalid modes, 500 for other errors).
1293
+ """
1294
+ try:
1295
+ # Validate modes if provided
1296
+ valid_modes = ["default", "naive", "local", "global", "hybrid", "mix"]
1297
+ if request.modes and not all(mode in valid_modes for mode in request.modes):
1298
+ invalid_modes = [
1299
+ mode for mode in request.modes if mode not in valid_modes
1300
+ ]
1301
+ raise HTTPException(
1302
+ status_code=400,
1303
+ detail=f"Invalid mode(s): {invalid_modes}. Valid modes are: {valid_modes}",
1304
+ )
1305
+
1306
+ # Call the aclear_cache method
1307
+ await rag.aclear_cache(request.modes)
1308
+
1309
+ # Prepare success message
1310
+ if request.modes:
1311
+ message = f"Successfully cleared cache for modes: {request.modes}"
1312
+ else:
1313
+ message = "Successfully cleared all cache"
1314
+
1315
+ return ClearCacheResponse(status="success", message=message)
1316
+ except HTTPException:
1317
+ # Re-raise HTTP exceptions
1318
+ raise
1319
+ except Exception as e:
1320
+ logger.error(f"Error clearing cache: {str(e)}")
1321
+ logger.error(traceback.format_exc())
1322
+ raise HTTPException(status_code=500, detail=str(e))
1323
+
1324
  return router
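The new document routes are easiest to see end to end from a client. The sketch below exercises the reworked `DELETE /documents` endpoint and the new `POST /documents/clear_cache` endpoint; the base URL is an assumption, and any configured API key or JWT would have to be supplied in the request headers:

```python
import requests

BASE = "http://localhost:9621"

# Drop every storage backend and delete the files in the input directory.
# The response status is one of: success, partial_success, busy, fail.
resp = requests.delete(f"{BASE}/documents", timeout=300)
print(resp.json())

# Clear only selected LLM-cache modes ("default" is the extraction cache).
resp = requests.post(
    f"{BASE}/documents/clear_cache",
    json={"modes": ["default", "naive"]},
    timeout=60,
)
print(resp.json())
```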
lightrag/api/routers/graph_routes.py CHANGED
@@ -3,7 +3,7 @@ This module contains all graph-related routes for the LightRAG API.
3
  """
4
 
5
  from typing import Optional
6
- from fastapi import APIRouter, Depends
7
 
8
  from ..utils_api import get_combined_auth_dependency
9
 
@@ -25,23 +25,20 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
25
 
26
  @router.get("/graphs", dependencies=[Depends(combined_auth)])
27
  async def get_knowledge_graph(
28
- label: str, max_depth: int = 3, min_degree: int = 0, inclusive: bool = False
 
 
29
  ):
30
  """
31
  Retrieve a connected subgraph of nodes where the label includes the specified label.
32
- Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
33
  When reducing the number of nodes, the prioritization criteria are as follows:
34
- 1. min_degree does not affect nodes directly connected to the matching nodes
35
- 2. Label matching nodes take precedence
36
- 3. Followed by nodes directly connected to the matching nodes
37
- 4. Finally, the degree of the nodes
38
- Maximum number of nodes is limited to env MAX_GRAPH_NODES(default: 1000)
39
 
40
  Args:
41
- label (str): Label to get knowledge graph for
42
- max_depth (int, optional): Maximum depth of graph. Defaults to 3.
43
- inclusive_search (bool, optional): If True, search for nodes that include the label. Defaults to False.
44
- min_degree (int, optional): Minimum degree of nodes. Defaults to 0.
45
 
46
  Returns:
47
  Dict[str, List[str]]: Knowledge graph for label
@@ -49,8 +46,7 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
49
  return await rag.get_knowledge_graph(
50
  node_label=label,
51
  max_depth=max_depth,
52
- inclusive=inclusive,
53
- min_degree=min_degree,
54
  )
55
 
56
  return router
 
3
  """
4
 
5
  from typing import Optional
6
+ from fastapi import APIRouter, Depends, Query
7
 
8
  from ..utils_api import get_combined_auth_dependency
9
 
 
25
 
26
  @router.get("/graphs", dependencies=[Depends(combined_auth)])
27
  async def get_knowledge_graph(
28
+ label: str = Query(..., description="Label to get knowledge graph for"),
29
+ max_depth: int = Query(3, description="Maximum depth of graph", ge=1),
30
+ max_nodes: int = Query(1000, description="Maximum nodes to return", ge=1),
31
  ):
32
  """
33
  Retrieve a connected subgraph of nodes where the label includes the specified label.
 
34
  When reducing the number of nodes, the prioritization criteria are as follows:
35
+ 1. Hops (path) to the starting node take precedence
36
+ 2. Followed by the degree of the nodes
 
 
 
37
 
38
  Args:
39
+ label (str): Label of the starting node
40
+ max_depth (int, optional): Maximum depth of the subgraph. Defaults to 3
41
+ max_nodes (int, optional): Maximum number of nodes to return. Defaults to 1000
 
42
 
43
  Returns:
44
  Dict[str, List[str]]: Knowledge graph for label
 
46
  return await rag.get_knowledge_graph(
47
  node_label=label,
48
  max_depth=max_depth,
49
+ max_nodes=max_nodes,
 
50
  )
51
 
52
  return router
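Since `min_degree` and `inclusive` are gone, callers now shape the subgraph purely with `label`, `max_depth`, and `max_nodes`. A hypothetical query against a running server (the base URL and the `requests` dependency are assumptions; the exact payload shape comes from the core `get_knowledge_graph` API, so it is inspected rather than assumed):

```python
import requests

params = {"label": "Apple", "max_depth": 3, "max_nodes": 200}
resp = requests.get("http://localhost:9621/graphs", params=params, timeout=60)
resp.raise_for_status()

graph = resp.json()
# Inspect the returned structure rather than hard-coding its keys.
print(sorted(graph) if isinstance(graph, dict) else graph)
```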
lightrag/api/run_with_gunicorn.py CHANGED
@@ -7,14 +7,9 @@ import os
7
  import sys
8
  import signal
9
  import pipmaster as pm
10
- from lightrag.api.utils_api import parse_args, display_splash_screen, check_env_file
11
  from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data
12
- from dotenv import load_dotenv
13
-
14
- # use the .env that is inside the current folder
15
- # allows to use different .env file for each lightrag instance
16
- # the OS environment variables take precedence over the .env file
17
- load_dotenv(dotenv_path=".env", override=False)
18
 
19
 
20
  def check_and_install_dependencies():
@@ -59,20 +54,17 @@ def main():
59
  signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
60
  signal.signal(signal.SIGTERM, signal_handler) # kill command
61
 
62
- # Parse all arguments using parse_args
63
- args = parse_args(is_uvicorn_mode=False)
64
-
65
  # Display startup information
66
- display_splash_screen(args)
67
 
68
  print("🚀 Starting LightRAG with Gunicorn")
69
- print(f"🔄 Worker management: Gunicorn (workers={args.workers})")
70
  print("🔍 Preloading app: Enabled")
71
  print("📝 Note: Using Gunicorn's preload feature for shared data initialization")
72
  print("\n\n" + "=" * 80)
73
  print("MAIN PROCESS INITIALIZATION")
74
  print(f"Process ID: {os.getpid()}")
75
- print(f"Workers setting: {args.workers}")
76
  print("=" * 80 + "\n")
77
 
78
  # Import Gunicorn's StandaloneApplication
@@ -128,31 +120,43 @@ def main():
128
 
129
  # Set configuration variables in gunicorn_config, prioritizing command line arguments
130
  gunicorn_config.workers = (
131
- args.workers if args.workers else int(os.getenv("WORKERS", 1))
 
 
132
  )
133
 
134
  # Bind configuration prioritizes command line arguments
135
- host = args.host if args.host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0")
136
- port = args.port if args.port != 9621 else int(os.getenv("PORT", 9621))
 
 
 
 
 
 
 
 
137
  gunicorn_config.bind = f"{host}:{port}"
138
 
139
  # Log level configuration prioritizes command line arguments
140
  gunicorn_config.loglevel = (
141
- args.log_level.lower()
142
- if args.log_level
143
  else os.getenv("LOG_LEVEL", "info")
144
  )
145
 
146
  # Timeout configuration prioritizes command line arguments
147
  gunicorn_config.timeout = (
148
- args.timeout if args.timeout * 2 else int(os.getenv("TIMEOUT", 150 * 2))
 
 
149
  )
150
 
151
  # Keepalive configuration
152
  gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5))
153
 
154
  # SSL configuration prioritizes command line arguments
155
- if args.ssl or os.getenv("SSL", "").lower() in (
156
  "true",
157
  "1",
158
  "yes",
@@ -160,12 +164,14 @@ def main():
160
  "on",
161
  ):
162
  gunicorn_config.certfile = (
163
- args.ssl_certfile
164
- if args.ssl_certfile
165
  else os.getenv("SSL_CERTFILE")
166
  )
167
  gunicorn_config.keyfile = (
168
- args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE")
 
 
169
  )
170
 
171
  # Set configuration options from the module
@@ -190,13 +196,13 @@ def main():
190
  # Import the application
191
  from lightrag.api.lightrag_server import get_application
192
 
193
- return get_application(args)
194
 
195
  # Create the application
196
  app = GunicornApp("")
197
 
198
  # Force workers to be an integer and greater than 1 for multi-process mode
199
- workers_count = int(args.workers)
200
  if workers_count > 1:
201
  # Set a flag to indicate we're in the main process
202
  os.environ["LIGHTRAG_MAIN_PROCESS"] = "1"
 
7
  import sys
8
  import signal
9
  import pipmaster as pm
10
+ from lightrag.api.utils_api import display_splash_screen, check_env_file
11
  from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data
12
+ from .config import global_args
 
 
 
 
 
13
 
14
 
15
  def check_and_install_dependencies():
 
54
  signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
55
  signal.signal(signal.SIGTERM, signal_handler) # kill command
56
 
 
 
 
57
  # Display startup information
58
+ display_splash_screen(global_args)
59
 
60
  print("🚀 Starting LightRAG with Gunicorn")
61
+ print(f"🔄 Worker management: Gunicorn (workers={global_args.workers})")
62
  print("🔍 Preloading app: Enabled")
63
  print("📝 Note: Using Gunicorn's preload feature for shared data initialization")
64
  print("\n\n" + "=" * 80)
65
  print("MAIN PROCESS INITIALIZATION")
66
  print(f"Process ID: {os.getpid()}")
67
+ print(f"Workers setting: {global_args.workers}")
68
  print("=" * 80 + "\n")
69
 
70
  # Import Gunicorn's StandaloneApplication
 
120
 
121
  # Set configuration variables in gunicorn_config, prioritizing command line arguments
122
  gunicorn_config.workers = (
123
+ global_args.workers
124
+ if global_args.workers
125
+ else int(os.getenv("WORKERS", 1))
126
  )
127
 
128
  # Bind configuration prioritizes command line arguments
129
+ host = (
130
+ global_args.host
131
+ if global_args.host != "0.0.0.0"
132
+ else os.getenv("HOST", "0.0.0.0")
133
+ )
134
+ port = (
135
+ global_args.port
136
+ if global_args.port != 9621
137
+ else int(os.getenv("PORT", 9621))
138
+ )
139
  gunicorn_config.bind = f"{host}:{port}"
140
 
141
  # Log level configuration prioritizes command line arguments
142
  gunicorn_config.loglevel = (
143
+ global_args.log_level.lower()
144
+ if global_args.log_level
145
  else os.getenv("LOG_LEVEL", "info")
146
  )
147
 
148
  # Timeout configuration prioritizes command line arguments
149
  gunicorn_config.timeout = (
150
+ global_args.timeout
151
+ if global_args.timeout * 2
152
+ else int(os.getenv("TIMEOUT", 150 * 2))
153
  )
154
 
155
  # Keepalive configuration
156
  gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5))
157
 
158
  # SSL configuration prioritizes command line arguments
159
+ if global_args.ssl or os.getenv("SSL", "").lower() in (
160
  "true",
161
  "1",
162
  "yes",
 
164
  "on",
165
  ):
166
  gunicorn_config.certfile = (
167
+ global_args.ssl_certfile
168
+ if global_args.ssl_certfile
169
  else os.getenv("SSL_CERTFILE")
170
  )
171
  gunicorn_config.keyfile = (
172
+ global_args.ssl_keyfile
173
+ if global_args.ssl_keyfile
174
+ else os.getenv("SSL_KEYFILE")
175
  )
176
 
177
  # Set configuration options from the module
 
196
  # Import the application
197
  from lightrag.api.lightrag_server import get_application
198
 
199
+ return get_application(global_args)
200
 
201
  # Create the application
202
  app = GunicornApp("")
203
 
204
  # Force workers to be an integer and greater than 1 for multi-process mode
205
+ workers_count = int(global_args.workers)
206
  if workers_count > 1:
207
  # Set a flag to indicate we're in the main process
208
  os.environ["LIGHTRAG_MAIN_PROCESS"] = "1"
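The Gunicorn path keeps the old precedence rule, now reading from `global_args`: a value changed on the command line wins, while an untouched default falls back to the corresponding environment variable. A minimal sketch of that rule for the bind address (the helper name is invented for illustration; the constants 0.0.0.0 and 9621 mirror the defaults above):

```python
import os


def resolve_bind(cli_host: str, cli_port: int) -> str:
    # Keep explicit CLI values; let HOST/PORT env vars fill in untouched defaults.
    host = cli_host if cli_host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0")
    port = cli_port if cli_port != 9621 else int(os.getenv("PORT", 9621))
    return f"{host}:{port}"


os.environ["PORT"] = "8080"
print(resolve_bind("0.0.0.0", 9621))    # 0.0.0.0:8080 -> env fills the untouched default
print(resolve_bind("127.0.0.1", 9000))  # 127.0.0.1:9000 -> explicit CLI values are kept
```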
lightrag/api/utils_api.py CHANGED
@@ -7,15 +7,13 @@ import argparse
7
  from typing import Optional, List, Tuple
8
  import sys
9
  from ascii_colors import ASCIIColors
10
- import logging
11
  from lightrag.api import __api_version__ as api_version
12
  from lightrag import __version__ as core_version
13
  from fastapi import HTTPException, Security, Request, status
14
- from dotenv import load_dotenv
15
  from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
16
  from starlette.status import HTTP_403_FORBIDDEN
17
  from .auth import auth_handler
18
- from ..prompt import PROMPTS
19
 
20
 
21
  def check_env_file():
@@ -36,16 +34,8 @@ def check_env_file():
36
  return True
37
 
38
 
39
- # use the .env that is inside the current folder
40
- # allows to use different .env file for each lightrag instance
41
- # the OS environment variables take precedence over the .env file
42
- load_dotenv(dotenv_path=".env", override=False)
43
-
44
- global_args = {"main_args": None}
45
-
46
- # Get whitelist paths from environment variable, only once during initialization
47
- default_whitelist = "/health,/api/*"
48
- whitelist_paths = os.getenv("WHITELIST_PATHS", default_whitelist).split(",")
49
 
50
  # Pre-compile path matching patterns
51
  whitelist_patterns: List[Tuple[str, bool]] = []
@@ -63,19 +53,6 @@ for path in whitelist_paths:
63
  auth_configured = bool(auth_handler.accounts)
64
 
65
 
66
- class OllamaServerInfos:
67
- # Constants for emulated Ollama model information
68
- LIGHTRAG_NAME = "lightrag"
69
- LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
70
- LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
71
- LIGHTRAG_SIZE = 7365960935 # it's a dummy value
72
- LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
73
- LIGHTRAG_DIGEST = "sha256:lightrag"
74
-
75
-
76
- ollama_server_infos = OllamaServerInfos()
77
-
78
-
79
  def get_combined_auth_dependency(api_key: Optional[str] = None):
80
  """
81
  Create a combined authentication dependency that implements authentication logic
@@ -186,299 +163,6 @@ def get_combined_auth_dependency(api_key: Optional[str] = None):
186
  return combined_dependency
187
 
188
 
189
- class DefaultRAGStorageConfig:
190
- KV_STORAGE = "JsonKVStorage"
191
- VECTOR_STORAGE = "NanoVectorDBStorage"
192
- GRAPH_STORAGE = "NetworkXStorage"
193
- DOC_STATUS_STORAGE = "JsonDocStatusStorage"
194
-
195
-
196
- def get_default_host(binding_type: str) -> str:
197
- default_hosts = {
198
- "ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
199
- "lollms": os.getenv("LLM_BINDING_HOST", "http://localhost:9600"),
200
- "azure_openai": os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com/v1"),
201
- "openai": os.getenv("LLM_BINDING_HOST", "https://api.openai.com/v1"),
202
- }
203
- return default_hosts.get(
204
- binding_type, os.getenv("LLM_BINDING_HOST", "http://localhost:11434")
205
- ) # fallback to ollama if unknown
206
-
207
-
208
- def get_env_value(env_key: str, default: any, value_type: type = str) -> any:
209
- """
210
- Get value from environment variable with type conversion
211
-
212
- Args:
213
- env_key (str): Environment variable key
214
- default (any): Default value if env variable is not set
215
- value_type (type): Type to convert the value to
216
-
217
- Returns:
218
- any: Converted value from environment or default
219
- """
220
- value = os.getenv(env_key)
221
- if value is None:
222
- return default
223
-
224
- if value_type is bool:
225
- return value.lower() in ("true", "1", "yes", "t", "on")
226
- try:
227
- return value_type(value)
228
- except ValueError:
229
- return default
230
-
231
-
232
- def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
233
- """
234
- Parse command line arguments with environment variable fallback
235
-
236
- Args:
237
- is_uvicorn_mode: Whether running under uvicorn mode
238
-
239
- Returns:
240
- argparse.Namespace: Parsed arguments
241
- """
242
-
243
- parser = argparse.ArgumentParser(
244
- description="LightRAG FastAPI Server with separate working and input directories"
245
- )
246
-
247
- # Server configuration
248
- parser.add_argument(
249
- "--host",
250
- default=get_env_value("HOST", "0.0.0.0"),
251
- help="Server host (default: from env or 0.0.0.0)",
252
- )
253
- parser.add_argument(
254
- "--port",
255
- type=int,
256
- default=get_env_value("PORT", 9621, int),
257
- help="Server port (default: from env or 9621)",
258
- )
259
-
260
- # Directory configuration
261
- parser.add_argument(
262
- "--working-dir",
263
- default=get_env_value("WORKING_DIR", "./rag_storage"),
264
- help="Working directory for RAG storage (default: from env or ./rag_storage)",
265
- )
266
- parser.add_argument(
267
- "--input-dir",
268
- default=get_env_value("INPUT_DIR", "./inputs"),
269
- help="Directory containing input documents (default: from env or ./inputs)",
270
- )
271
-
272
- def timeout_type(value):
273
- if value is None:
274
- return 150
275
- if value is None or value == "None":
276
- return None
277
- return int(value)
278
-
279
- parser.add_argument(
280
- "--timeout",
281
- default=get_env_value("TIMEOUT", None, timeout_type),
282
- type=timeout_type,
283
- help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout",
284
- )
285
-
286
- # RAG configuration
287
- parser.add_argument(
288
- "--max-async",
289
- type=int,
290
- default=get_env_value("MAX_ASYNC", 4, int),
291
- help="Maximum async operations (default: from env or 4)",
292
- )
293
- parser.add_argument(
294
- "--max-tokens",
295
- type=int,
296
- default=get_env_value("MAX_TOKENS", 32768, int),
297
- help="Maximum token size (default: from env or 32768)",
298
- )
299
-
300
- # Logging configuration
301
- parser.add_argument(
302
- "--log-level",
303
- default=get_env_value("LOG_LEVEL", "INFO"),
304
- choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
305
- help="Logging level (default: from env or INFO)",
306
- )
307
- parser.add_argument(
308
- "--verbose",
309
- action="store_true",
310
- default=get_env_value("VERBOSE", False, bool),
311
- help="Enable verbose debug output(only valid for DEBUG log-level)",
312
- )
313
-
314
- parser.add_argument(
315
- "--key",
316
- type=str,
317
- default=get_env_value("LIGHTRAG_API_KEY", None),
318
- help="API key for authentication. This protects lightrag server against unauthorized access",
319
- )
320
-
321
- # Optional https parameters
322
- parser.add_argument(
323
- "--ssl",
324
- action="store_true",
325
- default=get_env_value("SSL", False, bool),
326
- help="Enable HTTPS (default: from env or False)",
327
- )
328
- parser.add_argument(
329
- "--ssl-certfile",
330
- default=get_env_value("SSL_CERTFILE", None),
331
- help="Path to SSL certificate file (required if --ssl is enabled)",
332
- )
333
- parser.add_argument(
334
- "--ssl-keyfile",
335
- default=get_env_value("SSL_KEYFILE", None),
336
- help="Path to SSL private key file (required if --ssl is enabled)",
337
- )
338
-
339
- parser.add_argument(
340
- "--history-turns",
341
- type=int,
342
- default=get_env_value("HISTORY_TURNS", 3, int),
343
- help="Number of conversation history turns to include (default: from env or 3)",
344
- )
345
-
346
- # Search parameters
347
- parser.add_argument(
348
- "--top-k",
349
- type=int,
350
- default=get_env_value("TOP_K", 60, int),
351
- help="Number of most similar results to return (default: from env or 60)",
352
- )
353
- parser.add_argument(
354
- "--cosine-threshold",
355
- type=float,
356
- default=get_env_value("COSINE_THRESHOLD", 0.2, float),
357
- help="Cosine similarity threshold (default: from env or 0.4)",
358
- )
359
-
360
- # Ollama model name
361
- parser.add_argument(
362
- "--simulated-model-name",
363
- type=str,
364
- default=get_env_value(
365
- "SIMULATED_MODEL_NAME", ollama_server_infos.LIGHTRAG_MODEL
366
- ),
367
- help="Number of conversation history turns to include (default: from env or 3)",
368
- )
369
-
370
- # Namespace
371
- parser.add_argument(
372
- "--namespace-prefix",
373
- type=str,
374
- default=get_env_value("NAMESPACE_PREFIX", ""),
375
- help="Prefix of the namespace",
376
- )
377
-
378
- parser.add_argument(
379
- "--auto-scan-at-startup",
380
- action="store_true",
381
- default=False,
382
- help="Enable automatic scanning when the program starts",
383
- )
384
-
385
- # Server workers configuration
386
- parser.add_argument(
387
- "--workers",
388
- type=int,
389
- default=get_env_value("WORKERS", 1, int),
390
- help="Number of worker processes (default: from env or 1)",
391
- )
392
-
393
- # LLM and embedding bindings
394
- parser.add_argument(
395
- "--llm-binding",
396
- type=str,
397
- default=get_env_value("LLM_BINDING", "ollama"),
398
- choices=["lollms", "ollama", "openai", "openai-ollama", "azure_openai"],
399
- help="LLM binding type (default: from env or ollama)",
400
- )
401
- parser.add_argument(
402
- "--embedding-binding",
403
- type=str,
404
- default=get_env_value("EMBEDDING_BINDING", "ollama"),
405
- choices=["lollms", "ollama", "openai", "azure_openai"],
406
- help="Embedding binding type (default: from env or ollama)",
407
- )
408
-
409
- args = parser.parse_args()
410
-
411
- # If in uvicorn mode and workers > 1, force it to 1 and log warning
412
- if is_uvicorn_mode and args.workers > 1:
413
- original_workers = args.workers
414
- args.workers = 1
415
- # Log warning directly here
416
- logging.warning(
417
- f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1"
418
- )
419
-
420
- # convert relative path to absolute path
421
- args.working_dir = os.path.abspath(args.working_dir)
422
- args.input_dir = os.path.abspath(args.input_dir)
423
-
424
- # Inject storage configuration from environment variables
425
- args.kv_storage = get_env_value(
426
- "LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
427
- )
428
- args.doc_status_storage = get_env_value(
429
- "LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
430
- )
431
- args.graph_storage = get_env_value(
432
- "LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
433
- )
434
- args.vector_storage = get_env_value(
435
- "LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
436
- )
437
-
438
- # Get MAX_PARALLEL_INSERT from environment
439
- args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int)
440
-
441
- # Handle openai-ollama special case
442
- if args.llm_binding == "openai-ollama":
443
- args.llm_binding = "openai"
444
- args.embedding_binding = "ollama"
445
-
446
- args.llm_binding_host = get_env_value(
447
- "LLM_BINDING_HOST", get_default_host(args.llm_binding)
448
- )
449
- args.embedding_binding_host = get_env_value(
450
- "EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding)
451
- )
452
- args.llm_binding_api_key = get_env_value("LLM_BINDING_API_KEY", None)
453
- args.embedding_binding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "")
454
-
455
- # Inject model configuration
456
- args.llm_model = get_env_value("LLM_MODEL", "mistral-nemo:latest")
457
- args.embedding_model = get_env_value("EMBEDDING_MODEL", "bge-m3:latest")
458
- args.embedding_dim = get_env_value("EMBEDDING_DIM", 1024, int)
459
- args.max_embed_tokens = get_env_value("MAX_EMBED_TOKENS", 8192, int)
460
-
461
- # Inject chunk configuration
462
- args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int)
463
- args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)
464
-
465
- # Inject LLM cache configuration
466
- args.enable_llm_cache_for_extract = get_env_value(
467
- "ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool
468
- )
469
-
470
- # Inject LLM temperature configuration
471
- args.temperature = get_env_value("TEMPERATURE", 0.5, float)
472
-
473
- # Select Document loading tool (DOCLING, DEFAULT)
474
- args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT")
475
-
476
- ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
477
-
478
- global_args["main_args"] = args
479
- return args
480
-
481
-
482
  def display_splash_screen(args: argparse.Namespace) -> None:
483
  """
484
  Display a colorful splash screen showing LightRAG server configuration
@@ -503,7 +187,7 @@ def display_splash_screen(args: argparse.Namespace) -> None:
503
  ASCIIColors.white(" ├─ Workers: ", end="")
504
  ASCIIColors.yellow(f"{args.workers}")
505
  ASCIIColors.white(" ├─ CORS Origins: ", end="")
506
- ASCIIColors.yellow(f"{os.getenv('CORS_ORIGINS', '*')}")
507
  ASCIIColors.white(" ├─ SSL Enabled: ", end="")
508
  ASCIIColors.yellow(f"{args.ssl}")
509
  if args.ssl:
@@ -519,8 +203,10 @@ def display_splash_screen(args: argparse.Namespace) -> None:
519
  ASCIIColors.yellow(f"{args.verbose}")
520
  ASCIIColors.white(" ├─ History Turns: ", end="")
521
  ASCIIColors.yellow(f"{args.history_turns}")
522
- ASCIIColors.white(" └─ API Key: ", end="")
523
  ASCIIColors.yellow("Set" if args.key else "Not Set")
 
 
524
 
525
  # Directory Configuration
526
  ASCIIColors.magenta("\n📂 Directory Configuration:")
@@ -558,10 +244,9 @@ def display_splash_screen(args: argparse.Namespace) -> None:
558
  ASCIIColors.yellow(f"{args.embedding_dim}")
559
 
560
  # RAG Configuration
561
- summary_language = os.getenv("SUMMARY_LANGUAGE", PROMPTS["DEFAULT_LANGUAGE"])
562
  ASCIIColors.magenta("\n⚙️ RAG Configuration:")
563
  ASCIIColors.white(" ├─ Summary Language: ", end="")
564
- ASCIIColors.yellow(f"{summary_language}")
565
  ASCIIColors.white(" ├─ Max Parallel Insert: ", end="")
566
  ASCIIColors.yellow(f"{args.max_parallel_insert}")
567
  ASCIIColors.white(" ├─ Max Embed Tokens: ", end="")
@@ -595,19 +280,17 @@ def display_splash_screen(args: argparse.Namespace) -> None:
595
  protocol = "https" if args.ssl else "http"
596
  if args.host == "0.0.0.0":
597
  ASCIIColors.magenta("\n🌐 Server Access Information:")
598
- ASCIIColors.white(" ├─ Local Access: ", end="")
599
  ASCIIColors.yellow(f"{protocol}://localhost:{args.port}")
600
  ASCIIColors.white(" ├─ Remote Access: ", end="")
601
  ASCIIColors.yellow(f"{protocol}://<your-ip-address>:{args.port}")
602
  ASCIIColors.white(" ├─ API Documentation (local): ", end="")
603
  ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/docs")
604
- ASCIIColors.white(" ├─ Alternative Documentation (local): ", end="")
605
  ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/redoc")
606
- ASCIIColors.white(" └─ WebUI (local): ", end="")
607
- ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/webui")
608
 
609
- ASCIIColors.yellow("\n📝 Note:")
610
- ASCIIColors.white(""" Since the server is running on 0.0.0.0:
611
  - Use 'localhost' or '127.0.0.1' for local access
612
  - Use your machine's IP address for remote access
613
  - To find your IP address:
@@ -617,42 +300,24 @@ def display_splash_screen(args: argparse.Namespace) -> None:
617
  else:
618
  base_url = f"{protocol}://{args.host}:{args.port}"
619
  ASCIIColors.magenta("\n🌐 Server Access Information:")
620
- ASCIIColors.white(" ├─ Base URL: ", end="")
621
  ASCIIColors.yellow(f"{base_url}")
622
  ASCIIColors.white(" ├─ API Documentation: ", end="")
623
  ASCIIColors.yellow(f"{base_url}/docs")
624
  ASCIIColors.white(" └─ Alternative Documentation: ", end="")
625
  ASCIIColors.yellow(f"{base_url}/redoc")
626
 
627
- # Usage Examples
628
- ASCIIColors.magenta("\n📚 Quick Start Guide:")
629
- ASCIIColors.cyan("""
630
- 1. Access the Swagger UI:
631
- Open your browser and navigate to the API documentation URL above
632
-
633
- 2. API Authentication:""")
634
- if args.key:
635
- ASCIIColors.cyan(""" Add the following header to your requests:
636
- X-API-Key: <your-api-key>
637
- """)
638
- else:
639
- ASCIIColors.cyan(" No authentication required\n")
640
-
641
- ASCIIColors.cyan(""" 3. Basic Operations:
642
- - POST /upload_document: Upload new documents to RAG
643
- - POST /query: Query your document collection
644
-
645
- 4. Monitor the server:
646
- - Check server logs for detailed operation information
647
- - Use healthcheck endpoint: GET /health
648
- """)
649
-
650
  # Security Notice
651
  if args.key:
652
  ASCIIColors.yellow("\n⚠️ Security Notice:")
653
  ASCIIColors.white(""" API Key authentication is enabled.
654
  Make sure to include the X-API-Key header in all your requests.
655
  """)
 
 
 
 
 
656
 
657
  # Ensure splash output flush to system log
658
  sys.stdout.flush()
 
7
  from typing import Optional, List, Tuple
8
  import sys
9
  from ascii_colors import ASCIIColors
 
10
  from lightrag.api import __api_version__ as api_version
11
  from lightrag import __version__ as core_version
12
  from fastapi import HTTPException, Security, Request, status
 
13
  from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
14
  from starlette.status import HTTP_403_FORBIDDEN
15
  from .auth import auth_handler
16
+ from .config import ollama_server_infos, global_args
17
 
18
 
19
  def check_env_file():
 
34
  return True
35
 
36
 
37
+ # Get whitelist paths from global_args, only once during initialization
38
+ whitelist_paths = global_args.whitelist_paths.split(",")
 
 
 
 
 
 
 
 
39
 
40
  # Pre-compile path matching patterns
41
  whitelist_patterns: List[Tuple[str, bool]] = []
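The whitelist read from `global_args.whitelist_paths` (default `/health,/api/*`) is split once at startup and pre-compiled into `(prefix, is_wildcard)` tuples so per-request checks stay cheap. The compilation loop itself is outside this hunk; the sketch below only illustrates that idea, assuming a trailing `/*` marks a prefix match (the helper names are made up, not the module's own):

```python
from typing import List, Tuple

def compile_whitelist(paths: List[str]) -> List[Tuple[str, bool]]:
    """Turn entries like '/health' or '/api/*' into (prefix, is_wildcard) tuples."""
    patterns: List[Tuple[str, bool]] = []
    for path in paths:
        path = path.strip()
        if not path:
            continue
        if path.endswith("/*"):
            patterns.append((path[:-2], True))   # wildcard entry: match by prefix
        else:
            patterns.append((path, False))       # exact entry: match the whole path
    return patterns

def is_whitelisted(request_path: str, patterns: List[Tuple[str, bool]]) -> bool:
    for prefix, is_wildcard in patterns:
        if (is_wildcard and request_path.startswith(prefix)) or request_path == prefix:
            return True
    return False

# Example: /health and anything under /api/ bypass authentication
patterns = compile_whitelist("/health,/api/*".split(","))
assert is_whitelisted("/api/version", patterns)
assert not is_whitelisted("/query", patterns)
```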
 
53
  auth_configured = bool(auth_handler.accounts)
54
 
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  def get_combined_auth_dependency(api_key: Optional[str] = None):
57
  """
58
  Create a combined authentication dependency that implements authentication logic
 
163
  return combined_dependency
164
 
165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  def display_splash_screen(args: argparse.Namespace) -> None:
167
  """
168
  Display a colorful splash screen showing LightRAG server configuration
 
187
  ASCIIColors.white(" ├─ Workers: ", end="")
188
  ASCIIColors.yellow(f"{args.workers}")
189
  ASCIIColors.white(" ├─ CORS Origins: ", end="")
190
+ ASCIIColors.yellow(f"{args.cors_origins}")
191
  ASCIIColors.white(" ├─ SSL Enabled: ", end="")
192
  ASCIIColors.yellow(f"{args.ssl}")
193
  if args.ssl:
 
203
  ASCIIColors.yellow(f"{args.verbose}")
204
  ASCIIColors.white(" ├─ History Turns: ", end="")
205
  ASCIIColors.yellow(f"{args.history_turns}")
206
+ ASCIIColors.white(" ├─ API Key: ", end="")
207
  ASCIIColors.yellow("Set" if args.key else "Not Set")
208
+ ASCIIColors.white(" └─ JWT Auth: ", end="")
209
+ ASCIIColors.yellow("Enabled" if args.auth_accounts else "Disabled")
210
 
211
  # Directory Configuration
212
  ASCIIColors.magenta("\n📂 Directory Configuration:")
 
244
  ASCIIColors.yellow(f"{args.embedding_dim}")
245
 
246
  # RAG Configuration
 
247
  ASCIIColors.magenta("\n⚙️ RAG Configuration:")
248
  ASCIIColors.white(" ├─ Summary Language: ", end="")
249
+ ASCIIColors.yellow(f"{args.summary_language}")
250
  ASCIIColors.white(" ├─ Max Parallel Insert: ", end="")
251
  ASCIIColors.yellow(f"{args.max_parallel_insert}")
252
  ASCIIColors.white(" ├─ Max Embed Tokens: ", end="")
 
280
  protocol = "https" if args.ssl else "http"
281
  if args.host == "0.0.0.0":
282
  ASCIIColors.magenta("\n🌐 Server Access Information:")
283
+ ASCIIColors.white(" ├─ WebUI (local): ", end="")
284
  ASCIIColors.yellow(f"{protocol}://localhost:{args.port}")
285
  ASCIIColors.white(" ├─ Remote Access: ", end="")
286
  ASCIIColors.yellow(f"{protocol}://<your-ip-address>:{args.port}")
287
  ASCIIColors.white(" ├─ API Documentation (local): ", end="")
288
  ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/docs")
289
+ ASCIIColors.white(" └─ Alternative Documentation (local): ", end="")
290
  ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/redoc")
 
 
291
 
292
+ ASCIIColors.magenta("\n📝 Note:")
293
+ ASCIIColors.cyan(""" Since the server is running on 0.0.0.0:
294
  - Use 'localhost' or '127.0.0.1' for local access
295
  - Use your machine's IP address for remote access
296
  - To find your IP address:
 
300
  else:
301
  base_url = f"{protocol}://{args.host}:{args.port}"
302
  ASCIIColors.magenta("\n🌐 Server Access Information:")
303
+ ASCIIColors.white(" ├─ WebUI (local): ", end="")
304
  ASCIIColors.yellow(f"{base_url}")
305
  ASCIIColors.white(" ├─ API Documentation: ", end="")
306
  ASCIIColors.yellow(f"{base_url}/docs")
307
  ASCIIColors.white(" └─ Alternative Documentation: ", end="")
308
  ASCIIColors.yellow(f"{base_url}/redoc")
309
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
  # Security Notice
311
  if args.key:
312
  ASCIIColors.yellow("\n⚠️ Security Notice:")
313
  ASCIIColors.white(""" API Key authentication is enabled.
314
  Make sure to include the X-API-Key header in all your requests.
315
  """)
316
+ if args.auth_accounts:
317
+ ASCIIColors.yellow("\n⚠️ Security Notice:")
318
+ ASCIIColors.white(""" JWT authentication is enabled.
319
+ Make sure to login before making the request, and include the 'Authorization' in the header.
320
+ """)
321
 
322
  # Ensure splash output flush to system log
323
  sys.stdout.flush()
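For completeness, a request against a server protected by `--key` would carry the `X-API-Key` header described in the security notice above, while whitelisted paths such as `/health` stay open. A minimal sketch using `requests` (the key value is a placeholder and the `/query` payload shape may differ between versions):

```python
import requests

BASE_URL = "http://localhost:9621"          # default port from the server configuration
headers = {"X-API-Key": "<your-api-key>"}   # only needed when LIGHTRAG_API_KEY / --key is set

# Health check is whitelisted by default, so it works without the header
print(requests.get(f"{BASE_URL}/health").json())

# Query the document collection (payload shape may vary by version)
resp = requests.post(
    f"{BASE_URL}/query",
    json={"query": "What are the top themes?", "mode": "hybrid"},
    headers=headers,
)
print(resp.status_code, resp.text[:200])
```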
lightrag/api/webui/assets/{index-D8zGvNlV.js → index-BaHKTcxB.js} RENAMED
Binary files a/lightrag/api/webui/assets/index-D8zGvNlV.js and b/lightrag/api/webui/assets/index-BaHKTcxB.js differ
 
lightrag/api/webui/assets/index-CD5HxTy1.css DELETED
Binary file (55.1 kB)
 
lightrag/api/webui/assets/index-f0HMqdqP.css ADDED
Binary file (57.1 kB). View file
 
lightrag/api/webui/index.html CHANGED
Binary files a/lightrag/api/webui/index.html and b/lightrag/api/webui/index.html differ
 
lightrag/base.py CHANGED
@@ -112,6 +112,32 @@ class StorageNameSpace(ABC):
112
  async def index_done_callback(self) -> None:
113
  """Commit the storage operations after indexing"""
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
  @dataclass
117
  class BaseVectorStorage(StorageNameSpace, ABC):
@@ -127,15 +153,33 @@ class BaseVectorStorage(StorageNameSpace, ABC):
127
 
128
  @abstractmethod
129
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
130
- """Insert or update vectors in the storage."""
 
 
 
 
 
 
131
 
132
  @abstractmethod
133
  async def delete_entity(self, entity_name: str) -> None:
134
- """Delete a single entity by its name."""
 
 
 
 
 
 
135
 
136
  @abstractmethod
137
  async def delete_entity_relation(self, entity_name: str) -> None:
138
- """Delete relations for a given entity."""
 
 
 
 
 
 
139
 
140
  @abstractmethod
141
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
@@ -161,6 +205,19 @@ class BaseVectorStorage(StorageNameSpace, ABC):
161
  """
162
  pass
163
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
  @dataclass
166
  class BaseKVStorage(StorageNameSpace, ABC):
@@ -180,7 +237,42 @@ class BaseKVStorage(StorageNameSpace, ABC):
180
 
181
  @abstractmethod
182
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
183
- """Upsert data"""
 
 
 
 
 
184
 
185
 
186
  @dataclass
@@ -205,13 +297,13 @@ class BaseGraphStorage(StorageNameSpace, ABC):
205
 
206
  @abstractmethod
207
  async def get_node(self, node_id: str) -> dict[str, str] | None:
208
- """Get an edge by its source and target node ids."""
209
 
210
  @abstractmethod
211
  async def get_edge(
212
  self, source_node_id: str, target_node_id: str
213
  ) -> dict[str, str] | None:
214
- """Get all edges connected to a node."""
215
 
216
  @abstractmethod
217
  async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
@@ -225,7 +317,13 @@ class BaseGraphStorage(StorageNameSpace, ABC):
225
  async def upsert_edge(
226
  self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
227
  ) -> None:
228
- """Delete a node from the graph."""
 
 
 
 
 
 
229
 
230
  @abstractmethod
231
  async def delete_node(self, node_id: str) -> None:
@@ -243,9 +341,20 @@ class BaseGraphStorage(StorageNameSpace, ABC):
243
 
244
  @abstractmethod
245
  async def get_knowledge_graph(
246
- self, node_label: str, max_depth: int = 3
247
  ) -> KnowledgeGraph:
248
- """Retrieve a subgraph of the knowledge graph starting from a given node."""
 
 
 
 
 
 
 
 
 
 
 
249
 
250
 
251
  class DocStatus(str, Enum):
@@ -297,6 +406,10 @@ class DocStatusStorage(BaseKVStorage, ABC):
297
  ) -> dict[str, DocProcessingStatus]:
298
  """Get all documents with a specific status"""
299
 
 
 
 
 
300
 
301
  class StoragesStatus(str, Enum):
302
  """Storages status"""
 
112
  async def index_done_callback(self) -> None:
113
  """Commit the storage operations after indexing"""
114
 
115
+ @abstractmethod
116
+ async def drop(self) -> dict[str, str]:
117
+ """Drop all data from storage and clean up resources
118
+
119
+ This abstract method defines the contract for dropping all data from a storage implementation.
120
+ Each storage type must implement this method to:
121
+ 1. Clear all data from memory and/or external storage
122
+ 2. Remove any associated storage files if applicable
123
+ 3. Reset the storage to its initial state
124
+ 4. Handle cleanup of any resources
125
+ 5. Notify other processes if necessary
126
+ 6. This action should persist the data to disk immediately.
127
+
128
+ Returns:
129
+ dict[str, str]: Operation status and message with the following format:
130
+ {
131
+ "status": str, # "success" or "error"
132
+ "message": str # "data dropped" on success, error details on failure
133
+ }
134
+
135
+ Implementation specific:
136
+ - On success: return {"status": "success", "message": "data dropped"}
137
+ - On failure: return {"status": "error", "message": "<error details>"}
138
+ - If not supported: return {"status": "error", "message": "unsupported"}
139
+ """
140
+
141
 
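To make the `drop()` contract above concrete, here is a toy in-memory implementation that follows the same pattern as the JSON-based storages later in this diff: clear the data under a lock, then persist the empty state immediately and return the status dict. The class and attribute names are illustrative only, not LightRAG's own:

```python
import asyncio

class InMemoryKV:
    """Toy storage illustrating the drop() contract (not an actual LightRAG class)."""

    def __init__(self, namespace: str = "demo"):
        self.namespace = namespace
        self._data: dict[str, dict] = {}
        self._lock = asyncio.Lock()

    async def index_done_callback(self) -> None:
        pass  # a real implementation would persist self._data to disk here

    async def drop(self) -> dict[str, str]:
        try:
            async with self._lock:
                self._data.clear()                # 1. clear all in-memory data
            await self.index_done_callback()      # 6. persist the empty state immediately
            return {"status": "success", "message": "data dropped"}
        except Exception as e:
            return {"status": "error", "message": str(e)}

print(asyncio.run(InMemoryKV().drop()))  # {'status': 'success', 'message': 'data dropped'}
```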
142
  @dataclass
143
  class BaseVectorStorage(StorageNameSpace, ABC):
 
153
 
154
  @abstractmethod
155
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
156
+ """Insert or update vectors in the storage.
157
+
158
+ Important notes for in-memory storage:
159
+ 1. Changes will be persisted to disk during the next index_done_callback
160
+ 2. Only one process should be updating the storage at a time before index_done_callback,
161
+ KG-storage-log should be used to avoid data corruption
162
+ """
163
 
164
  @abstractmethod
165
  async def delete_entity(self, entity_name: str) -> None:
166
+ """Delete a single entity by its name.
167
+
168
+ Important notes for in-memory storage:
169
+ 1. Changes will be persisted to disk during the next index_done_callback
170
+ 2. Only one process should be updating the storage at a time before index_done_callback,
171
+ KG-storage-log should be used to avoid data corruption
172
+ """
173
 
174
  @abstractmethod
175
  async def delete_entity_relation(self, entity_name: str) -> None:
176
+ """Delete relations for a given entity.
177
+
178
+ Important notes for in-memory storage:
179
+ 1. Changes will be persisted to disk during the next index_done_callback
180
+ 2. Only one process should be updating the storage at a time before index_done_callback,
181
+ KG-storage-log should be used to avoid data corruption
182
+ """
183
 
184
  @abstractmethod
185
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
 
205
  """
206
  pass
207
 
208
+ @abstractmethod
209
+ async def delete(self, ids: list[str]):
210
+ """Delete vectors with specified IDs
211
+
212
+ Importance notes for in-memory storage:
213
+ 1. Changes will be persisted to disk during the next index_done_callback
214
+ 2. Only one process should updating the storage at a time before index_done_callback,
215
+ KG-storage-log should be used to avoid data corruption
216
+
217
+ Args:
218
+ ids: List of vector IDs to be deleted
219
+ """
220
+
221
 
222
  @dataclass
223
  class BaseKVStorage(StorageNameSpace, ABC):
 
237
 
238
  @abstractmethod
239
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
240
+ """Upsert data
241
+
242
+ Important notes for in-memory storage:
243
+ 1. Changes will be persisted to disk during the next index_done_callback
244
+ 2. update flags to notify other processes that data persistence is needed
245
+ """
246
+
247
+ @abstractmethod
248
+ async def delete(self, ids: list[str]) -> None:
249
+ """Delete specific records from storage by their IDs
250
+
251
+ Important notes for in-memory storage:
252
+ 1. Changes will be persisted to disk during the next index_done_callback
253
+ 2. update flags to notify other processes that data persistence is needed
254
+
255
+ Args:
256
+ ids (list[str]): List of document IDs to be deleted from storage
257
+
258
+ Returns:
259
+ None
260
+ """
261
+
262
+ async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
263
+ """Delete specific records from storage by cache mode
264
+
265
+ Important notes for in-memory storage:
266
+ 1. Changes will be persisted to disk during the next index_done_callback
267
+ 2. update flags to notify other processes that data persistence is needed
268
+
269
+ Args:
270
+ modes (list[str]): List of cache modes to be dropped from storage
271
+
272
+ Returns:
273
+ True: if the cache was dropped successfully
274
+ False: if the cache drop failed, or the cache mode is not supported
275
+ """
276
 
277
 
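Implementations are free to honour `drop_cache_by_modes` in different ways: the JSON KV storage later in this diff simply deletes the mode keys, while the MongoDB storage further down matches cache record IDs that start with a mode prefix. A tiny illustration of the prefix-matching variant (the IDs are made-up examples):

```python
import re

modes = ["local", "global"]
pattern = f"^({'|'.join(modes)})_"        # -> "^(local|global)_"

cache_ids = ["local_ab12", "global_cd34", "hybrid_ef56"]
dropped = [cid for cid in cache_ids if re.match(pattern, cid)]
assert dropped == ["local_ab12", "global_cd34"]   # hybrid_ef56 is kept
```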
278
  @dataclass
 
297
 
298
  @abstractmethod
299
  async def get_node(self, node_id: str) -> dict[str, str] | None:
300
+ """Get node by its label identifier, return only node properties"""
301
 
302
  @abstractmethod
303
  async def get_edge(
304
  self, source_node_id: str, target_node_id: str
305
  ) -> dict[str, str] | None:
306
+ """Get edge properties between two nodes"""
307
 
308
  @abstractmethod
309
  async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
 
317
  async def upsert_edge(
318
  self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
319
  ) -> None:
320
+ """Delete a node from the graph.
321
+
322
+ Important notes for in-memory storage:
323
+ 1. Changes will be persisted to disk during the next index_done_callback
324
+ 2. Only one process should be updating the storage at a time before index_done_callback,
325
+ KG-storage-log should be used to avoid data corruption
326
+ """
327
 
328
  @abstractmethod
329
  async def delete_node(self, node_id: str) -> None:
 
341
 
342
  @abstractmethod
343
  async def get_knowledge_graph(
344
+ self, node_label: str, max_depth: int = 3, max_nodes: int = 1000
345
  ) -> KnowledgeGraph:
346
+ """
347
+ Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
348
+
349
+ Args:
350
+ node_label: Label of the starting node, * means all nodes
351
+ max_depth: Maximum depth of the subgraph, Defaults to 3
352
+ max_nodes: Maximum number of nodes to return, defaults to 1000 (BFS if possible)
353
+
354
+ Returns:
355
+ KnowledgeGraph object containing nodes and edges, with an is_truncated flag
356
+ indicating whether the graph was truncated due to max_nodes limit
357
+ """
358
 
359
 
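A short usage sketch of the extended signature, assuming `graph_storage` is any initialized `BaseGraphStorage` implementation; `is_truncated` is the flag described in the docstring above:

```python
async def inspect_graph(graph_storage):
    # graph_storage: any initialized BaseGraphStorage implementation
    kg = await graph_storage.get_knowledge_graph("*", max_depth=3, max_nodes=500)
    print(f"{len(kg.nodes)} nodes / {len(kg.edges)} edges")
    if kg.is_truncated:
        print("Subgraph was capped at max_nodes; raise the limit or use a narrower label.")
```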
360
  class DocStatus(str, Enum):
 
406
  ) -> dict[str, DocProcessingStatus]:
407
  """Get all documents with a specific status"""
408
 
409
+ async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
410
+ """Drop cache is not supported for Doc Status storage"""
411
+ return False
412
+
413
 
414
  class StoragesStatus(str, Enum):
415
  """Storages status"""
lightrag/kg/__init__.py CHANGED
@@ -2,11 +2,10 @@ STORAGE_IMPLEMENTATIONS = {
2
  "KV_STORAGE": {
3
  "implementations": [
4
  "JsonKVStorage",
5
- "MongoKVStorage",
6
  "RedisKVStorage",
7
- "TiDBKVStorage",
8
  "PGKVStorage",
9
- "OracleKVStorage",
 
10
  ],
11
  "required_methods": ["get_by_id", "upsert"],
12
  },
@@ -14,12 +13,11 @@ STORAGE_IMPLEMENTATIONS = {
14
  "implementations": [
15
  "NetworkXStorage",
16
  "Neo4JStorage",
17
- "MongoGraphStorage",
18
- "TiDBGraphStorage",
19
- "AGEStorage",
20
- "GremlinStorage",
21
  "PGGraphStorage",
22
- "OracleGraphStorage",
 
 
 
23
  ],
24
  "required_methods": ["upsert_node", "upsert_edge"],
25
  },
@@ -28,12 +26,11 @@ STORAGE_IMPLEMENTATIONS = {
28
  "NanoVectorDBStorage",
29
  "MilvusVectorDBStorage",
30
  "ChromaVectorDBStorage",
31
- "TiDBVectorDBStorage",
32
  "PGVectorStorage",
33
  "FaissVectorDBStorage",
34
  "QdrantVectorDBStorage",
35
- "OracleVectorDBStorage",
36
  "MongoVectorDBStorage",
 
37
  ],
38
  "required_methods": ["query", "upsert"],
39
  },
@@ -41,7 +38,6 @@ STORAGE_IMPLEMENTATIONS = {
41
  "implementations": [
42
  "JsonDocStatusStorage",
43
  "PGDocStatusStorage",
44
- "PGDocStatusStorage",
45
  "MongoDocStatusStorage",
46
  ],
47
  "required_methods": ["get_docs_by_status"],
@@ -54,50 +50,32 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
54
  "JsonKVStorage": [],
55
  "MongoKVStorage": [],
56
  "RedisKVStorage": ["REDIS_URI"],
57
- "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
58
  "PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
59
- "OracleKVStorage": [
60
- "ORACLE_DSN",
61
- "ORACLE_USER",
62
- "ORACLE_PASSWORD",
63
- "ORACLE_CONFIG_DIR",
64
- ],
65
  # Graph Storage Implementations
66
  "NetworkXStorage": [],
67
  "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
68
  "MongoGraphStorage": [],
69
- "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
70
  "AGEStorage": [
71
  "AGE_POSTGRES_DB",
72
  "AGE_POSTGRES_USER",
73
  "AGE_POSTGRES_PASSWORD",
74
  ],
75
- "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
76
  "PGGraphStorage": [
77
  "POSTGRES_USER",
78
  "POSTGRES_PASSWORD",
79
  "POSTGRES_DATABASE",
80
  ],
81
- "OracleGraphStorage": [
82
- "ORACLE_DSN",
83
- "ORACLE_USER",
84
- "ORACLE_PASSWORD",
85
- "ORACLE_CONFIG_DIR",
86
- ],
87
  # Vector Storage Implementations
88
  "NanoVectorDBStorage": [],
89
  "MilvusVectorDBStorage": [],
90
  "ChromaVectorDBStorage": [],
91
- "TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
92
  "PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
93
  "FaissVectorDBStorage": [],
94
  "QdrantVectorDBStorage": ["QDRANT_URL"], # QDRANT_API_KEY has default value None
95
- "OracleVectorDBStorage": [
96
- "ORACLE_DSN",
97
- "ORACLE_USER",
98
- "ORACLE_PASSWORD",
99
- "ORACLE_CONFIG_DIR",
100
- ],
101
  "MongoVectorDBStorage": [],
102
  # Document Status Storage Implementations
103
  "JsonDocStatusStorage": [],
@@ -112,9 +90,6 @@ STORAGES = {
112
  "NanoVectorDBStorage": ".kg.nano_vector_db_impl",
113
  "JsonDocStatusStorage": ".kg.json_doc_status_impl",
114
  "Neo4JStorage": ".kg.neo4j_impl",
115
- "OracleKVStorage": ".kg.oracle_impl",
116
- "OracleGraphStorage": ".kg.oracle_impl",
117
- "OracleVectorDBStorage": ".kg.oracle_impl",
118
  "MilvusVectorDBStorage": ".kg.milvus_impl",
119
  "MongoKVStorage": ".kg.mongo_impl",
120
  "MongoDocStatusStorage": ".kg.mongo_impl",
@@ -122,14 +97,14 @@ STORAGES = {
122
  "MongoVectorDBStorage": ".kg.mongo_impl",
123
  "RedisKVStorage": ".kg.redis_impl",
124
  "ChromaVectorDBStorage": ".kg.chroma_impl",
125
- "TiDBKVStorage": ".kg.tidb_impl",
126
- "TiDBVectorDBStorage": ".kg.tidb_impl",
127
- "TiDBGraphStorage": ".kg.tidb_impl",
128
  "PGKVStorage": ".kg.postgres_impl",
129
  "PGVectorStorage": ".kg.postgres_impl",
130
  "AGEStorage": ".kg.age_impl",
131
  "PGGraphStorage": ".kg.postgres_impl",
132
- "GremlinStorage": ".kg.gremlin_impl",
133
  "PGDocStatusStorage": ".kg.postgres_impl",
134
  "FaissVectorDBStorage": ".kg.faiss_impl",
135
  "QdrantVectorDBStorage": ".kg.qdrant_impl",
 
2
  "KV_STORAGE": {
3
  "implementations": [
4
  "JsonKVStorage",
 
5
  "RedisKVStorage",
 
6
  "PGKVStorage",
7
+ "MongoKVStorage",
8
+ # "TiDBKVStorage",
9
  ],
10
  "required_methods": ["get_by_id", "upsert"],
11
  },
 
13
  "implementations": [
14
  "NetworkXStorage",
15
  "Neo4JStorage",
 
 
 
 
16
  "PGGraphStorage",
17
+ # "AGEStorage",
18
+ # "MongoGraphStorage",
19
+ # "TiDBGraphStorage",
20
+ # "GremlinStorage",
21
  ],
22
  "required_methods": ["upsert_node", "upsert_edge"],
23
  },
 
26
  "NanoVectorDBStorage",
27
  "MilvusVectorDBStorage",
28
  "ChromaVectorDBStorage",
 
29
  "PGVectorStorage",
30
  "FaissVectorDBStorage",
31
  "QdrantVectorDBStorage",
 
32
  "MongoVectorDBStorage",
33
+ # "TiDBVectorDBStorage",
34
  ],
35
  "required_methods": ["query", "upsert"],
36
  },
 
38
  "implementations": [
39
  "JsonDocStatusStorage",
40
  "PGDocStatusStorage",
 
41
  "MongoDocStatusStorage",
42
  ],
43
  "required_methods": ["get_docs_by_status"],
 
50
  "JsonKVStorage": [],
51
  "MongoKVStorage": [],
52
  "RedisKVStorage": ["REDIS_URI"],
53
+ # "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
54
  "PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
 
 
 
 
 
 
55
  # Graph Storage Implementations
56
  "NetworkXStorage": [],
57
  "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
58
  "MongoGraphStorage": [],
59
+ # "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
60
  "AGEStorage": [
61
  "AGE_POSTGRES_DB",
62
  "AGE_POSTGRES_USER",
63
  "AGE_POSTGRES_PASSWORD",
64
  ],
65
+ # "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
66
  "PGGraphStorage": [
67
  "POSTGRES_USER",
68
  "POSTGRES_PASSWORD",
69
  "POSTGRES_DATABASE",
70
  ],
 
 
 
 
 
 
71
  # Vector Storage Implementations
72
  "NanoVectorDBStorage": [],
73
  "MilvusVectorDBStorage": [],
74
  "ChromaVectorDBStorage": [],
75
+ # "TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
76
  "PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
77
  "FaissVectorDBStorage": [],
78
  "QdrantVectorDBStorage": ["QDRANT_URL"], # QDRANT_API_KEY has default value None
 
 
 
 
 
 
79
  "MongoVectorDBStorage": [],
80
  # Document Status Storage Implementations
81
  "JsonDocStatusStorage": [],
 
90
  "NanoVectorDBStorage": ".kg.nano_vector_db_impl",
91
  "JsonDocStatusStorage": ".kg.json_doc_status_impl",
92
  "Neo4JStorage": ".kg.neo4j_impl",
 
 
 
93
  "MilvusVectorDBStorage": ".kg.milvus_impl",
94
  "MongoKVStorage": ".kg.mongo_impl",
95
  "MongoDocStatusStorage": ".kg.mongo_impl",
 
97
  "MongoVectorDBStorage": ".kg.mongo_impl",
98
  "RedisKVStorage": ".kg.redis_impl",
99
  "ChromaVectorDBStorage": ".kg.chroma_impl",
100
+ # "TiDBKVStorage": ".kg.tidb_impl",
101
+ # "TiDBVectorDBStorage": ".kg.tidb_impl",
102
+ # "TiDBGraphStorage": ".kg.tidb_impl",
103
  "PGKVStorage": ".kg.postgres_impl",
104
  "PGVectorStorage": ".kg.postgres_impl",
105
  "AGEStorage": ".kg.age_impl",
106
  "PGGraphStorage": ".kg.postgres_impl",
107
+ # "GremlinStorage": ".kg.gremlin_impl",
108
  "PGDocStatusStorage": ".kg.postgres_impl",
109
  "FaissVectorDBStorage": ".kg.faiss_impl",
110
  "QdrantVectorDBStorage": ".kg.qdrant_impl",
lightrag/kg/age_impl.py CHANGED
@@ -34,9 +34,9 @@ if not pm.is_installed("psycopg-pool"):
34
  if not pm.is_installed("asyncpg"):
35
  pm.install("asyncpg")
36
 
37
- import psycopg
38
- from psycopg.rows import namedtuple_row
39
- from psycopg_pool import AsyncConnectionPool, PoolTimeout
40
 
41
 
42
  class AGEQueryException(Exception):
@@ -871,3 +871,21 @@ class AGEStorage(BaseGraphStorage):
871
  async def index_done_callback(self) -> None:
872
  # AGES handles persistence automatically
873
  pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  if not pm.is_installed("asyncpg"):
35
  pm.install("asyncpg")
36
 
37
+ import psycopg # type: ignore
38
+ from psycopg.rows import namedtuple_row # type: ignore
39
+ from psycopg_pool import AsyncConnectionPool, PoolTimeout # type: ignore
40
 
41
 
42
  class AGEQueryException(Exception):
 
871
  async def index_done_callback(self) -> None:
872
  # AGES handles persistence automatically
873
  pass
874
+
875
+ async def drop(self) -> dict[str, str]:
876
+ """Drop the storage by removing all nodes and relationships in the graph.
877
+
878
+ Returns:
879
+ dict[str, str]: Status of the operation with keys 'status' and 'message'
880
+ """
881
+ try:
882
+ query = """
883
+ MATCH (n)
884
+ DETACH DELETE n
885
+ """
886
+ await self._query(query)
887
+ logger.info(f"Successfully dropped all data from graph {self.graph_name}")
888
+ return {"status": "success", "message": "graph data dropped"}
889
+ except Exception as e:
890
+ logger.error(f"Error dropping graph {self.graph_name}: {e}")
891
+ return {"status": "error", "message": str(e)}
lightrag/kg/chroma_impl.py CHANGED
@@ -1,4 +1,5 @@
1
  import asyncio
 
2
  from dataclasses import dataclass
3
  from typing import Any, final
4
  import numpy as np
@@ -10,8 +11,8 @@ import pipmaster as pm
10
  if not pm.is_installed("chromadb"):
11
  pm.install("chromadb")
12
 
13
- from chromadb import HttpClient, PersistentClient
14
- from chromadb.config import Settings
15
 
16
 
17
  @final
@@ -335,3 +336,28 @@ class ChromaVectorDBStorage(BaseVectorStorage):
335
  except Exception as e:
336
  logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
337
  return []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import asyncio
2
+ import os
3
  from dataclasses import dataclass
4
  from typing import Any, final
5
  import numpy as np
 
11
  if not pm.is_installed("chromadb"):
12
  pm.install("chromadb")
13
 
14
+ from chromadb import HttpClient, PersistentClient # type: ignore
15
+ from chromadb.config import Settings # type: ignore
16
 
17
 
18
  @final
 
336
  except Exception as e:
337
  logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
338
  return []
339
+
340
+ async def drop(self) -> dict[str, str]:
341
+ """Drop all vector data from storage and clean up resources
342
+
343
+ This method will delete all documents from the ChromaDB collection.
344
+
345
+ Returns:
346
+ dict[str, str]: Operation status and message
347
+ - On success: {"status": "success", "message": "data dropped"}
348
+ - On failure: {"status": "error", "message": "<error details>"}
349
+ """
350
+ try:
351
+ # Get all IDs in the collection
352
+ result = self._collection.get(include=[])
353
+ if result and result["ids"] and len(result["ids"]) > 0:
354
+ # Delete all documents
355
+ self._collection.delete(ids=result["ids"])
356
+
357
+ logger.info(
358
+ f"Process {os.getpid()} drop ChromaDB collection {self.namespace}"
359
+ )
360
+ return {"status": "success", "message": "data dropped"}
361
+ except Exception as e:
362
+ logger.error(f"Error dropping ChromaDB collection {self.namespace}: {e}")
363
+ return {"status": "error", "message": str(e)}
lightrag/kg/faiss_impl.py CHANGED
@@ -11,16 +11,20 @@ import pipmaster as pm
11
  from lightrag.utils import logger, compute_mdhash_id
12
  from lightrag.base import BaseVectorStorage
13
 
14
- if not pm.is_installed("faiss"):
15
- pm.install("faiss")
16
-
17
- import faiss # type: ignore
18
  from .shared_storage import (
19
  get_storage_lock,
20
  get_update_flag,
21
  set_all_update_flags,
22
  )
23
 
 
 
 
 
 
 
 
 
24
 
25
  @final
26
  @dataclass
@@ -217,6 +221,11 @@ class FaissVectorDBStorage(BaseVectorStorage):
217
  async def delete(self, ids: list[str]):
218
  """
219
  Delete vectors for the provided custom IDs.
 
 
 
 
 
220
  """
221
  logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
222
  to_remove = []
@@ -232,13 +241,22 @@ class FaissVectorDBStorage(BaseVectorStorage):
232
  )
233
 
234
  async def delete_entity(self, entity_name: str) -> None:
 
 
 
 
 
 
235
  entity_id = compute_mdhash_id(entity_name, prefix="ent-")
236
  logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
237
  await self.delete([entity_id])
238
 
239
  async def delete_entity_relation(self, entity_name: str) -> None:
240
  """
241
- Delete relations for a given entity by scanning metadata.
 
 
 
242
  """
243
  logger.debug(f"Searching relations for entity {entity_name}")
244
  relations = []
@@ -429,3 +447,44 @@ class FaissVectorDBStorage(BaseVectorStorage):
429
  results.append({**metadata, "id": metadata.get("__id__")})
430
 
431
  return results
 
 
 
 
 
 
 
11
  from lightrag.utils import logger, compute_mdhash_id
12
  from lightrag.base import BaseVectorStorage
13
 
 
 
 
 
14
  from .shared_storage import (
15
  get_storage_lock,
16
  get_update_flag,
17
  set_all_update_flags,
18
  )
19
 
20
+ import faiss # type: ignore
21
+
22
+ USE_GPU = os.getenv("FAISS_USE_GPU", "0") == "1"
23
+ FAISS_PACKAGE = "faiss-gpu" if USE_GPU else "faiss-cpu"
24
+
25
+ if not pm.is_installed(FAISS_PACKAGE):
26
+ pm.install(FAISS_PACKAGE)
27
+
28
 
29
  @final
30
  @dataclass
 
221
  async def delete(self, ids: list[str]):
222
  """
223
  Delete vectors for the provided custom IDs.
224
+
225
+ Important notes:
226
+ 1. Changes will be persisted to disk during the next index_done_callback
227
+ 2. Only one process should be updating the storage at a time before index_done_callback,
228
+ KG-storage-log should be used to avoid data corruption
229
  """
230
  logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
231
  to_remove = []
 
241
  )
242
 
243
  async def delete_entity(self, entity_name: str) -> None:
244
+ """
245
+ Important notes:
246
+ 1. Changes will be persisted to disk during the next index_done_callback
247
+ 2. Only one process should be updating the storage at a time before index_done_callback,
248
+ KG-storage-log should be used to avoid data corruption
249
+ """
250
  entity_id = compute_mdhash_id(entity_name, prefix="ent-")
251
  logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
252
  await self.delete([entity_id])
253
 
254
  async def delete_entity_relation(self, entity_name: str) -> None:
255
  """
256
+ Important notes:
257
+ 1. Changes will be persisted to disk during the next index_done_callback
258
+ 2. Only one process should be updating the storage at a time before index_done_callback,
259
+ KG-storage-log should be used to avoid data corruption
260
  """
261
  logger.debug(f"Searching relations for entity {entity_name}")
262
  relations = []
 
447
  results.append({**metadata, "id": metadata.get("__id__")})
448
 
449
  return results
450
+
451
+ async def drop(self) -> dict[str, str]:
452
+ """Drop all vector data from storage and clean up resources
453
+
454
+ This method will:
455
+ 1. Remove the vector database storage file if it exists
456
+ 2. Reinitialize the vector database client
457
+ 3. Update flags to notify other processes
458
+ 4. Changes are persisted to disk immediately
459
+
460
+ This method will remove all vectors from the Faiss index and delete the storage files.
461
+
462
+ Returns:
463
+ dict[str, str]: Operation status and message
464
+ - On success: {"status": "success", "message": "data dropped"}
465
+ - On failure: {"status": "error", "message": "<error details>"}
466
+ """
467
+ try:
468
+ async with self._storage_lock:
469
+ # Reset the index
470
+ self._index = faiss.IndexFlatIP(self._dim)
471
+ self._id_to_meta = {}
472
+
473
+ # Remove storage files if they exist
474
+ if os.path.exists(self._faiss_index_file):
475
+ os.remove(self._faiss_index_file)
476
+ if os.path.exists(self._meta_file):
477
+ os.remove(self._meta_file)
478
+
479
+ self._id_to_meta = {}
480
+ self._load_faiss_index()
481
+
482
+ # Notify other processes
483
+ await set_all_update_flags(self.namespace)
484
+ self.storage_updated.value = False
485
+
486
+ logger.info(f"Process {os.getpid()} drop FAISS index {self.namespace}")
487
+ return {"status": "success", "message": "data dropped"}
488
+ except Exception as e:
489
+ logger.error(f"Error dropping FAISS index {self.namespace}: {e}")
490
+ return {"status": "error", "message": str(e)}
lightrag/kg/gremlin_impl.py CHANGED
@@ -24,9 +24,9 @@ from ..base import BaseGraphStorage
24
  if not pm.is_installed("gremlinpython"):
25
  pm.install("gremlinpython")
26
 
27
- from gremlin_python.driver import client, serializer
28
- from gremlin_python.driver.aiohttp.transport import AiohttpTransport
29
- from gremlin_python.driver.protocol import GremlinServerError
30
 
31
 
32
  @final
@@ -695,3 +695,24 @@ class GremlinStorage(BaseGraphStorage):
695
  except Exception as e:
696
  logger.error(f"Error during edge deletion: {str(e)}")
697
  raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  if not pm.is_installed("gremlinpython"):
25
  pm.install("gremlinpython")
26
 
27
+ from gremlin_python.driver import client, serializer # type: ignore
28
+ from gremlin_python.driver.aiohttp.transport import AiohttpTransport # type: ignore
29
+ from gremlin_python.driver.protocol import GremlinServerError # type: ignore
30
 
31
 
32
  @final
 
695
  except Exception as e:
696
  logger.error(f"Error during edge deletion: {str(e)}")
697
  raise
698
+
699
+ async def drop(self) -> dict[str, str]:
700
+ """Drop the storage by removing all nodes and relationships in the graph.
701
+
702
+ This function deletes all nodes with the specified graph name property,
703
+ which automatically removes all associated edges.
704
+
705
+ Returns:
706
+ dict[str, str]: Status of the operation with keys 'status' and 'message'
707
+ """
708
+ try:
709
+ query = f"""g
710
+ .V().has('graph', {self.graph_name})
711
+ .drop()
712
+ """
713
+ await self._query(query)
714
+ logger.info(f"Successfully dropped all data from graph {self.graph_name}")
715
+ return {"status": "success", "message": "graph data dropped"}
716
+ except Exception as e:
717
+ logger.error(f"Error dropping graph {self.graph_name}: {e}")
718
+ return {"status": "error", "message": str(e)}
lightrag/kg/json_doc_status_impl.py CHANGED
@@ -109,6 +109,11 @@ class JsonDocStatusStorage(DocStatusStorage):
109
  await clear_all_update_flags(self.namespace)
110
 
111
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
 
 
 
 
 
112
  if not data:
113
  return
114
  logger.info(f"Inserting {len(data)} records to {self.namespace}")
@@ -122,16 +127,50 @@ class JsonDocStatusStorage(DocStatusStorage):
122
  async with self._storage_lock:
123
  return self._data.get(id)
124
 
125
- async def delete(self, doc_ids: list[str]):
 
 
 
 
 
 
 
 
 
 
 
 
126
  async with self._storage_lock:
 
127
  for doc_id in doc_ids:
128
- self._data.pop(doc_id, None)
129
- await set_all_update_flags(self.namespace)
130
- await self.index_done_callback()
131
 
132
- async def drop(self) -> None:
133
- """Drop the storage"""
134
- async with self._storage_lock:
135
- self._data.clear()
136
- await set_all_update_flags(self.namespace)
137
- await self.index_done_callback()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  await clear_all_update_flags(self.namespace)
110
 
111
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
112
+ """
113
+ Important notes for in-memory storage:
114
+ 1. Changes will be persisted to disk during the next index_done_callback
115
+ 2. update flags to notify other processes that data persistence is needed
116
+ """
117
  if not data:
118
  return
119
  logger.info(f"Inserting {len(data)} records to {self.namespace}")
 
127
  async with self._storage_lock:
128
  return self._data.get(id)
129
 
130
+ async def delete(self, doc_ids: list[str]) -> None:
131
+ """Delete specific records from storage by their IDs
132
+
133
+ Important notes for in-memory storage:
134
+ 1. Changes will be persisted to disk during the next index_done_callback
135
+ 2. update flags to notify other processes that data persistence is needed
136
+
137
+ Args:
138
+ doc_ids (list[str]): List of document IDs to be deleted from storage
139
+
140
+ Returns:
141
+ None
142
+ """
143
  async with self._storage_lock:
144
+ any_deleted = False
145
  for doc_id in doc_ids:
146
+ result = self._data.pop(doc_id, None)
147
+ if result is not None:
148
+ any_deleted = True
149
 
150
+ if any_deleted:
151
+ await set_all_update_flags(self.namespace)
152
+
153
+ async def drop(self) -> dict[str, str]:
154
+ """Drop all document status data from storage and clean up resources
155
+
156
+ This method will:
157
+ 1. Clear all document status data from memory
158
+ 2. Update flags to notify other processes
159
+ 3. Trigger index_done_callback to save the empty state
160
+
161
+ Returns:
162
+ dict[str, str]: Operation status and message
163
+ - On success: {"status": "success", "message": "data dropped"}
164
+ - On failure: {"status": "error", "message": "<error details>"}
165
+ """
166
+ try:
167
+ async with self._storage_lock:
168
+ self._data.clear()
169
+ await set_all_update_flags(self.namespace)
170
+
171
+ await self.index_done_callback()
172
+ logger.info(f"Process {os.getpid()} drop {self.namespace}")
173
+ return {"status": "success", "message": "data dropped"}
174
+ except Exception as e:
175
+ logger.error(f"Error dropping {self.namespace}: {e}")
176
+ return {"status": "error", "message": str(e)}
lightrag/kg/json_kv_impl.py CHANGED
@@ -114,6 +114,11 @@ class JsonKVStorage(BaseKVStorage):
114
  return set(keys) - set(self._data.keys())
115
 
116
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
 
 
 
 
 
117
  if not data:
118
  return
119
  logger.info(f"Inserting {len(data)} records to {self.namespace}")
@@ -122,8 +127,73 @@ class JsonKVStorage(BaseKVStorage):
122
  await set_all_update_flags(self.namespace)
123
 
124
  async def delete(self, ids: list[str]) -> None:
 
 
 
 
 
 
 
 
 
 
 
 
125
  async with self._storage_lock:
 
126
  for doc_id in ids:
127
- self._data.pop(doc_id, None)
128
- await set_all_update_flags(self.namespace)
129
- await self.index_done_callback()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  return set(keys) - set(self._data.keys())
115
 
116
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
117
+ """
118
+ Important notes for in-memory storage:
119
+ 1. Changes will be persisted to disk during the next index_done_callback
120
+ 2. update flags to notify other processes that data persistence is needed
121
+ """
122
  if not data:
123
  return
124
  logger.info(f"Inserting {len(data)} records to {self.namespace}")
 
127
  await set_all_update_flags(self.namespace)
128
 
129
  async def delete(self, ids: list[str]) -> None:
130
+ """Delete specific records from storage by their IDs
131
+
132
+ Important notes for in-memory storage:
133
+ 1. Changes will be persisted to disk during the next index_done_callback
134
+ 2. update flags to notify other processes that data persistence is needed
135
+
136
+ Args:
137
+ ids (list[str]): List of document IDs to be deleted from storage
138
+
139
+ Returns:
140
+ None
141
+ """
142
  async with self._storage_lock:
143
+ any_deleted = False
144
  for doc_id in ids:
145
+ result = self._data.pop(doc_id, None)
146
+ if result is not None:
147
+ any_deleted = True
148
+
149
+ if any_deleted:
150
+ await set_all_update_flags(self.namespace)
151
+
152
+ async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
153
+ """Delete specific records from storage by by cache mode
154
+
155
+ Important notes for in-memory storage:
156
+ 1. Changes will be persisted to disk during the next index_done_callback
157
+ 2. update flags to notify other processes that data persistence is needed
158
+
159
+ Args:
160
+ modes (list[str]): List of cache modes to be dropped from storage
161
+
162
+ Returns:
163
+ True: if the cache was dropped successfully
164
+ False: if the cache drop failed
165
+ """
166
+ if not modes:
167
+ return False
168
+
169
+ try:
170
+ await self.delete(modes)
171
+ return True
172
+ except Exception:
173
+ return False
174
+
175
+ async def drop(self) -> dict[str, str]:
176
+ """Drop all data from storage and clean up resources
177
+ This action will persist the data to disk immediately.
178
+
179
+ This method will:
180
+ 1. Clear all data from memory
181
+ 2. Update flags to notify other processes
182
+ 3. Trigger index_done_callback to save the empty state
183
+
184
+ Returns:
185
+ dict[str, str]: Operation status and message
186
+ - On success: {"status": "success", "message": "data dropped"}
187
+ - On failure: {"status": "error", "message": "<error details>"}
188
+ """
189
+ try:
190
+ async with self._storage_lock:
191
+ self._data.clear()
192
+ await set_all_update_flags(self.namespace)
193
+
194
+ await self.index_done_callback()
195
+ logger.info(f"Process {os.getpid()} drop {self.namespace}")
196
+ return {"status": "success", "message": "data dropped"}
197
+ except Exception as e:
198
+ logger.error(f"Error dropping {self.namespace}: {e}")
199
+ return {"status": "error", "message": str(e)}
lightrag/kg/milvus_impl.py CHANGED
@@ -15,7 +15,7 @@ if not pm.is_installed("pymilvus"):
15
  pm.install("pymilvus")
16
 
17
  import configparser
18
- from pymilvus import MilvusClient
19
 
20
  config = configparser.ConfigParser()
21
  config.read("config.ini", "utf-8")
@@ -287,3 +287,33 @@ class MilvusVectorDBStorage(BaseVectorStorage):
287
  except Exception as e:
288
  logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
289
  return []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  pm.install("pymilvus")
16
 
17
  import configparser
18
+ from pymilvus import MilvusClient # type: ignore
19
 
20
  config = configparser.ConfigParser()
21
  config.read("config.ini", "utf-8")
 
287
  except Exception as e:
288
  logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
289
  return []
290
+
291
+ async def drop(self) -> dict[str, str]:
292
+ """Drop all vector data from storage and clean up resources
293
+
294
+ This method will delete all data from the Milvus collection.
295
+
296
+ Returns:
297
+ dict[str, str]: Operation status and message
298
+ - On success: {"status": "success", "message": "data dropped"}
299
+ - On failure: {"status": "error", "message": "<error details>"}
300
+ """
301
+ try:
302
+ # Drop the collection and recreate it
303
+ if self._client.has_collection(self.namespace):
304
+ self._client.drop_collection(self.namespace)
305
+
306
+ # Recreate the collection
307
+ MilvusVectorDBStorage.create_collection_if_not_exist(
308
+ self._client,
309
+ self.namespace,
310
+ dimension=self.embedding_func.embedding_dim,
311
+ )
312
+
313
+ logger.info(
314
+ f"Process {os.getpid()} drop Milvus collection {self.namespace}"
315
+ )
316
+ return {"status": "success", "message": "data dropped"}
317
+ except Exception as e:
318
+ logger.error(f"Error dropping Milvus collection {self.namespace}: {e}")
319
+ return {"status": "error", "message": str(e)}
lightrag/kg/mongo_impl.py CHANGED
@@ -25,13 +25,13 @@ if not pm.is_installed("pymongo"):
25
  if not pm.is_installed("motor"):
26
  pm.install("motor")
27
 
28
- from motor.motor_asyncio import (
29
  AsyncIOMotorClient,
30
  AsyncIOMotorDatabase,
31
  AsyncIOMotorCollection,
32
  )
33
- from pymongo.operations import SearchIndexModel
34
- from pymongo.errors import PyMongoError
35
 
36
  config = configparser.ConfigParser()
37
  config.read("config.ini", "utf-8")
@@ -150,6 +150,66 @@ class MongoKVStorage(BaseKVStorage):
150
  # Mongo handles persistence automatically
151
  pass
152
 
 
 
 
 
 
 
 
 
 
153
 
154
  @final
155
  @dataclass
@@ -230,6 +290,27 @@ class MongoDocStatusStorage(DocStatusStorage):
230
  # Mongo handles persistence automatically
231
  pass
232
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
 
234
  @final
235
  @dataclass
@@ -840,6 +921,27 @@ class MongoGraphStorage(BaseGraphStorage):
840
 
841
  logger.debug(f"Successfully deleted edges: {edges}")
842
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
843
 
844
  @final
845
  @dataclass
@@ -1127,6 +1229,31 @@ class MongoVectorDBStorage(BaseVectorStorage):
1127
  logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
1128
  return []
1129
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1130
 
1131
  async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
1132
  collection_names = await db.list_collection_names()
 
25
  if not pm.is_installed("motor"):
26
  pm.install("motor")
27
 
28
+ from motor.motor_asyncio import ( # type: ignore
29
  AsyncIOMotorClient,
30
  AsyncIOMotorDatabase,
31
  AsyncIOMotorCollection,
32
  )
33
+ from pymongo.operations import SearchIndexModel # type: ignore
34
+ from pymongo.errors import PyMongoError # type: ignore
35
 
36
  config = configparser.ConfigParser()
37
  config.read("config.ini", "utf-8")
 
150
  # Mongo handles persistence automatically
151
  pass
152
 
153
+ async def delete(self, ids: list[str]) -> None:
154
+ """Delete documents with specified IDs
155
+
156
+ Args:
157
+ ids: List of document IDs to be deleted
158
+ """
159
+ if not ids:
160
+ return
161
+
162
+ try:
163
+ result = await self._data.delete_many({"_id": {"$in": ids}})
164
+ logger.info(
165
+ f"Deleted {result.deleted_count} documents from {self.namespace}"
166
+ )
167
+ except PyMongoError as e:
168
+ logger.error(f"Error deleting documents from {self.namespace}: {e}")
169
+
170
+ async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
171
+ """Delete specific records from storage by cache mode
172
+
173
+ Args:
174
+ modes (list[str]): List of cache modes to be dropped from storage
175
+
176
+ Returns:
177
+ bool: True if successful, False otherwise
178
+ """
179
+ if not modes:
180
+ return False
181
+
182
+ try:
183
+ # Build regex pattern to match documents with the specified modes
184
+ pattern = f"^({'|'.join(modes)})_"
185
+ result = await self._data.delete_many({"_id": {"$regex": pattern}})
186
+ logger.info(f"Deleted {result.deleted_count} documents by modes: {modes}")
187
+ return True
188
+ except Exception as e:
189
+ logger.error(f"Error deleting cache by modes {modes}: {e}")
190
+ return False
191
+
192
+ async def drop(self) -> dict[str, str]:
193
+ """Drop the storage by removing all documents in the collection.
194
+
195
+ Returns:
196
+ dict[str, str]: Status of the operation with keys 'status' and 'message'
197
+ """
198
+ try:
199
+ result = await self._data.delete_many({})
200
+ deleted_count = result.deleted_count
201
+
202
+ logger.info(
203
+ f"Dropped {deleted_count} documents from doc status {self._collection_name}"
204
+ )
205
+ return {
206
+ "status": "success",
207
+ "message": f"{deleted_count} documents dropped",
208
+ }
209
+ except PyMongoError as e:
210
+ logger.error(f"Error dropping doc status {self._collection_name}: {e}")
211
+ return {"status": "error", "message": str(e)}
212
+
213
 
214
  @final
215
  @dataclass
 
290
  # Mongo handles persistence automatically
291
  pass
292
 
293
+ async def drop(self) -> dict[str, str]:
294
+ """Drop the storage by removing all documents in the collection.
295
+
296
+ Returns:
297
+ dict[str, str]: Status of the operation with keys 'status' and 'message'
298
+ """
299
+ try:
300
+ result = await self._data.delete_many({})
301
+ deleted_count = result.deleted_count
302
+
303
+ logger.info(
304
+ f"Dropped {deleted_count} documents from doc status {self._collection_name}"
305
+ )
306
+ return {
307
+ "status": "success",
308
+ "message": f"{deleted_count} documents dropped",
309
+ }
310
+ except PyMongoError as e:
311
+ logger.error(f"Error dropping doc status {self._collection_name}: {e}")
312
+ return {"status": "error", "message": str(e)}
313
+
314
 
315
  @final
316
  @dataclass
 
921
 
922
  logger.debug(f"Successfully deleted edges: {edges}")
923
 
924
+ async def drop(self) -> dict[str, str]:
925
+ """Drop the storage by removing all documents in the collection.
926
+
927
+ Returns:
928
+ dict[str, str]: Status of the operation with keys 'status' and 'message'
929
+ """
930
+ try:
931
+ result = await self.collection.delete_many({})
932
+ deleted_count = result.deleted_count
933
+
934
+ logger.info(
935
+ f"Dropped {deleted_count} documents from graph {self._collection_name}"
936
+ )
937
+ return {
938
+ "status": "success",
939
+ "message": f"{deleted_count} documents dropped",
940
+ }
941
+ except PyMongoError as e:
942
+ logger.error(f"Error dropping graph {self._collection_name}: {e}")
943
+ return {"status": "error", "message": str(e)}
944
+
945
 
946
  @final
947
  @dataclass
 
1229
  logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
1230
  return []
1231
 
1232
+ async def drop(self) -> dict[str, str]:
1233
+ """Drop the storage by removing all documents in the collection and recreating vector index.
1234
+
1235
+ Returns:
1236
+ dict[str, str]: Status of the operation with keys 'status' and 'message'
1237
+ """
1238
+ try:
1239
+ # Delete all documents
1240
+ result = await self._data.delete_many({})
1241
+ deleted_count = result.deleted_count
1242
+
1243
+ # Recreate vector index
1244
+ await self.create_vector_index_if_not_exists()
1245
+
1246
+ logger.info(
1247
+ f"Dropped {deleted_count} documents from vector storage {self._collection_name} and recreated vector index"
1248
+ )
1249
+ return {
1250
+ "status": "success",
1251
+ "message": f"{deleted_count} documents dropped and vector index recreated",
1252
+ }
1253
+ except PyMongoError as e:
1254
+ logger.error(f"Error dropping vector storage {self._collection_name}: {e}")
1255
+ return {"status": "error", "message": str(e)}
1256
+
1257
 
1258
  async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
1259
  collection_names = await db.list_collection_names()
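
Note: drop_cache_by_modes above builds a regular expression from the requested modes and matches it against the _id prefix of each cached document. A small self-contained sketch of that key-matching logic; the key list is made up for illustration, and only the "^(<mode>)_" pattern mirrors the code above, which assumes cache keys are stored as "<mode>_<hash>":

import re

# Assumed key convention: "<mode>_<hash>", as implied by the regex in drop_cache_by_modes.
keys = ["local_abc123", "global_def456", "naive_789", "hybrid_xyz"]
modes = ["local", "global"]

pattern = re.compile(f"^({'|'.join(modes)})_")
to_delete = [k for k in keys if pattern.match(k)]
print(to_delete)  # ['local_abc123', 'global_def456']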
lightrag/kg/nano_vector_db_impl.py CHANGED
@@ -78,6 +78,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
78
  return self._client
79
 
80
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
 
 
 
 
 
 
 
81
  logger.info(f"Inserting {len(data)} to {self.namespace}")
82
  if not data:
83
  return
@@ -146,6 +153,11 @@ class NanoVectorDBStorage(BaseVectorStorage):
146
  async def delete(self, ids: list[str]):
147
  """Delete vectors with specified IDs
148
 
 
 
 
 
 
149
  Args:
150
  ids: List of vector IDs to be deleted
151
  """
@@ -159,6 +171,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
159
  logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
160
 
161
  async def delete_entity(self, entity_name: str) -> None:
 
 
 
 
 
 
 
162
  try:
163
  entity_id = compute_mdhash_id(entity_name, prefix="ent-")
164
  logger.debug(
@@ -176,6 +195,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
176
  logger.error(f"Error deleting entity {entity_name}: {e}")
177
 
178
  async def delete_entity_relation(self, entity_name: str) -> None:
 
 
 
 
 
 
 
179
  try:
180
  client = await self._get_client()
181
  storage = getattr(client, "_NanoVectorDB__storage")
@@ -280,3 +306,43 @@ class NanoVectorDBStorage(BaseVectorStorage):
280
 
281
  client = await self._get_client()
282
  return client.get(ids)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  return self._client
79
 
80
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
81
+ """
82
+ Important notes:
83
+ 1. Changes will be persisted to disk during the next index_done_callback
84
+ 2. Only one process should update the storage at a time before index_done_callback,
85
+ KG-storage-log should be used to avoid data corruption
86
+ """
87
+
88
  logger.info(f"Inserting {len(data)} to {self.namespace}")
89
  if not data:
90
  return
 
153
  async def delete(self, ids: list[str]):
154
  """Delete vectors with specified IDs
155
 
156
+ Important notes:
157
+ 1. Changes will be persisted to disk during the next index_done_callback
158
+ 2. Only one process should update the storage at a time before index_done_callback,
159
+ KG-storage-log should be used to avoid data corruption
160
+
161
  Args:
162
  ids: List of vector IDs to be deleted
163
  """
 
171
  logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
172
 
173
  async def delete_entity(self, entity_name: str) -> None:
174
+ """
175
+ Important notes:
176
+ 1. Changes will be persisted to disk during the next index_done_callback
177
+ 2. Only one process should update the storage at a time before index_done_callback,
178
+ KG-storage-log should be used to avoid data corruption
179
+ """
180
+
181
  try:
182
  entity_id = compute_mdhash_id(entity_name, prefix="ent-")
183
  logger.debug(
 
195
  logger.error(f"Error deleting entity {entity_name}: {e}")
196
 
197
  async def delete_entity_relation(self, entity_name: str) -> None:
198
+ """
199
+ Important notes:
200
+ 1. Changes will be persisted to disk during the next index_done_callback
201
+ 2. Only one process should update the storage at a time before index_done_callback,
202
+ KG-storage-log should be used to avoid data corruption
203
+ """
204
+
205
  try:
206
  client = await self._get_client()
207
  storage = getattr(client, "_NanoVectorDB__storage")
 
306
 
307
  client = await self._get_client()
308
  return client.get(ids)
309
+
310
+ async def drop(self) -> dict[str, str]:
311
+ """Drop all vector data from storage and clean up resources
312
+
313
+ This method will:
314
+ 1. Remove the vector database storage file if it exists
315
+ 2. Reinitialize the vector database client
316
+ 3. Update flags to notify other processes
317
+ 4. Changes are persisted to disk immediately
318
+
319
+ This method is intended for use in scenarios where all data needs to be removed.
320
+
321
+ Returns:
322
+ dict[str, str]: Operation status and message
323
+ - On success: {"status": "success", "message": "data dropped"}
324
+ - On failure: {"status": "error", "message": "<error details>"}
325
+ """
326
+ try:
327
+ async with self._storage_lock:
328
+ # delete _client_file_name
329
+ if os.path.exists(self._client_file_name):
330
+ os.remove(self._client_file_name)
331
+
332
+ self._client = NanoVectorDB(
333
+ self.embedding_func.embedding_dim,
334
+ storage_file=self._client_file_name,
335
+ )
336
+
337
+ # Notify other processes that data has been updated
338
+ await set_all_update_flags(self.namespace)
339
+ # Reset own update flag to avoid self-reloading
340
+ self.storage_updated.value = False
341
+
342
+ logger.info(
343
+ f"Process {os.getpid()} drop {self.namespace}(file:{self._client_file_name})"
344
+ )
345
+ return {"status": "success", "message": "data dropped"}
346
+ except Exception as e:
347
+ logger.error(f"Error dropping {self.namespace}: {e}")
348
+ return {"status": "error", "message": str(e)}
lightrag/kg/neo4j_impl.py CHANGED
@@ -1,9 +1,8 @@
1
- import asyncio
2
  import inspect
3
  import os
4
  import re
5
  from dataclasses import dataclass
6
- from typing import Any, final, Optional
7
  import numpy as np
8
  import configparser
9
 
@@ -29,7 +28,6 @@ from neo4j import ( # type: ignore
29
  exceptions as neo4jExceptions,
30
  AsyncDriver,
31
  AsyncManagedTransaction,
32
- GraphDatabase,
33
  )
34
 
35
  config = configparser.ConfigParser()
@@ -52,8 +50,13 @@ class Neo4JStorage(BaseGraphStorage):
52
  embedding_func=embedding_func,
53
  )
54
  self._driver = None
55
- self._driver_lock = asyncio.Lock()
56
 
 
 
 
 
 
 
57
  URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
58
  USERNAME = os.environ.get(
59
  "NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
@@ -86,7 +89,7 @@ class Neo4JStorage(BaseGraphStorage):
86
  ),
87
  )
88
  DATABASE = os.environ.get(
89
- "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace)
90
  )
91
 
92
  self._driver: AsyncDriver = AsyncGraphDatabase.driver(
@@ -98,71 +101,92 @@ class Neo4JStorage(BaseGraphStorage):
98
  max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
99
  )
100
 
101
- # Try to connect to the database
102
- with GraphDatabase.driver(
103
- URI,
104
- auth=(USERNAME, PASSWORD),
105
- max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
106
- connection_timeout=CONNECTION_TIMEOUT,
107
- connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
108
- ) as _sync_driver:
109
- for database in (DATABASE, None):
110
- self._DATABASE = database
111
- connected = False
112
 
113
- try:
114
- with _sync_driver.session(database=database) as session:
115
- try:
116
- session.run("MATCH (n) RETURN n LIMIT 0")
117
- logger.info(f"Connected to {database} at {URI}")
118
- connected = True
119
- except neo4jExceptions.ServiceUnavailable as e:
120
- logger.error(
121
- f"{database} at {URI} is not available".capitalize()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  raise e
124
- except neo4jExceptions.AuthError as e:
125
- logger.error(f"Authentication failed for {database} at {URI}")
126
- raise e
127
- except neo4jExceptions.ClientError as e:
128
- if e.code == "Neo.ClientError.Database.DatabaseNotFound":
129
- logger.info(
130
- f"{database} at {URI} not found. Try to create specified database.".capitalize()
131
- )
 
 
 
132
  try:
133
- with _sync_driver.session() as session:
134
- session.run(
135
- f"CREATE DATABASE `{database}` IF NOT EXISTS"
136
- )
137
- logger.info(f"{database} at {URI} created".capitalize())
138
- connected = True
139
- except (
140
- neo4jExceptions.ClientError,
141
- neo4jExceptions.DatabaseError,
142
- ) as e:
143
- if (
144
- e.code
145
- == "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
146
- ) or (
147
- e.code == "Neo.DatabaseError.Statement.ExecutionFailed"
148
- ):
149
- if database is not None:
150
- logger.warning(
151
- "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
152
- )
153
- if database is None:
154
- logger.error(f"Failed to create {database} at {URI}")
155
- raise e
156
 
157
- if connected:
158
- break
159
 
160
- def __post_init__(self):
161
- self._node_embed_algorithms = {
162
- "node2vec": self._node2vec_embed,
163
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
- async def close(self):
166
  """Close the Neo4j driver and release all resources"""
167
  if self._driver:
168
  await self._driver.close()
@@ -170,7 +194,7 @@ class Neo4JStorage(BaseGraphStorage):
170
 
171
  async def __aexit__(self, exc_type, exc, tb):
172
  """Ensure driver is closed when context manager exits"""
173
- await self.close()
174
 
175
  async def index_done_callback(self) -> None:
176
  # Neo4J handles persistence automatically
@@ -243,7 +267,7 @@ class Neo4JStorage(BaseGraphStorage):
243
  raise
244
 
245
  async def get_node(self, node_id: str) -> dict[str, str] | None:
246
- """Get node by its label identifier.
247
 
248
  Args:
249
  node_id: The node label to look up
@@ -428,13 +452,8 @@ class Neo4JStorage(BaseGraphStorage):
428
  logger.debug(
429
  f"{inspect.currentframe().f_code.co_name}: No edge found between {source_node_id} and {target_node_id}"
430
  )
431
- # Return default edge properties when no edge found
432
- return {
433
- "weight": 0.0,
434
- "source_id": None,
435
- "description": None,
436
- "keywords": None,
437
- }
438
  finally:
439
  await result.consume() # Ensure result is fully consumed
440
 
@@ -526,7 +545,6 @@ class Neo4JStorage(BaseGraphStorage):
526
  """
527
  properties = node_data
528
  entity_type = properties["entity_type"]
529
- entity_id = properties["entity_id"]
530
  if "entity_id" not in properties:
531
  raise ValueError("Neo4j: node properties must contain an 'entity_id' field")
532
 
@@ -536,15 +554,17 @@ class Neo4JStorage(BaseGraphStorage):
536
  async def execute_upsert(tx: AsyncManagedTransaction):
537
  query = (
538
  """
539
- MERGE (n:base {entity_id: $properties.entity_id})
540
  SET n += $properties
541
  SET n:`%s`
542
  """
543
  % entity_type
544
  )
545
- result = await tx.run(query, properties=properties)
 
 
546
  logger.debug(
547
- f"Upserted node with entity_id '{entity_id}' and properties: {properties}"
548
  )
549
  await result.consume() # Ensure result is fully consumed
550
 
@@ -622,25 +642,19 @@ class Neo4JStorage(BaseGraphStorage):
622
  self,
623
  node_label: str,
624
  max_depth: int = 3,
625
- min_degree: int = 0,
626
- inclusive: bool = False,
627
  ) -> KnowledgeGraph:
628
  """
629
  Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
630
- Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
631
- When reducing the number of nodes, the prioritization criteria are as follows:
632
- 1. min_degree does not affect nodes directly connected to the matching nodes
633
- 2. Label matching nodes take precedence
634
- 3. Followed by nodes directly connected to the matching nodes
635
- 4. Finally, the degree of the nodes
636
 
637
  Args:
638
- node_label: Label of the starting node
639
- max_depth: Maximum depth of the subgraph
640
- min_degree: Minimum degree of nodes to include. Defaults to 0
641
- inclusive: Do an inclusive search if true
642
  Returns:
643
- KnowledgeGraph: Complete connected subgraph for specified node
 
644
  """
645
  result = KnowledgeGraph()
646
  seen_nodes = set()
@@ -651,11 +665,27 @@ class Neo4JStorage(BaseGraphStorage):
651
  ) as session:
652
  try:
653
  if node_label == "*":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
654
  main_query = """
655
  MATCH (n)
656
  OPTIONAL MATCH (n)-[r]-()
657
  WITH n, COALESCE(count(r), 0) AS degree
658
- WHERE degree >= $min_degree
659
  ORDER BY degree DESC
660
  LIMIT $max_nodes
661
  WITH collect({node: n}) AS filtered_nodes
@@ -666,20 +696,23 @@ class Neo4JStorage(BaseGraphStorage):
666
  RETURN filtered_nodes AS node_info,
667
  collect(DISTINCT r) AS relationships
668
  """
669
- result_set = await session.run(
670
- main_query,
671
- {"max_nodes": MAX_GRAPH_NODES, "min_degree": min_degree},
672
- )
 
 
 
 
 
 
673
 
674
  else:
675
- # Main query uses partial matching
676
- main_query = """
 
677
  MATCH (start)
678
- WHERE
679
- CASE
680
- WHEN $inclusive THEN start.entity_id CONTAINS $entity_id
681
- ELSE start.entity_id = $entity_id
682
- END
683
  WITH start
684
  CALL apoc.path.subgraphAll(start, {
685
  relationshipFilter: '',
@@ -688,78 +721,115 @@ class Neo4JStorage(BaseGraphStorage):
688
  bfs: true
689
  })
690
  YIELD nodes, relationships
691
- WITH start, nodes, relationships
692
  UNWIND nodes AS node
693
- OPTIONAL MATCH (node)-[r]-()
694
- WITH node, COALESCE(count(r), 0) AS degree, start, nodes, relationships
695
- WHERE node = start OR EXISTS((start)--(node)) OR degree >= $min_degree
696
- ORDER BY
697
- CASE
698
- WHEN node = start THEN 3
699
- WHEN EXISTS((start)--(node)) THEN 2
700
- ELSE 1
701
- END DESC,
702
- degree DESC
703
- LIMIT $max_nodes
704
- WITH collect({node: node}) AS filtered_nodes
705
- UNWIND filtered_nodes AS node_info
706
- WITH collect(node_info.node) AS kept_nodes, filtered_nodes
707
- OPTIONAL MATCH (a)-[r]-(b)
708
- WHERE a IN kept_nodes AND b IN kept_nodes
709
- RETURN filtered_nodes AS node_info,
710
- collect(DISTINCT r) AS relationships
711
  """
712
- result_set = await session.run(
713
- main_query,
714
- {
715
- "max_nodes": MAX_GRAPH_NODES,
716
- "entity_id": node_label,
717
- "inclusive": inclusive,
718
- "max_depth": max_depth,
719
- "min_degree": min_degree,
720
- },
721
- )
722
 
723
- try:
724
- record = await result_set.single()
725
-
726
- if record:
727
- # Handle nodes (compatible with multi-label cases)
728
- for node_info in record["node_info"]:
729
- node = node_info["node"]
730
- node_id = node.id
731
- if node_id not in seen_nodes:
732
- result.nodes.append(
733
- KnowledgeGraphNode(
734
- id=f"{node_id}",
735
- labels=[node.get("entity_id")],
736
- properties=dict(node),
737
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
738
  )
739
- seen_nodes.add(node_id)
740
-
741
- # Handle relationships (including direction information)
742
- for rel in record["relationships"]:
743
- edge_id = rel.id
744
- if edge_id not in seen_edges:
745
- start = rel.start_node
746
- end = rel.end_node
747
- result.edges.append(
748
- KnowledgeGraphEdge(
749
- id=f"{edge_id}",
750
- type=rel.type,
751
- source=f"{start.id}",
752
- target=f"{end.id}",
753
- properties=dict(rel),
754
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
755
  )
756
- seen_edges.add(edge_id)
 
757
 
758
- logger.info(
759
- f"Process {os.getpid()} graph query return: {len(result.nodes)} nodes, {len(result.edges)} edges"
760
- )
761
- finally:
762
- await result_set.consume() # Ensure result set is consumed
763
 
764
  except neo4jExceptions.ClientError as e:
765
  logger.warning(f"APOC plugin error: {str(e)}")
@@ -767,46 +837,89 @@ class Neo4JStorage(BaseGraphStorage):
767
  logger.warning(
768
  "Neo4j: falling back to basic Cypher recursive search..."
769
  )
770
- if inclusive:
771
- logger.warning(
772
- "Neo4j: inclusive search mode is not supported in recursive query, using exact matching"
773
- )
774
- return await self._robust_fallback(
775
- node_label, max_depth, min_degree
776
  )
777
 
778
  return result
779
 
780
  async def _robust_fallback(
781
- self, node_label: str, max_depth: int, min_degree: int = 0
782
  ) -> KnowledgeGraph:
783
  """
784
  Fallback implementation when APOC plugin is not available or incompatible.
785
  This method implements the same functionality as get_knowledge_graph but uses
786
- only basic Cypher queries and recursive traversal instead of APOC procedures.
787
  """
 
 
788
  result = KnowledgeGraph()
789
  visited_nodes = set()
790
  visited_edges = set()
 
791
 
792
- async def traverse(
793
- node: KnowledgeGraphNode,
794
- edge: Optional[KnowledgeGraphEdge],
795
- current_depth: int,
796
- ):
797
- # Check traversal limits
798
- if current_depth > max_depth:
799
- logger.debug(f"Reached max depth: {max_depth}")
800
- return
801
- if len(visited_nodes) >= MAX_GRAPH_NODES:
802
- logger.debug(f"Reached max nodes limit: {MAX_GRAPH_NODES}")
803
- return
 
 
 
 
 
 
 
 
 
 
804
 
805
- # Check if node already visited
806
- if node.id in visited_nodes:
807
- return
808
 
809
- # Get all edges and target nodes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
810
  async with self._driver.session(
811
  database=self._DATABASE, default_access_mode="READ"
812
  ) as session:
@@ -815,32 +928,17 @@ class Neo4JStorage(BaseGraphStorage):
815
  WITH r, b, id(r) as edge_id, id(b) as target_id
816
  RETURN r, b, edge_id, target_id
817
  """
818
- results = await session.run(query, entity_id=node.id)
819
 
820
  # Get all records and release database connection
821
- records = await results.fetch(
822
- 1000
823
- ) # Max neighbour nodes we can handled
824
  await results.consume() # Ensure results are consumed
825
 
826
- # Nodes not connected to start node need to check degree
827
- if current_depth > 1 and len(records) < min_degree:
828
- return
829
-
830
- # Add current node to result
831
- result.nodes.append(node)
832
- visited_nodes.add(node.id)
833
-
834
- # Add edge to result if it exists and not already added
835
- if edge and edge.id not in visited_edges:
836
- result.edges.append(edge)
837
- visited_edges.add(edge.id)
838
-
839
- # Prepare nodes and edges for recursive processing
840
- nodes_to_process = []
841
  for record in records:
842
  rel = record["r"]
843
  edge_id = str(record["edge_id"])
 
844
  if edge_id not in visited_edges:
845
  b_node = record["b"]
846
  target_id = b_node.get("entity_id")
@@ -849,55 +947,59 @@ class Neo4JStorage(BaseGraphStorage):
849
  # Create KnowledgeGraphNode for target
850
  target_node = KnowledgeGraphNode(
851
  id=f"{target_id}",
852
- labels=list(f"{target_id}"),
853
- properties=dict(b_node.properties),
854
  )
855
 
856
  # Create KnowledgeGraphEdge
857
  target_edge = KnowledgeGraphEdge(
858
  id=f"{edge_id}",
859
  type=rel.type,
860
- source=f"{node.id}",
861
  target=f"{target_id}",
862
  properties=dict(rel),
863
  )
864
 
865
- nodes_to_process.append((target_node, target_edge))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
866
  else:
867
  logger.warning(
868
- f"Skipping edge {edge_id} due to missing labels on target node"
869
  )
870
 
871
- # Process nodes after releasing database connection
872
- for target_node, target_edge in nodes_to_process:
873
- await traverse(target_node, target_edge, current_depth + 1)
874
-
875
- # Get the starting node's data
876
- async with self._driver.session(
877
- database=self._DATABASE, default_access_mode="READ"
878
- ) as session:
879
- query = """
880
- MATCH (n:base {entity_id: $entity_id})
881
- RETURN id(n) as node_id, n
882
- """
883
- node_result = await session.run(query, entity_id=node_label)
884
- try:
885
- node_record = await node_result.single()
886
- if not node_record:
887
- return result
888
-
889
- # Create initial KnowledgeGraphNode
890
- start_node = KnowledgeGraphNode(
891
- id=f"{node_record['n'].get('entity_id')}",
892
- labels=list(f"{node_record['n'].get('entity_id')}"),
893
- properties=dict(node_record["n"].properties),
894
- )
895
- finally:
896
- await node_result.consume() # Ensure results are consumed
897
-
898
- # Start traversal with the initial node
899
- await traverse(start_node, None, 0)
900
-
901
  return result
902
 
903
  async def get_all_labels(self) -> list[str]:
@@ -914,7 +1016,7 @@ class Neo4JStorage(BaseGraphStorage):
914
 
915
  # Method 2: Query compatible with older versions
916
  query = """
917
- MATCH (n)
918
  WHERE n.entity_id IS NOT NULL
919
  RETURN DISTINCT n.entity_id AS label
920
  ORDER BY label
@@ -1028,3 +1130,28 @@ class Neo4JStorage(BaseGraphStorage):
1028
  self, algorithm: str
1029
  ) -> tuple[np.ndarray[Any, Any], list[str]]:
1030
  raise NotImplementedError
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import inspect
2
  import os
3
  import re
4
  from dataclasses import dataclass
5
+ from typing import Any, final
6
  import numpy as np
7
  import configparser
8
 
 
28
  exceptions as neo4jExceptions,
29
  AsyncDriver,
30
  AsyncManagedTransaction,
 
31
  )
32
 
33
  config = configparser.ConfigParser()
 
50
  embedding_func=embedding_func,
51
  )
52
  self._driver = None
 
53
 
54
+ def __post_init__(self):
55
+ self._node_embed_algorithms = {
56
+ "node2vec": self._node2vec_embed,
57
+ }
58
+
59
+ async def initialize(self):
60
  URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
61
  USERNAME = os.environ.get(
62
  "NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
 
89
  ),
90
  )
91
  DATABASE = os.environ.get(
92
+ "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace)
93
  )
94
 
95
  self._driver: AsyncDriver = AsyncGraphDatabase.driver(
 
101
  max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
102
  )
103
 
104
+ # Try to connect to the database and create it if it doesn't exist
105
+ for database in (DATABASE, None):
106
+ self._DATABASE = database
107
+ connected = False
 
 
 
 
 
 
 
108
 
109
+ try:
110
+ async with self._driver.session(database=database) as session:
111
+ try:
112
+ result = await session.run("MATCH (n) RETURN n LIMIT 0")
113
+ await result.consume() # Ensure result is consumed
114
+ logger.info(f"Connected to {database} at {URI}")
115
+ connected = True
116
+ except neo4jExceptions.ServiceUnavailable as e:
117
+ logger.error(
118
+ f"{database} at {URI} is not available".capitalize()
119
+ )
120
+ raise e
121
+ except neo4jExceptions.AuthError as e:
122
+ logger.error(f"Authentication failed for {database} at {URI}")
123
+ raise e
124
+ except neo4jExceptions.ClientError as e:
125
+ if e.code == "Neo.ClientError.Database.DatabaseNotFound":
126
+ logger.info(
127
+ f"{database} at {URI} not found. Try to create specified database.".capitalize()
128
+ )
129
+ try:
130
+ async with self._driver.session() as session:
131
+ result = await session.run(
132
+ f"CREATE DATABASE `{database}` IF NOT EXISTS"
133
  )
134
+ await result.consume() # Ensure result is consumed
135
+ logger.info(f"{database} at {URI} created".capitalize())
136
+ connected = True
137
+ except (
138
+ neo4jExceptions.ClientError,
139
+ neo4jExceptions.DatabaseError,
140
+ ) as e:
141
+ if (
142
+ e.code
143
+ == "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
144
+ ) or (e.code == "Neo.DatabaseError.Statement.ExecutionFailed"):
145
+ if database is not None:
146
+ logger.warning(
147
+ "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
148
+ )
149
+ if database is None:
150
+ logger.error(f"Failed to create {database} at {URI}")
151
  raise e
152
+
153
+ if connected:
154
+ # Create index for base nodes on entity_id if it doesn't exist
155
+ try:
156
+ async with self._driver.session(database=database) as session:
157
+ # Check if index exists first
158
+ check_query = """
159
+ CALL db.indexes() YIELD name, labelsOrTypes, properties
160
+ WHERE labelsOrTypes = ['base'] AND properties = ['entity_id']
161
+ RETURN count(*) > 0 AS exists
162
+ """
163
  try:
164
+ check_result = await session.run(check_query)
165
+ record = await check_result.single()
166
+ await check_result.consume()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
+ index_exists = record and record.get("exists", False)
 
169
 
170
+ if not index_exists:
171
+ # Create index only if it doesn't exist
172
+ result = await session.run(
173
+ "CREATE INDEX FOR (n:base) ON (n.entity_id)"
174
+ )
175
+ await result.consume()
176
+ logger.info(
177
+ f"Created index for base nodes on entity_id in {database}"
178
+ )
179
+ except Exception:
180
+ # Fallback if db.indexes() is not supported in this Neo4j version
181
+ result = await session.run(
182
+ "CREATE INDEX IF NOT EXISTS FOR (n:base) ON (n.entity_id)"
183
+ )
184
+ await result.consume()
185
+ except Exception as e:
186
+ logger.warning(f"Failed to create index: {str(e)}")
187
+ break
188
 
189
+ async def finalize(self):
190
  """Close the Neo4j driver and release all resources"""
191
  if self._driver:
192
  await self._driver.close()
 
194
 
195
  async def __aexit__(self, exc_type, exc, tb):
196
  """Ensure driver is closed when context manager exits"""
197
+ await self.finalize()
198
 
199
  async def index_done_callback(self) -> None:
200
  # Neo4J handles persistence automatically
 
267
  raise
268
 
269
  async def get_node(self, node_id: str) -> dict[str, str] | None:
270
+ """Get node by its label identifier, return only node properties
271
 
272
  Args:
273
  node_id: The node label to look up
 
452
  logger.debug(
453
  f"{inspect.currentframe().f_code.co_name}: No edge found between {source_node_id} and {target_node_id}"
454
  )
455
+ # Return None when no edge found
456
+ return None
 
 
 
 
 
457
  finally:
458
  await result.consume() # Ensure result is fully consumed
459
 
 
545
  """
546
  properties = node_data
547
  entity_type = properties["entity_type"]
 
548
  if "entity_id" not in properties:
549
  raise ValueError("Neo4j: node properties must contain an 'entity_id' field")
550
 
 
554
  async def execute_upsert(tx: AsyncManagedTransaction):
555
  query = (
556
  """
557
+ MERGE (n:base {entity_id: $entity_id})
558
  SET n += $properties
559
  SET n:`%s`
560
  """
561
  % entity_type
562
  )
563
+ result = await tx.run(
564
+ query, entity_id=node_id, properties=properties
565
+ )
566
  logger.debug(
567
+ f"Upserted node with entity_id '{node_id}' and properties: {properties}"
568
  )
569
  await result.consume() # Ensure result is fully consumed
570
 
 
642
  self,
643
  node_label: str,
644
  max_depth: int = 3,
645
+ max_nodes: int = MAX_GRAPH_NODES,
 
646
  ) -> KnowledgeGraph:
647
  """
648
  Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
 
 
 
 
 
 
649
 
650
  Args:
651
+ node_label: Label of the starting node, * means all nodes
652
+ max_depth: Maximum depth of the subgraph, Defaults to 3
653
+ max_nodes: Maximum nodes to return by BFS. Defaults to 1000
654
+
655
  Returns:
656
+ KnowledgeGraph object containing nodes and edges, with an is_truncated flag
657
+ indicating whether the graph was truncated due to max_nodes limit
658
  """
659
  result = KnowledgeGraph()
660
  seen_nodes = set()
 
665
  ) as session:
666
  try:
667
  if node_label == "*":
668
+ # First check total node count to determine if graph is truncated
669
+ count_query = "MATCH (n) RETURN count(n) as total"
670
+ count_result = None
671
+ try:
672
+ count_result = await session.run(count_query)
673
+ count_record = await count_result.single()
674
+
675
+ if count_record and count_record["total"] > max_nodes:
676
+ result.is_truncated = True
677
+ logger.info(
678
+ f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}"
679
+ )
680
+ finally:
681
+ if count_result:
682
+ await count_result.consume()
683
+
684
+ # Run main query to get nodes with highest degree
685
  main_query = """
686
  MATCH (n)
687
  OPTIONAL MATCH (n)-[r]-()
688
  WITH n, COALESCE(count(r), 0) AS degree
 
689
  ORDER BY degree DESC
690
  LIMIT $max_nodes
691
  WITH collect({node: n}) AS filtered_nodes
 
696
  RETURN filtered_nodes AS node_info,
697
  collect(DISTINCT r) AS relationships
698
  """
699
+ result_set = None
700
+ try:
701
+ result_set = await session.run(
702
+ main_query,
703
+ {"max_nodes": max_nodes},
704
+ )
705
+ record = await result_set.single()
706
+ finally:
707
+ if result_set:
708
+ await result_set.consume()
709
 
710
  else:
711
+ # return await self._robust_fallback(node_label, max_depth, max_nodes)
712
+ # First try without limit to check if we need to truncate
713
+ full_query = """
714
  MATCH (start)
715
+ WHERE start.entity_id = $entity_id
 
 
 
 
716
  WITH start
717
  CALL apoc.path.subgraphAll(start, {
718
  relationshipFilter: '',
 
721
  bfs: true
722
  })
723
  YIELD nodes, relationships
724
+ WITH nodes, relationships, size(nodes) AS total_nodes
725
  UNWIND nodes AS node
726
+ WITH collect({node: node}) AS node_info, relationships, total_nodes
727
+ RETURN node_info, relationships, total_nodes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
728
  """
 
 
 
 
 
 
 
 
 
 
729
 
730
+ # Try to get full result
731
+ full_result = None
732
+ try:
733
+ full_result = await session.run(
734
+ full_query,
735
+ {
736
+ "entity_id": node_label,
737
+ "max_depth": max_depth,
738
+ },
739
+ )
740
+ full_record = await full_result.single()
741
+
742
+ # If no record found, return empty KnowledgeGraph
743
+ if not full_record:
744
+ logger.debug(f"No nodes found for entity_id: {node_label}")
745
+ return result
746
+
747
+ # If record found, check node count
748
+ total_nodes = full_record["total_nodes"]
749
+
750
+ if total_nodes <= max_nodes:
751
+ # If node count is within limit, use full result directly
752
+ logger.debug(
753
+ f"Using full result with {total_nodes} nodes (no truncation needed)"
754
+ )
755
+ record = full_record
756
+ else:
757
+ # If node count exceeds limit, set truncated flag and run limited query
758
+ result.is_truncated = True
759
+ logger.info(
760
+ f"Graph truncated: {total_nodes} nodes found, breadth-first search limited to {max_nodes}"
761
+ )
762
+
763
+ # Run limited query
764
+ limited_query = """
765
+ MATCH (start)
766
+ WHERE start.entity_id = $entity_id
767
+ WITH start
768
+ CALL apoc.path.subgraphAll(start, {
769
+ relationshipFilter: '',
770
+ minLevel: 0,
771
+ maxLevel: $max_depth,
772
+ limit: $max_nodes,
773
+ bfs: true
774
+ })
775
+ YIELD nodes, relationships
776
+ UNWIND nodes AS node
777
+ WITH collect({node: node}) AS node_info, relationships
778
+ RETURN node_info, relationships
779
+ """
780
+ result_set = None
781
+ try:
782
+ result_set = await session.run(
783
+ limited_query,
784
+ {
785
+ "entity_id": node_label,
786
+ "max_depth": max_depth,
787
+ "max_nodes": max_nodes,
788
+ },
789
  )
790
+ record = await result_set.single()
791
+ finally:
792
+ if result_set:
793
+ await result_set.consume()
794
+ finally:
795
+ if full_result:
796
+ await full_result.consume()
797
+
798
+ if record:
799
+ # Handle nodes (compatible with multi-label cases)
800
+ for node_info in record["node_info"]:
801
+ node = node_info["node"]
802
+ node_id = node.id
803
+ if node_id not in seen_nodes:
804
+ result.nodes.append(
805
+ KnowledgeGraphNode(
806
+ id=f"{node_id}",
807
+ labels=[node.get("entity_id")],
808
+ properties=dict(node),
809
+ )
810
+ )
811
+ seen_nodes.add(node_id)
812
+
813
+ # Handle relationships (including direction information)
814
+ for rel in record["relationships"]:
815
+ edge_id = rel.id
816
+ if edge_id not in seen_edges:
817
+ start = rel.start_node
818
+ end = rel.end_node
819
+ result.edges.append(
820
+ KnowledgeGraphEdge(
821
+ id=f"{edge_id}",
822
+ type=rel.type,
823
+ source=f"{start.id}",
824
+ target=f"{end.id}",
825
+ properties=dict(rel),
826
  )
827
+ )
828
+ seen_edges.add(edge_id)
829
 
830
+ logger.info(
831
+ f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
832
+ )
 
 
833
 
834
  except neo4jExceptions.ClientError as e:
835
  logger.warning(f"APOC plugin error: {str(e)}")
 
837
  logger.warning(
838
  "Neo4j: falling back to basic Cypher recursive search..."
839
  )
840
+ return await self._robust_fallback(node_label, max_depth, max_nodes)
841
+ else:
842
+ logger.warning(
843
+ "Neo4j: APOC plugin error with wildcard query, returning empty result"
 
 
844
  )
845
 
846
  return result
847
 
848
  async def _robust_fallback(
849
+ self, node_label: str, max_depth: int, max_nodes: int
850
  ) -> KnowledgeGraph:
851
  """
852
  Fallback implementation when APOC plugin is not available or incompatible.
853
  This method implements the same functionality as get_knowledge_graph but uses
854
+ only basic Cypher queries and true breadth-first traversal instead of APOC procedures.
855
  """
856
+ from collections import deque
857
+
858
  result = KnowledgeGraph()
859
  visited_nodes = set()
860
  visited_edges = set()
861
+ visited_edge_pairs = set()  # Tracks processed edge pairs as sorted (source_id, target_id) tuples
862
 
863
+ # Get the starting node's data
864
+ async with self._driver.session(
865
+ database=self._DATABASE, default_access_mode="READ"
866
+ ) as session:
867
+ query = """
868
+ MATCH (n:base {entity_id: $entity_id})
869
+ RETURN id(n) as node_id, n
870
+ """
871
+ node_result = await session.run(query, entity_id=node_label)
872
+ try:
873
+ node_record = await node_result.single()
874
+ if not node_record:
875
+ return result
876
+
877
+ # Create initial KnowledgeGraphNode
878
+ start_node = KnowledgeGraphNode(
879
+ id=f"{node_record['n'].get('entity_id')}",
880
+ labels=[node_record["n"].get("entity_id")],
881
+ properties=dict(node_record["n"]._properties),
882
+ )
883
+ finally:
884
+ await node_result.consume() # Ensure results are consumed
885
 
886
+ # Initialize queue for BFS with (node, edge, depth) tuples
887
+ # edge is None for the starting node
888
+ queue = deque([(start_node, None, 0)])
889
 
890
+ # True BFS implementation using a queue
891
+ while queue and len(visited_nodes) < max_nodes:
892
+ # Dequeue the next node to process
893
+ current_node, current_edge, current_depth = queue.popleft()
894
+
895
+ # Skip if already visited or exceeds max depth
896
+ if current_node.id in visited_nodes:
897
+ continue
898
+
899
+ if current_depth > max_depth:
900
+ logger.debug(
901
+ f"Skipping node at depth {current_depth} (max_depth: {max_depth})"
902
+ )
903
+ continue
904
+
905
+ # Add current node to result
906
+ result.nodes.append(current_node)
907
+ visited_nodes.add(current_node.id)
908
+
909
+ # Add edge to result if it exists and not already added
910
+ if current_edge and current_edge.id not in visited_edges:
911
+ result.edges.append(current_edge)
912
+ visited_edges.add(current_edge.id)
913
+
914
+ # Stop if we've reached the node limit
915
+ if len(visited_nodes) >= max_nodes:
916
+ result.is_truncated = True
917
+ logger.info(
918
+ f"Graph truncated: breadth-first search limited to: {max_nodes} nodes"
919
+ )
920
+ break
921
+
922
+ # Get all edges and target nodes for the current node (even at max_depth)
923
  async with self._driver.session(
924
  database=self._DATABASE, default_access_mode="READ"
925
  ) as session:
 
928
  WITH r, b, id(r) as edge_id, id(b) as target_id
929
  RETURN r, b, edge_id, target_id
930
  """
931
+ results = await session.run(query, entity_id=current_node.id)
932
 
933
  # Get all records and release database connection
934
+ records = await results.fetch(1000) # Max neighbor nodes we can handle
 
 
935
  await results.consume() # Ensure results are consumed
936
 
937
+ # Process all neighbors - capture all edges but only queue unvisited nodes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
938
  for record in records:
939
  rel = record["r"]
940
  edge_id = str(record["edge_id"])
941
+
942
  if edge_id not in visited_edges:
943
  b_node = record["b"]
944
  target_id = b_node.get("entity_id")
 
947
  # Create KnowledgeGraphNode for target
948
  target_node = KnowledgeGraphNode(
949
  id=f"{target_id}",
950
+ labels=[target_id],
951
+ properties=dict(b_node._properties),
952
  )
953
 
954
  # Create KnowledgeGraphEdge
955
  target_edge = KnowledgeGraphEdge(
956
  id=f"{edge_id}",
957
  type=rel.type,
958
+ source=f"{current_node.id}",
959
  target=f"{target_id}",
960
  properties=dict(rel),
961
  )
962
 
963
+ # Sort source_id and target_id so that (A,B) and (B,A) are treated as the same edge
964
+ sorted_pair = tuple(sorted([current_node.id, target_id]))
965
+
966
+ # Check whether an equivalent edge already exists (treating edges as undirected)
967
+ if sorted_pair not in visited_edge_pairs:
968
+ # Only add the edge when the target node is already in the result or will be added to it
969
+ if target_id in visited_nodes or (
970
+ target_id not in visited_nodes
971
+ and current_depth < max_depth
972
+ ):
973
+ result.edges.append(target_edge)
974
+ visited_edges.add(edge_id)
975
+ visited_edge_pairs.add(sorted_pair)
976
+
977
+ # Only add unvisited nodes to the queue for further expansion
978
+ if target_id not in visited_nodes:
979
+ # Only add to queue if we're not at max depth yet
980
+ if current_depth < max_depth:
981
+ # Add node to queue with incremented depth
982
+ # Edge is already added to result, so we pass None as edge
983
+ queue.append((target_node, None, current_depth + 1))
984
+ else:
985
+ # At max depth, we've already added the edge but we don't add the node
986
+ # This prevents adding nodes beyond max_depth to the result
987
+ logger.debug(
988
+ f"Node {target_id} beyond max depth {max_depth}, edge added but node not included"
989
+ )
990
+ else:
991
+ # If target node already exists in result, we don't need to add it again
992
+ logger.debug(
993
+ f"Node {target_id} already visited, edge added but node not queued"
994
+ )
995
  else:
996
  logger.warning(
997
+ f"Skipping edge {edge_id} due to missing entity_id on target node"
998
  )
999
 
1000
+ logger.info(
1001
+ f"BFS subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
1002
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1003
  return result
1004
 
1005
  async def get_all_labels(self) -> list[str]:
 
1016
 
1017
  # Method 2: Query compatible with older versions
1018
  query = """
1019
+ MATCH (n:base)
1020
  WHERE n.entity_id IS NOT NULL
1021
  RETURN DISTINCT n.entity_id AS label
1022
  ORDER BY label
 
1130
  self, algorithm: str
1131
  ) -> tuple[np.ndarray[Any, Any], list[str]]:
1132
  raise NotImplementedError
1133
+
1134
+ async def drop(self) -> dict[str, str]:
1135
+ """Drop all data from storage and clean up resources
1136
+
1137
+ This method will delete all nodes and relationships in the Neo4j database.
1138
+
1139
+ Returns:
1140
+ dict[str, str]: Operation status and message
1141
+ - On success: {"status": "success", "message": "data dropped"}
1142
+ - On failure: {"status": "error", "message": "<error details>"}
1143
+ """
1144
+ try:
1145
+ async with self._driver.session(database=self._DATABASE) as session:
1146
+ # Delete all nodes and relationships
1147
+ query = "MATCH (n) DETACH DELETE n"
1148
+ result = await session.run(query)
1149
+ await result.consume() # Ensure result is fully consumed
1150
+
1151
+ logger.info(
1152
+ f"Process {os.getpid()} drop Neo4j database {self._DATABASE}"
1153
+ )
1154
+ return {"status": "success", "message": "data dropped"}
1155
+ except Exception as e:
1156
+ logger.error(f"Error dropping Neo4j database {self._DATABASE}: {e}")
1157
+ return {"status": "error", "message": str(e)}
lightrag/kg/networkx_impl.py CHANGED
@@ -42,6 +42,7 @@ class NetworkXStorage(BaseGraphStorage):
42
  )
43
  nx.write_graphml(graph, file_name)
44
 
 
45
  @staticmethod
46
  def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
47
  """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
@@ -155,16 +156,34 @@ class NetworkXStorage(BaseGraphStorage):
155
  return None
156
 
157
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
 
 
 
 
 
 
158
  graph = await self._get_graph()
159
  graph.add_node(node_id, **node_data)
160
 
161
  async def upsert_edge(
162
  self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
163
  ) -> None:
 
 
 
 
 
 
164
  graph = await self._get_graph()
165
  graph.add_edge(source_node_id, target_node_id, **edge_data)
166
 
167
  async def delete_node(self, node_id: str) -> None:
 
 
 
 
 
 
168
  graph = await self._get_graph()
169
  if graph.has_node(node_id):
170
  graph.remove_node(node_id)
@@ -172,6 +191,7 @@ class NetworkXStorage(BaseGraphStorage):
172
  else:
173
  logger.warning(f"Node {node_id} not found in the graph for deletion.")
174
 
 
175
  async def embed_nodes(
176
  self, algorithm: str
177
  ) -> tuple[np.ndarray[Any, Any], list[str]]:
@@ -192,6 +212,11 @@ class NetworkXStorage(BaseGraphStorage):
192
  async def remove_nodes(self, nodes: list[str]):
193
  """Delete multiple nodes
194
 
 
 
 
 
 
195
  Args:
196
  nodes: List of node IDs to be deleted
197
  """
@@ -203,6 +228,11 @@ class NetworkXStorage(BaseGraphStorage):
203
  async def remove_edges(self, edges: list[tuple[str, str]]):
204
  """Delete multiple edges
205
 
 
 
 
 
 
206
  Args:
207
  edges: List of edges to be deleted, each edge is a (source, target) tuple
208
  """
@@ -229,118 +259,81 @@ class NetworkXStorage(BaseGraphStorage):
229
  self,
230
  node_label: str,
231
  max_depth: int = 3,
232
- min_degree: int = 0,
233
- inclusive: bool = False,
234
  ) -> KnowledgeGraph:
235
  """
236
  Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
237
- Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
238
- When reducing the number of nodes, the prioritization criteria are as follows:
239
- 1. min_degree does not affect nodes directly connected to the matching nodes
240
- 2. Label matching nodes take precedence
241
- 3. Followed by nodes directly connected to the matching nodes
242
- 4. Finally, the degree of the nodes
243
 
244
  Args:
245
- node_label: Label of the starting node
246
- max_depth: Maximum depth of the subgraph
247
- min_degree: Minimum degree of nodes to include. Defaults to 0
248
- inclusive: Do an inclusive search if true
249
 
250
  Returns:
251
- KnowledgeGraph object containing nodes and edges
 
252
  """
253
- result = KnowledgeGraph()
254
- seen_nodes = set()
255
- seen_edges = set()
256
-
257
  graph = await self._get_graph()
258
 
259
- # Initialize sets for start nodes and direct connected nodes
260
- start_nodes = set()
261
- direct_connected_nodes = set()
262
 
263
  # Handle special case for "*" label
264
  if node_label == "*":
265
- # For "*", return the entire graph including all nodes and edges
266
- subgraph = (
267
- graph.copy()
268
- ) # Create a copy to avoid modifying the original graph
 
 
 
 
 
 
 
 
 
 
 
269
  else:
270
- # Find nodes with matching node id based on search_mode
271
- nodes_to_explore = []
272
- for n, attr in graph.nodes(data=True):
273
- node_str = str(n)
274
- if not inclusive:
275
- if node_label == node_str: # Use exact matching
276
- nodes_to_explore.append(n)
277
- else: # inclusive mode
278
- if node_label in node_str: # Use partial matching
279
- nodes_to_explore.append(n)
280
-
281
- if not nodes_to_explore:
282
- logger.warning(f"No nodes found with label {node_label}")
283
- return result
284
-
285
- # Get subgraph using ego_graph from all matching nodes
286
- combined_subgraph = nx.Graph()
287
- for start_node in nodes_to_explore:
288
- node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth)
289
- combined_subgraph = nx.compose(combined_subgraph, node_subgraph)
290
-
291
- # Get start nodes and direct connected nodes
292
- if nodes_to_explore:
293
- start_nodes = set(nodes_to_explore)
294
- # Get nodes directly connected to all start nodes
295
- for start_node in start_nodes:
296
- direct_connected_nodes.update(
297
- combined_subgraph.neighbors(start_node)
298
- )
299
-
300
- # Remove start nodes from directly connected nodes (avoid duplicates)
301
- direct_connected_nodes -= start_nodes
302
-
303
- subgraph = combined_subgraph
304
-
305
- # Filter nodes based on min_degree, but keep start nodes and direct connected nodes
306
- if min_degree > 0:
307
- nodes_to_keep = [
308
- node
309
- for node, degree in subgraph.degree()
310
- if node in start_nodes
311
- or node in direct_connected_nodes
312
- or degree >= min_degree
313
- ]
314
- subgraph = subgraph.subgraph(nodes_to_keep)
315
-
316
- # Check if number of nodes exceeds max_graph_nodes
317
- if len(subgraph.nodes()) > MAX_GRAPH_NODES:
318
- origin_nodes = len(subgraph.nodes())
319
- node_degrees = dict(subgraph.degree())
320
-
321
- def priority_key(node_item):
322
- node, degree = node_item
323
- # Priority order: start(2) > directly connected(1) > other nodes(0)
324
- if node in start_nodes:
325
- priority = 2
326
- elif node in direct_connected_nodes:
327
- priority = 1
328
- else:
329
- priority = 0
330
- return (priority, degree)
331
-
332
- # Sort by priority and degree and select top MAX_GRAPH_NODES nodes
333
- top_nodes = sorted(node_degrees.items(), key=priority_key, reverse=True)[
334
- :MAX_GRAPH_NODES
335
- ]
336
- top_node_ids = [node[0] for node in top_nodes]
337
- # Create new subgraph and keep nodes only with most degree
338
- subgraph = subgraph.subgraph(top_node_ids)
339
- logger.info(
340
- f"Reduced graph from {origin_nodes} nodes to {MAX_GRAPH_NODES} nodes (depth={max_depth})"
341
- )
342
 
343
  # Add nodes to result
 
 
344
  for node in subgraph.nodes():
345
  if str(node) in seen_nodes:
346
  continue
@@ -368,7 +361,7 @@ class NetworkXStorage(BaseGraphStorage):
368
  for edge in subgraph.edges():
369
  source, target = edge
370
  # Ensure unique edge_id for undirected graph
371
- if source > target:
372
  source, target = target, source
373
  edge_id = f"{source}-{target}"
374
  if edge_id in seen_edges:
@@ -424,3 +417,35 @@ class NetworkXStorage(BaseGraphStorage):
424
  return False # Return error
425
 
426
  return True
 
 
 
 
42
  )
43
  nx.write_graphml(graph, file_name)
44
 
45
+ # TODO:deprecated, remove later
46
  @staticmethod
47
  def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
48
  """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
 
156
  return None
157
 
158
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
159
+ """
160
+ Important notes:
161
+ 1. Changes will be persisted to disk during the next index_done_callback
162
+ 2. Only one process should update the storage at a time before index_done_callback,
163
+ KG-storage-log should be used to avoid data corruption
164
+ """
165
  graph = await self._get_graph()
166
  graph.add_node(node_id, **node_data)
167
 
168
  async def upsert_edge(
169
  self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
170
  ) -> None:
171
+ """
172
+ Important notes:
173
+ 1. Changes will be persisted to disk during the next index_done_callback
174
+ 2. Only one process should update the storage at a time before index_done_callback,
175
+ KG-storage-log should be used to avoid data corruption
176
+ """
177
  graph = await self._get_graph()
178
  graph.add_edge(source_node_id, target_node_id, **edge_data)
179
 
180
  async def delete_node(self, node_id: str) -> None:
181
+ """
182
+ Important notes:
183
+ 1. Changes will be persisted to disk during the next index_done_callback
184
+ 2. Only one process should update the storage at a time before index_done_callback,
185
+ KG-storage-log should be used to avoid data corruption
186
+ """
187
  graph = await self._get_graph()
188
  if graph.has_node(node_id):
189
  graph.remove_node(node_id)
 
191
  else:
192
  logger.warning(f"Node {node_id} not found in the graph for deletion.")
193
 
194
+ # TODO: NOT USED
195
  async def embed_nodes(
196
  self, algorithm: str
197
  ) -> tuple[np.ndarray[Any, Any], list[str]]:
 
212
  async def remove_nodes(self, nodes: list[str]):
213
  """Delete multiple nodes
214
 
215
+ Important notes:
216
+ 1. Changes will be persisted to disk during the next index_done_callback
217
+ 2. Only one process should update the storage at a time before index_done_callback,
218
+ KG-storage-log should be used to avoid data corruption
219
+
220
  Args:
221
  nodes: List of node IDs to be deleted
222
  """
 
228
  async def remove_edges(self, edges: list[tuple[str, str]]):
229
  """Delete multiple edges
230
 
231
+ Important notes:
232
+ 1. Changes will be persisted to disk during the next index_done_callback
233
+ 2. Only one process should update the storage at a time before index_done_callback,
234
+ KG-storage-log should be used to avoid data corruption
235
+
236
  Args:
237
  edges: List of edges to be deleted, each edge is a (source, target) tuple
238
  """
 
259
  self,
260
  node_label: str,
261
  max_depth: int = 3,
262
+ max_nodes: int = MAX_GRAPH_NODES,
 
263
  ) -> KnowledgeGraph:
264
  """
265
  Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
 
 
 
 
 
 
266
 
267
  Args:
268
+ node_label: Label of the starting node, * means all nodes
269
+ max_depth: Maximum depth of the subgraph, Defaults to 3
270
+ max_nodes: Maximum nodes to return by BFS. Defaults to 1000
 
271
 
272
  Returns:
273
+ KnowledgeGraph object containing nodes and edges, with an is_truncated flag
274
+ indicating whether the graph was truncated due to max_nodes limit
275
  """
 
 
 
 
276
  graph = await self._get_graph()
277
 
278
+ result = KnowledgeGraph()
 
 
279
 
280
  # Handle special case for "*" label
281
  if node_label == "*":
282
+ # Get degrees of all nodes
283
+ degrees = dict(graph.degree())
284
+ # Sort nodes by degree in descending order and take top max_nodes
285
+ sorted_nodes = sorted(degrees.items(), key=lambda x: x[1], reverse=True)
286
+
287
+ # Check if graph is truncated
288
+ if len(sorted_nodes) > max_nodes:
289
+ result.is_truncated = True
290
+ logger.info(
291
+ f"Graph truncated: {len(sorted_nodes)} nodes found, limited to {max_nodes}"
292
+ )
293
+
294
+ limited_nodes = [node for node, _ in sorted_nodes[:max_nodes]]
295
+ # Create subgraph with the highest degree nodes
296
+ subgraph = graph.subgraph(limited_nodes)
297
  else:
298
+ # Check if node exists
299
+ if node_label not in graph:
300
+ logger.warning(f"Node {node_label} not found in the graph")
301
+ return KnowledgeGraph() # Return empty graph
302
+
303
+ # Use BFS to get nodes
304
+ bfs_nodes = []
305
+ visited = set()
306
+ queue = [(node_label, 0)] # (node, depth) tuple
307
+
308
+ # Breadth-first search
309
+ while queue and len(bfs_nodes) < max_nodes:
310
+ current, depth = queue.pop(0)
311
+ if current not in visited:
312
+ visited.add(current)
313
+ bfs_nodes.append(current)
314
+
315
+ # Only explore neighbors if we haven't reached max_depth
316
+ if depth < max_depth:
317
+ # Add neighbor nodes to queue with incremented depth
318
+ neighbors = list(graph.neighbors(current))
319
+ queue.extend(
320
+ [(n, depth + 1) for n in neighbors if n not in visited]
321
+ )
322
+
323
+ # Check if graph is truncated - if we still have nodes in the queue
324
+ # and we've reached max_nodes, then the graph is truncated
325
+ if queue and len(bfs_nodes) >= max_nodes:
326
+ result.is_truncated = True
327
+ logger.info(
328
+ f"Graph truncated: breadth-first search limited to {max_nodes} nodes"
329
+ )
330
+
331
+ # Create subgraph with BFS discovered nodes
332
+ subgraph = graph.subgraph(bfs_nodes)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
 
334
  # Add nodes to result
335
+ seen_nodes = set()
336
+ seen_edges = set()
337
  for node in subgraph.nodes():
338
  if str(node) in seen_nodes:
339
  continue
 
361
  for edge in subgraph.edges():
362
  source, target = edge
363
  # Ensure unique edge_id for undirected graph
364
+ if str(source) > str(target):
365
  source, target = target, source
366
  edge_id = f"{source}-{target}"
367
  if edge_id in seen_edges:
 
417
  return False # Return error
418
 
419
  return True
420
+
421
+ async def drop(self) -> dict[str, str]:
422
+ """Drop all graph data from storage and clean up resources
423
+
424
+ This method will:
425
+ 1. Remove the graph storage file if it exists
426
+ 2. Reset the graph to an empty state
427
+ 3. Update flags to notify other processes
428
+ 4. Changes are persisted to disk immediately
429
+
430
+ Returns:
431
+ dict[str, str]: Operation status and message
432
+ - On success: {"status": "success", "message": "data dropped"}
433
+ - On failure: {"status": "error", "message": "<error details>"}
434
+ """
435
+ try:
436
+ async with self._storage_lock:
437
+ # delete _client_file_name
438
+ if os.path.exists(self._graphml_xml_file):
439
+ os.remove(self._graphml_xml_file)
440
+ self._graph = nx.Graph()
441
+ # Notify other processes that data has been updated
442
+ await set_all_update_flags(self.namespace)
443
+ # Reset own update flag to avoid self-reloading
444
+ self.storage_updated.value = False
445
+ logger.info(
446
+ f"Process {os.getpid()} drop graph {self.namespace} (file:{self._graphml_xml_file})"
447
+ )
448
+ return {"status": "success", "message": "data dropped"}
449
+ except Exception as e:
450
+ logger.error(f"Error dropping graph {self.namespace}: {e}")
451
+ return {"status": "error", "message": str(e)}
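The new `drop()` reports its outcome through a status dictionary instead of raising, so callers can branch on the result. A hypothetical usage sketch (`storage` is assumed to be an already-initialized NetworkXStorage instance and is not shown here):

```python
# Hypothetical caller of the new drop() contract; `storage` is an assumption.
async def reset_graph(storage):
    result = await storage.drop()
    if result["status"] == "success":
        print("Graph cleared:", result["message"])   # "data dropped"
    else:
        print("Drop failed:", result["message"])     # error details from the exception
```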
lightrag/kg/oracle_impl.py DELETED
@@ -1,1346 +0,0 @@
1
- import array
2
- import asyncio
3
-
4
- # import html
5
- import os
6
- from dataclasses import dataclass, field
7
- from typing import Any, Union, final
8
- import numpy as np
9
- import configparser
10
-
11
- from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
12
-
13
- from ..base import (
14
- BaseGraphStorage,
15
- BaseKVStorage,
16
- BaseVectorStorage,
17
- )
18
- from ..namespace import NameSpace, is_namespace
19
- from ..utils import logger
20
-
21
- import pipmaster as pm
22
-
23
- if not pm.is_installed("graspologic"):
24
- pm.install("graspologic")
25
-
26
- if not pm.is_installed("oracledb"):
27
- pm.install("oracledb")
28
-
29
- from graspologic import embed
30
- import oracledb
31
-
32
-
33
- class OracleDB:
34
- def __init__(self, config, **kwargs):
35
- self.host = config.get("host", None)
36
- self.port = config.get("port", None)
37
- self.user = config.get("user", None)
38
- self.password = config.get("password", None)
39
- self.dsn = config.get("dsn", None)
40
- self.config_dir = config.get("config_dir", None)
41
- self.wallet_location = config.get("wallet_location", None)
42
- self.wallet_password = config.get("wallet_password", None)
43
- self.workspace = config.get("workspace", None)
44
- self.max = 12
45
- self.increment = 1
46
- logger.info(f"Using the label {self.workspace} for Oracle Graph as identifier")
47
- if self.user is None or self.password is None:
48
- raise ValueError("Missing database user or password")
49
-
50
- try:
51
- oracledb.defaults.fetch_lobs = False
52
-
53
- self.pool = oracledb.create_pool_async(
54
- user=self.user,
55
- password=self.password,
56
- dsn=self.dsn,
57
- config_dir=self.config_dir,
58
- wallet_location=self.wallet_location,
59
- wallet_password=self.wallet_password,
60
- min=1,
61
- max=self.max,
62
- increment=self.increment,
63
- )
64
- logger.info(f"Connected to Oracle database at {self.dsn}")
65
- except Exception as e:
66
- logger.error(f"Failed to connect to Oracle database at {self.dsn}")
67
- logger.error(f"Oracle database error: {e}")
68
- raise
69
-
70
- def numpy_converter_in(self, value):
71
- """Convert numpy array to array.array"""
72
- if value.dtype == np.float64:
73
- dtype = "d"
74
- elif value.dtype == np.float32:
75
- dtype = "f"
76
- else:
77
- dtype = "b"
78
- return array.array(dtype, value)
79
-
80
- def input_type_handler(self, cursor, value, arraysize):
81
- """Set the type handler for the input data"""
82
- if isinstance(value, np.ndarray):
83
- return cursor.var(
84
- oracledb.DB_TYPE_VECTOR,
85
- arraysize=arraysize,
86
- inconverter=self.numpy_converter_in,
87
- )
88
-
89
- def numpy_converter_out(self, value):
90
- """Convert array.array to numpy array"""
91
- if value.typecode == "b":
92
- dtype = np.int8
93
- elif value.typecode == "f":
94
- dtype = np.float32
95
- else:
96
- dtype = np.float64
97
- return np.array(value, copy=False, dtype=dtype)
98
-
99
- def output_type_handler(self, cursor, metadata):
100
- """Set the type handler for the output data"""
101
- if metadata.type_code is oracledb.DB_TYPE_VECTOR:
102
- return cursor.var(
103
- metadata.type_code,
104
- arraysize=cursor.arraysize,
105
- outconverter=self.numpy_converter_out,
106
- )
107
-
108
- async def check_tables(self):
109
- for k, v in TABLES.items():
110
- try:
111
- if k.lower() == "lightrag_graph":
112
- await self.query(
113
- "SELECT id FROM GRAPH_TABLE (lightrag_graph MATCH (a) COLUMNS (a.id)) fetch first row only"
114
- )
115
- else:
116
- await self.query(f"SELECT 1 FROM {k}")
117
- except Exception as e:
118
- logger.error(f"Failed to check table {k} in Oracle database")
119
- logger.error(f"Oracle database error: {e}")
120
- try:
121
- # print(v["ddl"])
122
- await self.execute(v["ddl"])
123
- logger.info(f"Created table {k} in Oracle database")
124
- except Exception as e:
125
- logger.error(f"Failed to create table {k} in Oracle database")
126
- logger.error(f"Oracle database error: {e}")
127
-
128
- logger.info("Finished check all tables in Oracle database")
129
-
130
- async def query(
131
- self, sql: str, params: dict = None, multirows: bool = False
132
- ) -> Union[dict, None]:
133
- async with self.pool.acquire() as connection:
134
- connection.inputtypehandler = self.input_type_handler
135
- connection.outputtypehandler = self.output_type_handler
136
- with connection.cursor() as cursor:
137
- try:
138
- await cursor.execute(sql, params)
139
- except Exception as e:
140
- logger.error(f"Oracle database error: {e}")
141
- raise
142
- columns = [column[0].lower() for column in cursor.description]
143
- if multirows:
144
- rows = await cursor.fetchall()
145
- if rows:
146
- data = [dict(zip(columns, row)) for row in rows]
147
- else:
148
- data = []
149
- else:
150
- row = await cursor.fetchone()
151
- if row:
152
- data = dict(zip(columns, row))
153
- else:
154
- data = None
155
- return data
156
-
157
- async def execute(self, sql: str, data: Union[list, dict] = None):
158
- # logger.info("go into OracleDB execute method")
159
- try:
160
- async with self.pool.acquire() as connection:
161
- connection.inputtypehandler = self.input_type_handler
162
- connection.outputtypehandler = self.output_type_handler
163
- with connection.cursor() as cursor:
164
- if data is None:
165
- await cursor.execute(sql)
166
- else:
167
- await cursor.execute(sql, data)
168
- await connection.commit()
169
- except Exception as e:
170
- logger.error(f"Oracle database error: {e}")
171
- raise
172
-
173
-
174
- class ClientManager:
175
- _instances: dict[str, Any] = {"db": None, "ref_count": 0}
176
- _lock = asyncio.Lock()
177
-
178
- @staticmethod
179
- def get_config() -> dict[str, Any]:
180
- config = configparser.ConfigParser()
181
- config.read("config.ini", "utf-8")
182
-
183
- return {
184
- "user": os.environ.get(
185
- "ORACLE_USER",
186
- config.get("oracle", "user", fallback=None),
187
- ),
188
- "password": os.environ.get(
189
- "ORACLE_PASSWORD",
190
- config.get("oracle", "password", fallback=None),
191
- ),
192
- "dsn": os.environ.get(
193
- "ORACLE_DSN",
194
- config.get("oracle", "dsn", fallback=None),
195
- ),
196
- "config_dir": os.environ.get(
197
- "ORACLE_CONFIG_DIR",
198
- config.get("oracle", "config_dir", fallback=None),
199
- ),
200
- "wallet_location": os.environ.get(
201
- "ORACLE_WALLET_LOCATION",
202
- config.get("oracle", "wallet_location", fallback=None),
203
- ),
204
- "wallet_password": os.environ.get(
205
- "ORACLE_WALLET_PASSWORD",
206
- config.get("oracle", "wallet_password", fallback=None),
207
- ),
208
- "workspace": os.environ.get(
209
- "ORACLE_WORKSPACE",
210
- config.get("oracle", "workspace", fallback="default"),
211
- ),
212
- }
213
-
214
- @classmethod
215
- async def get_client(cls) -> OracleDB:
216
- async with cls._lock:
217
- if cls._instances["db"] is None:
218
- config = ClientManager.get_config()
219
- db = OracleDB(config)
220
- await db.check_tables()
221
- cls._instances["db"] = db
222
- cls._instances["ref_count"] = 0
223
- cls._instances["ref_count"] += 1
224
- return cls._instances["db"]
225
-
226
- @classmethod
227
- async def release_client(cls, db: OracleDB):
228
- async with cls._lock:
229
- if db is not None:
230
- if db is cls._instances["db"]:
231
- cls._instances["ref_count"] -= 1
232
- if cls._instances["ref_count"] == 0:
233
- await db.pool.close()
234
- logger.info("Closed OracleDB database connection pool")
235
- cls._instances["db"] = None
236
- else:
237
- await db.pool.close()
238
-
239
-
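The removed `ClientManager` shared a single Oracle connection pool across storage instances through reference counting. Its intended acquire/release pattern, reconstructed from the code above (illustrative only, since the module is deleted in this commit):

```python
# Acquire/release pattern implemented by the removed ClientManager (sketch).
async def storage_lifecycle_example():
    db = await ClientManager.get_client()       # first caller creates the pool
    try:
        row = await db.query("SELECT 1 FROM DUAL")
        print(row)
    finally:
        await ClientManager.release_client(db)  # the pool is closed only when the
                                                # last reference is released
```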
240
- @final
241
- @dataclass
242
- class OracleKVStorage(BaseKVStorage):
243
- db: OracleDB = field(default=None)
244
- meta_fields = None
245
-
246
- def __post_init__(self):
247
- self._data = {}
248
- self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
249
-
250
- async def initialize(self):
251
- if self.db is None:
252
- self.db = await ClientManager.get_client()
253
-
254
- async def finalize(self):
255
- if self.db is not None:
256
- await ClientManager.release_client(self.db)
257
- self.db = None
258
-
259
- ################ QUERY METHODS ################
260
-
261
- async def get_by_id(self, id: str) -> dict[str, Any] | None:
262
- """Get doc_full data based on id."""
263
- SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
264
- params = {"workspace": self.db.workspace, "id": id}
265
- # print("get_by_id:"+SQL)
266
- if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
267
- array_res = await self.db.query(SQL, params, multirows=True)
268
- res = {}
269
- for row in array_res:
270
- res[row["id"]] = row
271
- if res:
272
- return res
273
- else:
274
- return None
275
- else:
276
- return await self.db.query(SQL, params)
277
-
278
- async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
279
- """Specifically for llm_response_cache."""
280
- SQL = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
281
- params = {"workspace": self.db.workspace, "cache_mode": mode, "id": id}
282
- if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
283
- array_res = await self.db.query(SQL, params, multirows=True)
284
- res = {}
285
- for row in array_res:
286
- res[row["id"]] = row
287
- return res
288
- else:
289
- return None
290
-
291
- async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
292
- """Get doc_chunks data based on id"""
293
- SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
294
- ids=",".join([f"'{id}'" for id in ids])
295
- )
296
- params = {"workspace": self.db.workspace}
297
- # print("get_by_ids:"+SQL)
298
- res = await self.db.query(SQL, params, multirows=True)
299
- if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
300
- modes = set()
301
- dict_res: dict[str, dict] = {}
302
- for row in res:
303
- modes.add(row["mode"])
304
- for mode in modes:
305
- if mode not in dict_res:
306
- dict_res[mode] = {}
307
- for row in res:
308
- dict_res[row["mode"]][row["id"]] = row
309
- res = [{k: v} for k, v in dict_res.items()]
310
- return res
311
-
312
- async def filter_keys(self, keys: set[str]) -> set[str]:
313
- """Return keys that don't exist in storage"""
314
- SQL = SQL_TEMPLATES["filter_keys"].format(
315
- table_name=namespace_to_table_name(self.namespace),
316
- ids=",".join([f"'{id}'" for id in keys]),
317
- )
318
- params = {"workspace": self.db.workspace}
319
- res = await self.db.query(SQL, params, multirows=True)
320
- if res:
321
- exist_keys = [key["id"] for key in res]
322
- data = set([s for s in keys if s not in exist_keys])
323
- return data
324
- else:
325
- return set(keys)
326
-
327
- ################ INSERT METHODS ################
328
- async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
329
- logger.info(f"Inserting {len(data)} to {self.namespace}")
330
- if not data:
331
- return
332
-
333
- if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
334
- list_data = [
335
- {
336
- "id": k,
337
- **{k1: v1 for k1, v1 in v.items()},
338
- }
339
- for k, v in data.items()
340
- ]
341
- contents = [v["content"] for v in data.values()]
342
- batches = [
343
- contents[i : i + self._max_batch_size]
344
- for i in range(0, len(contents), self._max_batch_size)
345
- ]
346
- embeddings_list = await asyncio.gather(
347
- *[self.embedding_func(batch) for batch in batches]
348
- )
349
- embeddings = np.concatenate(embeddings_list)
350
- for i, d in enumerate(list_data):
351
- d["__vector__"] = embeddings[i]
352
-
353
- merge_sql = SQL_TEMPLATES["merge_chunk"]
354
- for item in list_data:
355
- _data = {
356
- "id": item["id"],
357
- "content": item["content"],
358
- "workspace": self.db.workspace,
359
- "tokens": item["tokens"],
360
- "chunk_order_index": item["chunk_order_index"],
361
- "full_doc_id": item["full_doc_id"],
362
- "content_vector": item["__vector__"],
363
- "status": item["status"],
364
- }
365
- await self.db.execute(merge_sql, _data)
366
- if is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
367
- for k, v in data.items():
368
- # values.clear()
369
- merge_sql = SQL_TEMPLATES["merge_doc_full"]
370
- _data = {
371
- "id": k,
372
- "content": v["content"],
373
- "workspace": self.db.workspace,
374
- }
375
- await self.db.execute(merge_sql, _data)
376
-
377
- if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
378
- for mode, items in data.items():
379
- for k, v in items.items():
380
- upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
381
- _data = {
382
- "workspace": self.db.workspace,
383
- "id": k,
384
- "original_prompt": v["original_prompt"],
385
- "return_value": v["return"],
386
- "cache_mode": mode,
387
- }
388
-
389
- await self.db.execute(upsert_sql, _data)
390
-
391
- async def index_done_callback(self) -> None:
392
- # Oracle handles persistence automatically
393
- pass
394
-
395
-
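`upsert()` above splits chunk contents into batches of `embedding_batch_num` and embeds them concurrently before concatenating the vectors. A standalone sketch of that batching pattern, with a stubbed embedding function in place of the project's `embedding_func`:

```python
# Standalone sketch of the batch-embedding pattern used by upsert() above.
import asyncio
import numpy as np

async def fake_embedding_func(texts):
    # Stand-in embedder: one 8-dimensional random vector per text
    return np.random.rand(len(texts), 8)

async def embed_in_batches(contents, batch_size=10):
    batches = [
        contents[i : i + batch_size]
        for i in range(0, len(contents), batch_size)
    ]
    embeddings_list = await asyncio.gather(
        *[fake_embedding_func(batch) for batch in batches]
    )
    return np.concatenate(embeddings_list)

vectors = asyncio.run(embed_in_batches([f"chunk {i}" for i in range(25)]))
print(vectors.shape)  # (25, 8)
```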
396
- @final
397
- @dataclass
398
- class OracleVectorDBStorage(BaseVectorStorage):
399
- db: OracleDB | None = field(default=None)
400
-
401
- def __post_init__(self):
402
- config = self.global_config.get("vector_db_storage_cls_kwargs", {})
403
- cosine_threshold = config.get("cosine_better_than_threshold")
404
- if cosine_threshold is None:
405
- raise ValueError(
406
- "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
407
- )
408
- self.cosine_better_than_threshold = cosine_threshold
409
-
410
- async def initialize(self):
411
- if self.db is None:
412
- self.db = await ClientManager.get_client()
413
-
414
- async def finalize(self):
415
- if self.db is not None:
416
- await ClientManager.release_client(self.db)
417
- self.db = None
418
-
419
- #################### query method ###############
420
- async def query(
421
- self, query: str, top_k: int, ids: list[str] | None = None
422
- ) -> list[dict[str, Any]]:
423
- embeddings = await self.embedding_func([query])
424
- embedding = embeddings[0]
425
- # Convert precision
426
- dtype = str(embedding.dtype).upper()
427
- dimension = embedding.shape[0]
428
- embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]"
429
-
430
- SQL = SQL_TEMPLATES[self.namespace].format(dimension=dimension, dtype=dtype)
431
- params = {
432
- "embedding_string": embedding_string,
433
- "workspace": self.db.workspace,
434
- "top_k": top_k,
435
- "better_than_threshold": self.cosine_better_than_threshold,
436
- }
437
- results = await self.db.query(SQL, params=params, multirows=True)
438
- return results
439
-
440
- async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
441
- raise NotImplementedError
442
-
443
- async def index_done_callback(self) -> None:
444
- # Oracle handles persistence automatically
445
- pass
446
-
447
- async def delete(self, ids: list[str]) -> None:
448
- """Delete vectors with specified IDs
449
-
450
- Args:
451
- ids: List of vector IDs to be deleted
452
- """
453
- if not ids:
454
- return
455
-
456
- try:
457
- SQL = SQL_TEMPLATES["delete_vectors"].format(
458
- ids=",".join([f"'{id}'" for id in ids])
459
- )
460
- params = {"workspace": self.db.workspace}
461
- await self.db.execute(SQL, params)
462
- logger.info(
463
- f"Successfully deleted {len(ids)} vectors from {self.namespace}"
464
- )
465
- except Exception as e:
466
- logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
467
- raise
468
-
469
- async def delete_entity(self, entity_name: str) -> None:
470
- """Delete entity by name
471
-
472
- Args:
473
- entity_name: Name of the entity to delete
474
- """
475
- try:
476
- SQL = SQL_TEMPLATES["delete_entity"]
477
- params = {"workspace": self.db.workspace, "entity_name": entity_name}
478
- await self.db.execute(SQL, params)
479
- logger.info(f"Successfully deleted entity {entity_name}")
480
- except Exception as e:
481
- logger.error(f"Error deleting entity {entity_name}: {e}")
482
- raise
483
-
484
- async def delete_entity_relation(self, entity_name: str) -> None:
485
- """Delete all relations connected to an entity
486
-
487
- Args:
488
- entity_name: Name of the entity whose relations should be deleted
489
- """
490
- try:
491
- SQL = SQL_TEMPLATES["delete_entity_relations"]
492
- params = {"workspace": self.db.workspace, "entity_name": entity_name}
493
- await self.db.execute(SQL, params)
494
- logger.info(f"Successfully deleted relations for entity {entity_name}")
495
- except Exception as e:
496
- logger.error(f"Error deleting relations for entity {entity_name}: {e}")
497
- raise
498
-
499
- async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
500
- """Search for records with IDs starting with a specific prefix.
501
-
502
- Args:
503
- prefix: The prefix to search for in record IDs
504
-
505
- Returns:
506
- List of records with matching ID prefixes
507
- """
508
- try:
509
- # Determine the appropriate table based on namespace
510
- table_name = namespace_to_table_name(self.namespace)
511
-
512
- # Create SQL query to find records with IDs starting with prefix
513
- search_sql = f"""
514
- SELECT * FROM {table_name}
515
- WHERE workspace = :workspace
516
- AND id LIKE :prefix_pattern
517
- ORDER BY id
518
- """
519
-
520
- params = {"workspace": self.db.workspace, "prefix_pattern": f"{prefix}%"}
521
-
522
- # Execute query and get results
523
- results = await self.db.query(search_sql, params, multirows=True)
524
-
525
- logger.debug(
526
- f"Found {len(results) if results else 0} records with prefix '{prefix}'"
527
- )
528
- return results or []
529
-
530
- except Exception as e:
531
- logger.error(f"Error searching records with prefix '{prefix}': {e}")
532
- return []
533
-
534
- async def get_by_id(self, id: str) -> dict[str, Any] | None:
535
- """Get vector data by its ID
536
-
537
- Args:
538
- id: The unique identifier of the vector
539
-
540
- Returns:
541
- The vector data if found, or None if not found
542
- """
543
- try:
544
- # Determine the table name based on namespace
545
- table_name = namespace_to_table_name(self.namespace)
546
- if not table_name:
547
- logger.error(f"Unknown namespace for ID lookup: {self.namespace}")
548
- return None
549
-
550
- # Create the appropriate ID field name based on namespace
551
- id_field = "entity_id" if "NODES" in table_name else "relation_id"
552
- if "CHUNKS" in table_name:
553
- id_field = "chunk_id"
554
-
555
- # Prepare and execute the query
556
- query = f"""
557
- SELECT * FROM {table_name}
558
- WHERE {id_field} = :id AND workspace = :workspace
559
- """
560
- params = {"id": id, "workspace": self.db.workspace}
561
-
562
- result = await self.db.query(query, params)
563
- return result
564
- except Exception as e:
565
- logger.error(f"Error retrieving vector data for ID {id}: {e}")
566
- return None
567
-
568
- async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
569
- """Get multiple vector data by their IDs
570
-
571
- Args:
572
- ids: List of unique identifiers
573
-
574
- Returns:
575
- List of vector data objects that were found
576
- """
577
- if not ids:
578
- return []
579
-
580
- try:
581
- # Determine the table name based on namespace
582
- table_name = namespace_to_table_name(self.namespace)
583
- if not table_name:
584
- logger.error(f"Unknown namespace for IDs lookup: {self.namespace}")
585
- return []
586
-
587
- # Create the appropriate ID field name based on namespace
588
- id_field = "entity_id" if "NODES" in table_name else "relation_id"
589
- if "CHUNKS" in table_name:
590
- id_field = "chunk_id"
591
-
592
- # Format the list of IDs for SQL IN clause
593
- ids_list = ", ".join([f"'{id}'" for id in ids])
594
-
595
- # Prepare and execute the query
596
- query = f"""
597
- SELECT * FROM {table_name}
598
- WHERE {id_field} IN ({ids_list}) AND workspace = :workspace
599
- """
600
- params = {"workspace": self.db.workspace}
601
-
602
- results = await self.db.query(query, params, multirows=True)
603
- return results or []
604
- except Exception as e:
605
- logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
606
- return []
607
-
608
-
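`OracleVectorDBStorage.query()` above serialized the query embedding into Oracle's `vector(...)` literal form and bound the threshold and `top_k` as parameters (the SQL templates appear near the end of this file). A reconstruction of just the bind-value preparation, runnable without a database; the workspace name and threshold values are assumptions:

```python
# Reconstruction of how the removed query() prepared its SQL bind values.
import numpy as np

embedding = np.random.rand(4).astype(np.float32)   # stand-in query vector
dtype = str(embedding.dtype).upper()               # "FLOAT32"
dimension = embedding.shape[0]
embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]"

params = {
    "embedding_string": embedding_string,
    "workspace": "default",                        # assumed workspace name
    "top_k": 5,
    "better_than_threshold": 0.2,                  # assumed cosine threshold
}
print(dimension, dtype)
print(params["embedding_string"][:60], "...")
```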
609
- @final
610
- @dataclass
611
- class OracleGraphStorage(BaseGraphStorage):
612
- db: OracleDB = field(default=None)
613
-
614
- def __post_init__(self):
615
- self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
616
-
617
- async def initialize(self):
618
- if self.db is None:
619
- self.db = await ClientManager.get_client()
620
-
621
- async def finalize(self):
622
- if self.db is not None:
623
- await ClientManager.release_client(self.db)
624
- self.db = None
625
-
626
- #################### insert method ################
627
-
628
- async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
629
- entity_name = node_id
630
- entity_type = node_data["entity_type"]
631
- description = node_data["description"]
632
- source_id = node_data["source_id"]
633
- logger.debug(f"entity_name:{entity_name}, entity_type:{entity_type}")
634
-
635
- content = entity_name + description
636
- contents = [content]
637
- batches = [
638
- contents[i : i + self._max_batch_size]
639
- for i in range(0, len(contents), self._max_batch_size)
640
- ]
641
- embeddings_list = await asyncio.gather(
642
- *[self.embedding_func(batch) for batch in batches]
643
- )
644
- embeddings = np.concatenate(embeddings_list)
645
- content_vector = embeddings[0]
646
- merge_sql = SQL_TEMPLATES["merge_node"]
647
- data = {
648
- "workspace": self.db.workspace,
649
- "name": entity_name,
650
- "entity_type": entity_type,
651
- "description": description,
652
- "source_chunk_id": source_id,
653
- "content": content,
654
- "content_vector": content_vector,
655
- }
656
- await self.db.execute(merge_sql, data)
657
- # self._graph.add_node(node_id, **node_data)
658
-
659
- async def upsert_edge(
660
- self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
661
- ) -> None:
662
- """插入或更新边"""
663
- # print("go into upsert edge method")
664
- source_name = source_node_id
665
- target_name = target_node_id
666
- weight = edge_data["weight"]
667
- keywords = edge_data["keywords"]
668
- description = edge_data["description"]
669
- source_chunk_id = edge_data["source_id"]
670
- logger.debug(
671
- f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}"
672
- )
673
-
674
- content = keywords + source_name + target_name + description
675
- contents = [content]
676
- batches = [
677
- contents[i : i + self._max_batch_size]
678
- for i in range(0, len(contents), self._max_batch_size)
679
- ]
680
- embeddings_list = await asyncio.gather(
681
- *[self.embedding_func(batch) for batch in batches]
682
- )
683
- embeddings = np.concatenate(embeddings_list)
684
- content_vector = embeddings[0]
685
- merge_sql = SQL_TEMPLATES["merge_edge"]
686
- data = {
687
- "workspace": self.db.workspace,
688
- "source_name": source_name,
689
- "target_name": target_name,
690
- "weight": weight,
691
- "keywords": keywords,
692
- "description": description,
693
- "source_chunk_id": source_chunk_id,
694
- "content": content,
695
- "content_vector": content_vector,
696
- }
697
- # print(merge_sql)
698
- await self.db.execute(merge_sql, data)
699
- # self._graph.add_edge(source_node_id, target_node_id, **edge_data)
700
-
701
- async def embed_nodes(
702
- self, algorithm: str
703
- ) -> tuple[np.ndarray[Any, Any], list[str]]:
704
- if algorithm not in self._node_embed_algorithms:
705
- raise ValueError(f"Node embedding algorithm {algorithm} not supported")
706
- return await self._node_embed_algorithms[algorithm]()
707
-
708
- async def _node2vec_embed(self):
709
- """为节点生成向量"""
710
- embeddings, nodes = embed.node2vec_embed(
711
- self._graph,
712
- **self.config["node2vec_params"],
713
- )
714
-
715
- nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
716
- return embeddings, nodes_ids
717
-
718
- async def index_done_callback(self) -> None:
719
- # Oracle handles persistence automatically
720
- pass
721
-
722
- #################### query method #################
723
- async def has_node(self, node_id: str) -> bool:
724
- """根据节点id检查节点是否存在"""
725
- SQL = SQL_TEMPLATES["has_node"]
726
- params = {"workspace": self.db.workspace, "node_id": node_id}
727
- res = await self.db.query(SQL, params)
728
- if res:
729
- # print("Node exist!",res)
730
- return True
731
- else:
732
- # print("Node not exist!")
733
- return False
734
-
735
- async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
736
- SQL = SQL_TEMPLATES["has_edge"]
737
- params = {
738
- "workspace": self.db.workspace,
739
- "source_node_id": source_node_id,
740
- "target_node_id": target_node_id,
741
- }
742
- res = await self.db.query(SQL, params)
743
- if res:
744
- # print("Edge exist!",res)
745
- return True
746
- else:
747
- # print("Edge not exist!")
748
- return False
749
-
750
- async def node_degree(self, node_id: str) -> int:
751
- SQL = SQL_TEMPLATES["node_degree"]
752
- params = {"workspace": self.db.workspace, "node_id": node_id}
753
- res = await self.db.query(SQL, params)
754
- if res:
755
- return res["degree"]
756
- else:
757
- return 0
758
-
759
- async def edge_degree(self, src_id: str, tgt_id: str) -> int:
760
- """根据源和目标节点id获取边的度"""
761
- degree = await self.node_degree(src_id) + await self.node_degree(tgt_id)
762
- return degree
763
-
764
- async def get_node(self, node_id: str) -> dict[str, str] | None:
765
- """根据节点id获取节点数据"""
766
- SQL = SQL_TEMPLATES["get_node"]
767
- params = {"workspace": self.db.workspace, "node_id": node_id}
768
- res = await self.db.query(SQL, params)
769
- if res:
770
- return res
771
- else:
772
- return None
773
-
774
- async def get_edge(
775
- self, source_node_id: str, target_node_id: str
776
- ) -> dict[str, str] | None:
777
- SQL = SQL_TEMPLATES["get_edge"]
778
- params = {
779
- "workspace": self.db.workspace,
780
- "source_node_id": source_node_id,
781
- "target_node_id": target_node_id,
782
- }
783
- res = await self.db.query(SQL, params)
784
- if res:
785
- # print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
786
- return res
787
- else:
788
- # print("Edge not exist!",self.db.workspace, source_node_id, target_node_id)
789
- return None
790
-
791
- async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
792
- if await self.has_node(source_node_id):
793
- SQL = SQL_TEMPLATES["get_node_edges"]
794
- params = {"workspace": self.db.workspace, "source_node_id": source_node_id}
795
- res = await self.db.query(sql=SQL, params=params, multirows=True)
796
- if res:
797
- data = [(i["source_name"], i["target_name"]) for i in res]
798
- # print("Get node edge!",self.db.workspace, source_node_id,data)
799
- return data
800
- else:
801
- # print("Node Edge not exist!",self.db.workspace, source_node_id)
802
- return []
803
-
804
- async def get_all_nodes(self, limit: int):
805
- """查询所有节点"""
806
- SQL = SQL_TEMPLATES["get_all_nodes"]
807
- params = {"workspace": self.db.workspace, "limit": str(limit)}
808
- res = await self.db.query(sql=SQL, params=params, multirows=True)
809
- if res:
810
- return res
811
-
812
- async def get_all_edges(self, limit: int):
813
- """查询所有边"""
814
- SQL = SQL_TEMPLATES["get_all_edges"]
815
- params = {"workspace": self.db.workspace, "limit": str(limit)}
816
- res = await self.db.query(sql=SQL, params=params, multirows=True)
817
- if res:
818
- return res
819
-
820
- async def get_statistics(self):
821
- SQL = SQL_TEMPLATES["get_statistics"]
822
- params = {"workspace": self.db.workspace}
823
- res = await self.db.query(sql=SQL, params=params, multirows=True)
824
- if res:
825
- return res
826
-
827
- async def delete_node(self, node_id: str) -> None:
828
- """Delete a node from the graph
829
-
830
- Args:
831
- node_id: ID of the node to delete
832
- """
833
- try:
834
- # First delete all relations connected to this node
835
- delete_relations_sql = SQL_TEMPLATES["delete_entity_relations"]
836
- params_relations = {"workspace": self.db.workspace, "entity_name": node_id}
837
- await self.db.execute(delete_relations_sql, params_relations)
838
-
839
- # Then delete the node itself
840
- delete_node_sql = SQL_TEMPLATES["delete_entity"]
841
- params_node = {"workspace": self.db.workspace, "entity_name": node_id}
842
- await self.db.execute(delete_node_sql, params_node)
843
-
844
- logger.info(
845
- f"Successfully deleted node {node_id} and all its relationships"
846
- )
847
- except Exception as e:
848
- logger.error(f"Error deleting node {node_id}: {e}")
849
- raise
850
-
851
- async def remove_nodes(self, nodes: list[str]) -> None:
852
- """Delete multiple nodes from the graph
853
-
854
- Args:
855
- nodes: List of node IDs to be deleted
856
- """
857
- if not nodes:
858
- return
859
-
860
- try:
861
- for node in nodes:
862
- # For each node, first delete all its relationships
863
- delete_relations_sql = SQL_TEMPLATES["delete_entity_relations"]
864
- params_relations = {"workspace": self.db.workspace, "entity_name": node}
865
- await self.db.execute(delete_relations_sql, params_relations)
866
-
867
- # Then delete the node itself
868
- delete_node_sql = SQL_TEMPLATES["delete_entity"]
869
- params_node = {"workspace": self.db.workspace, "entity_name": node}
870
- await self.db.execute(delete_node_sql, params_node)
871
-
872
- logger.info(
873
- f"Successfully deleted {len(nodes)} nodes and their relationships"
874
- )
875
- except Exception as e:
876
- logger.error(f"Error during batch node deletion: {e}")
877
- raise
878
-
879
- async def remove_edges(self, edges: list[tuple[str, str]]) -> None:
880
- """Delete multiple edges from the graph
881
-
882
- Args:
883
- edges: List of edges to be deleted, each edge is a (source, target) tuple
884
- """
885
- if not edges:
886
- return
887
-
888
- try:
889
- for source, target in edges:
890
- # Check if the edge exists before attempting to delete
891
- if await self.has_edge(source, target):
892
- # Delete the edge using a SQL query that matches both source and target
893
- delete_edge_sql = """
894
- DELETE FROM LIGHTRAG_GRAPH_EDGES
895
- WHERE workspace = :workspace
896
- AND source_name = :source_name
897
- AND target_name = :target_name
898
- """
899
- params = {
900
- "workspace": self.db.workspace,
901
- "source_name": source,
902
- "target_name": target,
903
- }
904
- await self.db.execute(delete_edge_sql, params)
905
-
906
- logger.info(f"Successfully deleted {len(edges)} edges from the graph")
907
- except Exception as e:
908
- logger.error(f"Error during batch edge deletion: {e}")
909
- raise
910
-
911
- async def get_all_labels(self) -> list[str]:
912
- """Get all unique entity types (labels) in the graph
913
-
914
- Returns:
915
- List of unique entity types/labels
916
- """
917
- try:
918
- SQL = """
919
- SELECT DISTINCT entity_type
920
- FROM LIGHTRAG_GRAPH_NODES
921
- WHERE workspace = :workspace
922
- ORDER BY entity_type
923
- """
924
- params = {"workspace": self.db.workspace}
925
- results = await self.db.query(SQL, params, multirows=True)
926
-
927
- if results:
928
- labels = [row["entity_type"] for row in results]
929
- return labels
930
- else:
931
- return []
932
- except Exception as e:
933
- logger.error(f"Error retrieving entity types: {e}")
934
- return []
935
-
936
- async def get_knowledge_graph(
937
- self, node_label: str, max_depth: int = 5
938
- ) -> KnowledgeGraph:
939
- """Retrieve a connected subgraph starting from nodes matching the given label
940
-
941
- Maximum number of nodes is constrained by MAX_GRAPH_NODES environment variable.
942
- Prioritizes nodes by:
943
- 1. Nodes matching the specified label
944
- 2. Nodes directly connected to matching nodes
945
- 3. Node degree (number of connections)
946
-
947
- Args:
948
- node_label: Label to match for starting nodes (use "*" for all nodes)
949
- max_depth: Maximum depth of traversal from starting nodes
950
-
951
- Returns:
952
- KnowledgeGraph object containing nodes and edges
953
- """
954
- result = KnowledgeGraph()
955
-
956
- try:
957
- # Define maximum number of nodes to return
958
- max_graph_nodes = int(os.environ.get("MAX_GRAPH_NODES", 1000))
959
-
960
- if node_label == "*":
961
- # For "*" label, get all nodes up to the limit
962
- nodes_sql = """
963
- SELECT name, entity_type, description, source_chunk_id
964
- FROM LIGHTRAG_GRAPH_NODES
965
- WHERE workspace = :workspace
966
- ORDER BY id
967
- FETCH FIRST :limit ROWS ONLY
968
- """
969
- nodes_params = {
970
- "workspace": self.db.workspace,
971
- "limit": max_graph_nodes,
972
- }
973
- nodes = await self.db.query(nodes_sql, nodes_params, multirows=True)
974
- else:
975
- # For specific label, find matching nodes and related nodes
976
- nodes_sql = """
977
- WITH matching_nodes AS (
978
- SELECT name
979
- FROM LIGHTRAG_GRAPH_NODES
980
- WHERE workspace = :workspace
981
- AND (name LIKE '%' || :node_label || '%' OR entity_type LIKE '%' || :node_label || '%')
982
- )
983
- SELECT n.name, n.entity_type, n.description, n.source_chunk_id,
984
- CASE
985
- WHEN n.name IN (SELECT name FROM matching_nodes) THEN 2
986
- WHEN EXISTS (
987
- SELECT 1 FROM LIGHTRAG_GRAPH_EDGES e
988
- WHERE workspace = :workspace
989
- AND ((e.source_name = n.name AND e.target_name IN (SELECT name FROM matching_nodes))
990
- OR (e.target_name = n.name AND e.source_name IN (SELECT name FROM matching_nodes)))
991
- ) THEN 1
992
- ELSE 0
993
- END AS priority,
994
- (SELECT COUNT(*) FROM LIGHTRAG_GRAPH_EDGES e
995
- WHERE workspace = :workspace
996
- AND (e.source_name = n.name OR e.target_name = n.name)) AS degree
997
- FROM LIGHTRAG_GRAPH_NODES n
998
- WHERE workspace = :workspace
999
- ORDER BY priority DESC, degree DESC
1000
- FETCH FIRST :limit ROWS ONLY
1001
- """
1002
- nodes_params = {
1003
- "workspace": self.db.workspace,
1004
- "node_label": node_label,
1005
- "limit": max_graph_nodes,
1006
- }
1007
- nodes = await self.db.query(nodes_sql, nodes_params, multirows=True)
1008
-
1009
- if not nodes:
1010
- logger.warning(f"No nodes found matching '{node_label}'")
1011
- return result
1012
-
1013
- # Create mapping of node IDs to be used to filter edges
1014
- node_names = [node["name"] for node in nodes]
1015
-
1016
- # Add nodes to result
1017
- seen_nodes = set()
1018
- for node in nodes:
1019
- node_id = node["name"]
1020
- if node_id in seen_nodes:
1021
- continue
1022
-
1023
- # Create node properties dictionary
1024
- properties = {
1025
- "entity_type": node["entity_type"],
1026
- "description": node["description"] or "",
1027
- "source_id": node["source_chunk_id"] or "",
1028
- }
1029
-
1030
- # Add node to result
1031
- result.nodes.append(
1032
- KnowledgeGraphNode(
1033
- id=node_id, labels=[node["entity_type"]], properties=properties
1034
- )
1035
- )
1036
- seen_nodes.add(node_id)
1037
-
1038
- # Get edges between these nodes
1039
- edges_sql = """
1040
- SELECT source_name, target_name, weight, keywords, description, source_chunk_id
1041
- FROM LIGHTRAG_GRAPH_EDGES
1042
- WHERE workspace = :workspace
1043
- AND source_name IN (SELECT COLUMN_VALUE FROM TABLE(CAST(:node_names AS SYS.ODCIVARCHAR2LIST)))
1044
- AND target_name IN (SELECT COLUMN_VALUE FROM TABLE(CAST(:node_names AS SYS.ODCIVARCHAR2LIST)))
1045
- ORDER BY id
1046
- """
1047
- edges_params = {"workspace": self.db.workspace, "node_names": node_names}
1048
- edges = await self.db.query(edges_sql, edges_params, multirows=True)
1049
-
1050
- # Add edges to result
1051
- seen_edges = set()
1052
- for edge in edges:
1053
- source = edge["source_name"]
1054
- target = edge["target_name"]
1055
- edge_id = f"{source}-{target}"
1056
-
1057
- if edge_id in seen_edges:
1058
- continue
1059
-
1060
- # Create edge properties dictionary
1061
- properties = {
1062
- "weight": edge["weight"] or 0.0,
1063
- "keywords": edge["keywords"] or "",
1064
- "description": edge["description"] or "",
1065
- "source_id": edge["source_chunk_id"] or "",
1066
- }
1067
-
1068
- # Add edge to result
1069
- result.edges.append(
1070
- KnowledgeGraphEdge(
1071
- id=edge_id,
1072
- type="RELATED",
1073
- source=source,
1074
- target=target,
1075
- properties=properties,
1076
- )
1077
- )
1078
- seen_edges.add(edge_id)
1079
-
1080
- logger.info(
1081
- f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
1082
- )
1083
-
1084
- except Exception as e:
1085
- logger.error(f"Error retrieving knowledge graph: {e}")
1086
-
1087
- return result
1088
-
1089
-
1090
- N_T = {
1091
- NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
1092
- NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
1093
- NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
1094
- NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_GRAPH_NODES",
1095
- NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_GRAPH_EDGES",
1096
- }
1097
-
1098
-
1099
- def namespace_to_table_name(namespace: str) -> str:
1100
- for k, v in N_T.items():
1101
- if is_namespace(namespace, k):
1102
- return v
1103
-
1104
-
1105
- TABLES = {
1106
- "LIGHTRAG_DOC_FULL": {
1107
- "ddl": """CREATE TABLE LIGHTRAG_DOC_FULL (
1108
- id varchar(256),
1109
- workspace varchar(1024),
1110
- doc_name varchar(1024),
1111
- content CLOB,
1112
- meta JSON,
1113
- content_summary varchar(1024),
1114
- content_length NUMBER,
1115
- status varchar(256),
1116
- chunks_count NUMBER,
1117
- createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
1118
- updatetime TIMESTAMP DEFAULT NULL,
1119
- error varchar(4096)
1120
- )"""
1121
- },
1122
- "LIGHTRAG_DOC_CHUNKS": {
1123
- "ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS (
1124
- id varchar(256),
1125
- workspace varchar(1024),
1126
- full_doc_id varchar(256),
1127
- status varchar(256),
1128
- chunk_order_index NUMBER,
1129
- tokens NUMBER,
1130
- content CLOB,
1131
- content_vector VECTOR,
1132
- createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
1133
- updatetime TIMESTAMP DEFAULT NULL
1134
- )"""
1135
- },
1136
- "LIGHTRAG_GRAPH_NODES": {
1137
- "ddl": """CREATE TABLE LIGHTRAG_GRAPH_NODES (
1138
- id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
1139
- workspace varchar(1024),
1140
- name varchar(2048),
1141
- entity_type varchar(1024),
1142
- description CLOB,
1143
- source_chunk_id varchar(256),
1144
- content CLOB,
1145
- content_vector VECTOR,
1146
- createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
1147
- updatetime TIMESTAMP DEFAULT NULL
1148
- )"""
1149
- },
1150
- "LIGHTRAG_GRAPH_EDGES": {
1151
- "ddl": """CREATE TABLE LIGHTRAG_GRAPH_EDGES (
1152
- id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
1153
- workspace varchar(1024),
1154
- source_name varchar(2048),
1155
- target_name varchar(2048),
1156
- weight NUMBER,
1157
- keywords CLOB,
1158
- description CLOB,
1159
- source_chunk_id varchar(256),
1160
- content CLOB,
1161
- content_vector VECTOR,
1162
- createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
1163
- updatetime TIMESTAMP DEFAULT NULL
1164
- )"""
1165
- },
1166
- "LIGHTRAG_LLM_CACHE": {
1167
- "ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE (
1168
- id varchar(256) PRIMARY KEY,
1169
- workspace varchar(1024),
1170
- cache_mode varchar(256),
1171
- model_name varchar(256),
1172
- original_prompt clob,
1173
- return_value clob,
1174
- embedding CLOB,
1175
- embedding_shape NUMBER,
1176
- embedding_min NUMBER,
1177
- embedding_max NUMBER,
1178
- createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
1179
- updatetime TIMESTAMP DEFAULT NULL
1180
- )"""
1181
- },
1182
- "LIGHTRAG_GRAPH": {
1183
- "ddl": """CREATE OR REPLACE PROPERTY GRAPH lightrag_graph
1184
- VERTEX TABLES (
1185
- lightrag_graph_nodes KEY (id)
1186
- LABEL entity
1187
- PROPERTIES (id,workspace,name) -- ,entity_type,description,source_chunk_id)
1188
- )
1189
- EDGE TABLES (
1190
- lightrag_graph_edges KEY (id)
1191
- SOURCE KEY (source_name) REFERENCES lightrag_graph_nodes(name)
1192
- DESTINATION KEY (target_name) REFERENCES lightrag_graph_nodes(name)
1193
- LABEL has_relation
1194
- PROPERTIES (id,workspace,source_name,target_name) -- ,weight, keywords,description,source_chunk_id)
1195
- ) OPTIONS(ALLOW MIXED PROPERTY TYPES)"""
1196
- },
1197
- }
1198
-
1199
-
1200
- SQL_TEMPLATES = {
1201
- # SQL for KVStorage
1202
- "get_by_id_full_docs": "select ID,content,status from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id",
1203
- "get_by_id_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id",
1204
- "get_by_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
1205
- FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id=:id""",
1206
- "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
1207
- FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND cache_mode=:cache_mode AND id=:id""",
1208
- "get_by_ids_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
1209
- FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id IN ({ids})""",
1210
- "get_by_ids_full_docs": "select t.*,createtime as created_at from LIGHTRAG_DOC_FULL t where workspace=:workspace and ID in ({ids})",
1211
- "get_by_ids_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})",
1212
- "get_by_status_ids_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status and ID in ({ids})",
1213
- "get_by_status_ids_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status ID in ({ids})",
1214
- "get_by_status_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status",
1215
- "get_by_status_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status",
1216
- "filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})",
1217
- "merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a
1218
- USING DUAL
1219
- ON (a.id = :id and a.workspace = :workspace)
1220
- WHEN NOT MATCHED THEN
1221
- INSERT(id,content,workspace) values(:id,:content,:workspace)""",
1222
- "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS
1223
- USING DUAL
1224
- ON (id = :id and workspace = :workspace)
1225
- WHEN NOT MATCHED THEN INSERT
1226
- (id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector,status)
1227
- values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector,:status) """,
1228
- "upsert_llm_response_cache": """MERGE INTO LIGHTRAG_LLM_CACHE a
1229
- USING DUAL
1230
- ON (a.id = :id)
1231
- WHEN NOT MATCHED THEN
1232
- INSERT (workspace,id,original_prompt,return_value,cache_mode)
1233
- VALUES (:workspace,:id,:original_prompt,:return_value,:cache_mode)
1234
- WHEN MATCHED THEN UPDATE
1235
- SET original_prompt = :original_prompt,
1236
- return_value = :return_value,
1237
- cache_mode = :cache_mode,
1238
- updatetime = SYSDATE""",
1239
- # SQL for VectorStorage
1240
- "entities": """SELECT name as entity_name FROM
1241
- (SELECT id,name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
1242
- FROM LIGHTRAG_GRAPH_NODES WHERE workspace=:workspace)
1243
- WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
1244
- "relationships": """SELECT source_name as src_id, target_name as tgt_id FROM
1245
- (SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
1246
- FROM LIGHTRAG_GRAPH_EDGES WHERE workspace=:workspace)
1247
- WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
1248
- "chunks": """SELECT id FROM
1249
- (SELECT id,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
1250
- FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=:workspace)
1251
- WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
1252
- # SQL for GraphStorage
1253
- "has_node": """SELECT * FROM GRAPH_TABLE (lightrag_graph
1254
- MATCH (a)
1255
- WHERE a.workspace=:workspace AND a.name=:node_id
1256
- COLUMNS (a.name))""",
1257
- "has_edge": """SELECT * FROM GRAPH_TABLE (lightrag_graph
1258
- MATCH (a) -[e]-> (b)
1259
- WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
1260
- AND a.name=:source_node_id AND b.name=:target_node_id
1261
- COLUMNS (e.source_name,e.target_name) )""",
1262
- "node_degree": """SELECT count(1) as degree FROM GRAPH_TABLE (lightrag_graph
1263
- MATCH (a)-[e]->(b)
1264
- WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
1265
- AND a.name=:node_id or b.name = :node_id
1266
- COLUMNS (a.name))""",
1267
- "get_node": """SELECT t1.name,t2.entity_type,t2.source_chunk_id as source_id,NVL(t2.description,'') AS description
1268
- FROM GRAPH_TABLE (lightrag_graph
1269
- MATCH (a)
1270
- WHERE a.workspace=:workspace AND a.name=:node_id
1271
- COLUMNS (a.name)
1272
- ) t1 JOIN LIGHTRAG_GRAPH_NODES t2 on t1.name=t2.name
1273
- WHERE t2.workspace=:workspace""",
1274
- "get_edge": """SELECT t1.source_id,t2.weight,t2.source_chunk_id as source_id,t2.keywords,
1275
- NVL(t2.description,'') AS description,NVL(t2.KEYWORDS,'') AS keywords
1276
- FROM GRAPH_TABLE (lightrag_graph
1277
- MATCH (a)-[e]->(b)
1278
- WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
1279
- AND a.name=:source_node_id and b.name = :target_node_id
1280
- COLUMNS (e.id,a.name as source_id)
1281
- ) t1 JOIN LIGHTRAG_GRAPH_EDGES t2 on t1.id=t2.id""",
1282
- "get_node_edges": """SELECT source_name,target_name
1283
- FROM GRAPH_TABLE (lightrag_graph
1284
- MATCH (a)-[e]->(b)
1285
- WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
1286
- AND a.name=:source_node_id
1287
- COLUMNS (a.name as source_name,b.name as target_name))""",
1288
- "merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a
1289
- USING DUAL
1290
- ON (a.workspace=:workspace and a.name=:name)
1291
- WHEN NOT MATCHED THEN
1292
- INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector)
1293
- values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector)
1294
- WHEN MATCHED THEN
1295
- UPDATE SET
1296
- entity_type=:entity_type,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""",
1297
- "merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a
1298
- USING DUAL
1299
- ON (a.workspace=:workspace and a.source_name=:source_name and a.target_name=:target_name)
1300
- WHEN NOT MATCHED THEN
1301
- INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
1302
- values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector)
1303
- WHEN MATCHED THEN
1304
- UPDATE SET
1305
- weight=:weight,keywords=:keywords,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""",
1306
- "get_all_nodes": """WITH t0 AS (
1307
- SELECT name AS id, entity_type AS label, entity_type, description,
1308
- '["' || replace(source_chunk_id, '<SEP>', '","') || '"]' source_chunk_ids
1309
- FROM lightrag_graph_nodes
1310
- WHERE workspace = :workspace
1311
- ORDER BY createtime DESC fetch first :limit rows only
1312
- ), t1 AS (
1313
- SELECT t0.id, source_chunk_id
1314
- FROM t0, JSON_TABLE ( source_chunk_ids, '$[*]' COLUMNS ( source_chunk_id PATH '$' ) )
1315
- ), t2 AS (
1316
- SELECT t1.id, LISTAGG(t2.content, '\n') content
1317
- FROM t1 LEFT JOIN lightrag_doc_chunks t2 ON t1.source_chunk_id = t2.id
1318
- GROUP BY t1.id
1319
- )
1320
- SELECT t0.id, label, entity_type, description, t2.content
1321
- FROM t0 LEFT JOIN t2 ON t0.id = t2.id""",
1322
- "get_all_edges": """SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target,
1323
- t1.weight,t1.DESCRIPTION,t2.content
1324
- FROM LIGHTRAG_GRAPH_EDGES t1
1325
- LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
1326
- WHERE t1.workspace=:workspace
1327
- order by t1.CREATETIME DESC
1328
- fetch first :limit rows only""",
1329
- "get_statistics": """select count(distinct CASE WHEN type='node' THEN id END) as nodes_count,
1330
- count(distinct CASE WHEN type='edge' THEN id END) as edges_count
1331
- FROM (
1332
- select 'node' as type, id FROM GRAPH_TABLE (lightrag_graph
1333
- MATCH (a) WHERE a.workspace=:workspace columns(a.name as id))
1334
- UNION
1335
- select 'edge' as type, TO_CHAR(id) id FROM GRAPH_TABLE (lightrag_graph
1336
- MATCH (a)-[e]->(b) WHERE e.workspace=:workspace columns(e.id))
1337
- )""",
1338
- # SQL for deletion
1339
- "delete_vectors": "DELETE FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=:workspace AND id IN ({ids})",
1340
- "delete_entity": "DELETE FROM LIGHTRAG_GRAPH_NODES WHERE workspace=:workspace AND name=:entity_name",
1341
- "delete_entity_relations": "DELETE FROM LIGHTRAG_GRAPH_EDGES WHERE workspace=:workspace AND (source_name=:entity_name OR target_name=:entity_name)",
1342
- "delete_node": """DELETE FROM GRAPH_TABLE (lightrag_graph
1343
- MATCH (a)
1344
- WHERE a.workspace=:workspace AND a.name=:node_id
1345
- ACTION DELETE a)""",
1346
- }

lightrag/kg/postgres_impl.py CHANGED
@@ -9,7 +9,6 @@ import configparser
9
 
10
  from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
11
 
12
- import sys
13
  from tenacity import (
14
  retry,
15
  retry_if_exception_type,
@@ -28,11 +27,6 @@ from ..base import (
28
  from ..namespace import NameSpace, is_namespace
29
  from ..utils import logger
30
 
31
- if sys.platform.startswith("win"):
32
- import asyncio.windows_events
33
-
34
- asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
35
-
36
  import pipmaster as pm
37
 
38
  if not pm.is_installed("asyncpg"):
@@ -41,6 +35,9 @@ if not pm.is_installed("asyncpg"):
41
  import asyncpg # type: ignore
42
  from asyncpg import Pool # type: ignore
43
 
 
 
 
44
 
45
  class PostgreSQLDB:
46
  def __init__(self, config: dict[str, Any], **kwargs: Any):
@@ -118,6 +115,25 @@ class PostgreSQLDB:
118
  )
119
  raise e
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  async def query(
122
  self,
123
  sql: str,
@@ -254,8 +270,6 @@ class PGKVStorage(BaseKVStorage):
254
  db: PostgreSQLDB = field(default=None)
255
 
256
  def __post_init__(self):
257
- namespace_prefix = self.global_config.get("namespace_prefix")
258
- self.base_namespace = self.namespace.replace(namespace_prefix, "")
259
  self._max_batch_size = self.global_config["embedding_batch_num"]
260
 
261
  async def initialize(self):
@@ -271,7 +285,7 @@ class PGKVStorage(BaseKVStorage):
271
 
272
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
273
  """Get doc_full data by id."""
274
- sql = SQL_TEMPLATES["get_by_id_" + self.base_namespace]
275
  params = {"workspace": self.db.workspace, "id": id}
276
  if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
277
  array_res = await self.db.query(sql, params, multirows=True)
@@ -285,7 +299,7 @@ class PGKVStorage(BaseKVStorage):
285
 
286
  async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
287
  """Specifically for llm_response_cache."""
288
- sql = SQL_TEMPLATES["get_by_mode_id_" + self.base_namespace]
289
  params = {"workspace": self.db.workspace, mode: mode, "id": id}
290
  if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
291
  array_res = await self.db.query(sql, params, multirows=True)
@@ -299,7 +313,7 @@ class PGKVStorage(BaseKVStorage):
299
  # Query by id
300
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
301
  """Get doc_chunks data by id"""
302
- sql = SQL_TEMPLATES["get_by_ids_" + self.base_namespace].format(
303
  ids=",".join([f"'{id}'" for id in ids])
304
  )
305
  params = {"workspace": self.db.workspace}
@@ -320,7 +334,7 @@ class PGKVStorage(BaseKVStorage):
320
 
321
  async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
322
  """Specifically for llm_response_cache."""
323
- SQL = SQL_TEMPLATES["get_by_status_" + self.base_namespace]
324
  params = {"workspace": self.db.workspace, "status": status}
325
  return await self.db.query(SQL, params, multirows=True)
326
 
@@ -380,10 +394,85 @@ class PGKVStorage(BaseKVStorage):
380
  # PG handles persistence automatically
381
  pass
382
 
383
- async def drop(self) -> None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
  """Drop the storage"""
385
- drop_sql = SQL_TEMPLATES["drop_all"]
386
- await self.db.execute(drop_sql)
 
 
 
 
 
 
 
 
 
 
 
 
 
387
 
388
 
389
  @final
@@ -393,8 +482,6 @@ class PGVectorStorage(BaseVectorStorage):
393
 
394
  def __post_init__(self):
395
  self._max_batch_size = self.global_config["embedding_batch_num"]
396
- namespace_prefix = self.global_config.get("namespace_prefix")
397
- self.base_namespace = self.namespace.replace(namespace_prefix, "")
398
  config = self.global_config.get("vector_db_storage_cls_kwargs", {})
399
  cosine_threshold = config.get("cosine_better_than_threshold")
400
  if cosine_threshold is None:
@@ -523,7 +610,7 @@ class PGVectorStorage(BaseVectorStorage):
523
  else:
524
  formatted_ids = "NULL"
525
 
526
- sql = SQL_TEMPLATES[self.base_namespace].format(
527
  embedding_string=embedding_string, doc_ids=formatted_ids
528
  )
529
  params = {
@@ -552,13 +639,12 @@ class PGVectorStorage(BaseVectorStorage):
552
  logger.error(f"Unknown namespace for vector deletion: {self.namespace}")
553
  return
554
 
555
- ids_list = ",".join([f"'{id}'" for id in ids])
556
- delete_sql = (
557
- f"DELETE FROM {table_name} WHERE workspace=$1 AND id IN ({ids_list})"
558
- )
559
 
560
  try:
561
- await self.db.execute(delete_sql, {"workspace": self.db.workspace})
 
 
562
  logger.debug(
563
  f"Successfully deleted {len(ids)} vectors from {self.namespace}"
564
  )
@@ -690,6 +776,24 @@ class PGVectorStorage(BaseVectorStorage):
690
  logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
691
  return []
692
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
693
 
694
  @final
695
  @dataclass
@@ -810,6 +914,35 @@ class PGDocStatusStorage(DocStatusStorage):
810
  # PG handles persistence automatically
811
  pass
812
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
813
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
814
  """Update or insert document status
815
 
@@ -846,10 +979,23 @@ class PGDocStatusStorage(DocStatusStorage):
846
  },
847
  )
848
 
849
- async def drop(self) -> None:
850
  """Drop the storage"""
851
- drop_sql = SQL_TEMPLATES["drop_doc_full"]
852
- await self.db.execute(drop_sql)
 
 
 
 
 
 
 
 
 
 
 
 
 
853
 
854
 
855
  class PGGraphQueryException(Exception):
@@ -937,31 +1083,11 @@ class PGGraphStorage(BaseGraphStorage):
937
  if v.startswith("[") and v.endswith("]"):
938
  if "::vertex" in v:
939
  v = v.replace("::vertex", "")
940
- vertexes = json.loads(v)
941
- dl = []
942
- for vertex in vertexes:
943
- prop = vertex.get("properties")
944
- if not prop:
945
- prop = {}
946
- prop["label"] = PGGraphStorage._decode_graph_label(
947
- prop["node_id"]
948
- )
949
- dl.append(prop)
950
- d[k] = dl
951
 
952
  elif "::edge" in v:
953
  v = v.replace("::edge", "")
954
- edges = json.loads(v)
955
- dl = []
956
- for edge in edges:
957
- dl.append(
958
- (
959
- vertices[edge["start_id"]],
960
- edge["label"],
961
- vertices[edge["end_id"]],
962
- )
963
- )
964
- d[k] = dl
965
  else:
966
  print("WARNING: unsupported type")
967
  continue
@@ -970,32 +1096,19 @@ class PGGraphStorage(BaseGraphStorage):
970
  dtype = v.split("::")[-1]
971
  v = v.split("::")[0]
972
  if dtype == "vertex":
973
- vertex = json.loads(v)
974
- field = vertex.get("properties")
975
- if not field:
976
- field = {}
977
- field["label"] = PGGraphStorage._decode_graph_label(
978
- field["node_id"]
979
- )
980
- d[k] = field
981
- # convert edge from id-label->id by replacing id with node information
982
- # we only do this if the vertex was also returned in the query
983
- # this is an attempt to be consistent with neo4j implementation
984
  elif dtype == "edge":
985
- edge = json.loads(v)
986
- d[k] = (
987
- vertices.get(edge["start_id"], {}),
988
- edge[
989
- "label"
990
- ], # we don't use decode_graph_label(), since edge label is always "DIRECTED"
991
- vertices.get(edge["end_id"], {}),
992
- )
993
  else:
994
- d[k] = (
995
- json.loads(v)
996
- if isinstance(v, str) and ("{" in v or "[" in v)
997
- else v
998
- )
 
 
 
 
999
 
1000
  return d
1001
 
@@ -1025,56 +1138,6 @@ class PGGraphStorage(BaseGraphStorage):
1025
  )
1026
  return "{" + ", ".join(props) + "}"
1027
 
1028
- @staticmethod
1029
- def _encode_graph_label(label: str) -> str:
1030
- """
1031
- Since AGE supports only alphanumerical labels, we will encode generic label as HEX string
1032
-
1033
- Args:
1034
- label (str): the original label
1035
-
1036
- Returns:
1037
- str: the encoded label
1038
- """
1039
- return "x" + label.encode().hex()
1040
-
1041
- @staticmethod
1042
- def _decode_graph_label(encoded_label: str) -> str:
1043
- """
1044
- Since AGE supports only alphanumerical labels, we will encode generic label as HEX string
1045
-
1046
- Args:
1047
- encoded_label (str): the encoded label
1048
-
1049
- Returns:
1050
- str: the decoded label
1051
- """
1052
- return bytes.fromhex(encoded_label.removeprefix("x")).decode()
1053
-
1054
- @staticmethod
1055
- def _get_col_name(field: str, idx: int) -> str:
1056
- """
1057
- Convert a cypher return field to a pgsql select field
1058
- If possible keep the cypher column name, but create a generic name if necessary
1059
-
1060
- Args:
1061
- field (str): a return field from a cypher query to be formatted for pgsql
1062
- idx (int): the position of the field in the return statement
1063
-
1064
- Returns:
1065
- str: the field to be used in the pgsql select statement
1066
- """
1067
- # remove white space
1068
- field = field.strip()
1069
- # if an alias is provided for the field, use it
1070
- if " as " in field:
1071
- return field.split(" as ")[-1].strip()
1072
- # if the return value is an unnamed primitive, give it a generic name
1073
- if field.isnumeric() or field in ("true", "false", "null"):
1074
- return f"column_{idx}"
1075
- # otherwise return the value stripping out some common special chars
1076
- return field.replace("(", "_").replace(")", "")
1077
-
1078
  async def _query(
1079
  self,
1080
  query: str,
@@ -1125,10 +1188,10 @@ class PGGraphStorage(BaseGraphStorage):
1125
  return result
1126
 
1127
  async def has_node(self, node_id: str) -> bool:
1128
- entity_name_label = self._encode_graph_label(node_id.strip('"'))
1129
 
1130
  query = """SELECT * FROM cypher('%s', $$
1131
- MATCH (n:Entity {node_id: "%s"})
1132
  RETURN count(n) > 0 AS node_exists
1133
  $$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
1134
 
@@ -1137,11 +1200,11 @@ class PGGraphStorage(BaseGraphStorage):
1137
  return single_result["node_exists"]
1138
 
1139
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
1140
- src_label = self._encode_graph_label(source_node_id.strip('"'))
1141
- tgt_label = self._encode_graph_label(target_node_id.strip('"'))
1142
 
1143
  query = """SELECT * FROM cypher('%s', $$
1144
- MATCH (a:Entity {node_id: "%s"})-[r]-(b:Entity {node_id: "%s"})
1145
  RETURN COUNT(r) > 0 AS edge_exists
1146
  $$) AS (edge_exists bool)""" % (
1147
  self.graph_name,
@@ -1154,30 +1217,31 @@ class PGGraphStorage(BaseGraphStorage):
1154
  return single_result["edge_exists"]
1155
 
1156
  async def get_node(self, node_id: str) -> dict[str, str] | None:
1157
- label = self._encode_graph_label(node_id.strip('"'))
 
 
1158
  query = """SELECT * FROM cypher('%s', $$
1159
- MATCH (n:Entity {node_id: "%s"})
1160
  RETURN n
1161
  $$) AS (n agtype)""" % (self.graph_name, label)
1162
  record = await self._query(query)
1163
  if record:
1164
  node = record[0]
1165
- node_dict = node["n"]
1166
 
1167
  return node_dict
1168
  return None
1169
 
1170
  async def node_degree(self, node_id: str) -> int:
1171
- label = self._encode_graph_label(node_id.strip('"'))
1172
 
1173
  query = """SELECT * FROM cypher('%s', $$
1174
- MATCH (n:Entity {node_id: "%s"})-[]->(x)
1175
  RETURN count(x) AS total_edge_count
1176
  $$) AS (total_edge_count integer)""" % (self.graph_name, label)
1177
  record = (await self._query(query))[0]
1178
  if record:
1179
  edge_count = int(record["total_edge_count"])
1180
-
1181
  return edge_count
1182
 
1183
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
@@ -1195,11 +1259,13 @@ class PGGraphStorage(BaseGraphStorage):
1195
  async def get_edge(
1196
  self, source_node_id: str, target_node_id: str
1197
  ) -> dict[str, str] | None:
1198
- src_label = self._encode_graph_label(source_node_id.strip('"'))
1199
- tgt_label = self._encode_graph_label(target_node_id.strip('"'))
 
 
1200
 
1201
  query = """SELECT * FROM cypher('%s', $$
1202
- MATCH (a:Entity {node_id: "%s"})-[r]->(b:Entity {node_id: "%s"})
1203
  RETURN properties(r) as edge_properties
1204
  LIMIT 1
1205
  $$) AS (edge_properties agtype)""" % (
@@ -1218,11 +1284,11 @@ class PGGraphStorage(BaseGraphStorage):
1218
  Retrieves all edges (relationships) for a particular node identified by its label.
1219
  :return: list of dictionaries containing edge information
1220
  """
1221
- label = self._encode_graph_label(source_node_id.strip('"'))
1222
 
1223
  query = """SELECT * FROM cypher('%s', $$
1224
- MATCH (n:Entity {node_id: "%s"})
1225
- OPTIONAL MATCH (n)-[]-(connected)
1226
  RETURN n, connected
1227
  $$) AS (n agtype, connected agtype)""" % (
1228
  self.graph_name,
@@ -1235,24 +1301,17 @@ class PGGraphStorage(BaseGraphStorage):
1235
  source_node = record["n"] if record["n"] else None
1236
  connected_node = record["connected"] if record["connected"] else None
1237
 
1238
- source_label = (
1239
- source_node["node_id"]
1240
- if source_node and source_node["node_id"]
1241
- else None
1242
- )
1243
- target_label = (
1244
- connected_node["node_id"]
1245
- if connected_node and connected_node["node_id"]
1246
- else None
1247
- )
1248
 
1249
- if source_label and target_label:
1250
- edges.append(
1251
- (
1252
- self._decode_graph_label(source_label),
1253
- self._decode_graph_label(target_label),
1254
- )
1255
- )
1256
 
1257
  return edges
1258
 
@@ -1262,24 +1321,36 @@ class PGGraphStorage(BaseGraphStorage):
1262
  retry=retry_if_exception_type((PGGraphQueryException,)),
1263
  )
1264
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
1265
- label = self._encode_graph_label(node_id.strip('"'))
1266
- properties = node_data
 
 
 
 
 
 
 
 
 
 
 
 
1267
 
1268
  query = """SELECT * FROM cypher('%s', $$
1269
- MERGE (n:Entity {node_id: "%s"})
1270
  SET n += %s
1271
  RETURN n
1272
  $$) AS (n agtype)""" % (
1273
  self.graph_name,
1274
  label,
1275
- self._format_properties(properties),
1276
  )
1277
 
1278
  try:
1279
  await self._query(query, readonly=False, upsert=True)
1280
 
1281
- except Exception as e:
1282
- logger.error("POSTGRES, Error during upsert: {%s}", e)
1283
  raise
1284
 
1285
  @retry(
@@ -1298,14 +1369,14 @@ class PGGraphStorage(BaseGraphStorage):
1298
  target_node_id (str): Label of the target node (used as identifier)
1299
  edge_data (dict): dictionary of properties to set on the edge
1300
  """
1301
- src_label = self._encode_graph_label(source_node_id.strip('"'))
1302
- tgt_label = self._encode_graph_label(target_node_id.strip('"'))
1303
- edge_properties = edge_data
1304
 
1305
  query = """SELECT * FROM cypher('%s', $$
1306
- MATCH (source:Entity {node_id: "%s"})
1307
  WITH source
1308
- MATCH (target:Entity {node_id: "%s"})
1309
  MERGE (source)-[r:DIRECTED]->(target)
1310
  SET r += %s
1311
  RETURN r
@@ -1313,14 +1384,16 @@ class PGGraphStorage(BaseGraphStorage):
1313
  self.graph_name,
1314
  src_label,
1315
  tgt_label,
1316
- self._format_properties(edge_properties),
1317
  )
1318
 
1319
  try:
1320
  await self._query(query, readonly=False, upsert=True)
1321
 
1322
- except Exception as e:
1323
- logger.error("Error during edge upsert: {%s}", e)
 
 
1324
  raise
1325
 
1326
  async def _node2vec_embed(self):
@@ -1333,10 +1406,10 @@ class PGGraphStorage(BaseGraphStorage):
1333
  Args:
1334
  node_id (str): The ID of the node to delete.
1335
  """
1336
- label = self._encode_graph_label(node_id.strip('"'))
1337
 
1338
  query = """SELECT * FROM cypher('%s', $$
1339
- MATCH (n:Entity {node_id: "%s"})
1340
  DETACH DELETE n
1341
  $$) AS (n agtype)""" % (self.graph_name, label)
1342
 
@@ -1353,14 +1426,12 @@ class PGGraphStorage(BaseGraphStorage):
1353
  Args:
1354
  node_ids (list[str]): A list of node IDs to remove.
1355
  """
1356
- encoded_node_ids = [
1357
- self._encode_graph_label(node_id.strip('"')) for node_id in node_ids
1358
- ]
1359
- node_id_list = ", ".join([f'"{node_id}"' for node_id in encoded_node_ids])
1360
 
1361
  query = """SELECT * FROM cypher('%s', $$
1362
- MATCH (n:Entity)
1363
- WHERE n.node_id IN [%s]
1364
  DETACH DELETE n
1365
  $$) AS (n agtype)""" % (self.graph_name, node_id_list)
1366
 
@@ -1377,26 +1448,21 @@ class PGGraphStorage(BaseGraphStorage):
1377
  Args:
1378
  edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
1379
  """
1380
- encoded_edges = [
1381
- (
1382
- self._encode_graph_label(src.strip('"')),
1383
- self._encode_graph_label(tgt.strip('"')),
1384
- )
1385
- for src, tgt in edges
1386
- ]
1387
- edge_list = ", ".join([f'["{src}", "{tgt}"]' for src, tgt in encoded_edges])
1388
 
1389
- query = """SELECT * FROM cypher('%s', $$
1390
- MATCH (a:Entity)-[r]->(b:Entity)
1391
- WHERE [a.node_id, b.node_id] IN [%s]
1392
- DELETE r
1393
- $$) AS (r agtype)""" % (self.graph_name, edge_list)
1394
 
1395
- try:
1396
- await self._query(query, readonly=False)
1397
- except Exception as e:
1398
- logger.error("Error during edge removal: {%s}", e)
1399
- raise
 
1400
 
1401
  async def get_all_labels(self) -> list[str]:
1402
  """
@@ -1407,15 +1473,16 @@ class PGGraphStorage(BaseGraphStorage):
1407
  """
1408
  query = (
1409
  """SELECT * FROM cypher('%s', $$
1410
- MATCH (n:Entity)
1411
- RETURN DISTINCT n.node_id AS label
 
 
1412
  $$) AS (label text)"""
1413
  % self.graph_name
1414
  )
1415
 
1416
  results = await self._query(query)
1417
- labels = [self._decode_graph_label(result["label"]) for result in results]
1418
-
1419
  return labels
1420
 
1421
  async def embed_nodes(
@@ -1437,105 +1504,135 @@ class PGGraphStorage(BaseGraphStorage):
1437
  return await embed_func()
1438
 
1439
  async def get_knowledge_graph(
1440
- self, node_label: str, max_depth: int = 5
 
 
 
1441
  ) -> KnowledgeGraph:
1442
  """
1443
- Retrieve a subgraph containing the specified node and its neighbors up to the specified depth.
1444
 
1445
  Args:
1446
- node_label (str): The label of the node to start from. If "*", the entire graph is returned.
1447
- max_depth (int): The maximum depth to traverse from the starting node.
 
1448
 
1449
  Returns:
1450
- KnowledgeGraph: The retrieved subgraph.
 
1451
  """
1452
- MAX_GRAPH_NODES = 1000
1453
-
1454
- # Build the query based on whether we want the full graph or a specific subgraph.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1455
  if node_label == "*":
1456
  query = f"""SELECT * FROM cypher('{self.graph_name}', $$
1457
- MATCH (n:Entity)
1458
- OPTIONAL MATCH (n)-[r]->(m:Entity)
1459
- RETURN n, r, m
1460
- LIMIT {MAX_GRAPH_NODES}
1461
- $$) AS (n agtype, r agtype, m agtype)"""
1462
  else:
1463
- encoded_label = self._encode_graph_label(node_label.strip('"'))
1464
  query = f"""SELECT * FROM cypher('{self.graph_name}', $$
1465
- MATCH (n:Entity {{node_id: "{encoded_label}"}})
1466
- OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m)
1467
- RETURN nodes(p) AS nodes, relationships(p) AS relationships
1468
- LIMIT {MAX_GRAPH_NODES}
1469
- $$) AS (nodes agtype, relationships agtype)"""
1470
 
1471
  results = await self._query(query)
1472
 
1473
- nodes = {}
1474
- edges = []
1475
- unique_edge_ids = set()
1476
-
1477
- def add_node(node_data: dict):
1478
- node_id = self._decode_graph_label(node_data["node_id"])
1479
- if node_id not in nodes:
1480
- nodes[node_id] = node_data
1481
-
1482
- def add_edge(edge_data: list):
1483
- src_id = self._decode_graph_label(edge_data[0]["node_id"])
1484
- tgt_id = self._decode_graph_label(edge_data[2]["node_id"])
1485
- edge_key = f"{src_id},{tgt_id}"
1486
- if edge_key not in unique_edge_ids:
1487
- unique_edge_ids.add(edge_key)
1488
- edges.append(
1489
- (
1490
- edge_key,
1491
- src_id,
1492
- tgt_id,
1493
- {"source": edge_data[0], "target": edge_data[2]},
1494
  )
1495
- )
 
 
 
 
 
 
 
 
 
 
1496
 
1497
- # Process the query results.
1498
- if node_label == "*":
1499
- for result in results:
1500
- if result.get("n"):
1501
- add_node(result["n"])
1502
- if result.get("m"):
1503
- add_node(result["m"])
1504
- if result.get("r"):
1505
- add_edge(result["r"])
1506
- else:
1507
- for result in results:
1508
- for node in result.get("nodes", []):
1509
- add_node(node)
1510
- for edge in result.get("relationships", []):
1511
- add_edge(edge)
 
 
 
 
 
 
 
 
 
1512
 
1513
- # Construct and return the KnowledgeGraph.
1514
  kg = KnowledgeGraph(
1515
- nodes=[
1516
- KnowledgeGraphNode(id=node_id, labels=[node_id], properties=node_data)
1517
- for node_id, node_data in nodes.items()
1518
- ],
1519
- edges=[
1520
- KnowledgeGraphEdge(
1521
- id=edge_id,
1522
- type="DIRECTED",
1523
- source=src,
1524
- target=tgt,
1525
- properties=props,
1526
- )
1527
- for edge_id, src, tgt, props in edges
1528
- ],
1529
  )
1530
 
 
 
 
1531
  return kg
1532
 
1533
- async def drop(self) -> None:
1534
  """Drop the storage"""
1535
- drop_sql = SQL_TEMPLATES["drop_vdb_entity"]
1536
- await self.db.execute(drop_sql)
1537
- drop_sql = SQL_TEMPLATES["drop_vdb_relation"]
1538
- await self.db.execute(drop_sql)
 
 
 
 
 
 
 
1539
 
1540
 
1541
  NAMESPACE_TABLE_MAP = {
@@ -1693,6 +1790,7 @@ SQL_TEMPLATES = {
1693
  file_path=EXCLUDED.file_path,
1694
  update_time = CURRENT_TIMESTAMP
1695
  """,
 
1696
  "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
1697
  content_vector, chunk_ids, file_path)
1698
  VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7)
@@ -1716,45 +1814,6 @@ SQL_TEMPLATES = {
1716
  file_path=EXCLUDED.file_path,
1717
  update_time = CURRENT_TIMESTAMP
1718
  """,
1719
- # SQL for VectorStorage
1720
- # "entities": """SELECT entity_name FROM
1721
- # (SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
1722
- # FROM LIGHTRAG_VDB_ENTITY where workspace=$1)
1723
- # WHERE distance>$2 ORDER BY distance DESC LIMIT $3
1724
- # """,
1725
- # "relationships": """SELECT source_id as src_id, target_id as tgt_id FROM
1726
- # (SELECT id, source_id,target_id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
1727
- # FROM LIGHTRAG_VDB_RELATION where workspace=$1)
1728
- # WHERE distance>$2 ORDER BY distance DESC LIMIT $3
1729
- # """,
1730
- # "chunks": """SELECT id FROM
1731
- # (SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
1732
- # FROM LIGHTRAG_DOC_CHUNKS where workspace=$1)
1733
- # WHERE distance>$2 ORDER BY distance DESC LIMIT $3
1734
- # """,
1735
- # DROP tables
1736
- "drop_all": """
1737
- DROP TABLE IF EXISTS LIGHTRAG_DOC_FULL CASCADE;
1738
- DROP TABLE IF EXISTS LIGHTRAG_DOC_CHUNKS CASCADE;
1739
- DROP TABLE IF EXISTS LIGHTRAG_LLM_CACHE CASCADE;
1740
- DROP TABLE IF EXISTS LIGHTRAG_VDB_ENTITY CASCADE;
1741
- DROP TABLE IF EXISTS LIGHTRAG_VDB_RELATION CASCADE;
1742
- """,
1743
- "drop_doc_full": """
1744
- DROP TABLE IF EXISTS LIGHTRAG_DOC_FULL CASCADE;
1745
- """,
1746
- "drop_doc_chunks": """
1747
- DROP TABLE IF EXISTS LIGHTRAG_DOC_CHUNKS CASCADE;
1748
- """,
1749
- "drop_llm_cache": """
1750
- DROP TABLE IF EXISTS LIGHTRAG_LLM_CACHE CASCADE;
1751
- """,
1752
- "drop_vdb_entity": """
1753
- DROP TABLE IF EXISTS LIGHTRAG_VDB_ENTITY CASCADE;
1754
- """,
1755
- "drop_vdb_relation": """
1756
- DROP TABLE IF EXISTS LIGHTRAG_VDB_RELATION CASCADE;
1757
- """,
1758
  "relationships": """
1759
  WITH relevant_chunks AS (
1760
  SELECT id as chunk_id
@@ -1795,9 +1854,9 @@ SQL_TEMPLATES = {
1795
  FROM LIGHTRAG_DOC_CHUNKS
1796
  WHERE {doc_ids} IS NULL OR full_doc_id = ANY(ARRAY[{doc_ids}])
1797
  )
1798
- SELECT id FROM
1799
  (
1800
- SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
1801
  FROM LIGHTRAG_DOC_CHUNKS
1802
  where workspace=$1
1803
  AND id IN (SELECT chunk_id FROM relevant_chunks)
@@ -1806,4 +1865,8 @@ SQL_TEMPLATES = {
1806
  ORDER BY distance DESC
1807
  LIMIT $3
1808
  """,
 
 
 
 
1809
  }
 
9
 
10
  from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
11
 
 
12
  from tenacity import (
13
  retry,
14
  retry_if_exception_type,
 
27
  from ..namespace import NameSpace, is_namespace
28
  from ..utils import logger
29
 
 
 
 
 
 
30
  import pipmaster as pm
31
 
32
  if not pm.is_installed("asyncpg"):
 
35
  import asyncpg # type: ignore
36
  from asyncpg import Pool # type: ignore
37
 
38
+ # Get maximum number of graph nodes from environment variable, default is 1000
39
+ MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
40
+
41
 
42
  class PostgreSQLDB:
43
  def __init__(self, config: dict[str, Any], **kwargs: Any):
 
115
  )
116
  raise e
117
 
118
+ # Create index for id column in each table
119
+ try:
120
+ index_name = f"idx_{k.lower()}_id"
121
+ check_index_sql = f"""
122
+ SELECT 1 FROM pg_indexes
123
+ WHERE indexname = '{index_name}'
124
+ AND tablename = '{k.lower()}'
125
+ """
126
+ index_exists = await self.query(check_index_sql)
127
+
128
+ if not index_exists:
129
+ create_index_sql = f"CREATE INDEX {index_name} ON {k}(id)"
130
+ logger.info(f"PostgreSQL, Creating index {index_name} on table {k}")
131
+ await self.execute(create_index_sql)
132
+ except Exception as e:
133
+ logger.error(
134
+ f"PostgreSQL, Failed to create index on table {k}, Got: {e}"
135
+ )
136
+
137
  async def query(
138
  self,
139
  sql: str,
 
270
  db: PostgreSQLDB = field(default=None)
271
 
272
  def __post_init__(self):
 
 
273
  self._max_batch_size = self.global_config["embedding_batch_num"]
274
 
275
  async def initialize(self):
 
285
 
286
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
287
  """Get doc_full data by id."""
288
+ sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
289
  params = {"workspace": self.db.workspace, "id": id}
290
  if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
291
  array_res = await self.db.query(sql, params, multirows=True)
 
299
 
300
  async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
301
  """Specifically for llm_response_cache."""
302
+ sql = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
303
  params = {"workspace": self.db.workspace, mode: mode, "id": id}
304
  if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
305
  array_res = await self.db.query(sql, params, multirows=True)
 
313
  # Query by id
314
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
315
  """Get doc_chunks data by id"""
316
+ sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
317
  ids=",".join([f"'{id}'" for id in ids])
318
  )
319
  params = {"workspace": self.db.workspace}
 
334
 
335
  async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
336
  """Specifically for llm_response_cache."""
337
+ SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
338
  params = {"workspace": self.db.workspace, "status": status}
339
  return await self.db.query(SQL, params, multirows=True)
340
 
 
394
  # PG handles persistence automatically
395
  pass
396
 
397
+ async def delete(self, ids: list[str]) -> None:
398
+ """Delete specific records from storage by their IDs
399
+
400
+ Args:
401
+ ids (list[str]): List of document IDs to be deleted from storage
402
+
403
+ Returns:
404
+ None
405
+ """
406
+ if not ids:
407
+ return
408
+
409
+ table_name = namespace_to_table_name(self.namespace)
410
+ if not table_name:
411
+ logger.error(f"Unknown namespace for deletion: {self.namespace}")
412
+ return
413
+
414
+ delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
415
+
416
+ try:
417
+ await self.db.execute(
418
+ delete_sql, {"workspace": self.db.workspace, "ids": ids}
419
+ )
420
+ logger.debug(
421
+ f"Successfully deleted {len(ids)} records from {self.namespace}"
422
+ )
423
+ except Exception as e:
424
+ logger.error(f"Error while deleting records from {self.namespace}: {e}")
425
+
426
+ async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
427
+ """Delete specific records from storage by cache mode
428
+
429
+ Args:
430
+ modes (list[str]): List of cache modes to be dropped from storage
431
+
432
+ Returns:
433
+ bool: True if successful, False otherwise
434
+ """
435
+ if not modes:
436
+ return False
437
+
438
+ try:
439
+ table_name = namespace_to_table_name(self.namespace)
440
+ if not table_name:
441
+ return False
442
+
443
+ if table_name != "LIGHTRAG_LLM_CACHE":
444
+ return False
445
+
446
+ sql = f"""
447
+ DELETE FROM {table_name}
448
+ WHERE workspace = $1 AND mode = ANY($2)
449
+ """
450
+ params = {"workspace": self.db.workspace, "modes": modes}
451
+
452
+ logger.info(f"Deleting cache by modes: {modes}")
453
+ await self.db.execute(sql, params)
454
+ return True
455
+ except Exception as e:
456
+ logger.error(f"Error deleting cache by modes {modes}: {e}")
457
+ return False
458
+
459
+ async def drop(self) -> dict[str, str]:
460
  """Drop the storage"""
461
+ try:
462
+ table_name = namespace_to_table_name(self.namespace)
463
+ if not table_name:
464
+ return {
465
+ "status": "error",
466
+ "message": f"Unknown namespace: {self.namespace}",
467
+ }
468
+
469
+ drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
470
+ table_name=table_name
471
+ )
472
+ await self.db.execute(drop_sql, {"workspace": self.db.workspace})
473
+ return {"status": "success", "message": "data dropped"}
474
+ except Exception as e:
475
+ return {"status": "error", "message": str(e)}
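A minimal sketch, outside this PR, of the asyncpg array binding that the new `id = ANY($2)` deletes depend on; `PostgreSQLDB.execute` is assumed to map the dict values onto the positional `$1`, `$2` placeholders, and the DSN, table and ids below are illustrative only:

    import asyncio
    import asyncpg

    async def delete_ids(dsn: str, table: str, workspace: str, ids: list[str]) -> None:
        conn = await asyncpg.connect(dsn)
        try:
            # asyncpg encodes the Python list as a Postgres array for ANY($2)
            await conn.execute(
                f"DELETE FROM {table} WHERE workspace = $1 AND id = ANY($2)",
                workspace,
                ids,
            )
        finally:
            await conn.close()

    # asyncio.run(delete_ids("postgresql://user:pass@localhost/lightrag", "LIGHTRAG_LLM_CACHE", "default", ["id-1", "id-2"]))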
476
 
477
 
478
  @final
 
482
 
483
  def __post_init__(self):
484
  self._max_batch_size = self.global_config["embedding_batch_num"]
 
 
485
  config = self.global_config.get("vector_db_storage_cls_kwargs", {})
486
  cosine_threshold = config.get("cosine_better_than_threshold")
487
  if cosine_threshold is None:
 
610
  else:
611
  formatted_ids = "NULL"
612
 
613
+ sql = SQL_TEMPLATES[self.namespace].format(
614
  embedding_string=embedding_string, doc_ids=formatted_ids
615
  )
616
  params = {
 
639
  logger.error(f"Unknown namespace for vector deletion: {self.namespace}")
640
  return
641
 
642
+ delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
 
 
 
643
 
644
  try:
645
+ await self.db.execute(
646
+ delete_sql, {"workspace": self.db.workspace, "ids": ids}
647
+ )
648
  logger.debug(
649
  f"Successfully deleted {len(ids)} vectors from {self.namespace}"
650
  )
 
776
  logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
777
  return []
778
 
779
+ async def drop(self) -> dict[str, str]:
780
+ """Drop the storage"""
781
+ try:
782
+ table_name = namespace_to_table_name(self.namespace)
783
+ if not table_name:
784
+ return {
785
+ "status": "error",
786
+ "message": f"Unknown namespace: {self.namespace}",
787
+ }
788
+
789
+ drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
790
+ table_name=table_name
791
+ )
792
+ await self.db.execute(drop_sql, {"workspace": self.db.workspace})
793
+ return {"status": "success", "message": "data dropped"}
794
+ except Exception as e:
795
+ return {"status": "error", "message": str(e)}
796
+
797
 
798
  @final
799
  @dataclass
 
914
  # PG handles persistence automatically
915
  pass
916
 
917
+ async def delete(self, ids: list[str]) -> None:
918
+ """Delete specific records from storage by their IDs
919
+
920
+ Args:
921
+ ids (list[str]): List of document IDs to be deleted from storage
922
+
923
+ Returns:
924
+ None
925
+ """
926
+ if not ids:
927
+ return
928
+
929
+ table_name = namespace_to_table_name(self.namespace)
930
+ if not table_name:
931
+ logger.error(f"Unknown namespace for deletion: {self.namespace}")
932
+ return
933
+
934
+ delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
935
+
936
+ try:
937
+ await self.db.execute(
938
+ delete_sql, {"workspace": self.db.workspace, "ids": ids}
939
+ )
940
+ logger.debug(
941
+ f"Successfully deleted {len(ids)} records from {self.namespace}"
942
+ )
943
+ except Exception as e:
944
+ logger.error(f"Error while deleting records from {self.namespace}: {e}")
945
+
946
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
947
  """Update or insert document status
948
 
 
979
  },
980
  )
981
 
982
+ async def drop(self) -> dict[str, str]:
983
  """Drop the storage"""
984
+ try:
985
+ table_name = namespace_to_table_name(self.namespace)
986
+ if not table_name:
987
+ return {
988
+ "status": "error",
989
+ "message": f"Unknown namespace: {self.namespace}",
990
+ }
991
+
992
+ drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
993
+ table_name=table_name
994
+ )
995
+ await self.db.execute(drop_sql, {"workspace": self.db.workspace})
996
+ return {"status": "success", "message": "data dropped"}
997
+ except Exception as e:
998
+ return {"status": "error", "message": str(e)}
999
 
1000
 
1001
  class PGGraphQueryException(Exception):
 
1083
  if v.startswith("[") and v.endswith("]"):
1084
  if "::vertex" in v:
1085
  v = v.replace("::vertex", "")
1086
+ d[k] = json.loads(v)
 
 
 
 
 
 
 
 
 
 
1087
 
1088
  elif "::edge" in v:
1089
  v = v.replace("::edge", "")
1090
+ d[k] = json.loads(v)
 
 
 
 
 
 
 
 
 
 
1091
  else:
1092
  print("WARNING: unsupported type")
1093
  continue
 
1096
  dtype = v.split("::")[-1]
1097
  v = v.split("::")[0]
1098
  if dtype == "vertex":
1099
+ d[k] = json.loads(v)
 
 
 
 
 
 
 
 
 
 
1100
  elif dtype == "edge":
1101
+ d[k] = json.loads(v)
 
 
 
 
 
 
 
1102
  else:
1103
+ try:
1104
+ d[k] = (
1105
+ json.loads(v)
1106
+ if isinstance(v, str)
1107
+ and (v.startswith("{") or v.startswith("["))
1108
+ else v
1109
+ )
1110
+ except json.JSONDecodeError:
1111
+ d[k] = v
1112
 
1113
  return d
1114
 
 
1138
  )
1139
  return "{" + ", ".join(props) + "}"
1140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1141
  async def _query(
1142
  self,
1143
  query: str,
 
1188
  return result
1189
 
1190
  async def has_node(self, node_id: str) -> bool:
1191
+ entity_name_label = node_id.strip('"')
1192
 
1193
  query = """SELECT * FROM cypher('%s', $$
1194
+ MATCH (n:base {entity_id: "%s"})
1195
  RETURN count(n) > 0 AS node_exists
1196
  $$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
1197
 
 
1200
  return single_result["node_exists"]
1201
 
1202
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
1203
+ src_label = source_node_id.strip('"')
1204
+ tgt_label = target_node_id.strip('"')
1205
 
1206
  query = """SELECT * FROM cypher('%s', $$
1207
+ MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
1208
  RETURN COUNT(r) > 0 AS edge_exists
1209
  $$) AS (edge_exists bool)""" % (
1210
  self.graph_name,
 
1217
  return single_result["edge_exists"]
1218
 
1219
  async def get_node(self, node_id: str) -> dict[str, str] | None:
1220
+ """Get node by its label identifier, return only node properties"""
1221
+
1222
+ label = node_id.strip('"')
1223
  query = """SELECT * FROM cypher('%s', $$
1224
+ MATCH (n:base {entity_id: "%s"})
1225
  RETURN n
1226
  $$) AS (n agtype)""" % (self.graph_name, label)
1227
  record = await self._query(query)
1228
  if record:
1229
  node = record[0]
1230
+ node_dict = node["n"]["properties"]
1231
 
1232
  return node_dict
1233
  return None
1234
 
1235
  async def node_degree(self, node_id: str) -> int:
1236
+ label = node_id.strip('"')
1237
 
1238
  query = """SELECT * FROM cypher('%s', $$
1239
+ MATCH (n:base {entity_id: "%s"})-[]-(x)
1240
  RETURN count(x) AS total_edge_count
1241
  $$) AS (total_edge_count integer)""" % (self.graph_name, label)
1242
  record = (await self._query(query))[0]
1243
  if record:
1244
  edge_count = int(record["total_edge_count"])
 
1245
  return edge_count
1246
 
1247
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
 
1259
  async def get_edge(
1260
  self, source_node_id: str, target_node_id: str
1261
  ) -> dict[str, str] | None:
1262
+ """Get edge properties between two nodes"""
1263
+
1264
+ src_label = source_node_id.strip('"')
1265
+ tgt_label = target_node_id.strip('"')
1266
 
1267
  query = """SELECT * FROM cypher('%s', $$
1268
+ MATCH (a:base {entity_id: "%s"})-[r]->(b:base {entity_id: "%s"})
1269
  RETURN properties(r) as edge_properties
1270
  LIMIT 1
1271
  $$) AS (edge_properties agtype)""" % (
 
1284
  Retrieves all edges (relationships) for a particular node identified by its label.
1285
  :return: list of dictionaries containing edge information
1286
  """
1287
+ label = source_node_id.strip('"')
1288
 
1289
  query = """SELECT * FROM cypher('%s', $$
1290
+ MATCH (n:base {entity_id: "%s"})
1291
+ OPTIONAL MATCH (n)-[]-(connected:base)
1292
  RETURN n, connected
1293
  $$) AS (n agtype, connected agtype)""" % (
1294
  self.graph_name,
 
1301
  source_node = record["n"] if record["n"] else None
1302
  connected_node = record["connected"] if record["connected"] else None
1303
 
1304
+ if (
1305
+ source_node
1306
+ and connected_node
1307
+ and "properties" in source_node
1308
+ and "properties" in connected_node
1309
+ ):
1310
+ source_label = source_node["properties"].get("entity_id")
1311
+ target_label = connected_node["properties"].get("entity_id")
 
 
1312
 
1313
+ if source_label and target_label:
1314
+ edges.append((source_label, target_label))
 
 
 
 
 
1315
 
1316
  return edges
1317
 
 
1321
  retry=retry_if_exception_type((PGGraphQueryException,)),
1322
  )
1323
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
1324
+ """
1325
+ Upsert a node in the PostgreSQL (Apache AGE) graph.
1326
+
1327
+ Args:
1328
+ node_id: The unique identifier for the node (stored as the entity_id property)
1329
+ node_data: Dictionary of node properties
1330
+ """
1331
+ if "entity_id" not in node_data:
1332
+ raise ValueError(
1333
+ "PostgreSQL: node properties must contain an 'entity_id' field"
1334
+ )
1335
+
1336
+ label = node_id.strip('"')
1337
+ properties = self._format_properties(node_data)
1338
 
1339
  query = """SELECT * FROM cypher('%s', $$
1340
+ MERGE (n:base {entity_id: "%s"})
1341
  SET n += %s
1342
  RETURN n
1343
  $$) AS (n agtype)""" % (
1344
  self.graph_name,
1345
  label,
1346
+ properties,
1347
  )
1348
 
1349
  try:
1350
  await self._query(query, readonly=False, upsert=True)
1351
 
1352
+ except Exception:
1353
+ logger.error(f"POSTGRES, upsert_node error on node_id: `{node_id}`")
1354
  raise
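A usage sketch (hypothetical `graph_storage` instance, illustrative values): the reworked upsert_node stores nodes under the shared `:base` label keyed by an `entity_id` property, and the validation above raises ValueError when that field is missing.

    await graph_storage.upsert_node(
        "Apple Inc.",
        {
            "entity_id": "Apple Inc.",        # required by the validation above
            "entity_type": "organization",    # illustrative extra properties
            "description": "Consumer electronics company",
        },
    )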
1355
 
1356
  @retry(
 
1369
  target_node_id (str): Label of the target node (used as identifier)
1370
  edge_data (dict): dictionary of properties to set on the edge
1371
  """
1372
+ src_label = source_node_id.strip('"')
1373
+ tgt_label = target_node_id.strip('"')
1374
+ edge_properties = self._format_properties(edge_data)
1375
 
1376
  query = """SELECT * FROM cypher('%s', $$
1377
+ MATCH (source:base {entity_id: "%s"})
1378
  WITH source
1379
+ MATCH (target:base {entity_id: "%s"})
1380
  MERGE (source)-[r:DIRECTED]->(target)
1381
  SET r += %s
1382
  RETURN r
 
1384
  self.graph_name,
1385
  src_label,
1386
  tgt_label,
1387
+ edge_properties,
1388
  )
1389
 
1390
  try:
1391
  await self._query(query, readonly=False, upsert=True)
1392
 
1393
+ except Exception:
1394
+ logger.error(
1395
+ f"POSTGRES, upsert_edge error on edge: `{source_node_id}`-`{target_node_id}`"
1396
+ )
1397
  raise
1398
 
1399
  async def _node2vec_embed(self):
 
1406
  Args:
1407
  node_id (str): The ID of the node to delete.
1408
  """
1409
+ label = node_id.strip('"')
1410
 
1411
  query = """SELECT * FROM cypher('%s', $$
1412
+ MATCH (n:base {entity_id: "%s"})
1413
  DETACH DELETE n
1414
  $$) AS (n agtype)""" % (self.graph_name, label)
1415
 
 
1426
  Args:
1427
  node_ids (list[str]): A list of node IDs to remove.
1428
  """
1429
+ node_ids = [node_id.strip('"') for node_id in node_ids]
1430
+ node_id_list = ", ".join([f'"{node_id}"' for node_id in node_ids])
 
 
1431
 
1432
  query = """SELECT * FROM cypher('%s', $$
1433
+ MATCH (n:base)
1434
+ WHERE n.entity_id IN [%s]
1435
  DETACH DELETE n
1436
  $$) AS (n agtype)""" % (self.graph_name, node_id_list)
1437
 
 
1448
  Args:
1449
  edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
1450
  """
1451
+ for source, target in edges:
1452
+ src_label = source.strip('"')
1453
+ tgt_label = target.strip('"')
 
 
 
 
 
1454
 
1455
+ query = """SELECT * FROM cypher('%s', $$
1456
+ MATCH (a:base {entity_id: "%s"})-[r]->(b:base {entity_id: "%s"})
1457
+ DELETE r
1458
+ $$) AS (r agtype)""" % (self.graph_name, src_label, tgt_label)
 
1459
 
1460
+ try:
1461
+ await self._query(query, readonly=False)
1462
+ logger.debug(f"Deleted edge from '{source}' to '{target}'")
1463
+ except Exception as e:
1464
+ logger.error(f"Error during edge deletion: {str(e)}")
1465
+ raise
1466
 
1467
  async def get_all_labels(self) -> list[str]:
1468
  """
 
1473
  """
1474
  query = (
1475
  """SELECT * FROM cypher('%s', $$
1476
+ MATCH (n:base)
1477
+ WHERE n.entity_id IS NOT NULL
1478
+ RETURN DISTINCT n.entity_id AS label
1479
+ ORDER BY n.entity_id
1480
  $$) AS (label text)"""
1481
  % self.graph_name
1482
  )
1483
 
1484
  results = await self._query(query)
1485
+ labels = [result["label"] for result in results]
 
1486
  return labels
1487
 
1488
  async def embed_nodes(
 
1504
  return await embed_func()
1505
 
1506
  async def get_knowledge_graph(
1507
+ self,
1508
+ node_label: str,
1509
+ max_depth: int = 3,
1510
+ max_nodes: int = MAX_GRAPH_NODES,
1511
  ) -> KnowledgeGraph:
1512
  """
1513
+ Retrieve a connected subgraph around the node whose entity_id matches the specified `node_label`.
1514
 
1515
  Args:
1516
+ node_label: Label of the starting node, * means all nodes
1517
+ max_depth: Maximum depth of the subgraph, Defaults to 3
1518
+ max_nodes: Maximum number of nodes to return, defaults to 1000 (neither BFS nor DFS order is guaranteed)
1519
 
1520
  Returns:
1521
+ KnowledgeGraph object containing nodes and edges, with an is_truncated flag
1522
+ indicating whether the graph was truncated due to the max_nodes limit
1523
  """
1524
+ # First, count the total number of nodes that would be returned without limit
1525
+ if node_label == "*":
1526
+ count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
1527
+ MATCH (n:base)
1528
+ RETURN count(distinct n) AS total_nodes
1529
+ $$) AS (total_nodes bigint)"""
1530
+ else:
1531
+ strip_label = node_label.strip('"')
1532
+ count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
1533
+ MATCH (n:base {{entity_id: "{strip_label}"}})
1534
+ OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m)
1535
+ RETURN count(distinct m) AS total_nodes
1536
+ $$) AS (total_nodes bigint)"""
1537
+
1538
+ count_result = await self._query(count_query)
1539
+ total_nodes = count_result[0]["total_nodes"] if count_result else 0
1540
+ is_truncated = total_nodes > max_nodes
1541
+
1542
+ # Now get the actual data with limit
1543
  if node_label == "*":
1544
  query = f"""SELECT * FROM cypher('{self.graph_name}', $$
1545
+ MATCH (n:base)
1546
+ OPTIONAL MATCH (n)-[r]->(target:base)
1547
+ RETURN collect(distinct n) AS n, collect(distinct r) AS r
1548
+ LIMIT {max_nodes}
1549
+ $$) AS (n agtype, r agtype)"""
1550
  else:
1551
+ strip_label = node_label.strip('"')
1552
  query = f"""SELECT * FROM cypher('{self.graph_name}', $$
1553
+ MATCH (n:base {{entity_id: "{strip_label}"}})
1554
+ OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m)
1555
+ RETURN nodes(p) AS n, relationships(p) AS r
1556
+ LIMIT {max_nodes}
1557
+ $$) AS (n agtype, r agtype)"""
1558
 
1559
  results = await self._query(query)
1560
 
1561
+ # Process the query results with deduplication by node and edge IDs
1562
+ nodes_dict = {}
1563
+ edges_dict = {}
1564
+ for result in results:
1565
+ # Handle single node cases
1566
+ if result.get("n") and isinstance(result["n"], dict):
1567
+ node_id = str(result["n"]["id"])
1568
+ if node_id not in nodes_dict:
1569
+ nodes_dict[node_id] = KnowledgeGraphNode(
1570
+ id=node_id,
1571
+ labels=[result["n"]["properties"]["entity_id"]],
1572
+ properties=result["n"]["properties"],
 
 
 
 
 
 
 
 
 
1573
  )
1574
+ # Handle node list cases
1575
+ elif result.get("n") and isinstance(result["n"], list):
1576
+ for node in result["n"]:
1577
+ if isinstance(node, dict) and "id" in node:
1578
+ node_id = str(node["id"])
1579
+ if node_id not in nodes_dict and "properties" in node:
1580
+ nodes_dict[node_id] = KnowledgeGraphNode(
1581
+ id=node_id,
1582
+ labels=[node["properties"]["entity_id"]],
1583
+ properties=node["properties"],
1584
+ )
1585
 
1586
+ # Handle single edge cases
1587
+ if result.get("r") and isinstance(result["r"], dict):
1588
+ edge_id = str(result["r"]["id"])
1589
+ if edge_id not in edges_dict:
1590
+ edges_dict[edge_id] = KnowledgeGraphEdge(
1591
+ id=edge_id,
1592
+ type="DIRECTED",
1593
+ source=str(result["r"]["start_id"]),
1594
+ target=str(result["r"]["end_id"]),
1595
+ properties=result["r"]["properties"],
1596
+ )
1597
+ # Handle edge list cases
1598
+ elif result.get("r") and isinstance(result["r"], list):
1599
+ for edge in result["r"]:
1600
+ if isinstance(edge, dict) and "id" in edge:
1601
+ edge_id = str(edge["id"])
1602
+ if edge_id not in edges_dict:
1603
+ edges_dict[edge_id] = KnowledgeGraphEdge(
1604
+ id=edge_id,
1605
+ type="DIRECTED",
1606
+ source=str(edge["start_id"]),
1607
+ target=str(edge["end_id"]),
1608
+ properties=edge["properties"],
1609
+ )
1610
 
1611
+ # Construct and return the KnowledgeGraph with deduplicated nodes and edges
1612
  kg = KnowledgeGraph(
1613
+ nodes=list(nodes_dict.values()),
1614
+ edges=list(edges_dict.values()),
1615
+ is_truncated=is_truncated,
 
 
 
 
 
 
 
 
 
 
 
1616
  )
1617
 
1618
+ logger.info(
1619
+ f"Subgraph query successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
1620
+ )
1621
  return kg
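A usage sketch (hypothetical `graph_storage` instance): the new signature caps the result at `max_nodes` (default `MAX_GRAPH_NODES`) and flags truncation on the returned object.

    kg = await graph_storage.get_knowledge_graph("Apple Inc.", max_depth=3, max_nodes=500)
    if kg.is_truncated:
        print(f"Returned {len(kg.nodes)} nodes; raise max_nodes or MAX_GRAPH_NODES to see more")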
1622
 
1623
+ async def drop(self) -> dict[str, str]:
1624
  """Drop the storage"""
1625
+ try:
1626
+ drop_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
1627
+ MATCH (n)
1628
+ DETACH DELETE n
1629
+ $$) AS (result agtype)"""
1630
+
1631
+ await self._query(drop_query, readonly=False)
1632
+ return {"status": "success", "message": "graph data dropped"}
1633
+ except Exception as e:
1634
+ logger.error(f"Error dropping graph: {e}")
1635
+ return {"status": "error", "message": str(e)}
1636
 
1637
 
1638
  NAMESPACE_TABLE_MAP = {
 
1790
  file_path=EXCLUDED.file_path,
1791
  update_time = CURRENT_TIMESTAMP
1792
  """,
1793
+ # SQL for VectorStorage
1794
  "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
1795
  content_vector, chunk_ids, file_path)
1796
  VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7)
 
1814
  file_path=EXCLUDED.file_path,
1815
  update_time = CURRENT_TIMESTAMP
1816
  """,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1817
  "relationships": """
1818
  WITH relevant_chunks AS (
1819
  SELECT id as chunk_id
 
1854
  FROM LIGHTRAG_DOC_CHUNKS
1855
  WHERE {doc_ids} IS NULL OR full_doc_id = ANY(ARRAY[{doc_ids}])
1856
  )
1857
+ SELECT id, content, file_path FROM
1858
  (
1859
+ SELECT id, content, file_path, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
1860
  FROM LIGHTRAG_DOC_CHUNKS
1861
  where workspace=$1
1862
  AND id IN (SELECT chunk_id FROM relevant_chunks)
 
1865
  ORDER BY distance DESC
1866
  LIMIT $3
1867
  """,
1868
+ # Clear table data for the current workspace
1869
+ "drop_specifiy_table_workspace": """
1870
+ DELETE FROM {table_name} WHERE workspace=$1
1871
+ """,
1872
  }
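A short sketch (hypothetical `kv_storage` instance): with the table-wide DROP statements gone, `drop()` on the KV, vector and doc-status storages above now clears only the current workspace's rows through the `drop_specifiy_table_workspace` template and reports a status dict.

    result = await kv_storage.drop()
    if result["status"] != "success":
        print(f"Drop failed: {result['message']}")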
lightrag/kg/qdrant_impl.py CHANGED
@@ -8,17 +8,15 @@ import uuid
8
  from ..utils import logger
9
  from ..base import BaseVectorStorage
10
  import configparser
11
-
12
-
13
- config = configparser.ConfigParser()
14
- config.read("config.ini", "utf-8")
15
-
16
  import pipmaster as pm
17
 
18
  if not pm.is_installed("qdrant-client"):
19
  pm.install("qdrant-client")
20
 
21
- from qdrant_client import QdrantClient, models
 
 
 
22
 
23
 
24
  def compute_mdhash_id_for_qdrant(
@@ -275,3 +273,92 @@ class QdrantVectorDBStorage(BaseVectorStorage):
275
  except Exception as e:
276
  logger.error(f"Error searching for prefix '{prefix}': {e}")
277
  return []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  from ..utils import logger
9
  from ..base import BaseVectorStorage
10
  import configparser
 
 
 
 
 
11
  import pipmaster as pm
12
 
13
  if not pm.is_installed("qdrant-client"):
14
  pm.install("qdrant-client")
15
 
16
+ from qdrant_client import QdrantClient, models # type: ignore
17
+
18
+ config = configparser.ConfigParser()
19
+ config.read("config.ini", "utf-8")
20
 
21
 
22
  def compute_mdhash_id_for_qdrant(
 
273
  except Exception as e:
274
  logger.error(f"Error searching for prefix '{prefix}': {e}")
275
  return []
276
+
277
+ async def get_by_id(self, id: str) -> dict[str, Any] | None:
278
+ """Get vector data by its ID
279
+
280
+ Args:
281
+ id: The unique identifier of the vector
282
+
283
+ Returns:
284
+ The vector data if found, or None if not found
285
+ """
286
+ try:
287
+ # Convert to Qdrant compatible ID
288
+ qdrant_id = compute_mdhash_id_for_qdrant(id)
289
+
290
+ # Retrieve the point by ID
291
+ result = self._client.retrieve(
292
+ collection_name=self.namespace,
293
+ ids=[qdrant_id],
294
+ with_payload=True,
295
+ )
296
+
297
+ if not result:
298
+ return None
299
+
300
+ return result[0].payload
301
+ except Exception as e:
302
+ logger.error(f"Error retrieving vector data for ID {id}: {e}")
303
+ return None
304
+
305
+ async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
306
+ """Get multiple vector data by their IDs
307
+
308
+ Args:
309
+ ids: List of unique identifiers
310
+
311
+ Returns:
312
+ List of vector data objects that were found
313
+ """
314
+ if not ids:
315
+ return []
316
+
317
+ try:
318
+ # Convert to Qdrant compatible IDs
319
+ qdrant_ids = [compute_mdhash_id_for_qdrant(id) for id in ids]
320
+
321
+ # Retrieve the points by IDs
322
+ results = self._client.retrieve(
323
+ collection_name=self.namespace,
324
+ ids=qdrant_ids,
325
+ with_payload=True,
326
+ )
327
+
328
+ return [point.payload for point in results]
329
+ except Exception as e:
330
+ logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
331
+ return []
332
+
333
+ async def drop(self) -> dict[str, str]:
334
+ """Drop all vector data from storage and clean up resources
335
+
336
+ This method will delete all data from the Qdrant collection.
337
+
338
+ Returns:
339
+ dict[str, str]: Operation status and message
340
+ - On success: {"status": "success", "message": "data dropped"}
341
+ - On failure: {"status": "error", "message": "<error details>"}
342
+ """
343
+ try:
344
+ # Delete the collection and recreate it
345
+ if self._client.collection_exists(self.namespace):
346
+ self._client.delete_collection(self.namespace)
347
+
348
+ # Recreate the collection
349
+ QdrantVectorDBStorage.create_collection_if_not_exist(
350
+ self._client,
351
+ self.namespace,
352
+ vectors_config=models.VectorParams(
353
+ size=self.embedding_func.embedding_dim,
354
+ distance=models.Distance.COSINE,
355
+ ),
356
+ )
357
+
358
+ logger.info(
359
+ f"Process {os.getpid()} dropped Qdrant collection {self.namespace}"
360
+ )
361
+ return {"status": "success", "message": "data dropped"}
362
+ except Exception as e:
363
+ logger.error(f"Error dropping Qdrant collection {self.namespace}: {e}")
364
+ return {"status": "error", "message": str(e)}
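A usage sketch (hypothetical `vdb` instance, illustrative ids): callers keep passing the original LightRAG hash ids; `compute_mdhash_id_for_qdrant()` converts them to the UUIDs Qdrant stores internally.

    payload = await vdb.get_by_id("ent-abc123")
    payloads = await vdb.get_by_ids(["ent-abc123", "rel-def456"])
    status = await vdb.drop()   # deletes and recreates the collection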
lightrag/kg/redis_impl.py CHANGED
@@ -8,8 +8,8 @@ if not pm.is_installed("redis"):
8
  pm.install("redis")
9
 
10
  # aioredis is a deprecated library, replaced with redis
11
- from redis.asyncio import Redis
12
- from lightrag.utils import logger, compute_mdhash_id
13
  from lightrag.base import BaseKVStorage
14
  import json
15
 
@@ -84,66 +84,50 @@ class RedisKVStorage(BaseKVStorage):
84
  f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}"
85
  )
86
 
87
- async def delete_entity(self, entity_name: str) -> None:
88
- """Delete an entity by name
 
 
 
89
 
90
  Args:
91
- entity_name: Name of the entity to delete
 
 
 
 
92
  """
 
 
93
 
94
  try:
95
- entity_id = compute_mdhash_id(entity_name, prefix="ent-")
96
- logger.debug(
97
- f"Attempting to delete entity {entity_name} with ID {entity_id}"
98
- )
99
-
100
- # Delete the entity
101
- result = await self._redis.delete(f"{self.namespace}:{entity_id}")
102
 
103
- if result:
104
- logger.debug(f"Successfully deleted entity {entity_name}")
105
- else:
106
- logger.debug(f"Entity {entity_name} not found in storage")
107
- except Exception as e:
108
- logger.error(f"Error deleting entity {entity_name}: {e}")
109
 
110
- async def delete_entity_relation(self, entity_name: str) -> None:
111
- """Delete all relations associated with an entity
112
-
113
- Args:
114
- entity_name: Name of the entity whose relations should be deleted
115
  """
116
  try:
117
- # Get all keys in this namespace
118
- cursor = 0
119
- relation_keys = []
120
- pattern = f"{self.namespace}:*"
121
 
122
- while True:
123
- cursor, keys = await self._redis.scan(cursor, match=pattern)
124
-
125
- # For each key, get the value and check if it's related to entity_name
126
  for key in keys:
127
- value = await self._redis.get(key)
128
- if value:
129
- data = json.loads(value)
130
- # Check if this is a relation involving the entity
131
- if (
132
- data.get("src_id") == entity_name
133
- or data.get("tgt_id") == entity_name
134
- ):
135
- relation_keys.append(key)
136
-
137
- # Exit loop when cursor returns to 0
138
- if cursor == 0:
139
- break
140
-
141
- # Delete the relation keys
142
- if relation_keys:
143
- deleted = await self._redis.delete(*relation_keys)
144
- logger.debug(f"Deleted {deleted} relations for {entity_name}")
145
  else:
146
- logger.debug(f"No relations found for entity {entity_name}")
 
147
 
148
  except Exception as e:
149
- logger.error(f"Error deleting relations for {entity_name}: {e}")
 
 
8
  pm.install("redis")
9
 
10
  # aioredis is a deprecated library, replaced with redis
11
+ from redis.asyncio import Redis # type: ignore
12
+ from lightrag.utils import logger
13
  from lightrag.base import BaseKVStorage
14
  import json
15
 
 
84
  f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}"
85
  )
86
 
87
+ async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
88
+ """Delete specific records from storage by cache mode
89
+
90
+ Important notes for Redis storage:
91
+ 1. This will immediately delete the specified cache modes from Redis
92
 
93
  Args:
94
+ modes (list[str]): List of cache modes to be dropped from storage
95
+
96
+ Returns:
97
+ True: if the cache was dropped successfully
98
+ False: if the cache drop failed
99
  """
100
+ if not modes:
101
+ return False
102
 
103
  try:
104
+ await self.delete(modes)
105
+ return True
106
+ except Exception:
107
+ return False
 
 
 
108
 
109
+ async def drop(self) -> dict[str, str]:
110
+ """Drop the storage by removing all keys under the current namespace.
 
 
 
 
111
 
112
+ Returns:
113
+ dict[str, str]: Status of the operation with keys 'status' and 'message'
 
 
 
114
  """
115
  try:
116
+ keys = await self._redis.keys(f"{self.namespace}:*")
 
 
 
117
 
118
+ if keys:
119
+ pipe = self._redis.pipeline()
 
 
120
  for key in keys:
121
+ pipe.delete(key)
122
+ results = await pipe.execute()
123
+ deleted_count = sum(results)
124
+
125
+ logger.info(f"Dropped {deleted_count} keys from {self.namespace}")
126
+ return {"status": "success", "message": f"{deleted_count} keys dropped"}
 
 
 
 
 
 
 
 
 
 
 
 
127
  else:
128
+ logger.info(f"No keys found to drop in {self.namespace}")
129
+ return {"status": "success", "message": "no keys to drop"}
130
 
131
  except Exception as e:
132
+ logger.error(f"Error dropping keys from {self.namespace}: {e}")
133
+ return {"status": "error", "message": str(e)}
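An alternative sketch, not what this PR implements: on a very large keyspace, SCAN-based iteration avoids the blocking KEYS call used in drop() above (names are illustrative).

    from redis.asyncio import Redis

    async def drop_namespace_with_scan(redis: Redis, namespace: str) -> int:
        deleted = 0
        # scan_iter walks the keyspace incrementally instead of one blocking KEYS call
        async for key in redis.scan_iter(match=f"{namespace}:*", count=1000):
            deleted += await redis.delete(key)
        return deleted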
lightrag/kg/tidb_impl.py CHANGED
@@ -20,7 +20,7 @@ if not pm.is_installed("pymysql"):
20
  if not pm.is_installed("sqlalchemy"):
21
  pm.install("sqlalchemy")
22
 
23
- from sqlalchemy import create_engine, text
24
 
25
 
26
  class TiDB:
@@ -278,6 +278,86 @@ class TiDBKVStorage(BaseKVStorage):
278
  # Ti handles persistence automatically
279
  pass
280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
 
282
  @final
283
  @dataclass
@@ -406,16 +486,91 @@ class TiDBVectorDBStorage(BaseVectorStorage):
406
  params = {"workspace": self.db.workspace, "status": status}
407
  return await self.db.query(SQL, params, multirows=True)
408
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
409
  async def delete_entity(self, entity_name: str) -> None:
410
- raise NotImplementedError
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
411
 
412
  async def delete_entity_relation(self, entity_name: str) -> None:
413
- raise NotImplementedError
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
414
 
415
  async def index_done_callback(self) -> None:
416
  # Ti handles persistence automatically
417
  pass
418
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
419
  async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
420
  """Search for records with IDs starting with a specific prefix.
421
 
@@ -710,6 +865,18 @@ class TiDBGraphStorage(BaseGraphStorage):
710
  # Ti handles persistence automatically
711
  pass
712
 
 
 
 
 
 
 
 
 
 
 
 
 
713
  async def delete_node(self, node_id: str) -> None:
714
  """Delete a node and all its related edges
715
 
@@ -1129,4 +1296,6 @@ SQL_TEMPLATES = {
1129
  FROM LIGHTRAG_DOC_CHUNKS
1130
  WHERE chunk_id LIKE :prefix_pattern AND workspace = :workspace
1131
  """,
 
 
1132
  }
 
20
  if not pm.is_installed("sqlalchemy"):
21
  pm.install("sqlalchemy")
22
 
23
+ from sqlalchemy import create_engine, text # type: ignore
24
 
25
 
26
  class TiDB:
 
278
  # Ti handles persistence automatically
279
  pass
280
 
281
+ async def delete(self, ids: list[str]) -> None:
282
+ """Delete records with specified IDs from the storage.
283
+
284
+ Args:
285
+ ids: List of record IDs to be deleted
286
+ """
287
+ if not ids:
288
+ return
289
+
290
+ try:
291
+ table_name = namespace_to_table_name(self.namespace)
292
+ id_field = namespace_to_id(self.namespace)
293
+
294
+ if not table_name or not id_field:
295
+ logger.error(f"Unknown namespace for deletion: {self.namespace}")
296
+ return
297
+
298
+ ids_list = ",".join([f"'{id}'" for id in ids])
299
+ delete_sql = f"DELETE FROM {table_name} WHERE workspace = :workspace AND {id_field} IN ({ids_list})"
300
+
301
+ await self.db.execute(delete_sql, {"workspace": self.db.workspace})
302
+ logger.info(
303
+ f"Successfully deleted {len(ids)} records from {self.namespace}"
304
+ )
305
+ except Exception as e:
306
+ logger.error(f"Error deleting records from {self.namespace}: {e}")
307
+
308
+ async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
309
+ """Delete specific records from storage by cache mode
310
+
311
+ Args:
312
+ modes (list[str]): List of cache modes to be dropped from storage
313
+
314
+ Returns:
315
+ bool: True if successful, False otherwise
316
+ """
317
+ if not modes:
318
+ return False
319
+
320
+ try:
321
+ table_name = namespace_to_table_name(self.namespace)
322
+ if not table_name:
323
+ return False
324
+
325
+ if table_name != "LIGHTRAG_LLM_CACHE":
326
+ return False
327
+
328
+ # Build a MySQL-style IN clause
329
+ modes_list = ", ".join([f"'{mode}'" for mode in modes])
330
+ sql = f"""
331
+ DELETE FROM {table_name}
332
+ WHERE workspace = :workspace
333
+ AND mode IN ({modes_list})
334
+ """
335
+
336
+ logger.info(f"Deleting cache by modes: {modes}")
337
+ await self.db.execute(sql, {"workspace": self.db.workspace})
338
+ return True
339
+ except Exception as e:
340
+ logger.error(f"Error deleting cache by modes {modes}: {e}")
341
+ return False
342
+
343
+ async def drop(self) -> dict[str, str]:
344
+ """Drop the storage"""
345
+ try:
346
+ table_name = namespace_to_table_name(self.namespace)
347
+ if not table_name:
348
+ return {
349
+ "status": "error",
350
+ "message": f"Unknown namespace: {self.namespace}",
351
+ }
352
+
353
+ drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
354
+ table_name=table_name
355
+ )
356
+ await self.db.execute(drop_sql, {"workspace": self.db.workspace})
357
+ return {"status": "success", "message": "data dropped"}
358
+ except Exception as e:
359
+ return {"status": "error", "message": str(e)}
360
+
361
 
362
  @final
363
  @dataclass
 
486
  params = {"workspace": self.db.workspace, "status": status}
487
  return await self.db.query(SQL, params, multirows=True)
488
 
489
+ async def delete(self, ids: list[str]) -> None:
490
+ """Delete vectors with specified IDs from the storage.
491
+
492
+ Args:
493
+ ids: List of vector IDs to be deleted
494
+ """
495
+ if not ids:
496
+ return
497
+
498
+ table_name = namespace_to_table_name(self.namespace)
499
+ id_field = namespace_to_id(self.namespace)
500
+
501
+ if not table_name or not id_field:
502
+ logger.error(f"Unknown namespace for vector deletion: {self.namespace}")
503
+ return
504
+
505
+ ids_list = ",".join([f"'{id}'" for id in ids])
506
+ delete_sql = f"DELETE FROM {table_name} WHERE workspace = :workspace AND {id_field} IN ({ids_list})"
507
+
508
+ try:
509
+ await self.db.execute(delete_sql, {"workspace": self.db.workspace})
510
+ logger.debug(
511
+ f"Successfully deleted {len(ids)} vectors from {self.namespace}"
512
+ )
513
+ except Exception as e:
514
+ logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
515
+
516
  async def delete_entity(self, entity_name: str) -> None:
517
+ """Delete an entity by its name from the vector storage.
518
+
519
+ Args:
520
+ entity_name: The name of the entity to delete
521
+ """
522
+ try:
523
+ # Construct SQL to delete the entity
524
+ delete_sql = """DELETE FROM LIGHTRAG_GRAPH_NODES
525
+ WHERE workspace = :workspace AND name = :entity_name"""
526
+
527
+ await self.db.execute(
528
+ delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
529
+ )
530
+ logger.debug(f"Successfully deleted entity {entity_name}")
531
+ except Exception as e:
532
+ logger.error(f"Error deleting entity {entity_name}: {e}")
533
 
534
  async def delete_entity_relation(self, entity_name: str) -> None:
535
+ """Delete all relations associated with an entity.
536
+
537
+ Args:
538
+ entity_name: The name of the entity whose relations should be deleted
539
+ """
540
+ try:
541
+ # Delete relations where the entity is either the source or target
542
+ delete_sql = """DELETE FROM LIGHTRAG_GRAPH_EDGES
543
+ WHERE workspace = :workspace AND (source_name = :entity_name OR target_name = :entity_name)"""
544
+
545
+ await self.db.execute(
546
+ delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
547
+ )
548
+ logger.debug(f"Successfully deleted relations for entity {entity_name}")
549
+ except Exception as e:
550
+ logger.error(f"Error deleting relations for entity {entity_name}: {e}")
551
 
552
  async def index_done_callback(self) -> None:
553
  # TiDB handles persistence automatically
554
  pass
555
 
556
+ async def drop(self) -> dict[str, str]:
557
+ """Drop the storage"""
558
+ try:
559
+ table_name = namespace_to_table_name(self.namespace)
560
+ if not table_name:
561
+ return {
562
+ "status": "error",
563
+ "message": f"Unknown namespace: {self.namespace}",
564
+ }
565
+
566
+ drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
567
+ table_name=table_name
568
+ )
569
+ await self.db.execute(drop_sql, {"workspace": self.db.workspace})
570
+ return {"status": "success", "message": "data dropped"}
571
+ except Exception as e:
572
+ return {"status": "error", "message": str(e)}
573
+
574
  async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
575
  """Search for records with IDs starting with a specific prefix.
576
 
 
865
  # TiDB handles persistence automatically
866
  pass
867
 
868
+ async def drop(self) -> dict[str, str]:
869
+ """Drop the storage"""
870
+ try:
871
+ drop_sql = """
872
+ DELETE FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace;
873
+ DELETE FROM LIGHTRAG_GRAPH_NODES WHERE workspace = :workspace;
874
+ """
875
+ await self.db.execute(drop_sql, {"workspace": self.db.workspace})
876
+ return {"status": "success", "message": "graph data dropped"}
877
+ except Exception as e:
878
+ return {"status": "error", "message": str(e)}
879
+
880
  async def delete_node(self, node_id: str) -> None:
881
  """Delete a node and all its related edges
882
 
 
1296
  FROM LIGHTRAG_DOC_CHUNKS
1297
  WHERE chunk_id LIKE :prefix_pattern AND workspace = :workspace
1298
  """,
1299
+ # Drop all rows of a table for the current workspace
1300
+ "drop_specifiy_table_workspace": "DELETE FROM {table_name} WHERE workspace = :workspace",
1301
  }
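The TiDB storage backends above gain explicit cleanup paths: `delete()` removes rows by ID, `drop_cache_by_modes()` clears LLM-cache rows for selected query modes, and `drop()` wipes a table for the current workspace. A minimal usage sketch, assuming `storage` is an already-initialized TiDB-backed storage instance exposing these methods:

```python
# Illustrative only: `storage` is assumed to be an initialized TiDB-backed
# storage instance exposing the methods added in this diff.
async def cleanup_workspace(storage) -> None:
    # Remove a few records by ID (workspace-scoped DELETE under the hood).
    await storage.delete(["doc-1", "doc-2"])

    # Drop only the LLM cache entries produced by the "local" query mode.
    cleared = await storage.drop_cache_by_modes(["local"])
    print(f"cache cleared: {cleared}")

    # Drop every row of this storage's table for the current workspace.
    result = await storage.drop()
    print(result)  # e.g. {"status": "success", "message": "data dropped"}

# await cleanup_workspace(storage)  # run from within the application's event loop
```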
lightrag/lightrag.py CHANGED
@@ -13,7 +13,6 @@ import pandas as pd
13
 
14
 
15
  from lightrag.kg import (
16
- STORAGE_ENV_REQUIREMENTS,
17
  STORAGES,
18
  verify_storage_implementation,
19
  )
@@ -230,6 +229,7 @@ class LightRAG:
230
  vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict)
231
  """Additional parameters for vector database storage."""
232
 
 
233
  namespace_prefix: str = field(default="")
234
  """Prefix for namespacing stored data across different environments."""
235
 
@@ -510,36 +510,22 @@ class LightRAG:
510
  self,
511
  node_label: str,
512
  max_depth: int = 3,
513
- min_degree: int = 0,
514
- inclusive: bool = False,
515
  ) -> KnowledgeGraph:
516
  """Get knowledge graph for a given label
517
 
518
  Args:
519
  node_label (str): Label to get knowledge graph for
520
  max_depth (int): Maximum depth of graph
521
- min_degree (int, optional): Minimum degree of nodes to include. Defaults to 0.
522
- inclusive (bool, optional): Whether to use inclusive search mode. Defaults to False.
523
 
524
  Returns:
525
  KnowledgeGraph: Knowledge graph containing nodes and edges
526
  """
527
- # get params supported by get_knowledge_graph of specified storage
528
- import inspect
529
 
530
- storage_params = inspect.signature(
531
- self.chunk_entity_relation_graph.get_knowledge_graph
532
- ).parameters
533
-
534
- kwargs = {"node_label": node_label, "max_depth": max_depth}
535
-
536
- if "min_degree" in storage_params and min_degree > 0:
537
- kwargs["min_degree"] = min_degree
538
-
539
- if "inclusive" in storage_params:
540
- kwargs["inclusive"] = inclusive
541
-
542
- return await self.chunk_entity_relation_graph.get_knowledge_graph(**kwargs)
543
 
544
  def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
545
  import_path = STORAGES[storage_name]
@@ -1449,6 +1435,7 @@ class LightRAG:
1449
  loop = always_get_an_event_loop()
1450
  return loop.run_until_complete(self.adelete_by_entity(entity_name))
1451
 
 
1452
  async def adelete_by_entity(self, entity_name: str) -> None:
1453
  try:
1454
  await self.entities_vdb.delete_entity(entity_name)
@@ -1486,6 +1473,7 @@ class LightRAG:
1486
  self.adelete_by_relation(source_entity, target_entity)
1487
  )
1488
 
 
1489
  async def adelete_by_relation(self, source_entity: str, target_entity: str) -> None:
1490
  """Asynchronously delete a relation between two entities.
1491
 
@@ -1494,6 +1482,7 @@ class LightRAG:
1494
  target_entity: Name of the target entity
1495
  """
1496
  try:
 
1497
  # Check if the relation exists
1498
  edge_exists = await self.chunk_entity_relation_graph.has_edge(
1499
  source_entity, target_entity
@@ -1554,6 +1543,7 @@ class LightRAG:
1554
  """
1555
  return await self.doc_status.get_docs_by_status(status)
1556
 
 
1557
  async def adelete_by_doc_id(self, doc_id: str) -> None:
1558
  """Delete a document and all its related data
1559
 
@@ -1586,6 +1576,8 @@ class LightRAG:
1586
  chunk_ids = set(related_chunks.keys())
1587
  logger.debug(f"Found {len(chunk_ids)} chunks to delete")
1588
 
 
 
1589
  # 3. Before deleting, check the related entities and relationships for these chunks
1590
  for chunk_id in chunk_ids:
1591
  # Check entities
@@ -1857,24 +1849,6 @@ class LightRAG:
1857
 
1858
  return result
1859
 
1860
- def check_storage_env_vars(self, storage_name: str) -> None:
1861
- """Check if all required environment variables for storage implementation exist
1862
-
1863
- Args:
1864
- storage_name: Storage implementation name
1865
-
1866
- Raises:
1867
- ValueError: If required environment variables are missing
1868
- """
1869
- required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
1870
- missing_vars = [var for var in required_vars if var not in os.environ]
1871
-
1872
- if missing_vars:
1873
- raise ValueError(
1874
- f"Storage implementation '{storage_name}' requires the following "
1875
- f"environment variables: {', '.join(missing_vars)}"
1876
- )
1877
-
1878
  async def aclear_cache(self, modes: list[str] | None = None) -> None:
1879
  """Clear cache data from the LLM response cache storage.
1880
 
@@ -1906,12 +1880,18 @@ class LightRAG:
1906
  try:
1907
  # Reset the cache storage for specified mode
1908
  if modes:
1909
- await self.llm_response_cache.delete(modes)
1910
- logger.info(f"Cleared cache for modes: {modes}")
 
 
 
1911
  else:
1912
  # Clear all modes
1913
- await self.llm_response_cache.delete(valid_modes)
1914
- logger.info("Cleared all cache")
 
 
 
1915
 
1916
  await self.llm_response_cache.index_done_callback()
1917
 
@@ -1922,6 +1902,7 @@ class LightRAG:
1922
  """Synchronous version of aclear_cache."""
1923
  return always_get_an_event_loop().run_until_complete(self.aclear_cache(modes))
1924
 
 
1925
  async def aedit_entity(
1926
  self, entity_name: str, updated_data: dict[str, str], allow_rename: bool = True
1927
  ) -> dict[str, Any]:
@@ -2134,6 +2115,7 @@ class LightRAG:
2134
  ]
2135
  )
2136
 
 
2137
  async def aedit_relation(
2138
  self, source_entity: str, target_entity: str, updated_data: dict[str, Any]
2139
  ) -> dict[str, Any]:
@@ -2448,6 +2430,7 @@ class LightRAG:
2448
  self.acreate_relation(source_entity, target_entity, relation_data)
2449
  )
2450
 
 
2451
  async def amerge_entities(
2452
  self,
2453
  source_entities: list[str],
 
13
 
14
 
15
  from lightrag.kg import (
 
16
  STORAGES,
17
  verify_storage_implementation,
18
  )
 
229
  vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict)
230
  """Additional parameters for vector database storage."""
231
 
232
+ # TODO: deprecated; remove in the future and use WORKSPACE instead
233
  namespace_prefix: str = field(default="")
234
  """Prefix for namespacing stored data across different environments."""
235
 
 
510
  self,
511
  node_label: str,
512
  max_depth: int = 3,
513
+ max_nodes: int = 1000,
 
514
  ) -> KnowledgeGraph:
515
  """Get knowledge graph for a given label
516
 
517
  Args:
518
  node_label (str): Label to get knowledge graph for
519
  max_depth (int): Maximum depth of graph
520
+ max_nodes (int, optional): Maximum number of nodes to return. Defaults to 1000.
 
521
 
522
  Returns:
523
  KnowledgeGraph: Knowledge graph containing nodes and edges
524
  """
 
 
525
 
526
+ return await self.chunk_entity_relation_graph.get_knowledge_graph(
527
+ node_label, max_depth, max_nodes
528
+ )
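The knowledge-graph accessor is simplified to forward `node_label`, `max_depth`, and the new `max_nodes` cap directly to the graph storage, dropping the old `inspect`-based parameter probing. A hedged sketch, assuming the method shown in this hunk is `LightRAG.get_knowledge_graph` and `rag` is an initialized instance:

```python
# Minimal sketch; the method name is assumed from the surrounding hunk context.
async def preview_graph(rag, label: str):
    kg = await rag.get_knowledge_graph(
        node_label=label,
        max_depth=3,     # hop limit around the label
        max_nodes=500,   # hard cap; backends may set is_truncated when it is hit
    )
    print(len(kg.nodes), len(kg.edges), kg.is_truncated)
```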
 
 
 
 
529
 
530
  def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
531
  import_path = STORAGES[storage_name]
 
1435
  loop = always_get_an_event_loop()
1436
  return loop.run_until_complete(self.adelete_by_entity(entity_name))
1437
 
1438
+ # TODO: Lock all KG-related DB storages to ensure consistency across multiple processes
1439
  async def adelete_by_entity(self, entity_name: str) -> None:
1440
  try:
1441
  await self.entities_vdb.delete_entity(entity_name)
 
1473
  self.adelete_by_relation(source_entity, target_entity)
1474
  )
1475
 
1476
+ # TODO: Lock all KG-related DB storages to ensure consistency across multiple processes
1477
  async def adelete_by_relation(self, source_entity: str, target_entity: str) -> None:
1478
  """Asynchronously delete a relation between two entities.
1479
 
 
1482
  target_entity: Name of the target entity
1483
  """
1484
  try:
1485
+ # TODO: check whether has_edge also detects the edge in the reverse direction
1486
  # Check if the relation exists
1487
  edge_exists = await self.chunk_entity_relation_graph.has_edge(
1488
  source_entity, target_entity
 
1543
  """
1544
  return await self.doc_status.get_docs_by_status(status)
1545
 
1546
+ # TODO: Lock all KG-related DB storages to ensure consistency across multiple processes
1547
  async def adelete_by_doc_id(self, doc_id: str) -> None:
1548
  """Delete a document and all its related data
1549
 
 
1576
  chunk_ids = set(related_chunks.keys())
1577
  logger.debug(f"Found {len(chunk_ids)} chunks to delete")
1578
 
1579
+ # TODO: self.entities_vdb.client_storage only works for local storage; this needs to be fixed
1580
+
1581
  # 3. Before deleting, check the related entities and relationships for these chunks
1582
  for chunk_id in chunk_ids:
1583
  # Check entities
 
1849
 
1850
  return result
1851
 
 
 
 
 
 
 
1852
  async def aclear_cache(self, modes: list[str] | None = None) -> None:
1853
  """Clear cache data from the LLM response cache storage.
1854
 
 
1880
  try:
1881
  # Reset the cache storage for specified mode
1882
  if modes:
1883
+ success = await self.llm_response_cache.drop_cache_by_modes(modes)
1884
+ if success:
1885
+ logger.info(f"Cleared cache for modes: {modes}")
1886
+ else:
1887
+ logger.warning(f"Failed to clear cache for modes: {modes}")
1888
  else:
1889
  # Clear all modes
1890
+ success = await self.llm_response_cache.drop_cache_by_modes(valid_modes)
1891
+ if success:
1892
+ logger.info("Cleared all cache")
1893
+ else:
1894
+ logger.warning("Failed to clear all cache")
1895
 
1896
  await self.llm_response_cache.index_done_callback()
1897
 
 
1902
  """Synchronous version of aclear_cache."""
1903
  return always_get_an_event_loop().run_until_complete(self.aclear_cache(modes))
1904
 
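With the cache storages now exposing `drop_cache_by_modes()`, `aclear_cache()` delegates mode-specific clearing to the backend and logs a warning instead of raising when the drop fails. A minimal sketch, assuming `rag` is an initialized LightRAG instance and the listed modes are among the accepted cache modes:

```python
# Hedged sketch: `rag` is assumed to be an initialized LightRAG instance.
async def clear_query_caches(rag) -> None:
    # Clear only the hybrid- and mix-mode caches; pass modes=None to clear everything.
    await rag.aclear_cache(modes=["hybrid", "mix"])
```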
1905
+ # TODO: Lock all KG-related DB storages to ensure consistency across multiple processes
1906
  async def aedit_entity(
1907
  self, entity_name: str, updated_data: dict[str, str], allow_rename: bool = True
1908
  ) -> dict[str, Any]:
 
2115
  ]
2116
  )
2117
 
2118
+ # TODO: Lock all KG-related DB storages to ensure consistency across multiple processes
2119
  async def aedit_relation(
2120
  self, source_entity: str, target_entity: str, updated_data: dict[str, Any]
2121
  ) -> dict[str, Any]:
 
2430
  self.acreate_relation(source_entity, target_entity, relation_data)
2431
  )
2432
 
2433
+ # TODO: Lock all KG-related DB storages to ensure consistency across multiple processes
2434
  async def amerge_entities(
2435
  self,
2436
  source_entities: list[str],
lightrag/operate.py CHANGED
@@ -25,7 +25,6 @@ from .utils import (
25
  CacheData,
26
  statistic_data,
27
  get_conversation_turns,
28
- verbose_debug,
29
  )
30
  from .base import (
31
  BaseGraphStorage,
@@ -441,6 +440,13 @@ async def extract_entities(
441
 
442
  processed_chunks = 0
443
  total_chunks = len(ordered_chunks)
 
 
 
 
 
 
 
444
 
445
  async def _user_llm_func_with_cache(
446
  input_text: str, history_messages: list[dict[str, str]] = None
@@ -539,7 +545,7 @@ async def extract_entities(
539
  chunk_key_dp (tuple[str, TextChunkSchema]):
540
  ("chunk-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int})
541
  """
542
- nonlocal processed_chunks
543
  chunk_key = chunk_key_dp[0]
544
  chunk_dp = chunk_key_dp[1]
545
  content = chunk_dp["content"]
@@ -597,102 +603,74 @@ async def extract_entities(
597
  async with pipeline_status_lock:
598
  pipeline_status["latest_message"] = log_message
599
  pipeline_status["history_messages"].append(log_message)
600
- return dict(maybe_nodes), dict(maybe_edges)
601
-
602
- tasks = [_process_single_content(c) for c in ordered_chunks]
603
- results = await asyncio.gather(*tasks)
604
 
605
- maybe_nodes = defaultdict(list)
606
- maybe_edges = defaultdict(list)
607
- for m_nodes, m_edges in results:
608
- for k, v in m_nodes.items():
609
- maybe_nodes[k].extend(v)
610
- for k, v in m_edges.items():
611
- maybe_edges[tuple(sorted(k))].extend(v)
612
 
613
- from .kg.shared_storage import get_graph_db_lock
614
-
615
- graph_db_lock = get_graph_db_lock(enable_logging=False)
616
-
617
- # Ensure that nodes and edges are merged and upserted atomically
618
- async with graph_db_lock:
619
- all_entities_data = await asyncio.gather(
620
- *[
621
- _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
622
- for k, v in maybe_nodes.items()
623
- ]
624
- )
625
-
626
- all_relationships_data = await asyncio.gather(
627
- *[
628
- _merge_edges_then_upsert(
629
- k[0], k[1], v, knowledge_graph_inst, global_config
630
  )
631
- for k, v in maybe_edges.items()
632
- ]
633
- )
 
 
 
 
 
 
 
 
 
 
634
 
635
- if not (all_entities_data or all_relationships_data):
636
- log_message = "Didn't extract any entities and relationships."
637
- logger.info(log_message)
638
- if pipeline_status is not None:
639
- async with pipeline_status_lock:
640
- pipeline_status["latest_message"] = log_message
641
- pipeline_status["history_messages"].append(log_message)
642
- return
643
 
644
- if not all_entities_data:
645
- log_message = "Didn't extract any entities"
646
- logger.info(log_message)
647
- if pipeline_status is not None:
648
- async with pipeline_status_lock:
649
- pipeline_status["latest_message"] = log_message
650
- pipeline_status["history_messages"].append(log_message)
651
- if not all_relationships_data:
652
- log_message = "Didn't extract any relationships"
653
- logger.info(log_message)
654
- if pipeline_status is not None:
655
- async with pipeline_status_lock:
656
- pipeline_status["latest_message"] = log_message
657
- pipeline_status["history_messages"].append(log_message)
658
 
659
- log_message = f"Extracted {len(all_entities_data)} entities + {len(all_relationships_data)} relationships (deduplicated)"
660
  logger.info(log_message)
661
  if pipeline_status is not None:
662
  async with pipeline_status_lock:
663
  pipeline_status["latest_message"] = log_message
664
  pipeline_status["history_messages"].append(log_message)
665
- verbose_debug(
666
- f"New entities:{all_entities_data}, relationships:{all_relationships_data}"
667
- )
668
- verbose_debug(f"New relationships:{all_relationships_data}")
669
-
670
- if entity_vdb is not None:
671
- data_for_vdb = {
672
- compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
673
- "entity_name": dp["entity_name"],
674
- "entity_type": dp["entity_type"],
675
- "content": f"{dp['entity_name']}\n{dp['description']}",
676
- "source_id": dp["source_id"],
677
- "file_path": dp.get("file_path", "unknown_source"),
678
- }
679
- for dp in all_entities_data
680
- }
681
- await entity_vdb.upsert(data_for_vdb)
682
-
683
- if relationships_vdb is not None:
684
- data_for_vdb = {
685
- compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
686
- "src_id": dp["src_id"],
687
- "tgt_id": dp["tgt_id"],
688
- "keywords": dp["keywords"],
689
- "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
690
- "source_id": dp["source_id"],
691
- "file_path": dp.get("file_path", "unknown_source"),
692
- }
693
- for dp in all_relationships_data
694
- }
695
- await relationships_vdb.upsert(data_for_vdb)
696
 
697
 
698
  async def kg_query(
@@ -1367,7 +1345,9 @@ async def _get_node_data(
1367
 
1368
  text_units_section_list = [["id", "content", "file_path"]]
1369
  for i, t in enumerate(use_text_units):
1370
- text_units_section_list.append([i, t["content"], t["file_path"]])
 
 
1371
  text_units_context = list_of_list_to_csv(text_units_section_list)
1372
  return entities_context, relations_context, text_units_context
1373
 
 
25
  CacheData,
26
  statistic_data,
27
  get_conversation_turns,
 
28
  )
29
  from .base import (
30
  BaseGraphStorage,
 
440
 
441
  processed_chunks = 0
442
  total_chunks = len(ordered_chunks)
443
+ total_entities_count = 0
444
+ total_relations_count = 0
445
+
446
+ # Get lock manager from shared storage
447
+ from .kg.shared_storage import get_graph_db_lock
448
+
449
+ graph_db_lock = get_graph_db_lock(enable_logging=False)
450
 
451
  async def _user_llm_func_with_cache(
452
  input_text: str, history_messages: list[dict[str, str]] = None
 
545
  chunk_key_dp (tuple[str, TextChunkSchema]):
546
  ("chunk-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int})
547
  """
548
+ nonlocal processed_chunks, total_entities_count, total_relations_count
549
  chunk_key = chunk_key_dp[0]
550
  chunk_dp = chunk_key_dp[1]
551
  content = chunk_dp["content"]
 
603
  async with pipeline_status_lock:
604
  pipeline_status["latest_message"] = log_message
605
  pipeline_status["history_messages"].append(log_message)
 
 
 
 
606
 
607
+ # Use graph database lock to ensure atomic merges and updates
608
+ chunk_entities_data = []
609
+ chunk_relationships_data = []
 
 
 
 
610
 
611
+ async with graph_db_lock:
612
+ # Process and update entities
613
+ for entity_name, entities in maybe_nodes.items():
614
+ entity_data = await _merge_nodes_then_upsert(
615
+ entity_name, entities, knowledge_graph_inst, global_config
 
 
 
 
616
  )
617
+ chunk_entities_data.append(entity_data)
618
+
619
+ # Process and update relationships
620
+ for edge_key, edges in maybe_edges.items():
621
+ # Ensure edge direction consistency
622
+ sorted_edge_key = tuple(sorted(edge_key))
623
+ edge_data = await _merge_edges_then_upsert(
624
+ sorted_edge_key[0],
625
+ sorted_edge_key[1],
626
+ edges,
627
+ knowledge_graph_inst,
628
+ global_config,
629
+ )
630
+ chunk_relationships_data.append(edge_data)
631
+
632
+ # Update vector database (within the same lock to ensure atomicity)
633
+ if entity_vdb is not None and chunk_entities_data:
634
+ data_for_vdb = {
635
+ compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
636
+ "entity_name": dp["entity_name"],
637
+ "entity_type": dp["entity_type"],
638
+ "content": f"{dp['entity_name']}\n{dp['description']}",
639
+ "source_id": dp["source_id"],
640
+ "file_path": dp.get("file_path", "unknown_source"),
641
+ }
642
+ for dp in chunk_entities_data
643
+ }
644
+ await entity_vdb.upsert(data_for_vdb)
645
+
646
+ if relationships_vdb is not None and chunk_relationships_data:
647
+ data_for_vdb = {
648
+ compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
649
+ "src_id": dp["src_id"],
650
+ "tgt_id": dp["tgt_id"],
651
+ "keywords": dp["keywords"],
652
+ "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
653
+ "source_id": dp["source_id"],
654
+ "file_path": dp.get("file_path", "unknown_source"),
655
+ }
656
+ for dp in chunk_relationships_data
657
+ }
658
+ await relationships_vdb.upsert(data_for_vdb)
659
 
660
+ # Update counters
661
+ total_entities_count += len(chunk_entities_data)
662
+ total_relations_count += len(chunk_relationships_data)
 
 
 
 
 
663
 
664
+ # Handle all chunks in parallel
665
+ tasks = [_process_single_content(c) for c in ordered_chunks]
666
+ await asyncio.gather(*tasks)
 
 
 
 
667
 
668
+ log_message = f"Extracted {total_entities_count} entities + {total_relations_count} relationships (total)"
669
  logger.info(log_message)
670
  if pipeline_status is not None:
671
  async with pipeline_status_lock:
672
  pipeline_status["latest_message"] = log_message
673
  pipeline_status["history_messages"].append(log_message)
 
 
 
 
 
 
 
 
674
 
675
 
676
  async def kg_query(
 
1345
 
1346
  text_units_section_list = [["id", "content", "file_path"]]
1347
  for i, t in enumerate(use_text_units):
1348
+ text_units_section_list.append(
1349
+ [i, t["content"], t.get("file_path", "unknown_source")]
1350
+ )
1351
  text_units_context = list_of_list_to_csv(text_units_section_list)
1352
  return entities_context, relations_context, text_units_context
1353
 
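The extraction refactor above moves entity/relationship merging and the vector-DB upserts into each chunk task, serialized by the shared graph-DB lock so concurrent chunks never interleave partial updates. The pattern, reduced to a self-contained sketch with illustrative names only (not LightRAG APIs):

```python
import asyncio

async def process_chunk(chunk_id, entities, merged, lock):
    # Per-chunk extraction would happen here, outside the lock.
    async with lock:
        # Merge-then-upsert step: only one chunk mutates shared state at a time.
        for name in entities:
            merged.setdefault(name, []).append(chunk_id)

async def main():
    chunks = {"chunk-1": ["Alice", "Bob"], "chunk-2": ["Alice"]}
    merged: dict[str, list[str]] = {}
    lock = asyncio.Lock()  # stand-in for the shared graph-DB lock
    await asyncio.gather(*(process_chunk(k, v, merged, lock) for k, v in chunks.items()))
    print(merged)  # each entity lists the chunks it was merged from

asyncio.run(main())
```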
lightrag/types.py CHANGED
@@ -26,3 +26,4 @@ class KnowledgeGraphEdge(BaseModel):
26
  class KnowledgeGraph(BaseModel):
27
  nodes: list[KnowledgeGraphNode] = []
28
  edges: list[KnowledgeGraphEdge] = []
 
 
26
  class KnowledgeGraph(BaseModel):
27
  nodes: list[KnowledgeGraphNode] = []
28
  edges: list[KnowledgeGraphEdge] = []
29
+ is_truncated: bool = False
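The new `is_truncated` field lets a storage backend tell the caller that the `max_nodes` cap trimmed the returned graph. A hedged sketch of how a backend might populate it, using simplified stand-in models rather than the real `KnowledgeGraph` types:

```python
from pydantic import BaseModel

class Node(BaseModel):
    id: str

class Graph(BaseModel):
    nodes: list[Node] = []
    is_truncated: bool = False

def build_graph(candidate_ids: list[str], max_nodes: int) -> Graph:
    kept = candidate_ids[:max_nodes]
    return Graph(
        nodes=[Node(id=i) for i in kept],
        is_truncated=len(candidate_ids) > max_nodes,  # signal that the cap was hit
    )

print(build_graph([f"n{i}" for i in range(5)], max_nodes=3).is_truncated)  # True
```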
lightrag_webui/src/AppRouter.tsx CHANGED
@@ -80,7 +80,12 @@ const AppRouter = () => {
80
  <ThemeProvider>
81
  <Router>
82
  <AppContent />
83
- <Toaster position="bottom-center" />
 
 
 
 
 
84
  </Router>
85
  </ThemeProvider>
86
  )
 
80
  <ThemeProvider>
81
  <Router>
82
  <AppContent />
83
+ <Toaster
84
+ position="bottom-center"
85
+ theme="system"
86
+ closeButton
87
+ richColors
88
+ />
89
  </Router>
90
  </ThemeProvider>
91
  )
lightrag_webui/src/api/lightrag.ts CHANGED
@@ -3,6 +3,7 @@ import { backendBaseUrl } from '@/lib/constants'
3
  import { errorMessage } from '@/lib/utils'
4
  import { useSettingsStore } from '@/stores/settings'
5
  import { navigationService } from '@/services/navigation'
 
6
 
7
  // Types
8
  export type LightragNodeType = {
@@ -46,6 +47,8 @@ export type LightragStatus = {
46
  api_version?: string
47
  auth_mode?: 'enabled' | 'disabled'
48
  pipeline_busy: boolean
 
 
49
  }
50
 
51
  export type LightragDocumentsScanProgress = {
@@ -140,6 +143,8 @@ export type AuthStatusResponse = {
140
  message?: string
141
  core_version?: string
142
  api_version?: string
 
 
143
  }
144
 
145
  export type PipelineStatusResponse = {
@@ -163,6 +168,8 @@ export type LoginResponse = {
163
  message?: string // Optional message
164
  core_version?: string
165
  api_version?: string
 
 
166
  }
167
 
168
  export const InvalidApiKeyError = 'Invalid API Key'
@@ -221,9 +228,9 @@ axiosInstance.interceptors.response.use(
221
  export const queryGraphs = async (
222
  label: string,
223
  maxDepth: number,
224
- minDegree: number
225
  ): Promise<LightragGraphType> => {
226
- const response = await axiosInstance.get(`/graphs?label=${encodeURIComponent(label)}&max_depth=${maxDepth}&min_degree=${minDegree}`)
227
  return response.data
228
  }
229
 
@@ -382,6 +389,14 @@ export const clearDocuments = async (): Promise<DocActionResponse> => {
382
  return response.data
383
  }
384
 
 
 
 
 
 
 
 
 
385
  export const getAuthStatus = async (): Promise<AuthStatusResponse> => {
386
  try {
387
  // Add a timeout to the request to prevent hanging
@@ -411,12 +426,26 @@ export const getAuthStatus = async (): Promise<AuthStatusResponse> => {
411
  // For unconfigured auth, ensure we have an access token
412
  if (!response.data.auth_configured) {
413
  if (response.data.access_token && typeof response.data.access_token === 'string') {
 
 
 
 
 
 
 
414
  return response.data;
415
  } else {
416
  console.warn('Auth not configured but no valid access token provided');
417
  }
418
  } else {
419
  // For configured auth, just return the data
 
 
 
 
 
 
 
420
  return response.data;
421
  }
422
  }
@@ -455,5 +484,13 @@ export const loginToServer = async (username: string, password: string): Promise
455
  }
456
  });
457
 
 
 
 
 
 
 
 
 
458
  return response.data;
459
  }
 
3
  import { errorMessage } from '@/lib/utils'
4
  import { useSettingsStore } from '@/stores/settings'
5
  import { navigationService } from '@/services/navigation'
6
+ import { useAuthStore } from '@/stores/state'
7
 
8
  // Types
9
  export type LightragNodeType = {
 
47
  api_version?: string
48
  auth_mode?: 'enabled' | 'disabled'
49
  pipeline_busy: boolean
50
+ webui_title?: string
51
+ webui_description?: string
52
  }
53
 
54
  export type LightragDocumentsScanProgress = {
 
143
  message?: string
144
  core_version?: string
145
  api_version?: string
146
+ webui_title?: string
147
+ webui_description?: string
148
  }
149
 
150
  export type PipelineStatusResponse = {
 
168
  message?: string // Optional message
169
  core_version?: string
170
  api_version?: string
171
+ webui_title?: string
172
+ webui_description?: string
173
  }
174
 
175
  export const InvalidApiKeyError = 'Invalid API Key'
 
228
  export const queryGraphs = async (
229
  label: string,
230
  maxDepth: number,
231
+ maxNodes: number
232
  ): Promise<LightragGraphType> => {
233
+ const response = await axiosInstance.get(`/graphs?label=${encodeURIComponent(label)}&max_depth=${maxDepth}&max_nodes=${maxNodes}`)
234
  return response.data
235
  }
236
 
 
389
  return response.data
390
  }
391
 
392
+ export const clearCache = async (modes?: string[]): Promise<{
393
+ status: 'success' | 'fail'
394
+ message: string
395
+ }> => {
396
+ const response = await axiosInstance.post('/documents/clear_cache', { modes })
397
+ return response.data
398
+ }
399
+
400
  export const getAuthStatus = async (): Promise<AuthStatusResponse> => {
401
  try {
402
  // Add a timeout to the request to prevent hanging
 
426
  // For unconfigured auth, ensure we have an access token
427
  if (!response.data.auth_configured) {
428
  if (response.data.access_token && typeof response.data.access_token === 'string') {
429
+ // Update custom title if available
430
+ if ('webui_title' in response.data || 'webui_description' in response.data) {
431
+ useAuthStore.getState().setCustomTitle(
432
+ 'webui_title' in response.data ? (response.data.webui_title ?? null) : null,
433
+ 'webui_description' in response.data ? (response.data.webui_description ?? null) : null
434
+ );
435
+ }
436
  return response.data;
437
  } else {
438
  console.warn('Auth not configured but no valid access token provided');
439
  }
440
  } else {
441
  // For configured auth, just return the data
442
+ // Update custom title if available
443
+ if ('webui_title' in response.data || 'webui_description' in response.data) {
444
+ useAuthStore.getState().setCustomTitle(
445
+ 'webui_title' in response.data ? (response.data.webui_title ?? null) : null,
446
+ 'webui_description' in response.data ? (response.data.webui_description ?? null) : null
447
+ );
448
+ }
449
  return response.data;
450
  }
451
  }
 
484
  }
485
  });
486
 
487
+ // Update custom title if available
488
+ if ('webui_title' in response.data || 'webui_description' in response.data) {
489
+ useAuthStore.getState().setCustomTitle(
490
+ 'webui_title' in response.data ? (response.data.webui_title ?? null) : null,
491
+ 'webui_description' in response.data ? (response.data.webui_description ?? null) : null
492
+ );
493
+ }
494
+
495
  return response.data;
496
  }
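Besides the `maxNodes` parameter on `queryGraphs`, the client gains a `clearCache` wrapper for the new `POST /documents/clear_cache` endpoint. The same endpoints can be exercised outside the web UI; a hedged sketch with `requests`, where the base URL and auth header are deployment-specific assumptions:

```python
import requests

BASE = "http://localhost:9621"  # assumed LightRAG API server address
headers = {}  # e.g. {"Authorization": "Bearer <token>"} if auth is enabled

# Clear the LLM cache for selected query modes (omit "modes" to clear all).
r = requests.post(f"{BASE}/documents/clear_cache",
                  json={"modes": ["local", "global"]}, headers=headers, timeout=30)
print(r.json())

# Fetch a bounded knowledge-graph slice, mirroring queryGraphs() above.
g = requests.get(f"{BASE}/graphs",
                 params={"label": "*", "max_depth": 3, "max_nodes": 200},
                 headers=headers, timeout=30)
print(len(g.json().get("nodes", [])))
```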
lightrag_webui/src/components/documents/ClearDocumentsDialog.tsx CHANGED
@@ -1,4 +1,4 @@
1
- import { useState, useCallback } from 'react'
2
  import Button from '@/components/ui/Button'
3
  import {
4
  Dialog,
@@ -6,32 +6,88 @@ import {
6
  DialogDescription,
7
  DialogHeader,
8
  DialogTitle,
9
- DialogTrigger
 
10
  } from '@/components/ui/Dialog'
 
 
11
  import { toast } from 'sonner'
12
  import { errorMessage } from '@/lib/utils'
13
- import { clearDocuments } from '@/api/lightrag'
14
 
15
- import { EraserIcon } from 'lucide-react'
16
  import { useTranslation } from 'react-i18next'
17
 
18
- export default function ClearDocumentsDialog() {
 
 
 
 
 
 
19
  const { t } = useTranslation()
20
  const [open, setOpen] = useState(false)
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  const handleClear = useCallback(async () => {
 
 
23
  try {
24
  const result = await clearDocuments()
25
- if (result.status === 'success') {
26
- toast.success(t('documentPanel.clearDocuments.success'))
27
- setOpen(false)
28
- } else {
29
  toast.error(t('documentPanel.clearDocuments.failed', { message: result.message }))
 
 
30
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  } catch (err) {
32
  toast.error(t('documentPanel.clearDocuments.error', { error: errorMessage(err) }))
 
33
  }
34
- }, [setOpen, t])
35
 
36
  return (
37
  <Dialog open={open} onOpenChange={setOpen}>
@@ -42,12 +98,58 @@ export default function ClearDocumentsDialog() {
42
  </DialogTrigger>
43
  <DialogContent className="sm:max-w-xl" onCloseAutoFocus={(e) => e.preventDefault()}>
44
  <DialogHeader>
45
- <DialogTitle>{t('documentPanel.clearDocuments.title')}</DialogTitle>
46
- <DialogDescription>{t('documentPanel.clearDocuments.confirm')}</DialogDescription>
 
 
 
 
 
 
 
 
 
 
47
  </DialogHeader>
48
- <Button variant="destructive" onClick={handleClear}>
49
- {t('documentPanel.clearDocuments.confirmButton')}
50
- </Button>
 
 
 
 
 
 
51
  </DialogContent>
52
  </Dialog>
53
  )
 
1
+ import { useState, useCallback, useEffect } from 'react'
2
  import Button from '@/components/ui/Button'
3
  import {
4
  Dialog,
 
6
  DialogDescription,
7
  DialogHeader,
8
  DialogTitle,
9
+ DialogTrigger,
10
+ DialogFooter
11
  } from '@/components/ui/Dialog'
12
+ import Input from '@/components/ui/Input'
13
+ import Checkbox from '@/components/ui/Checkbox'
14
  import { toast } from 'sonner'
15
  import { errorMessage } from '@/lib/utils'
16
+ import { clearDocuments, clearCache } from '@/api/lightrag'
17
 
18
+ import { EraserIcon, AlertTriangleIcon } from 'lucide-react'
19
  import { useTranslation } from 'react-i18next'
20
 
21
+ // Simple inline Label component
22
+ const Label = ({
23
+ htmlFor,
24
+ className,
25
+ children,
26
+ ...props
27
+ }: React.LabelHTMLAttributes<HTMLLabelElement>) => (
28
+ <label
29
+ htmlFor={htmlFor}
30
+ className={className}
31
+ {...props}
32
+ >
33
+ {children}
34
+ </label>
35
+ )
36
+
37
+ interface ClearDocumentsDialogProps {
38
+ onDocumentsCleared?: () => Promise<void>
39
+ }
40
+
41
+ export default function ClearDocumentsDialog({ onDocumentsCleared }: ClearDocumentsDialogProps) {
42
  const { t } = useTranslation()
43
  const [open, setOpen] = useState(false)
44
+ const [confirmText, setConfirmText] = useState('')
45
+ const [clearCacheOption, setClearCacheOption] = useState(false)
46
+ const isConfirmEnabled = confirmText.toLowerCase() === 'yes'
47
+
48
+ // Reset the form state when the dialog closes
49
+ useEffect(() => {
50
+ if (!open) {
51
+ setConfirmText('')
52
+ setClearCacheOption(false)
53
+ }
54
+ }, [open])
55
 
56
  const handleClear = useCallback(async () => {
57
+ if (!isConfirmEnabled) return
58
+
59
  try {
60
  const result = await clearDocuments()
61
+
62
+ if (result.status !== 'success') {
 
 
63
  toast.error(t('documentPanel.clearDocuments.failed', { message: result.message }))
64
+ setConfirmText('')
65
+ return
66
  }
67
+
68
+ toast.success(t('documentPanel.clearDocuments.success'))
69
+
70
+ if (clearCacheOption) {
71
+ try {
72
+ await clearCache()
73
+ toast.success(t('documentPanel.clearDocuments.cacheCleared'))
74
+ } catch (cacheErr) {
75
+ toast.error(t('documentPanel.clearDocuments.cacheClearFailed', { error: errorMessage(cacheErr) }))
76
+ }
77
+ }
78
+
79
+ // Refresh document list if provided
80
+ if (onDocumentsCleared) {
81
+ onDocumentsCleared().catch(console.error)
82
+ }
83
+
84
+ // Close the dialog only after all operations have succeeded
85
+ setOpen(false)
86
  } catch (err) {
87
  toast.error(t('documentPanel.clearDocuments.error', { error: errorMessage(err) }))
88
+ setConfirmText('')
89
  }
90
+ }, [isConfirmEnabled, clearCacheOption, setOpen, t, onDocumentsCleared])
91
 
92
  return (
93
  <Dialog open={open} onOpenChange={setOpen}>
 
98
  </DialogTrigger>
99
  <DialogContent className="sm:max-w-xl" onCloseAutoFocus={(e) => e.preventDefault()}>
100
  <DialogHeader>
101
+ <DialogTitle className="flex items-center gap-2 text-red-500 dark:text-red-400 font-bold">
102
+ <AlertTriangleIcon className="h-5 w-5" />
103
+ {t('documentPanel.clearDocuments.title')}
104
+ </DialogTitle>
105
+ <DialogDescription className="pt-2">
106
+ <div className="text-red-500 dark:text-red-400 font-semibold mb-4">
107
+ {t('documentPanel.clearDocuments.warning')}
108
+ </div>
109
+ <div className="mb-4">
110
+ {t('documentPanel.clearDocuments.confirm')}
111
+ </div>
112
+ </DialogDescription>
113
  </DialogHeader>
114
+
115
+ <div className="space-y-4">
116
+ <div className="space-y-2">
117
+ <Label htmlFor="confirm-text" className="text-sm font-medium">
118
+ {t('documentPanel.clearDocuments.confirmPrompt')}
119
+ </Label>
120
+ <Input
121
+ id="confirm-text"
122
+ value={confirmText}
123
+ onChange={(e: React.ChangeEvent<HTMLInputElement>) => setConfirmText(e.target.value)}
124
+ placeholder={t('documentPanel.clearDocuments.confirmPlaceholder')}
125
+ className="w-full"
126
+ />
127
+ </div>
128
+
129
+ <div className="flex items-center space-x-2">
130
+ <Checkbox
131
+ id="clear-cache"
132
+ checked={clearCacheOption}
133
+ onCheckedChange={(checked: boolean | 'indeterminate') => setClearCacheOption(checked === true)}
134
+ />
135
+ <Label htmlFor="clear-cache" className="text-sm font-medium cursor-pointer">
136
+ {t('documentPanel.clearDocuments.clearCache')}
137
+ </Label>
138
+ </div>
139
+ </div>
140
+
141
+ <DialogFooter>
142
+ <Button variant="outline" onClick={() => setOpen(false)}>
143
+ {t('common.cancel')}
144
+ </Button>
145
+ <Button
146
+ variant="destructive"
147
+ onClick={handleClear}
148
+ disabled={!isConfirmEnabled}
149
+ >
150
+ {t('documentPanel.clearDocuments.confirmButton')}
151
+ </Button>
152
+ </DialogFooter>
153
  </DialogContent>
154
  </Dialog>
155
  )
lightrag_webui/src/components/documents/UploadDocumentsDialog.tsx CHANGED
@@ -17,7 +17,11 @@ import { uploadDocument } from '@/api/lightrag'
17
  import { UploadIcon } from 'lucide-react'
18
  import { useTranslation } from 'react-i18next'
19
 
20
- export default function UploadDocumentsDialog() {
 
 
 
 
21
  const { t } = useTranslation()
22
  const [open, setOpen] = useState(false)
23
  const [isUploading, setIsUploading] = useState(false)
@@ -55,6 +59,7 @@ export default function UploadDocumentsDialog() {
55
  const handleDocumentsUpload = useCallback(
56
  async (filesToUpload: File[]) => {
57
  setIsUploading(true)
 
58
 
59
  // Only clear errors for files that are being uploaded, keep errors for rejected files
60
  setFileErrors(prev => {
@@ -101,6 +106,9 @@ export default function UploadDocumentsDialog() {
101
  ...prev,
102
  [file.name]: result.message
103
  }))
 
 
 
104
  }
105
  } catch (err) {
106
  console.error(`Upload failed for ${file.name}:`, err)
@@ -142,6 +150,16 @@ export default function UploadDocumentsDialog() {
142
  } else {
143
  toast.success(t('documentPanel.uploadDocuments.batch.success'), { id: toastId })
144
  }
 
 
 
 
 
 
 
 
 
 
145
  } catch (err) {
146
  console.error('Unexpected error during upload:', err)
147
  toast.error(t('documentPanel.uploadDocuments.generalError', { error: errorMessage(err) }), { id: toastId })
@@ -149,7 +167,7 @@ export default function UploadDocumentsDialog() {
149
  setIsUploading(false)
150
  }
151
  },
152
- [setIsUploading, setProgresses, setFileErrors, t]
153
  )
154
 
155
  return (
 
17
  import { UploadIcon } from 'lucide-react'
18
  import { useTranslation } from 'react-i18next'
19
 
20
+ interface UploadDocumentsDialogProps {
21
+ onDocumentsUploaded?: () => Promise<void>
22
+ }
23
+
24
+ export default function UploadDocumentsDialog({ onDocumentsUploaded }: UploadDocumentsDialogProps) {
25
  const { t } = useTranslation()
26
  const [open, setOpen] = useState(false)
27
  const [isUploading, setIsUploading] = useState(false)
 
59
  const handleDocumentsUpload = useCallback(
60
  async (filesToUpload: File[]) => {
61
  setIsUploading(true)
62
+ let hasSuccessfulUpload = false
63
 
64
  // Only clear errors for files that are being uploaded, keep errors for rejected files
65
  setFileErrors(prev => {
 
106
  ...prev,
107
  [file.name]: result.message
108
  }))
109
+ } else {
110
+ // Mark that we had at least one successful upload
111
+ hasSuccessfulUpload = true
112
  }
113
  } catch (err) {
114
  console.error(`Upload failed for ${file.name}:`, err)
 
150
  } else {
151
  toast.success(t('documentPanel.uploadDocuments.batch.success'), { id: toastId })
152
  }
153
+
154
+ // Only update if at least one file was uploaded successfully
155
+ if (hasSuccessfulUpload) {
156
+ // Refresh document list
157
+ if (onDocumentsUploaded) {
158
+ onDocumentsUploaded().catch(err => {
159
+ console.error('Error refreshing documents:', err)
160
+ })
161
+ }
162
+ }
163
  } catch (err) {
164
  console.error('Unexpected error during upload:', err)
165
  toast.error(t('documentPanel.uploadDocuments.generalError', { error: errorMessage(err) }), { id: toastId })
 
167
  setIsUploading(false)
168
  }
169
  },
170
+ [setIsUploading, setProgresses, setFileErrors, t, onDocumentsUploaded]
171
  )
172
 
173
  return (
lightrag_webui/src/components/graph/Settings.tsx CHANGED
@@ -8,7 +8,7 @@ import Input from '@/components/ui/Input'
8
  import { controlButtonVariant } from '@/lib/constants'
9
  import { useSettingsStore } from '@/stores/settings'
10
 
11
- import { SettingsIcon } from 'lucide-react'
12
  import { useTranslation } from 'react-i18next';
13
 
14
  /**
@@ -44,14 +44,17 @@ const LabeledNumberInput = ({
44
  onEditFinished,
45
  label,
46
  min,
47
- max
 
48
  }: {
49
  value: number
50
  onEditFinished: (value: number) => void
51
  label: string
52
  min: number
53
  max?: number
 
54
  }) => {
 
55
  const [currentValue, setCurrentValue] = useState<number | null>(value)
56
 
57
  const onValueChange = useCallback(
@@ -81,6 +84,13 @@ const LabeledNumberInput = ({
81
  }
82
  }, [value, currentValue, onEditFinished])
83
 
 
 
 
 
 
 
 
84
  return (
85
  <div className="flex flex-col gap-2">
86
  <label
@@ -89,20 +99,34 @@ const LabeledNumberInput = ({
89
  >
90
  {label}
91
  </label>
92
- <Input
93
- type="number"
94
- value={currentValue === null ? '' : currentValue}
95
- onChange={onValueChange}
96
- className="h-6 w-full min-w-0 pr-1"
97
- min={min}
98
- max={max}
99
- onBlur={onBlur}
100
- onKeyDown={(e) => {
101
- if (e.key === 'Enter') {
102
- onBlur()
103
- }
104
- }}
105
- />
 
 
 
 
 
106
  </div>
107
  )
108
  }
@@ -121,7 +145,7 @@ export default function Settings() {
121
  const enableHideUnselectedEdges = useSettingsStore.use.enableHideUnselectedEdges()
122
  const showEdgeLabel = useSettingsStore.use.showEdgeLabel()
123
  const graphQueryMaxDepth = useSettingsStore.use.graphQueryMaxDepth()
124
- const graphMinDegree = useSettingsStore.use.graphMinDegree()
125
  const graphLayoutMaxIterations = useSettingsStore.use.graphLayoutMaxIterations()
126
 
127
  const enableHealthCheck = useSettingsStore.use.enableHealthCheck()
@@ -180,15 +204,14 @@ export default function Settings() {
180
  }, 300)
181
  }, [])
182
 
183
- const setGraphMinDegree = useCallback((degree: number) => {
184
- if (degree < 0) return
185
- useSettingsStore.setState({ graphMinDegree: degree })
186
  const currentLabel = useSettingsStore.getState().queryLabel
187
  useSettingsStore.getState().setQueryLabel('')
188
  setTimeout(() => {
189
  useSettingsStore.getState().setQueryLabel(currentLabel)
190
  }, 300)
191
-
192
  }, [])
193
 
194
  const setGraphLayoutMaxIterations = useCallback((iterations: number) => {
@@ -274,19 +297,23 @@ export default function Settings() {
274
  label={t('graphPanel.sideBar.settings.maxQueryDepth')}
275
  min={1}
276
  value={graphQueryMaxDepth}
 
277
  onEditFinished={setGraphQueryMaxDepth}
278
  />
279
  <LabeledNumberInput
280
- label={t('graphPanel.sideBar.settings.minDegree')}
281
- min={0}
282
- value={graphMinDegree}
283
- onEditFinished={setGraphMinDegree}
 
 
284
  />
285
  <LabeledNumberInput
286
  label={t('graphPanel.sideBar.settings.maxLayoutIterations')}
287
  min={1}
288
  max={30}
289
  value={graphLayoutMaxIterations}
 
290
  onEditFinished={setGraphLayoutMaxIterations}
291
  />
292
  <Separator />
 
8
  import { controlButtonVariant } from '@/lib/constants'
9
  import { useSettingsStore } from '@/stores/settings'
10
 
11
+ import { SettingsIcon, Undo2 } from 'lucide-react'
12
  import { useTranslation } from 'react-i18next';
13
 
14
  /**
 
44
  onEditFinished,
45
  label,
46
  min,
47
+ max,
48
+ defaultValue
49
  }: {
50
  value: number
51
  onEditFinished: (value: number) => void
52
  label: string
53
  min: number
54
  max?: number
55
+ defaultValue?: number
56
  }) => {
57
+ const { t } = useTranslation();
58
  const [currentValue, setCurrentValue] = useState<number | null>(value)
59
 
60
  const onValueChange = useCallback(
 
84
  }
85
  }, [value, currentValue, onEditFinished])
86
 
87
+ const handleReset = useCallback(() => {
88
+ if (defaultValue !== undefined && value !== defaultValue) {
89
+ setCurrentValue(defaultValue)
90
+ onEditFinished(defaultValue)
91
+ }
92
+ }, [defaultValue, value, onEditFinished])
93
+
94
  return (
95
  <div className="flex flex-col gap-2">
96
  <label
 
99
  >
100
  {label}
101
  </label>
102
+ <div className="flex items-center gap-1">
103
+ <Input
104
+ type="number"
105
+ value={currentValue === null ? '' : currentValue}
106
+ onChange={onValueChange}
107
+ className="h-6 w-full min-w-0 pr-1"
108
+ min={min}
109
+ max={max}
110
+ onBlur={onBlur}
111
+ onKeyDown={(e) => {
112
+ if (e.key === 'Enter') {
113
+ onBlur()
114
+ }
115
+ }}
116
+ />
117
+ {defaultValue !== undefined && (
118
+ <Button
119
+ variant="ghost"
120
+ size="icon"
121
+ className="h-6 w-6 flex-shrink-0 hover:bg-muted text-muted-foreground hover:text-foreground"
122
+ onClick={handleReset}
123
+ type="button"
124
+ title={t('graphPanel.sideBar.settings.resetToDefault')}
125
+ >
126
+ <Undo2 className="h-3.5 w-3.5" />
127
+ </Button>
128
+ )}
129
+ </div>
130
  </div>
131
  )
132
  }
 
145
  const enableHideUnselectedEdges = useSettingsStore.use.enableHideUnselectedEdges()
146
  const showEdgeLabel = useSettingsStore.use.showEdgeLabel()
147
  const graphQueryMaxDepth = useSettingsStore.use.graphQueryMaxDepth()
148
+ const graphMaxNodes = useSettingsStore.use.graphMaxNodes()
149
  const graphLayoutMaxIterations = useSettingsStore.use.graphLayoutMaxIterations()
150
 
151
  const enableHealthCheck = useSettingsStore.use.enableHealthCheck()
 
204
  }, 300)
205
  }, [])
206
 
207
+ const setGraphMaxNodes = useCallback((nodes: number) => {
208
+ if (nodes < 1 || nodes > 1000) return
209
+ useSettingsStore.setState({ graphMaxNodes: nodes })
210
  const currentLabel = useSettingsStore.getState().queryLabel
211
  useSettingsStore.getState().setQueryLabel('')
212
  setTimeout(() => {
213
  useSettingsStore.getState().setQueryLabel(currentLabel)
214
  }, 300)
 
215
  }, [])
216
 
217
  const setGraphLayoutMaxIterations = useCallback((iterations: number) => {
 
297
  label={t('graphPanel.sideBar.settings.maxQueryDepth')}
298
  min={1}
299
  value={graphQueryMaxDepth}
300
+ defaultValue={3}
301
  onEditFinished={setGraphQueryMaxDepth}
302
  />
303
  <LabeledNumberInput
304
+ label={t('graphPanel.sideBar.settings.maxNodes')}
305
+ min={1}
306
+ max={1000}
307
+ value={graphMaxNodes}
308
+ defaultValue={1000}
309
+ onEditFinished={setGraphMaxNodes}
310
  />
311
  <LabeledNumberInput
312
  label={t('graphPanel.sideBar.settings.maxLayoutIterations')}
313
  min={1}
314
  max={30}
315
  value={graphLayoutMaxIterations}
316
+ defaultValue={15}
317
  onEditFinished={setGraphLayoutMaxIterations}
318
  />
319
  <Separator />