Merge pull request #1237 from danielaskdd/clear-doc
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README-zh.md +4 -4
- README.md +4 -4
- config.ini.example +0 -17
- env.example +20 -27
- examples/lightrag_api_ollama_demo.py +0 -188
- examples/lightrag_api_openai_compatible_demo.py +0 -204
- examples/lightrag_api_oracle_demo.py +0 -267
- examples/lightrag_ollama_gremlin_demo.py +4 -0
- examples/lightrag_oracle_demo.py +0 -141
- examples/lightrag_tidb_demo.py +4 -0
- lightrag/api/README-zh.md +4 -12
- lightrag/api/README.md +5 -13
- lightrag/api/__init__.py +1 -1
- lightrag/api/auth.py +9 -8
- lightrag/api/config.py +335 -0
- lightrag/api/lightrag_server.py +37 -19
- lightrag/api/routers/document_routes.py +465 -52
- lightrag/api/routers/graph_routes.py +10 -14
- lightrag/api/run_with_gunicorn.py +31 -25
- lightrag/api/utils_api.py +18 -353
- lightrag/api/webui/assets/{index-D8zGvNlV.js → index-BaHKTcxB.js} +0 -0
- lightrag/api/webui/assets/index-CD5HxTy1.css +0 -0
- lightrag/api/webui/assets/index-f0HMqdqP.css +0 -0
- lightrag/api/webui/index.html +0 -0
- lightrag/base.py +122 -9
- lightrag/kg/__init__.py +15 -40
- lightrag/kg/age_impl.py +21 -3
- lightrag/kg/chroma_impl.py +28 -2
- lightrag/kg/faiss_impl.py +64 -5
- lightrag/kg/gremlin_impl.py +24 -3
- lightrag/kg/json_doc_status_impl.py +49 -10
- lightrag/kg/json_kv_impl.py +73 -3
- lightrag/kg/milvus_impl.py +31 -1
- lightrag/kg/mongo_impl.py +130 -3
- lightrag/kg/nano_vector_db_impl.py +66 -0
- lightrag/kg/neo4j_impl.py +373 -246
- lightrag/kg/networkx_impl.py +122 -97
- lightrag/kg/oracle_impl.py +0 -1346
- lightrag/kg/postgres_impl.py +380 -317
- lightrag/kg/qdrant_impl.py +93 -6
- lightrag/kg/redis_impl.py +35 -51
- lightrag/kg/tidb_impl.py +172 -3
- lightrag/lightrag.py +25 -42
- lightrag/operate.py +68 -88
- lightrag/types.py +1 -0
- lightrag_webui/src/AppRouter.tsx +6 -1
- lightrag_webui/src/api/lightrag.ts +39 -2
- lightrag_webui/src/components/documents/ClearDocumentsDialog.tsx +117 -15
- lightrag_webui/src/components/documents/UploadDocumentsDialog.tsx +20 -2
- lightrag_webui/src/components/graph/Settings.tsx +52 -25
README-zh.md
CHANGED
@@ -11,7 +11,6 @@
|
|
11 |
- [X] [2024.12.31]🎯📢LightRAG现在支持[通过文档ID删除](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete)。
|
12 |
- [X] [2024.11.25]🎯📢LightRAG现在支持无缝集成[自定义知识图谱](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#insert-custom-kg),使用户能够用自己的领域专业知识增强系统。
|
13 |
- [X] [2024.11.19]🎯📢LightRAG的综合指南现已在[LearnOpenCV](https://learnopencv.com/lightrag)上发布。非常感谢博客作者。
|
14 |
-
- [X] [2024.11.12]🎯📢LightRAG现在支持[Oracle Database 23ai的所有存储类型(KV、向量和图)](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_oracle_demo.py)。
|
15 |
- [X] [2024.11.11]🎯📢LightRAG现在支持[通过实体名称删除实体](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete)。
|
16 |
- [X] [2024.11.09]🎯📢推出[LightRAG Gui](https://lightrag-gui.streamlit.app),允许您插入、查询、可视化和下载LightRAG知识。
|
17 |
- [X] [2024.11.04]🎯📢现在您可以[使用Neo4J进行存储](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-neo4j-for-storage)。
|
@@ -1085,9 +1084,10 @@ rag.clear_cache(modes=["local"])
|
|
1085 |
| **参数** | **类型** | **说明** | **默认值** |
|
1086 |
|--------------|----------|-----------------|-------------|
|
1087 |
| **working_dir** | `str` | 存储缓存的目录 | `lightrag_cache+timestamp` |
|
1088 |
-
| **kv_storage** | `str` |
|
1089 |
-
| **vector_storage** | `str` |
|
1090 |
-
| **graph_storage** | `str` |
|
|
|
1091 |
| **chunk_token_size** | `int` | 拆分文档时每个块的最大令牌大小 | `1200` |
|
1092 |
| **chunk_overlap_token_size** | `int` | 拆分文档时两个块之间的重叠令牌大小 | `100` |
|
1093 |
| **tiktoken_model_name** | `str` | 用于计算令牌数的Tiktoken编码器的模型名称 | `gpt-4o-mini` |
|
|
|
11 |
- [X] [2024.12.31]🎯📢LightRAG现在支持[通过文档ID删除](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete)。
|
12 |
- [X] [2024.11.25]🎯📢LightRAG现在支持无缝集成[自定义知识图谱](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#insert-custom-kg),使用户能够用自己的领域专业知识增强系统。
|
13 |
- [X] [2024.11.19]🎯📢LightRAG的综合指南现已在[LearnOpenCV](https://learnopencv.com/lightrag)上发布。非常感谢博客作者。
|
|
|
14 |
- [X] [2024.11.11]🎯📢LightRAG现在支持[通过实体名称删除实体](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete)。
|
15 |
- [X] [2024.11.09]🎯📢推出[LightRAG Gui](https://lightrag-gui.streamlit.app),允许您插入、查询、可视化和下载LightRAG知识。
|
16 |
- [X] [2024.11.04]🎯📢现在您可以[使用Neo4J进行存储](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-neo4j-for-storage)。
|
|
|
1084 |
| **参数** | **类型** | **说明** | **默认值** |
|
1085 |
|--------------|----------|-----------------|-------------|
|
1086 |
| **working_dir** | `str` | 存储缓存的目录 | `lightrag_cache+timestamp` |
|
1087 |
+
| **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`,`PGKVStorage`,`RedisKVStorage`,`MongoKVStorage` | `JsonKVStorage` |
|
1088 |
+
| **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`,`PGVectorStorage`,`MilvusVectorDBStorage`,`ChromaVectorDBStorage`,`FaissVectorDBStorage`,`MongoVectorDBStorage`,`QdrantVectorDBStorage` | `NanoVectorDBStorage` |
|
1089 |
+
| **graph_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`,`Neo4JStorage`,`PGGraphStorage`,`AGEStorage` | `NetworkXStorage` |
|
1090 |
+
| **doc_status_storage** | `str` | Storage type for documents process status. Supported types: `JsonDocStatusStorage`,`PGDocStatusStorage`,`MongoDocStatusStorage` | `JsonDocStatusStorage` |
|
1091 |
| **chunk_token_size** | `int` | 拆分文档时每个块的最大令牌大小 | `1200` |
|
1092 |
| **chunk_overlap_token_size** | `int` | 拆分文档时两个块之间的重叠令牌大小 | `100` |
|
1093 |
| **tiktoken_model_name** | `str` | 用于计算令牌数的Tiktoken编码器的模型名称 | `gpt-4o-mini` |
|
README.md
CHANGED
@@ -41,7 +41,6 @@
|
|
41 |
- [X] [2024.12.31]🎯📢LightRAG now supports [deletion by document ID](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete).
|
42 |
- [X] [2024.11.25]🎯📢LightRAG now supports seamless integration of [custom knowledge graphs](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#insert-custom-kg), empowering users to enhance the system with their own domain expertise.
|
43 |
- [X] [2024.11.19]🎯📢A comprehensive guide to LightRAG is now available on [LearnOpenCV](https://learnopencv.com/lightrag). Many thanks to the blog author.
|
44 |
-
- [X] [2024.11.12]🎯📢LightRAG now supports [Oracle Database 23ai for all storage types (KV, vector, and graph)](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_oracle_demo.py).
|
45 |
- [X] [2024.11.11]🎯📢LightRAG now supports [deleting entities by their names](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete).
|
46 |
- [X] [2024.11.09]🎯📢Introducing the [LightRAG Gui](https://lightrag-gui.streamlit.app), which allows you to insert, query, visualize, and download LightRAG knowledge.
|
47 |
- [X] [2024.11.04]🎯📢You can now [use Neo4J for Storage](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-neo4j-for-storage).
|
@@ -1145,9 +1144,10 @@ Valid modes are:
|
|
1145 |
| **Parameter** | **Type** | **Explanation** | **Default** |
|
1146 |
|--------------|----------|-----------------|-------------|
|
1147 |
| **working_dir** | `str` | Directory where the cache will be stored | `lightrag_cache+timestamp` |
|
1148 |
-
| **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage
|
1149 |
-
| **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage
|
1150 |
-
| **graph_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage
|
|
|
1151 |
| **chunk_token_size** | `int` | Maximum token size per chunk when splitting documents | `1200` |
|
1152 |
| **chunk_overlap_token_size** | `int` | Overlap token size between two chunks when splitting documents | `100` |
|
1153 |
| **tiktoken_model_name** | `str` | Model name for the Tiktoken encoder used to calculate token numbers | `gpt-4o-mini` |
|
|
|
41 |
- [X] [2024.12.31]🎯📢LightRAG now supports [deletion by document ID](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete).
|
42 |
- [X] [2024.11.25]🎯📢LightRAG now supports seamless integration of [custom knowledge graphs](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#insert-custom-kg), empowering users to enhance the system with their own domain expertise.
|
43 |
- [X] [2024.11.19]🎯📢A comprehensive guide to LightRAG is now available on [LearnOpenCV](https://learnopencv.com/lightrag). Many thanks to the blog author.
|
|
|
44 |
- [X] [2024.11.11]🎯📢LightRAG now supports [deleting entities by their names](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete).
|
45 |
- [X] [2024.11.09]🎯📢Introducing the [LightRAG Gui](https://lightrag-gui.streamlit.app), which allows you to insert, query, visualize, and download LightRAG knowledge.
|
46 |
- [X] [2024.11.04]🎯📢You can now [use Neo4J for Storage](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-neo4j-for-storage).
|
|
|
1144 |
| **Parameter** | **Type** | **Explanation** | **Default** |
|
1145 |
|--------------|----------|-----------------|-------------|
|
1146 |
| **working_dir** | `str` | Directory where the cache will be stored | `lightrag_cache+timestamp` |
|
1147 |
+
| **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`,`PGKVStorage`,`RedisKVStorage`,`MongoKVStorage` | `JsonKVStorage` |
|
1148 |
+
| **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`,`PGVectorStorage`,`MilvusVectorDBStorage`,`ChromaVectorDBStorage`,`FaissVectorDBStorage`,`MongoVectorDBStorage`,`QdrantVectorDBStorage` | `NanoVectorDBStorage` |
|
1149 |
+
| **graph_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`,`Neo4JStorage`,`PGGraphStorage`,`AGEStorage` | `NetworkXStorage` |
|
1150 |
+
| **doc_status_storage** | `str` | Storage type for documents process status. Supported types: `JsonDocStatusStorage`,`PGDocStatusStorage`,`MongoDocStatusStorage` | `JsonDocStatusStorage` |
|
1151 |
| **chunk_token_size** | `int` | Maximum token size per chunk when splitting documents | `1200` |
|
1152 |
| **chunk_overlap_token_size** | `int` | Overlap token size between two chunks when splitting documents | `100` |
|
1153 |
| **tiktoken_model_name** | `str` | Model name for the Tiktoken encoder used to calculate token numbers | `gpt-4o-mini` |
|
config.ini.example
CHANGED
@@ -13,23 +13,6 @@ uri=redis://localhost:6379/1
|
|
13 |
[qdrant]
|
14 |
uri = http://localhost:16333
|
15 |
|
16 |
-
[oracle]
|
17 |
-
dsn = localhost:1521/XEPDB1
|
18 |
-
user = your_username
|
19 |
-
password = your_password
|
20 |
-
config_dir = /path/to/oracle/config
|
21 |
-
wallet_location = /path/to/wallet # 可选
|
22 |
-
wallet_password = your_wallet_password # 可选
|
23 |
-
workspace = default # 可选,默认为default
|
24 |
-
|
25 |
-
[tidb]
|
26 |
-
host = localhost
|
27 |
-
port = 4000
|
28 |
-
user = your_username
|
29 |
-
password = your_password
|
30 |
-
database = your_database
|
31 |
-
workspace = default # 可选,默认为default
|
32 |
-
|
33 |
[postgres]
|
34 |
host = localhost
|
35 |
port = 5432
|
|
|
13 |
[qdrant]
|
14 |
uri = http://localhost:16333
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
[postgres]
|
17 |
host = localhost
|
18 |
port = 5432
|
env.example
CHANGED
@@ -4,11 +4,9 @@
|
|
4 |
# HOST=0.0.0.0
|
5 |
# PORT=9621
|
6 |
# WORKERS=2
|
7 |
-
### separating data from difference Lightrag instances
|
8 |
-
# NAMESPACE_PREFIX=lightrag
|
9 |
-
### Max nodes return from grap retrieval
|
10 |
-
# MAX_GRAPH_NODES=1000
|
11 |
# CORS_ORIGINS=http://localhost:3000,http://localhost:8080
|
|
|
|
|
12 |
|
13 |
### Optional SSL Configuration
|
14 |
# SSL=true
|
@@ -22,6 +20,9 @@
|
|
22 |
### Ollama Emulating Model Tag
|
23 |
# OLLAMA_EMULATING_MODEL_TAG=latest
|
24 |
|
|
|
|
|
|
|
25 |
### Logging level
|
26 |
# LOG_LEVEL=INFO
|
27 |
# VERBOSE=False
|
@@ -110,24 +111,14 @@ LIGHTRAG_VECTOR_STORAGE=NanoVectorDBStorage
|
|
110 |
LIGHTRAG_GRAPH_STORAGE=NetworkXStorage
|
111 |
LIGHTRAG_DOC_STATUS_STORAGE=JsonDocStatusStorage
|
112 |
|
113 |
-
###
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
#
|
119 |
-
|
120 |
-
|
121 |
-
#ORACLE_WORKSPACE=default
|
122 |
-
|
123 |
-
### TiDB Configuration
|
124 |
-
TIDB_HOST=localhost
|
125 |
-
TIDB_PORT=4000
|
126 |
-
TIDB_USER=your_username
|
127 |
-
TIDB_PASSWORD='your_password'
|
128 |
-
TIDB_DATABASE=your_database
|
129 |
-
### separating all data from difference Lightrag instances(deprecating, use NAMESPACE_PREFIX in future)
|
130 |
-
#TIDB_WORKSPACE=default
|
131 |
|
132 |
### PostgreSQL Configuration
|
133 |
POSTGRES_HOST=localhost
|
@@ -135,8 +126,8 @@ POSTGRES_PORT=5432
|
|
135 |
POSTGRES_USER=your_username
|
136 |
POSTGRES_PASSWORD='your_password'
|
137 |
POSTGRES_DATABASE=your_database
|
138 |
-
### separating all data from difference Lightrag instances(deprecating
|
139 |
-
#POSTGRES_WORKSPACE=default
|
140 |
|
141 |
### Independent AGM Configuration(not for AMG embedded in PostreSQL)
|
142 |
AGE_POSTGRES_DB=
|
@@ -145,8 +136,8 @@ AGE_POSTGRES_PASSWORD=
|
|
145 |
AGE_POSTGRES_HOST=
|
146 |
# AGE_POSTGRES_PORT=8529
|
147 |
|
148 |
-
### separating all data from difference Lightrag instances(deprecating, use NAMESPACE_PREFIX in future)
|
149 |
# AGE Graph Name(apply to PostgreSQL and independent AGM)
|
|
|
150 |
# AGE_GRAPH_NAME=lightrag
|
151 |
|
152 |
### Neo4j Configuration
|
@@ -157,7 +148,7 @@ NEO4J_PASSWORD='your_password'
|
|
157 |
### MongoDB Configuration
|
158 |
MONGO_URI=mongodb://root:root@localhost:27017/
|
159 |
MONGO_DATABASE=LightRAG
|
160 |
-
### separating all data from difference Lightrag instances(deprecating
|
161 |
# MONGODB_GRAPH=false
|
162 |
|
163 |
### Milvus Configuration
|
@@ -177,7 +168,9 @@ REDIS_URI=redis://localhost:6379
|
|
177 |
### For JWT Auth
|
178 |
# AUTH_ACCOUNTS='admin:admin123,user1:pass456'
|
179 |
# TOKEN_SECRET=Your-Key-For-LightRAG-API-Server
|
180 |
-
# TOKEN_EXPIRE_HOURS=
|
|
|
|
|
181 |
|
182 |
### API-Key to access LightRAG Server API
|
183 |
# LIGHTRAG_API_KEY=your-secure-api-key-here
|
|
|
4 |
# HOST=0.0.0.0
|
5 |
# PORT=9621
|
6 |
# WORKERS=2
|
|
|
|
|
|
|
|
|
7 |
# CORS_ORIGINS=http://localhost:3000,http://localhost:8080
|
8 |
+
WEBUI_TITLE='Graph RAG Engine'
|
9 |
+
WEBUI_DESCRIPTION="Simple and Fast Graph Based RAG System"
|
10 |
|
11 |
### Optional SSL Configuration
|
12 |
# SSL=true
|
|
|
20 |
### Ollama Emulating Model Tag
|
21 |
# OLLAMA_EMULATING_MODEL_TAG=latest
|
22 |
|
23 |
+
### Max nodes return from grap retrieval
|
24 |
+
# MAX_GRAPH_NODES=1000
|
25 |
+
|
26 |
### Logging level
|
27 |
# LOG_LEVEL=INFO
|
28 |
# VERBOSE=False
|
|
|
111 |
LIGHTRAG_GRAPH_STORAGE=NetworkXStorage
|
112 |
LIGHTRAG_DOC_STATUS_STORAGE=JsonDocStatusStorage
|
113 |
|
114 |
+
### TiDB Configuration (Deprecated)
|
115 |
+
# TIDB_HOST=localhost
|
116 |
+
# TIDB_PORT=4000
|
117 |
+
# TIDB_USER=your_username
|
118 |
+
# TIDB_PASSWORD='your_password'
|
119 |
+
# TIDB_DATABASE=your_database
|
120 |
+
### separating all data from difference Lightrag instances(deprecating)
|
121 |
+
# TIDB_WORKSPACE=default
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
|
123 |
### PostgreSQL Configuration
|
124 |
POSTGRES_HOST=localhost
|
|
|
126 |
POSTGRES_USER=your_username
|
127 |
POSTGRES_PASSWORD='your_password'
|
128 |
POSTGRES_DATABASE=your_database
|
129 |
+
### separating all data from difference Lightrag instances(deprecating)
|
130 |
+
# POSTGRES_WORKSPACE=default
|
131 |
|
132 |
### Independent AGM Configuration(not for AMG embedded in PostreSQL)
|
133 |
AGE_POSTGRES_DB=
|
|
|
136 |
AGE_POSTGRES_HOST=
|
137 |
# AGE_POSTGRES_PORT=8529
|
138 |
|
|
|
139 |
# AGE Graph Name(apply to PostgreSQL and independent AGM)
|
140 |
+
### AGE_GRAPH_NAME is precated
|
141 |
# AGE_GRAPH_NAME=lightrag
|
142 |
|
143 |
### Neo4j Configuration
|
|
|
148 |
### MongoDB Configuration
|
149 |
MONGO_URI=mongodb://root:root@localhost:27017/
|
150 |
MONGO_DATABASE=LightRAG
|
151 |
+
### separating all data from difference Lightrag instances(deprecating)
|
152 |
# MONGODB_GRAPH=false
|
153 |
|
154 |
### Milvus Configuration
|
|
|
168 |
### For JWT Auth
|
169 |
# AUTH_ACCOUNTS='admin:admin123,user1:pass456'
|
170 |
# TOKEN_SECRET=Your-Key-For-LightRAG-API-Server
|
171 |
+
# TOKEN_EXPIRE_HOURS=48
|
172 |
+
# GUEST_TOKEN_EXPIRE_HOURS=24
|
173 |
+
# JWT_ALGORITHM=HS256
|
174 |
|
175 |
### API-Key to access LightRAG Server API
|
176 |
# LIGHTRAG_API_KEY=your-secure-api-key-here
|
examples/lightrag_api_ollama_demo.py
DELETED
@@ -1,188 +0,0 @@
|
|
1 |
-
from fastapi import FastAPI, HTTPException, File, UploadFile
|
2 |
-
from contextlib import asynccontextmanager
|
3 |
-
from pydantic import BaseModel
|
4 |
-
import os
|
5 |
-
from lightrag import LightRAG, QueryParam
|
6 |
-
from lightrag.llm.ollama import ollama_embed, ollama_model_complete
|
7 |
-
from lightrag.utils import EmbeddingFunc
|
8 |
-
from typing import Optional
|
9 |
-
import asyncio
|
10 |
-
import nest_asyncio
|
11 |
-
import aiofiles
|
12 |
-
from lightrag.kg.shared_storage import initialize_pipeline_status
|
13 |
-
|
14 |
-
# Apply nest_asyncio to solve event loop issues
|
15 |
-
nest_asyncio.apply()
|
16 |
-
|
17 |
-
DEFAULT_RAG_DIR = "index_default"
|
18 |
-
|
19 |
-
DEFAULT_INPUT_FILE = "book.txt"
|
20 |
-
INPUT_FILE = os.environ.get("INPUT_FILE", f"{DEFAULT_INPUT_FILE}")
|
21 |
-
print(f"INPUT_FILE: {INPUT_FILE}")
|
22 |
-
|
23 |
-
# Configure working directory
|
24 |
-
WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
|
25 |
-
print(f"WORKING_DIR: {WORKING_DIR}")
|
26 |
-
|
27 |
-
|
28 |
-
if not os.path.exists(WORKING_DIR):
|
29 |
-
os.mkdir(WORKING_DIR)
|
30 |
-
|
31 |
-
|
32 |
-
async def init():
|
33 |
-
rag = LightRAG(
|
34 |
-
working_dir=WORKING_DIR,
|
35 |
-
llm_model_func=ollama_model_complete,
|
36 |
-
llm_model_name="gemma2:9b",
|
37 |
-
llm_model_max_async=4,
|
38 |
-
llm_model_max_token_size=8192,
|
39 |
-
llm_model_kwargs={
|
40 |
-
"host": "http://localhost:11434",
|
41 |
-
"options": {"num_ctx": 8192},
|
42 |
-
},
|
43 |
-
embedding_func=EmbeddingFunc(
|
44 |
-
embedding_dim=768,
|
45 |
-
max_token_size=8192,
|
46 |
-
func=lambda texts: ollama_embed(
|
47 |
-
texts, embed_model="nomic-embed-text", host="http://localhost:11434"
|
48 |
-
),
|
49 |
-
),
|
50 |
-
)
|
51 |
-
|
52 |
-
# Add initialization code
|
53 |
-
await rag.initialize_storages()
|
54 |
-
await initialize_pipeline_status()
|
55 |
-
|
56 |
-
return rag
|
57 |
-
|
58 |
-
|
59 |
-
@asynccontextmanager
|
60 |
-
async def lifespan(app: FastAPI):
|
61 |
-
global rag
|
62 |
-
rag = await init()
|
63 |
-
print("done!")
|
64 |
-
yield
|
65 |
-
|
66 |
-
|
67 |
-
app = FastAPI(
|
68 |
-
title="LightRAG API", description="API for RAG operations", lifespan=lifespan
|
69 |
-
)
|
70 |
-
|
71 |
-
|
72 |
-
# Data models
|
73 |
-
class QueryRequest(BaseModel):
|
74 |
-
query: str
|
75 |
-
mode: str = "hybrid"
|
76 |
-
only_need_context: bool = False
|
77 |
-
|
78 |
-
|
79 |
-
class InsertRequest(BaseModel):
|
80 |
-
text: str
|
81 |
-
|
82 |
-
|
83 |
-
class Response(BaseModel):
|
84 |
-
status: str
|
85 |
-
data: Optional[str] = None
|
86 |
-
message: Optional[str] = None
|
87 |
-
|
88 |
-
|
89 |
-
# API routes
|
90 |
-
@app.post("/query", response_model=Response)
|
91 |
-
async def query_endpoint(request: QueryRequest):
|
92 |
-
try:
|
93 |
-
loop = asyncio.get_event_loop()
|
94 |
-
result = await loop.run_in_executor(
|
95 |
-
None,
|
96 |
-
lambda: rag.query(
|
97 |
-
request.query,
|
98 |
-
param=QueryParam(
|
99 |
-
mode=request.mode, only_need_context=request.only_need_context
|
100 |
-
),
|
101 |
-
),
|
102 |
-
)
|
103 |
-
return Response(status="success", data=result)
|
104 |
-
except Exception as e:
|
105 |
-
raise HTTPException(status_code=500, detail=str(e))
|
106 |
-
|
107 |
-
|
108 |
-
# insert by text
|
109 |
-
@app.post("/insert", response_model=Response)
|
110 |
-
async def insert_endpoint(request: InsertRequest):
|
111 |
-
try:
|
112 |
-
loop = asyncio.get_event_loop()
|
113 |
-
await loop.run_in_executor(None, lambda: rag.insert(request.text))
|
114 |
-
return Response(status="success", message="Text inserted successfully")
|
115 |
-
except Exception as e:
|
116 |
-
raise HTTPException(status_code=500, detail=str(e))
|
117 |
-
|
118 |
-
|
119 |
-
# insert by file in payload
|
120 |
-
@app.post("/insert_file", response_model=Response)
|
121 |
-
async def insert_file(file: UploadFile = File(...)):
|
122 |
-
try:
|
123 |
-
file_content = await file.read()
|
124 |
-
# Read file content
|
125 |
-
try:
|
126 |
-
content = file_content.decode("utf-8")
|
127 |
-
except UnicodeDecodeError:
|
128 |
-
# If UTF-8 decoding fails, try other encodings
|
129 |
-
content = file_content.decode("gbk")
|
130 |
-
# Insert file content
|
131 |
-
loop = asyncio.get_event_loop()
|
132 |
-
await loop.run_in_executor(None, lambda: rag.insert(content))
|
133 |
-
|
134 |
-
return Response(
|
135 |
-
status="success",
|
136 |
-
message=f"File content from {file.filename} inserted successfully",
|
137 |
-
)
|
138 |
-
except Exception as e:
|
139 |
-
raise HTTPException(status_code=500, detail=str(e))
|
140 |
-
|
141 |
-
|
142 |
-
# insert by local default file
|
143 |
-
@app.post("/insert_default_file", response_model=Response)
|
144 |
-
@app.get("/insert_default_file", response_model=Response)
|
145 |
-
async def insert_default_file():
|
146 |
-
try:
|
147 |
-
# Read file content from book.txt
|
148 |
-
async with aiofiles.open(INPUT_FILE, "r", encoding="utf-8") as file:
|
149 |
-
content = await file.read()
|
150 |
-
print(f"read input file {INPUT_FILE} successfully")
|
151 |
-
# Insert file content
|
152 |
-
loop = asyncio.get_event_loop()
|
153 |
-
await loop.run_in_executor(None, lambda: rag.insert(content))
|
154 |
-
|
155 |
-
return Response(
|
156 |
-
status="success",
|
157 |
-
message=f"File content from {INPUT_FILE} inserted successfully",
|
158 |
-
)
|
159 |
-
except Exception as e:
|
160 |
-
raise HTTPException(status_code=500, detail=str(e))
|
161 |
-
|
162 |
-
|
163 |
-
@app.get("/health")
|
164 |
-
async def health_check():
|
165 |
-
return {"status": "healthy"}
|
166 |
-
|
167 |
-
|
168 |
-
if __name__ == "__main__":
|
169 |
-
import uvicorn
|
170 |
-
|
171 |
-
uvicorn.run(app, host="0.0.0.0", port=8020)
|
172 |
-
|
173 |
-
# Usage example
|
174 |
-
# To run the server, use the following command in your terminal:
|
175 |
-
# python lightrag_api_openai_compatible_demo.py
|
176 |
-
|
177 |
-
# Example requests:
|
178 |
-
# 1. Query:
|
179 |
-
# curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "your query here", "mode": "hybrid"}'
|
180 |
-
|
181 |
-
# 2. Insert text:
|
182 |
-
# curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}'
|
183 |
-
|
184 |
-
# 3. Insert file:
|
185 |
-
# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: multipart/form-data" -F "file=@path/to/your/file.txt"
|
186 |
-
|
187 |
-
# 4. Health check:
|
188 |
-
# curl -X GET "http://127.0.0.1:8020/health"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
examples/lightrag_api_openai_compatible_demo.py
DELETED
@@ -1,204 +0,0 @@
|
|
1 |
-
from fastapi import FastAPI, HTTPException, File, UploadFile
|
2 |
-
from contextlib import asynccontextmanager
|
3 |
-
from pydantic import BaseModel
|
4 |
-
import os
|
5 |
-
from lightrag import LightRAG, QueryParam
|
6 |
-
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
|
7 |
-
from lightrag.utils import EmbeddingFunc
|
8 |
-
import numpy as np
|
9 |
-
from typing import Optional
|
10 |
-
import asyncio
|
11 |
-
import nest_asyncio
|
12 |
-
from lightrag.kg.shared_storage import initialize_pipeline_status
|
13 |
-
|
14 |
-
# Apply nest_asyncio to solve event loop issues
|
15 |
-
nest_asyncio.apply()
|
16 |
-
|
17 |
-
DEFAULT_RAG_DIR = "index_default"
|
18 |
-
app = FastAPI(title="LightRAG API", description="API for RAG operations")
|
19 |
-
|
20 |
-
# Configure working directory
|
21 |
-
WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
|
22 |
-
print(f"WORKING_DIR: {WORKING_DIR}")
|
23 |
-
LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4o-mini")
|
24 |
-
print(f"LLM_MODEL: {LLM_MODEL}")
|
25 |
-
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large")
|
26 |
-
print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
|
27 |
-
EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
|
28 |
-
print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
|
29 |
-
BASE_URL = os.environ.get("BASE_URL", "https://api.openai.com/v1")
|
30 |
-
print(f"BASE_URL: {BASE_URL}")
|
31 |
-
API_KEY = os.environ.get("API_KEY", "xxxxxxxx")
|
32 |
-
print(f"API_KEY: {API_KEY}")
|
33 |
-
|
34 |
-
if not os.path.exists(WORKING_DIR):
|
35 |
-
os.mkdir(WORKING_DIR)
|
36 |
-
|
37 |
-
|
38 |
-
# LLM model function
|
39 |
-
|
40 |
-
|
41 |
-
async def llm_model_func(
|
42 |
-
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
43 |
-
) -> str:
|
44 |
-
return await openai_complete_if_cache(
|
45 |
-
model=LLM_MODEL,
|
46 |
-
prompt=prompt,
|
47 |
-
system_prompt=system_prompt,
|
48 |
-
history_messages=history_messages,
|
49 |
-
base_url=BASE_URL,
|
50 |
-
api_key=API_KEY,
|
51 |
-
**kwargs,
|
52 |
-
)
|
53 |
-
|
54 |
-
|
55 |
-
# Embedding function
|
56 |
-
|
57 |
-
|
58 |
-
async def embedding_func(texts: list[str]) -> np.ndarray:
|
59 |
-
return await openai_embed(
|
60 |
-
texts=texts,
|
61 |
-
model=EMBEDDING_MODEL,
|
62 |
-
base_url=BASE_URL,
|
63 |
-
api_key=API_KEY,
|
64 |
-
)
|
65 |
-
|
66 |
-
|
67 |
-
async def get_embedding_dim():
|
68 |
-
test_text = ["This is a test sentence."]
|
69 |
-
embedding = await embedding_func(test_text)
|
70 |
-
embedding_dim = embedding.shape[1]
|
71 |
-
print(f"{embedding_dim=}")
|
72 |
-
return embedding_dim
|
73 |
-
|
74 |
-
|
75 |
-
# Initialize RAG instance
|
76 |
-
async def init():
|
77 |
-
embedding_dimension = await get_embedding_dim()
|
78 |
-
|
79 |
-
rag = LightRAG(
|
80 |
-
working_dir=WORKING_DIR,
|
81 |
-
llm_model_func=llm_model_func,
|
82 |
-
embedding_func=EmbeddingFunc(
|
83 |
-
embedding_dim=embedding_dimension,
|
84 |
-
max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
|
85 |
-
func=embedding_func,
|
86 |
-
),
|
87 |
-
)
|
88 |
-
|
89 |
-
await rag.initialize_storages()
|
90 |
-
await initialize_pipeline_status()
|
91 |
-
|
92 |
-
return rag
|
93 |
-
|
94 |
-
|
95 |
-
@asynccontextmanager
|
96 |
-
async def lifespan(app: FastAPI):
|
97 |
-
global rag
|
98 |
-
rag = await init()
|
99 |
-
print("done!")
|
100 |
-
yield
|
101 |
-
|
102 |
-
|
103 |
-
app = FastAPI(
|
104 |
-
title="LightRAG API", description="API for RAG operations", lifespan=lifespan
|
105 |
-
)
|
106 |
-
|
107 |
-
# Data models
|
108 |
-
|
109 |
-
|
110 |
-
class QueryRequest(BaseModel):
|
111 |
-
query: str
|
112 |
-
mode: str = "hybrid"
|
113 |
-
only_need_context: bool = False
|
114 |
-
|
115 |
-
|
116 |
-
class InsertRequest(BaseModel):
|
117 |
-
text: str
|
118 |
-
|
119 |
-
|
120 |
-
class Response(BaseModel):
|
121 |
-
status: str
|
122 |
-
data: Optional[str] = None
|
123 |
-
message: Optional[str] = None
|
124 |
-
|
125 |
-
|
126 |
-
# API routes
|
127 |
-
|
128 |
-
|
129 |
-
@app.post("/query", response_model=Response)
|
130 |
-
async def query_endpoint(request: QueryRequest):
|
131 |
-
try:
|
132 |
-
loop = asyncio.get_event_loop()
|
133 |
-
result = await loop.run_in_executor(
|
134 |
-
None,
|
135 |
-
lambda: rag.query(
|
136 |
-
request.query,
|
137 |
-
param=QueryParam(
|
138 |
-
mode=request.mode, only_need_context=request.only_need_context
|
139 |
-
),
|
140 |
-
),
|
141 |
-
)
|
142 |
-
return Response(status="success", data=result)
|
143 |
-
except Exception as e:
|
144 |
-
raise HTTPException(status_code=500, detail=str(e))
|
145 |
-
|
146 |
-
|
147 |
-
@app.post("/insert", response_model=Response)
|
148 |
-
async def insert_endpoint(request: InsertRequest):
|
149 |
-
try:
|
150 |
-
loop = asyncio.get_event_loop()
|
151 |
-
await loop.run_in_executor(None, lambda: rag.insert(request.text))
|
152 |
-
return Response(status="success", message="Text inserted successfully")
|
153 |
-
except Exception as e:
|
154 |
-
raise HTTPException(status_code=500, detail=str(e))
|
155 |
-
|
156 |
-
|
157 |
-
@app.post("/insert_file", response_model=Response)
|
158 |
-
async def insert_file(file: UploadFile = File(...)):
|
159 |
-
try:
|
160 |
-
file_content = await file.read()
|
161 |
-
# Read file content
|
162 |
-
try:
|
163 |
-
content = file_content.decode("utf-8")
|
164 |
-
except UnicodeDecodeError:
|
165 |
-
# If UTF-8 decoding fails, try other encodings
|
166 |
-
content = file_content.decode("gbk")
|
167 |
-
# Insert file content
|
168 |
-
loop = asyncio.get_event_loop()
|
169 |
-
await loop.run_in_executor(None, lambda: rag.insert(content))
|
170 |
-
|
171 |
-
return Response(
|
172 |
-
status="success",
|
173 |
-
message=f"File content from {file.filename} inserted successfully",
|
174 |
-
)
|
175 |
-
except Exception as e:
|
176 |
-
raise HTTPException(status_code=500, detail=str(e))
|
177 |
-
|
178 |
-
|
179 |
-
@app.get("/health")
|
180 |
-
async def health_check():
|
181 |
-
return {"status": "healthy"}
|
182 |
-
|
183 |
-
|
184 |
-
if __name__ == "__main__":
|
185 |
-
import uvicorn
|
186 |
-
|
187 |
-
uvicorn.run(app, host="0.0.0.0", port=8020)
|
188 |
-
|
189 |
-
# Usage example
|
190 |
-
# To run the server, use the following command in your terminal:
|
191 |
-
# python lightrag_api_openai_compatible_demo.py
|
192 |
-
|
193 |
-
# Example requests:
|
194 |
-
# 1. Query:
|
195 |
-
# curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "your query here", "mode": "hybrid"}'
|
196 |
-
|
197 |
-
# 2. Insert text:
|
198 |
-
# curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}'
|
199 |
-
|
200 |
-
# 3. Insert file:
|
201 |
-
# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: multipart/form-data" -F "file=@path/to/your/file.txt"
|
202 |
-
|
203 |
-
# 4. Health check:
|
204 |
-
# curl -X GET "http://127.0.0.1:8020/health"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
examples/lightrag_api_oracle_demo.py
DELETED
@@ -1,267 +0,0 @@
|
|
1 |
-
from fastapi import FastAPI, HTTPException, File, UploadFile
|
2 |
-
from fastapi import Query
|
3 |
-
from contextlib import asynccontextmanager
|
4 |
-
from pydantic import BaseModel
|
5 |
-
from typing import Optional, Any
|
6 |
-
|
7 |
-
import sys
|
8 |
-
import os
|
9 |
-
|
10 |
-
|
11 |
-
from pathlib import Path
|
12 |
-
|
13 |
-
import asyncio
|
14 |
-
import nest_asyncio
|
15 |
-
from lightrag import LightRAG, QueryParam
|
16 |
-
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
|
17 |
-
from lightrag.utils import EmbeddingFunc
|
18 |
-
import numpy as np
|
19 |
-
from lightrag.kg.shared_storage import initialize_pipeline_status
|
20 |
-
|
21 |
-
|
22 |
-
print(os.getcwd())
|
23 |
-
script_directory = Path(__file__).resolve().parent.parent
|
24 |
-
sys.path.append(os.path.abspath(script_directory))
|
25 |
-
|
26 |
-
|
27 |
-
# Apply nest_asyncio to solve event loop issues
|
28 |
-
nest_asyncio.apply()
|
29 |
-
|
30 |
-
DEFAULT_RAG_DIR = "index_default"
|
31 |
-
|
32 |
-
|
33 |
-
# We use OpenAI compatible API to call LLM on Oracle Cloud
|
34 |
-
# More docs here https://github.com/jin38324/OCI_GenAI_access_gateway
|
35 |
-
BASE_URL = "http://xxx.xxx.xxx.xxx:8088/v1/"
|
36 |
-
APIKEY = "ocigenerativeai"
|
37 |
-
|
38 |
-
# Configure working directory
|
39 |
-
WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
|
40 |
-
print(f"WORKING_DIR: {WORKING_DIR}")
|
41 |
-
LLM_MODEL = os.environ.get("LLM_MODEL", "cohere.command-r-plus-08-2024")
|
42 |
-
print(f"LLM_MODEL: {LLM_MODEL}")
|
43 |
-
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "cohere.embed-multilingual-v3.0")
|
44 |
-
print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
|
45 |
-
EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 512))
|
46 |
-
print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
|
47 |
-
|
48 |
-
if not os.path.exists(WORKING_DIR):
|
49 |
-
os.mkdir(WORKING_DIR)
|
50 |
-
|
51 |
-
os.environ["ORACLE_USER"] = ""
|
52 |
-
os.environ["ORACLE_PASSWORD"] = ""
|
53 |
-
os.environ["ORACLE_DSN"] = ""
|
54 |
-
os.environ["ORACLE_CONFIG_DIR"] = "path_to_config_dir"
|
55 |
-
os.environ["ORACLE_WALLET_LOCATION"] = "path_to_wallet_location"
|
56 |
-
os.environ["ORACLE_WALLET_PASSWORD"] = "wallet_password"
|
57 |
-
os.environ["ORACLE_WORKSPACE"] = "company"
|
58 |
-
|
59 |
-
|
60 |
-
async def llm_model_func(
|
61 |
-
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
62 |
-
) -> str:
|
63 |
-
return await openai_complete_if_cache(
|
64 |
-
LLM_MODEL,
|
65 |
-
prompt,
|
66 |
-
system_prompt=system_prompt,
|
67 |
-
history_messages=history_messages,
|
68 |
-
api_key=APIKEY,
|
69 |
-
base_url=BASE_URL,
|
70 |
-
**kwargs,
|
71 |
-
)
|
72 |
-
|
73 |
-
|
74 |
-
async def embedding_func(texts: list[str]) -> np.ndarray:
|
75 |
-
return await openai_embed(
|
76 |
-
texts,
|
77 |
-
model=EMBEDDING_MODEL,
|
78 |
-
api_key=APIKEY,
|
79 |
-
base_url=BASE_URL,
|
80 |
-
)
|
81 |
-
|
82 |
-
|
83 |
-
async def get_embedding_dim():
|
84 |
-
test_text = ["This is a test sentence."]
|
85 |
-
embedding = await embedding_func(test_text)
|
86 |
-
embedding_dim = embedding.shape[1]
|
87 |
-
return embedding_dim
|
88 |
-
|
89 |
-
|
90 |
-
async def init():
|
91 |
-
# Detect embedding dimension
|
92 |
-
embedding_dimension = await get_embedding_dim()
|
93 |
-
print(f"Detected embedding dimension: {embedding_dimension}")
|
94 |
-
# Create Oracle DB connection
|
95 |
-
# The `config` parameter is the connection configuration of Oracle DB
|
96 |
-
# More docs here https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html
|
97 |
-
# We storage data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
|
98 |
-
# Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
|
99 |
-
|
100 |
-
# Initialize LightRAG
|
101 |
-
# We use Oracle DB as the KV/vector/graph storage
|
102 |
-
rag = LightRAG(
|
103 |
-
enable_llm_cache=False,
|
104 |
-
working_dir=WORKING_DIR,
|
105 |
-
chunk_token_size=512,
|
106 |
-
llm_model_func=llm_model_func,
|
107 |
-
embedding_func=EmbeddingFunc(
|
108 |
-
embedding_dim=embedding_dimension,
|
109 |
-
max_token_size=512,
|
110 |
-
func=embedding_func,
|
111 |
-
),
|
112 |
-
graph_storage="OracleGraphStorage",
|
113 |
-
kv_storage="OracleKVStorage",
|
114 |
-
vector_storage="OracleVectorDBStorage",
|
115 |
-
)
|
116 |
-
|
117 |
-
await rag.initialize_storages()
|
118 |
-
await initialize_pipeline_status()
|
119 |
-
|
120 |
-
return rag
|
121 |
-
|
122 |
-
|
123 |
-
# Extract and Insert into LightRAG storage
|
124 |
-
# with open("./dickens/book.txt", "r", encoding="utf-8") as f:
|
125 |
-
# await rag.ainsert(f.read())
|
126 |
-
|
127 |
-
# # Perform search in different modes
|
128 |
-
# modes = ["naive", "local", "global", "hybrid"]
|
129 |
-
# for mode in modes:
|
130 |
-
# print("="*20, mode, "="*20)
|
131 |
-
# print(await rag.aquery("这篇文档是关于什么内容的?", param=QueryParam(mode=mode)))
|
132 |
-
# print("-"*100, "\n")
|
133 |
-
|
134 |
-
# Data models
|
135 |
-
|
136 |
-
|
137 |
-
class QueryRequest(BaseModel):
|
138 |
-
query: str
|
139 |
-
mode: str = "hybrid"
|
140 |
-
only_need_context: bool = False
|
141 |
-
only_need_prompt: bool = False
|
142 |
-
|
143 |
-
|
144 |
-
class DataRequest(BaseModel):
|
145 |
-
limit: int = 100
|
146 |
-
|
147 |
-
|
148 |
-
class InsertRequest(BaseModel):
|
149 |
-
text: str
|
150 |
-
|
151 |
-
|
152 |
-
class Response(BaseModel):
|
153 |
-
status: str
|
154 |
-
data: Optional[Any] = None
|
155 |
-
message: Optional[str] = None
|
156 |
-
|
157 |
-
|
158 |
-
# API routes
|
159 |
-
|
160 |
-
rag = None
|
161 |
-
|
162 |
-
|
163 |
-
@asynccontextmanager
|
164 |
-
async def lifespan(app: FastAPI):
|
165 |
-
global rag
|
166 |
-
rag = await init()
|
167 |
-
print("done!")
|
168 |
-
yield
|
169 |
-
|
170 |
-
|
171 |
-
app = FastAPI(
|
172 |
-
title="LightRAG API", description="API for RAG operations", lifespan=lifespan
|
173 |
-
)
|
174 |
-
|
175 |
-
|
176 |
-
@app.post("/query", response_model=Response)
|
177 |
-
async def query_endpoint(request: QueryRequest):
|
178 |
-
# try:
|
179 |
-
# loop = asyncio.get_event_loop()
|
180 |
-
if request.mode == "naive":
|
181 |
-
top_k = 3
|
182 |
-
else:
|
183 |
-
top_k = 60
|
184 |
-
result = await rag.aquery(
|
185 |
-
request.query,
|
186 |
-
param=QueryParam(
|
187 |
-
mode=request.mode,
|
188 |
-
only_need_context=request.only_need_context,
|
189 |
-
only_need_prompt=request.only_need_prompt,
|
190 |
-
top_k=top_k,
|
191 |
-
),
|
192 |
-
)
|
193 |
-
return Response(status="success", data=result)
|
194 |
-
# except Exception as e:
|
195 |
-
# raise HTTPException(status_code=500, detail=str(e))
|
196 |
-
|
197 |
-
|
198 |
-
@app.get("/data", response_model=Response)
|
199 |
-
async def query_all_nodes(type: str = Query("nodes"), limit: int = Query(100)):
|
200 |
-
if type == "nodes":
|
201 |
-
result = await rag.chunk_entity_relation_graph.get_all_nodes(limit=limit)
|
202 |
-
elif type == "edges":
|
203 |
-
result = await rag.chunk_entity_relation_graph.get_all_edges(limit=limit)
|
204 |
-
elif type == "statistics":
|
205 |
-
result = await rag.chunk_entity_relation_graph.get_statistics()
|
206 |
-
return Response(status="success", data=result)
|
207 |
-
|
208 |
-
|
209 |
-
@app.post("/insert", response_model=Response)
|
210 |
-
async def insert_endpoint(request: InsertRequest):
|
211 |
-
try:
|
212 |
-
loop = asyncio.get_event_loop()
|
213 |
-
await loop.run_in_executor(None, lambda: rag.insert(request.text))
|
214 |
-
return Response(status="success", message="Text inserted successfully")
|
215 |
-
except Exception as e:
|
216 |
-
raise HTTPException(status_code=500, detail=str(e))
|
217 |
-
|
218 |
-
|
219 |
-
@app.post("/insert_file", response_model=Response)
|
220 |
-
async def insert_file(file: UploadFile = File(...)):
|
221 |
-
try:
|
222 |
-
file_content = await file.read()
|
223 |
-
# Read file content
|
224 |
-
try:
|
225 |
-
content = file_content.decode("utf-8")
|
226 |
-
except UnicodeDecodeError:
|
227 |
-
# If UTF-8 decoding fails, try other encodings
|
228 |
-
content = file_content.decode("gbk")
|
229 |
-
# Insert file content
|
230 |
-
loop = asyncio.get_event_loop()
|
231 |
-
await loop.run_in_executor(None, lambda: rag.insert(content))
|
232 |
-
|
233 |
-
return Response(
|
234 |
-
status="success",
|
235 |
-
message=f"File content from {file.filename} inserted successfully",
|
236 |
-
)
|
237 |
-
except Exception as e:
|
238 |
-
raise HTTPException(status_code=500, detail=str(e))
|
239 |
-
|
240 |
-
|
241 |
-
@app.get("/health")
|
242 |
-
async def health_check():
|
243 |
-
return {"status": "healthy"}
|
244 |
-
|
245 |
-
|
246 |
-
if __name__ == "__main__":
|
247 |
-
import uvicorn
|
248 |
-
|
249 |
-
uvicorn.run(app, host="127.0.0.1", port=8020)
|
250 |
-
|
251 |
-
# Usage example
|
252 |
-
# To run the server, use the following command in your terminal:
|
253 |
-
# python lightrag_api_openai_compatible_demo.py
|
254 |
-
|
255 |
-
# Example requests:
|
256 |
-
# 1. Query:
|
257 |
-
# curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "your query here", "mode": "hybrid"}'
|
258 |
-
|
259 |
-
# 2. Insert text:
|
260 |
-
# curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}'
|
261 |
-
|
262 |
-
# 3. Insert file:
|
263 |
-
# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: multipart/form-data" -F "file=@path/to/your/file.txt"
|
264 |
-
|
265 |
-
|
266 |
-
# 4. Health check:
|
267 |
-
# curl -X GET "http://127.0.0.1:8020/health"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
examples/lightrag_ollama_gremlin_demo.py
CHANGED
@@ -1,3 +1,7 @@
|
|
|
|
|
|
|
|
|
|
1 |
import asyncio
|
2 |
import inspect
|
3 |
import os
|
|
|
1 |
+
##############################################
|
2 |
+
# Gremlin storage implementation is deprecated
|
3 |
+
##############################################
|
4 |
+
|
5 |
import asyncio
|
6 |
import inspect
|
7 |
import os
|
examples/lightrag_oracle_demo.py
DELETED
@@ -1,141 +0,0 @@
|
|
1 |
-
import sys
|
2 |
-
import os
|
3 |
-
from pathlib import Path
|
4 |
-
import asyncio
|
5 |
-
from lightrag import LightRAG, QueryParam
|
6 |
-
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
|
7 |
-
from lightrag.utils import EmbeddingFunc
|
8 |
-
import numpy as np
|
9 |
-
from lightrag.kg.shared_storage import initialize_pipeline_status
|
10 |
-
|
11 |
-
print(os.getcwd())
|
12 |
-
script_directory = Path(__file__).resolve().parent.parent
|
13 |
-
sys.path.append(os.path.abspath(script_directory))
|
14 |
-
|
15 |
-
WORKING_DIR = "./dickens"
|
16 |
-
|
17 |
-
# We use OpenAI compatible API to call LLM on Oracle Cloud
|
18 |
-
# More docs here https://github.com/jin38324/OCI_GenAI_access_gateway
|
19 |
-
BASE_URL = "http://xxx.xxx.xxx.xxx:8088/v1/"
|
20 |
-
APIKEY = "ocigenerativeai"
|
21 |
-
CHATMODEL = "cohere.command-r-plus"
|
22 |
-
EMBEDMODEL = "cohere.embed-multilingual-v3.0"
|
23 |
-
CHUNK_TOKEN_SIZE = 1024
|
24 |
-
MAX_TOKENS = 4000
|
25 |
-
|
26 |
-
if not os.path.exists(WORKING_DIR):
|
27 |
-
os.mkdir(WORKING_DIR)
|
28 |
-
|
29 |
-
os.environ["ORACLE_USER"] = "username"
|
30 |
-
os.environ["ORACLE_PASSWORD"] = "xxxxxxxxx"
|
31 |
-
os.environ["ORACLE_DSN"] = "xxxxxxx_medium"
|
32 |
-
os.environ["ORACLE_CONFIG_DIR"] = "path_to_config_dir"
|
33 |
-
os.environ["ORACLE_WALLET_LOCATION"] = "path_to_wallet_location"
|
34 |
-
os.environ["ORACLE_WALLET_PASSWORD"] = "wallet_password"
|
35 |
-
os.environ["ORACLE_WORKSPACE"] = "company"
|
36 |
-
|
37 |
-
|
38 |
-
async def llm_model_func(
|
39 |
-
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
40 |
-
) -> str:
|
41 |
-
return await openai_complete_if_cache(
|
42 |
-
CHATMODEL,
|
43 |
-
prompt,
|
44 |
-
system_prompt=system_prompt,
|
45 |
-
history_messages=history_messages,
|
46 |
-
api_key=APIKEY,
|
47 |
-
base_url=BASE_URL,
|
48 |
-
**kwargs,
|
49 |
-
)
|
50 |
-
|
51 |
-
|
52 |
-
async def embedding_func(texts: list[str]) -> np.ndarray:
|
53 |
-
return await openai_embed(
|
54 |
-
texts,
|
55 |
-
model=EMBEDMODEL,
|
56 |
-
api_key=APIKEY,
|
57 |
-
base_url=BASE_URL,
|
58 |
-
)
|
59 |
-
|
60 |
-
|
61 |
-
async def get_embedding_dim():
|
62 |
-
test_text = ["This is a test sentence."]
|
63 |
-
embedding = await embedding_func(test_text)
|
64 |
-
embedding_dim = embedding.shape[1]
|
65 |
-
return embedding_dim
|
66 |
-
|
67 |
-
|
68 |
-
async def initialize_rag():
|
69 |
-
# Detect embedding dimension
|
70 |
-
embedding_dimension = await get_embedding_dim()
|
71 |
-
print(f"Detected embedding dimension: {embedding_dimension}")
|
72 |
-
|
73 |
-
# Initialize LightRAG
|
74 |
-
# We use Oracle DB as the KV/vector/graph storage
|
75 |
-
# You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt
|
76 |
-
rag = LightRAG(
|
77 |
-
# log_level="DEBUG",
|
78 |
-
working_dir=WORKING_DIR,
|
79 |
-
entity_extract_max_gleaning=1,
|
80 |
-
enable_llm_cache=True,
|
81 |
-
enable_llm_cache_for_entity_extract=True,
|
82 |
-
embedding_cache_config=None, # {"enabled": True,"similarity_threshold": 0.90},
|
83 |
-
chunk_token_size=CHUNK_TOKEN_SIZE,
|
84 |
-
llm_model_max_token_size=MAX_TOKENS,
|
85 |
-
llm_model_func=llm_model_func,
|
86 |
-
embedding_func=EmbeddingFunc(
|
87 |
-
embedding_dim=embedding_dimension,
|
88 |
-
max_token_size=500,
|
89 |
-
func=embedding_func,
|
90 |
-
),
|
91 |
-
graph_storage="OracleGraphStorage",
|
92 |
-
kv_storage="OracleKVStorage",
|
93 |
-
vector_storage="OracleVectorDBStorage",
|
94 |
-
addon_params={
|
95 |
-
"example_number": 1,
|
96 |
-
"language": "Simplfied Chinese",
|
97 |
-
"entity_types": ["organization", "person", "geo", "event"],
|
98 |
-
"insert_batch_size": 2,
|
99 |
-
},
|
100 |
-
)
|
101 |
-
await rag.initialize_storages()
|
102 |
-
await initialize_pipeline_status()
|
103 |
-
|
104 |
-
return rag
|
105 |
-
|
106 |
-
|
107 |
-
async def main():
|
108 |
-
try:
|
109 |
-
# Initialize RAG instance
|
110 |
-
rag = await initialize_rag()
|
111 |
-
|
112 |
-
# Extract and Insert into LightRAG storage
|
113 |
-
with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:
|
114 |
-
all_text = f.read()
|
115 |
-
texts = [x for x in all_text.split("\n") if x]
|
116 |
-
|
117 |
-
# New mode use pipeline
|
118 |
-
await rag.apipeline_enqueue_documents(texts)
|
119 |
-
await rag.apipeline_process_enqueue_documents()
|
120 |
-
|
121 |
-
# Old method use ainsert
|
122 |
-
# await rag.ainsert(texts)
|
123 |
-
|
124 |
-
# Perform search in different modes
|
125 |
-
modes = ["naive", "local", "global", "hybrid"]
|
126 |
-
for mode in modes:
|
127 |
-
print("=" * 20, mode, "=" * 20)
|
128 |
-
print(
|
129 |
-
await rag.aquery(
|
130 |
-
"What are the top themes in this story?",
|
131 |
-
param=QueryParam(mode=mode),
|
132 |
-
)
|
133 |
-
)
|
134 |
-
print("-" * 100, "\n")
|
135 |
-
|
136 |
-
except Exception as e:
|
137 |
-
print(f"An error occurred: {e}")
|
138 |
-
|
139 |
-
|
140 |
-
if __name__ == "__main__":
|
141 |
-
asyncio.run(main())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
examples/lightrag_tidb_demo.py
CHANGED
@@ -1,3 +1,7 @@
|
|
|
|
|
|
|
|
|
|
1 |
import asyncio
|
2 |
import os
|
3 |
|
|
|
1 |
+
###########################################
|
2 |
+
# TiDB storage implementation is deprecated
|
3 |
+
###########################################
|
4 |
+
|
5 |
import asyncio
|
6 |
import os
|
7 |
|
lightrag/api/README-zh.md
CHANGED
@@ -291,11 +291,9 @@ LightRAG 使用 4 种类型的存储用于不同目的:
|
|
291 |
|
292 |
```
|
293 |
JsonKVStorage JsonFile(默认)
|
294 |
-
MongoKVStorage MogonDB
|
295 |
-
RedisKVStorage Redis
|
296 |
-
TiDBKVStorage TiDB
|
297 |
PGKVStorage Postgres
|
298 |
-
|
|
|
299 |
```
|
300 |
|
301 |
* GRAPH_STORAGE 支持的实现名称
|
@@ -303,25 +301,19 @@ OracleKVStorage Oracle
|
|
303 |
```
|
304 |
NetworkXStorage NetworkX(默认)
|
305 |
Neo4JStorage Neo4J
|
306 |
-
MongoGraphStorage MongoDB
|
307 |
-
TiDBGraphStorage TiDB
|
308 |
-
AGEStorage AGE
|
309 |
-
GremlinStorage Gremlin
|
310 |
PGGraphStorage Postgres
|
311 |
-
|
312 |
```
|
313 |
|
314 |
* VECTOR_STORAGE 支持的实现名称
|
315 |
|
316 |
```
|
317 |
NanoVectorDBStorage NanoVector(默认)
|
|
|
318 |
MilvusVectorDBStorge Milvus
|
319 |
ChromaVectorDBStorage Chroma
|
320 |
-
TiDBVectorDBStorage TiDB
|
321 |
-
PGVectorStorage Postgres
|
322 |
FaissVectorDBStorage Faiss
|
323 |
QdrantVectorDBStorage Qdrant
|
324 |
-
OracleVectorDBStorage Oracle
|
325 |
MongoVectorDBStorage MongoDB
|
326 |
```
|
327 |
|
|
|
291 |
|
292 |
```
|
293 |
JsonKVStorage JsonFile(默认)
|
|
|
|
|
|
|
294 |
PGKVStorage Postgres
|
295 |
+
RedisKVStorage Redis
|
296 |
+
MongoKVStorage MogonDB
|
297 |
```
|
298 |
|
299 |
* GRAPH_STORAGE 支持的实现名称
|
|
|
301 |
```
|
302 |
NetworkXStorage NetworkX(默认)
|
303 |
Neo4JStorage Neo4J
|
|
|
|
|
|
|
|
|
304 |
PGGraphStorage Postgres
|
305 |
+
AGEStorage AGE
|
306 |
```
|
307 |
|
308 |
* VECTOR_STORAGE 支持的实现名称
|
309 |
|
310 |
```
|
311 |
NanoVectorDBStorage NanoVector(默认)
|
312 |
+
PGVectorStorage Postgres
|
313 |
MilvusVectorDBStorge Milvus
|
314 |
ChromaVectorDBStorage Chroma
|
|
|
|
|
315 |
FaissVectorDBStorage Faiss
|
316 |
QdrantVectorDBStorage Qdrant
|
|
|
317 |
MongoVectorDBStorage MongoDB
|
318 |
```
|
319 |
|
lightrag/api/README.md
CHANGED
@@ -302,11 +302,9 @@ Each storage type have servals implementations:
|
|
302 |
|
303 |
```
|
304 |
JsonKVStorage JsonFile(default)
|
305 |
-
MongoKVStorage MogonDB
|
306 |
-
RedisKVStorage Redis
|
307 |
-
TiDBKVStorage TiDB
|
308 |
PGKVStorage Postgres
|
309 |
-
|
|
|
310 |
```
|
311 |
|
312 |
* GRAPH_STORAGE supported implement-name
|
@@ -314,25 +312,19 @@ OracleKVStorage Oracle
|
|
314 |
```
|
315 |
NetworkXStorage NetworkX(defualt)
|
316 |
Neo4JStorage Neo4J
|
317 |
-
MongoGraphStorage MongoDB
|
318 |
-
TiDBGraphStorage TiDB
|
319 |
-
AGEStorage AGE
|
320 |
-
GremlinStorage Gremlin
|
321 |
PGGraphStorage Postgres
|
322 |
-
|
323 |
```
|
324 |
|
325 |
* VECTOR_STORAGE supported implement-name
|
326 |
|
327 |
```
|
328 |
NanoVectorDBStorage NanoVector(default)
|
329 |
-
MilvusVectorDBStorage Milvus
|
330 |
-
ChromaVectorDBStorage Chroma
|
331 |
-
TiDBVectorDBStorage TiDB
|
332 |
PGVectorStorage Postgres
|
|
|
|
|
333 |
FaissVectorDBStorage Faiss
|
334 |
QdrantVectorDBStorage Qdrant
|
335 |
-
OracleVectorDBStorage Oracle
|
336 |
MongoVectorDBStorage MongoDB
|
337 |
```
|
338 |
|
|
|
302 |
|
303 |
```
|
304 |
JsonKVStorage JsonFile(default)
|
|
|
|
|
|
|
305 |
PGKVStorage Postgres
|
306 |
+
RedisKVStorage Redis
|
307 |
+
MongoKVStorage MogonDB
|
308 |
```
|
309 |
|
310 |
* GRAPH_STORAGE supported implement-name
|
|
|
312 |
```
|
313 |
NetworkXStorage NetworkX(defualt)
|
314 |
Neo4JStorage Neo4J
|
|
|
|
|
|
|
|
|
315 |
PGGraphStorage Postgres
|
316 |
+
AGEStorage AGE
|
317 |
```
|
318 |
|
319 |
* VECTOR_STORAGE supported implement-name
|
320 |
|
321 |
```
|
322 |
NanoVectorDBStorage NanoVector(default)
|
|
|
|
|
|
|
323 |
PGVectorStorage Postgres
|
324 |
+
MilvusVectorDBStorge Milvus
|
325 |
+
ChromaVectorDBStorage Chroma
|
326 |
FaissVectorDBStorage Faiss
|
327 |
QdrantVectorDBStorage Qdrant
|
|
|
328 |
MongoVectorDBStorage MongoDB
|
329 |
```
|
330 |
|
lightrag/api/__init__.py
CHANGED
@@ -1 +1 @@
|
|
1 |
-
__api_version__ = "
|
|
|
1 |
+
__api_version__ = "0132"
|
lightrag/api/auth.py
CHANGED
@@ -1,9 +1,11 @@
|
|
1 |
-
import os
|
2 |
from datetime import datetime, timedelta
|
|
|
3 |
import jwt
|
|
|
4 |
from fastapi import HTTPException, status
|
5 |
from pydantic import BaseModel
|
6 |
-
|
|
|
7 |
|
8 |
# use the .env that is inside the current folder
|
9 |
# allows to use different .env file for each lightrag instance
|
@@ -20,13 +22,12 @@ class TokenPayload(BaseModel):
|
|
20 |
|
21 |
class AuthHandler:
|
22 |
def __init__(self):
|
23 |
-
self.secret =
|
24 |
-
self.algorithm =
|
25 |
-
self.expire_hours =
|
26 |
-
self.guest_expire_hours =
|
27 |
-
|
28 |
self.accounts = {}
|
29 |
-
auth_accounts =
|
30 |
if auth_accounts:
|
31 |
for account in auth_accounts.split(","):
|
32 |
username, password = account.split(":", 1)
|
|
|
|
|
1 |
from datetime import datetime, timedelta
|
2 |
+
|
3 |
import jwt
|
4 |
+
from dotenv import load_dotenv
|
5 |
from fastapi import HTTPException, status
|
6 |
from pydantic import BaseModel
|
7 |
+
|
8 |
+
from .config import global_args
|
9 |
|
10 |
# use the .env that is inside the current folder
|
11 |
# allows to use different .env file for each lightrag instance
|
|
|
22 |
|
23 |
class AuthHandler:
|
24 |
def __init__(self):
|
25 |
+
self.secret = global_args.token_secret
|
26 |
+
self.algorithm = global_args.jwt_algorithm
|
27 |
+
self.expire_hours = global_args.token_expire_hours
|
28 |
+
self.guest_expire_hours = global_args.guest_token_expire_hours
|
|
|
29 |
self.accounts = {}
|
30 |
+
auth_accounts = global_args.auth_accounts
|
31 |
if auth_accounts:
|
32 |
for account in auth_accounts.split(","):
|
33 |
username, password = account.split(":", 1)
|
lightrag/api/config.py
ADDED
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Configs for the LightRAG API.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
import argparse
|
7 |
+
import logging
|
8 |
+
from dotenv import load_dotenv
|
9 |
+
|
10 |
+
# use the .env that is inside the current folder
|
11 |
+
# allows to use different .env file for each lightrag instance
|
12 |
+
# the OS environment variables take precedence over the .env file
|
13 |
+
load_dotenv(dotenv_path=".env", override=False)
|
14 |
+
|
15 |
+
|
16 |
+
class OllamaServerInfos:
|
17 |
+
# Constants for emulated Ollama model information
|
18 |
+
LIGHTRAG_NAME = "lightrag"
|
19 |
+
LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
|
20 |
+
LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
|
21 |
+
LIGHTRAG_SIZE = 7365960935 # it's a dummy value
|
22 |
+
LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
|
23 |
+
LIGHTRAG_DIGEST = "sha256:lightrag"
|
24 |
+
|
25 |
+
|
26 |
+
ollama_server_infos = OllamaServerInfos()
|
27 |
+
|
28 |
+
|
29 |
+
class DefaultRAGStorageConfig:
|
30 |
+
KV_STORAGE = "JsonKVStorage"
|
31 |
+
VECTOR_STORAGE = "NanoVectorDBStorage"
|
32 |
+
GRAPH_STORAGE = "NetworkXStorage"
|
33 |
+
DOC_STATUS_STORAGE = "JsonDocStatusStorage"
|
34 |
+
|
35 |
+
|
36 |
+
def get_default_host(binding_type: str) -> str:
|
37 |
+
default_hosts = {
|
38 |
+
"ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
|
39 |
+
"lollms": os.getenv("LLM_BINDING_HOST", "http://localhost:9600"),
|
40 |
+
"azure_openai": os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com/v1"),
|
41 |
+
"openai": os.getenv("LLM_BINDING_HOST", "https://api.openai.com/v1"),
|
42 |
+
}
|
43 |
+
return default_hosts.get(
|
44 |
+
binding_type, os.getenv("LLM_BINDING_HOST", "http://localhost:11434")
|
45 |
+
) # fallback to ollama if unknown
|
46 |
+
|
47 |
+
|
48 |
+
def get_env_value(env_key: str, default: any, value_type: type = str) -> any:
|
49 |
+
"""
|
50 |
+
Get value from environment variable with type conversion
|
51 |
+
|
52 |
+
Args:
|
53 |
+
env_key (str): Environment variable key
|
54 |
+
default (any): Default value if env variable is not set
|
55 |
+
value_type (type): Type to convert the value to
|
56 |
+
|
57 |
+
Returns:
|
58 |
+
any: Converted value from environment or default
|
59 |
+
"""
|
60 |
+
value = os.getenv(env_key)
|
61 |
+
if value is None:
|
62 |
+
return default
|
63 |
+
|
64 |
+
if value_type is bool:
|
65 |
+
return value.lower() in ("true", "1", "yes", "t", "on")
|
66 |
+
try:
|
67 |
+
return value_type(value)
|
68 |
+
except ValueError:
|
69 |
+
return default
|
70 |
+
|
71 |
+
|
72 |
+
def parse_args() -> argparse.Namespace:
|
73 |
+
"""
|
74 |
+
Parse command line arguments with environment variable fallback
|
75 |
+
|
76 |
+
Args:
|
77 |
+
is_uvicorn_mode: Whether running under uvicorn mode
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
argparse.Namespace: Parsed arguments
|
81 |
+
"""
|
82 |
+
|
83 |
+
parser = argparse.ArgumentParser(
|
84 |
+
description="LightRAG FastAPI Server with separate working and input directories"
|
85 |
+
)
|
86 |
+
|
87 |
+
# Server configuration
|
88 |
+
parser.add_argument(
|
89 |
+
"--host",
|
90 |
+
default=get_env_value("HOST", "0.0.0.0"),
|
91 |
+
help="Server host (default: from env or 0.0.0.0)",
|
92 |
+
)
|
93 |
+
parser.add_argument(
|
94 |
+
"--port",
|
95 |
+
type=int,
|
96 |
+
default=get_env_value("PORT", 9621, int),
|
97 |
+
help="Server port (default: from env or 9621)",
|
98 |
+
)
|
99 |
+
|
100 |
+
# Directory configuration
|
101 |
+
parser.add_argument(
|
102 |
+
"--working-dir",
|
103 |
+
default=get_env_value("WORKING_DIR", "./rag_storage"),
|
104 |
+
help="Working directory for RAG storage (default: from env or ./rag_storage)",
|
105 |
+
)
|
106 |
+
parser.add_argument(
|
107 |
+
"--input-dir",
|
108 |
+
default=get_env_value("INPUT_DIR", "./inputs"),
|
109 |
+
help="Directory containing input documents (default: from env or ./inputs)",
|
110 |
+
)
|
111 |
+
|
112 |
+
def timeout_type(value):
|
113 |
+
if value is None:
|
114 |
+
return 150
|
115 |
+
if value is None or value == "None":
|
116 |
+
return None
|
117 |
+
return int(value)
|
118 |
+
|
119 |
+
parser.add_argument(
|
120 |
+
"--timeout",
|
121 |
+
default=get_env_value("TIMEOUT", None, timeout_type),
|
122 |
+
type=timeout_type,
|
123 |
+
help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout",
|
124 |
+
)
|
125 |
+
|
126 |
+
# RAG configuration
|
127 |
+
parser.add_argument(
|
128 |
+
"--max-async",
|
129 |
+
type=int,
|
130 |
+
default=get_env_value("MAX_ASYNC", 4, int),
|
131 |
+
help="Maximum async operations (default: from env or 4)",
|
132 |
+
)
|
133 |
+
parser.add_argument(
|
134 |
+
"--max-tokens",
|
135 |
+
type=int,
|
136 |
+
default=get_env_value("MAX_TOKENS", 32768, int),
|
137 |
+
help="Maximum token size (default: from env or 32768)",
|
138 |
+
)
|
139 |
+
|
140 |
+
# Logging configuration
|
141 |
+
parser.add_argument(
|
142 |
+
"--log-level",
|
143 |
+
default=get_env_value("LOG_LEVEL", "INFO"),
|
144 |
+
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
145 |
+
help="Logging level (default: from env or INFO)",
|
146 |
+
)
|
147 |
+
parser.add_argument(
|
148 |
+
"--verbose",
|
149 |
+
action="store_true",
|
150 |
+
default=get_env_value("VERBOSE", False, bool),
|
151 |
+
help="Enable verbose debug output(only valid for DEBUG log-level)",
|
152 |
+
)
|
153 |
+
|
154 |
+
parser.add_argument(
|
155 |
+
"--key",
|
156 |
+
type=str,
|
157 |
+
default=get_env_value("LIGHTRAG_API_KEY", None),
|
158 |
+
help="API key for authentication. This protects lightrag server against unauthorized access",
|
159 |
+
)
|
160 |
+
|
161 |
+
# Optional https parameters
|
162 |
+
parser.add_argument(
|
163 |
+
"--ssl",
|
164 |
+
action="store_true",
|
165 |
+
default=get_env_value("SSL", False, bool),
|
166 |
+
help="Enable HTTPS (default: from env or False)",
|
167 |
+
)
|
168 |
+
parser.add_argument(
|
169 |
+
"--ssl-certfile",
|
170 |
+
default=get_env_value("SSL_CERTFILE", None),
|
171 |
+
help="Path to SSL certificate file (required if --ssl is enabled)",
|
172 |
+
)
|
173 |
+
parser.add_argument(
|
174 |
+
"--ssl-keyfile",
|
175 |
+
default=get_env_value("SSL_KEYFILE", None),
|
176 |
+
help="Path to SSL private key file (required if --ssl is enabled)",
|
177 |
+
)
|
178 |
+
|
179 |
+
parser.add_argument(
|
180 |
+
"--history-turns",
|
181 |
+
type=int,
|
182 |
+
default=get_env_value("HISTORY_TURNS", 3, int),
|
183 |
+
help="Number of conversation history turns to include (default: from env or 3)",
|
184 |
+
)
|
185 |
+
|
186 |
+
# Search parameters
|
187 |
+
parser.add_argument(
|
188 |
+
"--top-k",
|
189 |
+
type=int,
|
190 |
+
default=get_env_value("TOP_K", 60, int),
|
191 |
+
help="Number of most similar results to return (default: from env or 60)",
|
192 |
+
)
|
193 |
+
parser.add_argument(
|
194 |
+
"--cosine-threshold",
|
195 |
+
type=float,
|
196 |
+
default=get_env_value("COSINE_THRESHOLD", 0.2, float),
|
197 |
+
help="Cosine similarity threshold (default: from env or 0.4)",
|
198 |
+
)
|
199 |
+
|
200 |
+
# Ollama model name
|
201 |
+
parser.add_argument(
|
202 |
+
"--simulated-model-name",
|
203 |
+
type=str,
|
204 |
+
default=get_env_value(
|
205 |
+
"SIMULATED_MODEL_NAME", ollama_server_infos.LIGHTRAG_MODEL
|
206 |
+
),
|
207 |
+
help="Number of conversation history turns to include (default: from env or 3)",
|
208 |
+
)
|
209 |
+
|
210 |
+
# Namespace
|
211 |
+
parser.add_argument(
|
212 |
+
"--namespace-prefix",
|
213 |
+
type=str,
|
214 |
+
default=get_env_value("NAMESPACE_PREFIX", ""),
|
215 |
+
help="Prefix of the namespace",
|
216 |
+
)
|
217 |
+
|
218 |
+
parser.add_argument(
|
219 |
+
"--auto-scan-at-startup",
|
220 |
+
action="store_true",
|
221 |
+
default=False,
|
222 |
+
help="Enable automatic scanning when the program starts",
|
223 |
+
)
|
224 |
+
|
225 |
+
# Server workers configuration
|
226 |
+
parser.add_argument(
|
227 |
+
"--workers",
|
228 |
+
type=int,
|
229 |
+
default=get_env_value("WORKERS", 1, int),
|
230 |
+
help="Number of worker processes (default: from env or 1)",
|
231 |
+
)
|
232 |
+
|
233 |
+
# LLM and embedding bindings
|
234 |
+
parser.add_argument(
|
235 |
+
"--llm-binding",
|
236 |
+
type=str,
|
237 |
+
default=get_env_value("LLM_BINDING", "ollama"),
|
238 |
+
choices=["lollms", "ollama", "openai", "openai-ollama", "azure_openai"],
|
239 |
+
help="LLM binding type (default: from env or ollama)",
|
240 |
+
)
|
241 |
+
parser.add_argument(
|
242 |
+
"--embedding-binding",
|
243 |
+
type=str,
|
244 |
+
default=get_env_value("EMBEDDING_BINDING", "ollama"),
|
245 |
+
choices=["lollms", "ollama", "openai", "azure_openai"],
|
246 |
+
help="Embedding binding type (default: from env or ollama)",
|
247 |
+
)
|
248 |
+
|
249 |
+
args = parser.parse_args()
|
250 |
+
|
251 |
+
# convert relative path to absolute path
|
252 |
+
args.working_dir = os.path.abspath(args.working_dir)
|
253 |
+
args.input_dir = os.path.abspath(args.input_dir)
|
254 |
+
|
255 |
+
# Inject storage configuration from environment variables
|
256 |
+
args.kv_storage = get_env_value(
|
257 |
+
"LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
|
258 |
+
)
|
259 |
+
args.doc_status_storage = get_env_value(
|
260 |
+
"LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
|
261 |
+
)
|
262 |
+
args.graph_storage = get_env_value(
|
263 |
+
"LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
|
264 |
+
)
|
265 |
+
args.vector_storage = get_env_value(
|
266 |
+
"LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
|
267 |
+
)
|
268 |
+
|
269 |
+
# Get MAX_PARALLEL_INSERT from environment
|
270 |
+
args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int)
|
271 |
+
|
272 |
+
# Handle openai-ollama special case
|
273 |
+
if args.llm_binding == "openai-ollama":
|
274 |
+
args.llm_binding = "openai"
|
275 |
+
args.embedding_binding = "ollama"
|
276 |
+
|
277 |
+
args.llm_binding_host = get_env_value(
|
278 |
+
"LLM_BINDING_HOST", get_default_host(args.llm_binding)
|
279 |
+
)
|
280 |
+
args.embedding_binding_host = get_env_value(
|
281 |
+
"EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding)
|
282 |
+
)
|
283 |
+
args.llm_binding_api_key = get_env_value("LLM_BINDING_API_KEY", None)
|
284 |
+
args.embedding_binding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "")
|
285 |
+
|
286 |
+
# Inject model configuration
|
287 |
+
args.llm_model = get_env_value("LLM_MODEL", "mistral-nemo:latest")
|
288 |
+
args.embedding_model = get_env_value("EMBEDDING_MODEL", "bge-m3:latest")
|
289 |
+
args.embedding_dim = get_env_value("EMBEDDING_DIM", 1024, int)
|
290 |
+
args.max_embed_tokens = get_env_value("MAX_EMBED_TOKENS", 8192, int)
|
291 |
+
|
292 |
+
# Inject chunk configuration
|
293 |
+
args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int)
|
294 |
+
args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)
|
295 |
+
|
296 |
+
# Inject LLM cache configuration
|
297 |
+
args.enable_llm_cache_for_extract = get_env_value(
|
298 |
+
"ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool
|
299 |
+
)
|
300 |
+
|
301 |
+
# Inject LLM temperature configuration
|
302 |
+
args.temperature = get_env_value("TEMPERATURE", 0.5, float)
|
303 |
+
|
304 |
+
# Select Document loading tool (DOCLING, DEFAULT)
|
305 |
+
args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT")
|
306 |
+
|
307 |
+
# Add environment variables that were previously read directly
|
308 |
+
args.cors_origins = get_env_value("CORS_ORIGINS", "*")
|
309 |
+
args.summary_language = get_env_value("SUMMARY_LANGUAGE", "en")
|
310 |
+
args.whitelist_paths = get_env_value("WHITELIST_PATHS", "/health,/api/*")
|
311 |
+
|
312 |
+
# For JWT Auth
|
313 |
+
args.auth_accounts = get_env_value("AUTH_ACCOUNTS", "")
|
314 |
+
args.token_secret = get_env_value("TOKEN_SECRET", "lightrag-jwt-default-secret")
|
315 |
+
args.token_expire_hours = get_env_value("TOKEN_EXPIRE_HOURS", 48, int)
|
316 |
+
args.guest_token_expire_hours = get_env_value("GUEST_TOKEN_EXPIRE_HOURS", 24, int)
|
317 |
+
args.jwt_algorithm = get_env_value("JWT_ALGORITHM", "HS256")
|
318 |
+
|
319 |
+
ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
|
320 |
+
|
321 |
+
return args
|
322 |
+
|
323 |
+
|
324 |
+
def update_uvicorn_mode_config():
|
325 |
+
# If in uvicorn mode and workers > 1, force it to 1 and log warning
|
326 |
+
if global_args.workers > 1:
|
327 |
+
original_workers = global_args.workers
|
328 |
+
global_args.workers = 1
|
329 |
+
# Log warning directly here
|
330 |
+
logging.warning(
|
331 |
+
f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1"
|
332 |
+
)
|
333 |
+
|
334 |
+
|
335 |
+
global_args = parse_args()
|
lightrag/api/lightrag_server.py
CHANGED
@@ -19,11 +19,14 @@ from contextlib import asynccontextmanager
|
|
19 |
from dotenv import load_dotenv
|
20 |
from lightrag.api.utils_api import (
|
21 |
get_combined_auth_dependency,
|
22 |
-
parse_args,
|
23 |
-
get_default_host,
|
24 |
display_splash_screen,
|
25 |
check_env_file,
|
26 |
)
|
|
|
|
|
|
|
|
|
|
|
27 |
import sys
|
28 |
from lightrag import LightRAG, __version__ as core_version
|
29 |
from lightrag.api import __api_version__
|
@@ -52,6 +55,10 @@ from lightrag.api.auth import auth_handler
|
|
52 |
# the OS environment variables take precedence over the .env file
|
53 |
load_dotenv(dotenv_path=".env", override=False)
|
54 |
|
|
|
|
|
|
|
|
|
55 |
# Initialize config parser
|
56 |
config = configparser.ConfigParser()
|
57 |
config.read("config.ini")
|
@@ -164,10 +171,10 @@ def create_app(args):
|
|
164 |
app = FastAPI(**app_kwargs)
|
165 |
|
166 |
def get_cors_origins():
|
167 |
-
"""Get allowed origins from
|
168 |
Returns a list of allowed origins, defaults to ["*"] if not set
|
169 |
"""
|
170 |
-
origins_str =
|
171 |
if origins_str == "*":
|
172 |
return ["*"]
|
173 |
return [origin.strip() for origin in origins_str.split(",")]
|
@@ -315,9 +322,10 @@ def create_app(args):
|
|
315 |
"similarity_threshold": 0.95,
|
316 |
"use_llm_check": False,
|
317 |
},
|
318 |
-
namespace_prefix=args.namespace_prefix,
|
319 |
auto_manage_storages_states=False,
|
320 |
max_parallel_insert=args.max_parallel_insert,
|
|
|
321 |
)
|
322 |
else: # azure_openai
|
323 |
rag = LightRAG(
|
@@ -345,9 +353,10 @@ def create_app(args):
|
|
345 |
"similarity_threshold": 0.95,
|
346 |
"use_llm_check": False,
|
347 |
},
|
348 |
-
namespace_prefix=args.namespace_prefix,
|
349 |
auto_manage_storages_states=False,
|
350 |
max_parallel_insert=args.max_parallel_insert,
|
|
|
351 |
)
|
352 |
|
353 |
# Add routes
|
@@ -381,6 +390,8 @@ def create_app(args):
|
|
381 |
"message": "Authentication is disabled. Using guest access.",
|
382 |
"core_version": core_version,
|
383 |
"api_version": __api_version__,
|
|
|
|
|
384 |
}
|
385 |
|
386 |
return {
|
@@ -388,6 +399,8 @@ def create_app(args):
|
|
388 |
"auth_mode": "enabled",
|
389 |
"core_version": core_version,
|
390 |
"api_version": __api_version__,
|
|
|
|
|
391 |
}
|
392 |
|
393 |
@app.post("/login")
|
@@ -404,6 +417,8 @@ def create_app(args):
|
|
404 |
"message": "Authentication is disabled. Using guest access.",
|
405 |
"core_version": core_version,
|
406 |
"api_version": __api_version__,
|
|
|
|
|
407 |
}
|
408 |
username = form_data.username
|
409 |
if auth_handler.accounts.get(username) != form_data.password:
|
@@ -454,10 +469,12 @@ def create_app(args):
|
|
454 |
"vector_storage": args.vector_storage,
|
455 |
"enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
|
456 |
},
|
457 |
-
"core_version": core_version,
|
458 |
-
"api_version": __api_version__,
|
459 |
"auth_mode": auth_mode,
|
460 |
"pipeline_busy": pipeline_status.get("busy", False),
|
|
|
|
|
|
|
|
|
461 |
}
|
462 |
except Exception as e:
|
463 |
logger.error(f"Error getting health status: {str(e)}")
|
@@ -490,7 +507,7 @@ def create_app(args):
|
|
490 |
def get_application(args=None):
|
491 |
"""Factory function for creating the FastAPI application"""
|
492 |
if args is None:
|
493 |
-
args =
|
494 |
return create_app(args)
|
495 |
|
496 |
|
@@ -611,30 +628,31 @@ def main():
|
|
611 |
|
612 |
# Configure logging before parsing args
|
613 |
configure_logging()
|
614 |
-
|
615 |
-
|
616 |
-
display_splash_screen(args)
|
617 |
|
618 |
# Create application instance directly instead of using factory function
|
619 |
-
app = create_app(
|
620 |
|
621 |
# Start Uvicorn in single process mode
|
622 |
uvicorn_config = {
|
623 |
"app": app, # Pass application instance directly instead of string path
|
624 |
-
"host":
|
625 |
-
"port":
|
626 |
"log_config": None, # Disable default config
|
627 |
}
|
628 |
|
629 |
-
if
|
630 |
uvicorn_config.update(
|
631 |
{
|
632 |
-
"ssl_certfile":
|
633 |
-
"ssl_keyfile":
|
634 |
}
|
635 |
)
|
636 |
|
637 |
-
print(
|
|
|
|
|
638 |
uvicorn.run(**uvicorn_config)
|
639 |
|
640 |
|
|
|
19 |
from dotenv import load_dotenv
|
20 |
from lightrag.api.utils_api import (
|
21 |
get_combined_auth_dependency,
|
|
|
|
|
22 |
display_splash_screen,
|
23 |
check_env_file,
|
24 |
)
|
25 |
+
from .config import (
|
26 |
+
global_args,
|
27 |
+
update_uvicorn_mode_config,
|
28 |
+
get_default_host,
|
29 |
+
)
|
30 |
import sys
|
31 |
from lightrag import LightRAG, __version__ as core_version
|
32 |
from lightrag.api import __api_version__
|
|
|
55 |
# the OS environment variables take precedence over the .env file
|
56 |
load_dotenv(dotenv_path=".env", override=False)
|
57 |
|
58 |
+
|
59 |
+
webui_title = os.getenv("WEBUI_TITLE")
|
60 |
+
webui_description = os.getenv("WEBUI_DESCRIPTION")
|
61 |
+
|
62 |
# Initialize config parser
|
63 |
config = configparser.ConfigParser()
|
64 |
config.read("config.ini")
|
|
|
171 |
app = FastAPI(**app_kwargs)
|
172 |
|
173 |
def get_cors_origins():
|
174 |
+
"""Get allowed origins from global_args
|
175 |
Returns a list of allowed origins, defaults to ["*"] if not set
|
176 |
"""
|
177 |
+
origins_str = global_args.cors_origins
|
178 |
if origins_str == "*":
|
179 |
return ["*"]
|
180 |
return [origin.strip() for origin in origins_str.split(",")]
|
|
|
322 |
"similarity_threshold": 0.95,
|
323 |
"use_llm_check": False,
|
324 |
},
|
325 |
+
# namespace_prefix=args.namespace_prefix,
|
326 |
auto_manage_storages_states=False,
|
327 |
max_parallel_insert=args.max_parallel_insert,
|
328 |
+
addon_params={"language": args.summary_language},
|
329 |
)
|
330 |
else: # azure_openai
|
331 |
rag = LightRAG(
|
|
|
353 |
"similarity_threshold": 0.95,
|
354 |
"use_llm_check": False,
|
355 |
},
|
356 |
+
# namespace_prefix=args.namespace_prefix,
|
357 |
auto_manage_storages_states=False,
|
358 |
max_parallel_insert=args.max_parallel_insert,
|
359 |
+
addon_params={"language": args.summary_language},
|
360 |
)
|
361 |
|
362 |
# Add routes
|
|
|
390 |
"message": "Authentication is disabled. Using guest access.",
|
391 |
"core_version": core_version,
|
392 |
"api_version": __api_version__,
|
393 |
+
"webui_title": webui_title,
|
394 |
+
"webui_description": webui_description,
|
395 |
}
|
396 |
|
397 |
return {
|
|
|
399 |
"auth_mode": "enabled",
|
400 |
"core_version": core_version,
|
401 |
"api_version": __api_version__,
|
402 |
+
"webui_title": webui_title,
|
403 |
+
"webui_description": webui_description,
|
404 |
}
|
405 |
|
406 |
@app.post("/login")
|
|
|
417 |
"message": "Authentication is disabled. Using guest access.",
|
418 |
"core_version": core_version,
|
419 |
"api_version": __api_version__,
|
420 |
+
"webui_title": webui_title,
|
421 |
+
"webui_description": webui_description,
|
422 |
}
|
423 |
username = form_data.username
|
424 |
if auth_handler.accounts.get(username) != form_data.password:
|
|
|
469 |
"vector_storage": args.vector_storage,
|
470 |
"enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
|
471 |
},
|
|
|
|
|
472 |
"auth_mode": auth_mode,
|
473 |
"pipeline_busy": pipeline_status.get("busy", False),
|
474 |
+
"core_version": core_version,
|
475 |
+
"api_version": __api_version__,
|
476 |
+
"webui_title": webui_title,
|
477 |
+
"webui_description": webui_description,
|
478 |
}
|
479 |
except Exception as e:
|
480 |
logger.error(f"Error getting health status: {str(e)}")
|
|
|
507 |
def get_application(args=None):
|
508 |
"""Factory function for creating the FastAPI application"""
|
509 |
if args is None:
|
510 |
+
args = global_args
|
511 |
return create_app(args)
|
512 |
|
513 |
|
|
|
628 |
|
629 |
# Configure logging before parsing args
|
630 |
configure_logging()
|
631 |
+
update_uvicorn_mode_config()
|
632 |
+
display_splash_screen(global_args)
|
|
|
633 |
|
634 |
# Create application instance directly instead of using factory function
|
635 |
+
app = create_app(global_args)
|
636 |
|
637 |
# Start Uvicorn in single process mode
|
638 |
uvicorn_config = {
|
639 |
"app": app, # Pass application instance directly instead of string path
|
640 |
+
"host": global_args.host,
|
641 |
+
"port": global_args.port,
|
642 |
"log_config": None, # Disable default config
|
643 |
}
|
644 |
|
645 |
+
if global_args.ssl:
|
646 |
uvicorn_config.update(
|
647 |
{
|
648 |
+
"ssl_certfile": global_args.ssl_certfile,
|
649 |
+
"ssl_keyfile": global_args.ssl_keyfile,
|
650 |
}
|
651 |
)
|
652 |
|
653 |
+
print(
|
654 |
+
f"Starting Uvicorn server in single-process mode on {global_args.host}:{global_args.port}"
|
655 |
+
)
|
656 |
uvicorn.run(**uvicorn_config)
|
657 |
|
658 |
|
lightrag/api/routers/document_routes.py
CHANGED
@@ -10,16 +10,14 @@ import traceback
|
|
10 |
import pipmaster as pm
|
11 |
from datetime import datetime
|
12 |
from pathlib import Path
|
13 |
-
from typing import Dict, List, Optional, Any
|
14 |
from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
|
15 |
from pydantic import BaseModel, Field, field_validator
|
16 |
|
17 |
from lightrag import LightRAG
|
18 |
from lightrag.base import DocProcessingStatus, DocStatus
|
19 |
-
from lightrag.api.utils_api import
|
20 |
-
|
21 |
-
global_args,
|
22 |
-
)
|
23 |
|
24 |
router = APIRouter(
|
25 |
prefix="/documents",
|
@@ -30,7 +28,37 @@ router = APIRouter(
|
|
30 |
temp_prefix = "__tmp__"
|
31 |
|
32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
class InsertTextRequest(BaseModel):
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
text: str = Field(
|
35 |
min_length=1,
|
36 |
description="The text to insert",
|
@@ -41,8 +69,21 @@ class InsertTextRequest(BaseModel):
|
|
41 |
def strip_after(cls, text: str) -> str:
|
42 |
return text.strip()
|
43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
class InsertTextsRequest(BaseModel):
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
texts: list[str] = Field(
|
47 |
min_length=1,
|
48 |
description="The texts to insert",
|
@@ -53,11 +94,116 @@ class InsertTextsRequest(BaseModel):
|
|
53 |
def strip_after(cls, texts: list[str]) -> list[str]:
|
54 |
return [text.strip() for text in texts]
|
55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
class InsertResponse(BaseModel):
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
message: str = Field(description="Message describing the operation result")
|
60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
class DocStatusResponse(BaseModel):
|
63 |
@staticmethod
|
@@ -68,34 +214,82 @@ class DocStatusResponse(BaseModel):
|
|
68 |
return dt
|
69 |
return dt.isoformat()
|
70 |
|
71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
|
73 |
Attributes:
|
74 |
-
|
75 |
-
content_summary: Summary of document content
|
76 |
-
content_length: Length of document content
|
77 |
-
status: Current processing status
|
78 |
-
created_at: Creation timestamp (ISO format string)
|
79 |
-
updated_at: Last update timestamp (ISO format string)
|
80 |
-
chunks_count: Number of chunks (optional)
|
81 |
-
error: Error message if any (optional)
|
82 |
-
metadata: Additional metadata (optional)
|
83 |
"""
|
84 |
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
created_at: str
|
90 |
-
updated_at: str
|
91 |
-
chunks_count: Optional[int] = None
|
92 |
-
error: Optional[str] = None
|
93 |
-
metadata: Optional[dict[str, Any]] = None
|
94 |
-
file_path: str
|
95 |
-
|
96 |
|
97 |
-
class
|
98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
|
100 |
|
101 |
class PipelineStatusResponse(BaseModel):
|
@@ -276,7 +470,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
|
|
276 |
)
|
277 |
return False
|
278 |
case ".pdf":
|
279 |
-
if global_args
|
280 |
if not pm.is_installed("docling"): # type: ignore
|
281 |
pm.install("docling")
|
282 |
from docling.document_converter import DocumentConverter # type: ignore
|
@@ -295,7 +489,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
|
|
295 |
for page in reader.pages:
|
296 |
content += page.extract_text() + "\n"
|
297 |
case ".docx":
|
298 |
-
if global_args
|
299 |
if not pm.is_installed("docling"): # type: ignore
|
300 |
pm.install("docling")
|
301 |
from docling.document_converter import DocumentConverter # type: ignore
|
@@ -315,7 +509,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
|
|
315 |
[paragraph.text for paragraph in doc.paragraphs]
|
316 |
)
|
317 |
case ".pptx":
|
318 |
-
if global_args
|
319 |
if not pm.is_installed("docling"): # type: ignore
|
320 |
pm.install("docling")
|
321 |
from docling.document_converter import DocumentConverter # type: ignore
|
@@ -336,7 +530,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
|
|
336 |
if hasattr(shape, "text"):
|
337 |
content += shape.text + "\n"
|
338 |
case ".xlsx":
|
339 |
-
if global_args
|
340 |
if not pm.is_installed("docling"): # type: ignore
|
341 |
pm.install("docling")
|
342 |
from docling.document_converter import DocumentConverter # type: ignore
|
@@ -443,6 +637,7 @@ async def pipeline_index_texts(rag: LightRAG, texts: List[str]):
|
|
443 |
await rag.apipeline_process_enqueue_documents()
|
444 |
|
445 |
|
|
|
446 |
async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path:
|
447 |
"""Save the uploaded file to a temporary location
|
448 |
|
@@ -476,8 +671,8 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
|
|
476 |
if not new_files:
|
477 |
return
|
478 |
|
479 |
-
# Get MAX_PARALLEL_INSERT from global_args
|
480 |
-
max_parallel = global_args
|
481 |
# Calculate batch size as 2 * MAX_PARALLEL_INSERT
|
482 |
batch_size = 2 * max_parallel
|
483 |
|
@@ -509,7 +704,9 @@ def create_document_routes(
|
|
509 |
# Create combined auth dependency for document routes
|
510 |
combined_auth = get_combined_auth_dependency(api_key)
|
511 |
|
512 |
-
@router.post(
|
|
|
|
|
513 |
async def scan_for_new_documents(background_tasks: BackgroundTasks):
|
514 |
"""
|
515 |
Trigger the scanning process for new documents.
|
@@ -519,13 +716,18 @@ def create_document_routes(
|
|
519 |
that fact.
|
520 |
|
521 |
Returns:
|
522 |
-
|
523 |
"""
|
524 |
# Start the scanning process in the background
|
525 |
background_tasks.add_task(run_scanning_process, rag, doc_manager)
|
526 |
-
return
|
|
|
|
|
|
|
527 |
|
528 |
-
@router.post(
|
|
|
|
|
529 |
async def upload_to_input_dir(
|
530 |
background_tasks: BackgroundTasks, file: UploadFile = File(...)
|
531 |
):
|
@@ -645,6 +847,7 @@ def create_document_routes(
|
|
645 |
logger.error(traceback.format_exc())
|
646 |
raise HTTPException(status_code=500, detail=str(e))
|
647 |
|
|
|
648 |
@router.post(
|
649 |
"/file", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
|
650 |
)
|
@@ -688,6 +891,7 @@ def create_document_routes(
|
|
688 |
logger.error(traceback.format_exc())
|
689 |
raise HTTPException(status_code=500, detail=str(e))
|
690 |
|
|
|
691 |
@router.post(
|
692 |
"/file_batch",
|
693 |
response_model=InsertResponse,
|
@@ -752,32 +956,186 @@ def create_document_routes(
|
|
752 |
raise HTTPException(status_code=500, detail=str(e))
|
753 |
|
754 |
@router.delete(
|
755 |
-
"", response_model=
|
756 |
)
|
757 |
async def clear_documents():
|
758 |
"""
|
759 |
Clear all documents from the RAG system.
|
760 |
|
761 |
-
This endpoint deletes all
|
762 |
-
|
|
|
763 |
|
764 |
Returns:
|
765 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
766 |
|
767 |
Raises:
|
768 |
-
HTTPException:
|
|
|
769 |
"""
|
770 |
-
|
771 |
-
|
772 |
-
|
773 |
-
|
774 |
-
|
775 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
776 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
777 |
except Exception as e:
|
778 |
-
|
|
|
779 |
logger.error(traceback.format_exc())
|
|
|
|
|
780 |
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
781 |
|
782 |
@router.get(
|
783 |
"/pipeline_status",
|
@@ -850,7 +1208,9 @@ def create_document_routes(
|
|
850 |
logger.error(traceback.format_exc())
|
851 |
raise HTTPException(status_code=500, detail=str(e))
|
852 |
|
853 |
-
@router.get(
|
|
|
|
|
854 |
async def documents() -> DocsStatusesResponse:
|
855 |
"""
|
856 |
Get the status of all documents in the system.
|
@@ -908,4 +1268,57 @@ def create_document_routes(
|
|
908 |
logger.error(traceback.format_exc())
|
909 |
raise HTTPException(status_code=500, detail=str(e))
|
910 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
911 |
return router
|
|
|
10 |
import pipmaster as pm
|
11 |
from datetime import datetime
|
12 |
from pathlib import Path
|
13 |
+
from typing import Dict, List, Optional, Any, Literal
|
14 |
from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
|
15 |
from pydantic import BaseModel, Field, field_validator
|
16 |
|
17 |
from lightrag import LightRAG
|
18 |
from lightrag.base import DocProcessingStatus, DocStatus
|
19 |
+
from lightrag.api.utils_api import get_combined_auth_dependency
|
20 |
+
from ..config import global_args
|
|
|
|
|
21 |
|
22 |
router = APIRouter(
|
23 |
prefix="/documents",
|
|
|
28 |
temp_prefix = "__tmp__"
|
29 |
|
30 |
|
31 |
+
class ScanResponse(BaseModel):
|
32 |
+
"""Response model for document scanning operation
|
33 |
+
|
34 |
+
Attributes:
|
35 |
+
status: Status of the scanning operation
|
36 |
+
message: Optional message with additional details
|
37 |
+
"""
|
38 |
+
|
39 |
+
status: Literal["scanning_started"] = Field(
|
40 |
+
description="Status of the scanning operation"
|
41 |
+
)
|
42 |
+
message: Optional[str] = Field(
|
43 |
+
default=None, description="Additional details about the scanning operation"
|
44 |
+
)
|
45 |
+
|
46 |
+
class Config:
|
47 |
+
json_schema_extra = {
|
48 |
+
"example": {
|
49 |
+
"status": "scanning_started",
|
50 |
+
"message": "Scanning process has been initiated in the background",
|
51 |
+
}
|
52 |
+
}
|
53 |
+
|
54 |
+
|
55 |
class InsertTextRequest(BaseModel):
|
56 |
+
"""Request model for inserting a single text document
|
57 |
+
|
58 |
+
Attributes:
|
59 |
+
text: The text content to be inserted into the RAG system
|
60 |
+
"""
|
61 |
+
|
62 |
text: str = Field(
|
63 |
min_length=1,
|
64 |
description="The text to insert",
|
|
|
69 |
def strip_after(cls, text: str) -> str:
|
70 |
return text.strip()
|
71 |
|
72 |
+
class Config:
|
73 |
+
json_schema_extra = {
|
74 |
+
"example": {
|
75 |
+
"text": "This is a sample text to be inserted into the RAG system."
|
76 |
+
}
|
77 |
+
}
|
78 |
+
|
79 |
|
80 |
class InsertTextsRequest(BaseModel):
|
81 |
+
"""Request model for inserting multiple text documents
|
82 |
+
|
83 |
+
Attributes:
|
84 |
+
texts: List of text contents to be inserted into the RAG system
|
85 |
+
"""
|
86 |
+
|
87 |
texts: list[str] = Field(
|
88 |
min_length=1,
|
89 |
description="The texts to insert",
|
|
|
94 |
def strip_after(cls, texts: list[str]) -> list[str]:
|
95 |
return [text.strip() for text in texts]
|
96 |
|
97 |
+
class Config:
|
98 |
+
json_schema_extra = {
|
99 |
+
"example": {
|
100 |
+
"texts": [
|
101 |
+
"This is the first text to be inserted.",
|
102 |
+
"This is the second text to be inserted.",
|
103 |
+
]
|
104 |
+
}
|
105 |
+
}
|
106 |
+
|
107 |
|
108 |
class InsertResponse(BaseModel):
|
109 |
+
"""Response model for document insertion operations
|
110 |
+
|
111 |
+
Attributes:
|
112 |
+
status: Status of the operation (success, duplicated, partial_success, failure)
|
113 |
+
message: Detailed message describing the operation result
|
114 |
+
"""
|
115 |
+
|
116 |
+
status: Literal["success", "duplicated", "partial_success", "failure"] = Field(
|
117 |
+
description="Status of the operation"
|
118 |
+
)
|
119 |
message: str = Field(description="Message describing the operation result")
|
120 |
|
121 |
+
class Config:
|
122 |
+
json_schema_extra = {
|
123 |
+
"example": {
|
124 |
+
"status": "success",
|
125 |
+
"message": "File 'document.pdf' uploaded successfully. Processing will continue in background.",
|
126 |
+
}
|
127 |
+
}
|
128 |
+
|
129 |
+
|
130 |
+
class ClearDocumentsResponse(BaseModel):
|
131 |
+
"""Response model for document clearing operation
|
132 |
+
|
133 |
+
Attributes:
|
134 |
+
status: Status of the clear operation
|
135 |
+
message: Detailed message describing the operation result
|
136 |
+
"""
|
137 |
+
|
138 |
+
status: Literal["success", "partial_success", "busy", "fail"] = Field(
|
139 |
+
description="Status of the clear operation"
|
140 |
+
)
|
141 |
+
message: str = Field(description="Message describing the operation result")
|
142 |
+
|
143 |
+
class Config:
|
144 |
+
json_schema_extra = {
|
145 |
+
"example": {
|
146 |
+
"status": "success",
|
147 |
+
"message": "All documents cleared successfully. Deleted 15 files.",
|
148 |
+
}
|
149 |
+
}
|
150 |
+
|
151 |
+
|
152 |
+
class ClearCacheRequest(BaseModel):
|
153 |
+
"""Request model for clearing cache
|
154 |
+
|
155 |
+
Attributes:
|
156 |
+
modes: Optional list of cache modes to clear
|
157 |
+
"""
|
158 |
+
|
159 |
+
modes: Optional[
|
160 |
+
List[Literal["default", "naive", "local", "global", "hybrid", "mix"]]
|
161 |
+
] = Field(
|
162 |
+
default=None,
|
163 |
+
description="Modes of cache to clear. If None, clears all cache.",
|
164 |
+
)
|
165 |
+
|
166 |
+
class Config:
|
167 |
+
json_schema_extra = {"example": {"modes": ["default", "naive"]}}
|
168 |
+
|
169 |
+
|
170 |
+
class ClearCacheResponse(BaseModel):
|
171 |
+
"""Response model for cache clearing operation
|
172 |
+
|
173 |
+
Attributes:
|
174 |
+
status: Status of the clear operation
|
175 |
+
message: Detailed message describing the operation result
|
176 |
+
"""
|
177 |
+
|
178 |
+
status: Literal["success", "fail"] = Field(
|
179 |
+
description="Status of the clear operation"
|
180 |
+
)
|
181 |
+
message: str = Field(description="Message describing the operation result")
|
182 |
+
|
183 |
+
class Config:
|
184 |
+
json_schema_extra = {
|
185 |
+
"example": {
|
186 |
+
"status": "success",
|
187 |
+
"message": "Successfully cleared cache for modes: ['default', 'naive']",
|
188 |
+
}
|
189 |
+
}
|
190 |
+
|
191 |
+
|
192 |
+
"""Response model for document status
|
193 |
+
|
194 |
+
Attributes:
|
195 |
+
id: Document identifier
|
196 |
+
content_summary: Summary of document content
|
197 |
+
content_length: Length of document content
|
198 |
+
status: Current processing status
|
199 |
+
created_at: Creation timestamp (ISO format string)
|
200 |
+
updated_at: Last update timestamp (ISO format string)
|
201 |
+
chunks_count: Number of chunks (optional)
|
202 |
+
error: Error message if any (optional)
|
203 |
+
metadata: Additional metadata (optional)
|
204 |
+
file_path: Path to the document file
|
205 |
+
"""
|
206 |
+
|
207 |
|
208 |
class DocStatusResponse(BaseModel):
|
209 |
@staticmethod
|
|
|
214 |
return dt
|
215 |
return dt.isoformat()
|
216 |
|
217 |
+
id: str = Field(description="Document identifier")
|
218 |
+
content_summary: str = Field(description="Summary of document content")
|
219 |
+
content_length: int = Field(description="Length of document content in characters")
|
220 |
+
status: DocStatus = Field(description="Current processing status")
|
221 |
+
created_at: str = Field(description="Creation timestamp (ISO format string)")
|
222 |
+
updated_at: str = Field(description="Last update timestamp (ISO format string)")
|
223 |
+
chunks_count: Optional[int] = Field(
|
224 |
+
default=None, description="Number of chunks the document was split into"
|
225 |
+
)
|
226 |
+
error: Optional[str] = Field(
|
227 |
+
default=None, description="Error message if processing failed"
|
228 |
+
)
|
229 |
+
metadata: Optional[dict[str, Any]] = Field(
|
230 |
+
default=None, description="Additional metadata about the document"
|
231 |
+
)
|
232 |
+
file_path: str = Field(description="Path to the document file")
|
233 |
+
|
234 |
+
class Config:
|
235 |
+
json_schema_extra = {
|
236 |
+
"example": {
|
237 |
+
"id": "doc_123456",
|
238 |
+
"content_summary": "Research paper on machine learning",
|
239 |
+
"content_length": 15240,
|
240 |
+
"status": "PROCESSED",
|
241 |
+
"created_at": "2025-03-31T12:34:56",
|
242 |
+
"updated_at": "2025-03-31T12:35:30",
|
243 |
+
"chunks_count": 12,
|
244 |
+
"error": None,
|
245 |
+
"metadata": {"author": "John Doe", "year": 2025},
|
246 |
+
"file_path": "research_paper.pdf",
|
247 |
+
}
|
248 |
+
}
|
249 |
+
|
250 |
+
|
251 |
+
class DocsStatusesResponse(BaseModel):
|
252 |
+
"""Response model for document statuses
|
253 |
|
254 |
Attributes:
|
255 |
+
statuses: Dictionary mapping document status to lists of document status responses
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
256 |
"""
|
257 |
|
258 |
+
statuses: Dict[DocStatus, List[DocStatusResponse]] = Field(
|
259 |
+
default_factory=dict,
|
260 |
+
description="Dictionary mapping document status to lists of document status responses",
|
261 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
262 |
|
263 |
+
class Config:
|
264 |
+
json_schema_extra = {
|
265 |
+
"example": {
|
266 |
+
"statuses": {
|
267 |
+
"PENDING": [
|
268 |
+
{
|
269 |
+
"id": "doc_123",
|
270 |
+
"content_summary": "Pending document",
|
271 |
+
"content_length": 5000,
|
272 |
+
"status": "PENDING",
|
273 |
+
"created_at": "2025-03-31T10:00:00",
|
274 |
+
"updated_at": "2025-03-31T10:00:00",
|
275 |
+
"file_path": "pending_doc.pdf",
|
276 |
+
}
|
277 |
+
],
|
278 |
+
"PROCESSED": [
|
279 |
+
{
|
280 |
+
"id": "doc_456",
|
281 |
+
"content_summary": "Processed document",
|
282 |
+
"content_length": 8000,
|
283 |
+
"status": "PROCESSED",
|
284 |
+
"created_at": "2025-03-31T09:00:00",
|
285 |
+
"updated_at": "2025-03-31T09:05:00",
|
286 |
+
"chunks_count": 8,
|
287 |
+
"file_path": "processed_doc.pdf",
|
288 |
+
}
|
289 |
+
],
|
290 |
+
}
|
291 |
+
}
|
292 |
+
}
|
293 |
|
294 |
|
295 |
class PipelineStatusResponse(BaseModel):
|
|
|
470 |
)
|
471 |
return False
|
472 |
case ".pdf":
|
473 |
+
if global_args.document_loading_engine == "DOCLING":
|
474 |
if not pm.is_installed("docling"): # type: ignore
|
475 |
pm.install("docling")
|
476 |
from docling.document_converter import DocumentConverter # type: ignore
|
|
|
489 |
for page in reader.pages:
|
490 |
content += page.extract_text() + "\n"
|
491 |
case ".docx":
|
492 |
+
if global_args.document_loading_engine == "DOCLING":
|
493 |
if not pm.is_installed("docling"): # type: ignore
|
494 |
pm.install("docling")
|
495 |
from docling.document_converter import DocumentConverter # type: ignore
|
|
|
509 |
[paragraph.text for paragraph in doc.paragraphs]
|
510 |
)
|
511 |
case ".pptx":
|
512 |
+
if global_args.document_loading_engine == "DOCLING":
|
513 |
if not pm.is_installed("docling"): # type: ignore
|
514 |
pm.install("docling")
|
515 |
from docling.document_converter import DocumentConverter # type: ignore
|
|
|
530 |
if hasattr(shape, "text"):
|
531 |
content += shape.text + "\n"
|
532 |
case ".xlsx":
|
533 |
+
if global_args.document_loading_engine == "DOCLING":
|
534 |
if not pm.is_installed("docling"): # type: ignore
|
535 |
pm.install("docling")
|
536 |
from docling.document_converter import DocumentConverter # type: ignore
|
|
|
637 |
await rag.apipeline_process_enqueue_documents()
|
638 |
|
639 |
|
640 |
+
# TODO: deprecate after /insert_file is removed
|
641 |
async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path:
|
642 |
"""Save the uploaded file to a temporary location
|
643 |
|
|
|
671 |
if not new_files:
|
672 |
return
|
673 |
|
674 |
+
# Get MAX_PARALLEL_INSERT from global_args
|
675 |
+
max_parallel = global_args.max_parallel_insert
|
676 |
# Calculate batch size as 2 * MAX_PARALLEL_INSERT
|
677 |
batch_size = 2 * max_parallel
|
678 |
|
|
|
704 |
# Create combined auth dependency for document routes
|
705 |
combined_auth = get_combined_auth_dependency(api_key)
|
706 |
|
707 |
+
@router.post(
|
708 |
+
"/scan", response_model=ScanResponse, dependencies=[Depends(combined_auth)]
|
709 |
+
)
|
710 |
async def scan_for_new_documents(background_tasks: BackgroundTasks):
|
711 |
"""
|
712 |
Trigger the scanning process for new documents.
|
|
|
716 |
that fact.
|
717 |
|
718 |
Returns:
|
719 |
+
ScanResponse: A response object containing the scanning status
|
720 |
"""
|
721 |
# Start the scanning process in the background
|
722 |
background_tasks.add_task(run_scanning_process, rag, doc_manager)
|
723 |
+
return ScanResponse(
|
724 |
+
status="scanning_started",
|
725 |
+
message="Scanning process has been initiated in the background",
|
726 |
+
)
|
727 |
|
728 |
+
@router.post(
|
729 |
+
"/upload", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
|
730 |
+
)
|
731 |
async def upload_to_input_dir(
|
732 |
background_tasks: BackgroundTasks, file: UploadFile = File(...)
|
733 |
):
|
|
|
847 |
logger.error(traceback.format_exc())
|
848 |
raise HTTPException(status_code=500, detail=str(e))
|
849 |
|
850 |
+
# TODO: deprecated, use /upload instead
|
851 |
@router.post(
|
852 |
"/file", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
|
853 |
)
|
|
|
891 |
logger.error(traceback.format_exc())
|
892 |
raise HTTPException(status_code=500, detail=str(e))
|
893 |
|
894 |
+
# TODO: deprecated, use /upload instead
|
895 |
@router.post(
|
896 |
"/file_batch",
|
897 |
response_model=InsertResponse,
|
|
|
956 |
raise HTTPException(status_code=500, detail=str(e))
|
957 |
|
958 |
@router.delete(
|
959 |
+
"", response_model=ClearDocumentsResponse, dependencies=[Depends(combined_auth)]
|
960 |
)
|
961 |
async def clear_documents():
|
962 |
"""
|
963 |
Clear all documents from the RAG system.
|
964 |
|
965 |
+
This endpoint deletes all documents, entities, relationships, and files from the system.
|
966 |
+
It uses the storage drop methods to properly clean up all data and removes all files
|
967 |
+
from the input directory.
|
968 |
|
969 |
Returns:
|
970 |
+
ClearDocumentsResponse: A response object containing the status and message.
|
971 |
+
- status="success": All documents and files were successfully cleared.
|
972 |
+
- status="partial_success": Document clear job exit with some errors.
|
973 |
+
- status="busy": Operation could not be completed because the pipeline is busy.
|
974 |
+
- status="fail": All storage drop operations failed, with message
|
975 |
+
- message: Detailed information about the operation results, including counts
|
976 |
+
of deleted files and any errors encountered.
|
977 |
|
978 |
Raises:
|
979 |
+
HTTPException: Raised when a serious error occurs during the clearing process,
|
980 |
+
with status code 500 and error details in the detail field.
|
981 |
"""
|
982 |
+
from lightrag.kg.shared_storage import (
|
983 |
+
get_namespace_data,
|
984 |
+
get_pipeline_status_lock,
|
985 |
+
)
|
986 |
+
|
987 |
+
# Get pipeline status and lock
|
988 |
+
pipeline_status = await get_namespace_data("pipeline_status")
|
989 |
+
pipeline_status_lock = get_pipeline_status_lock()
|
990 |
+
|
991 |
+
# Check and set status with lock
|
992 |
+
async with pipeline_status_lock:
|
993 |
+
if pipeline_status.get("busy", False):
|
994 |
+
return ClearDocumentsResponse(
|
995 |
+
status="busy",
|
996 |
+
message="Cannot clear documents while pipeline is busy",
|
997 |
+
)
|
998 |
+
# Set busy to true
|
999 |
+
pipeline_status.update(
|
1000 |
+
{
|
1001 |
+
"busy": True,
|
1002 |
+
"job_name": "Clearing Documents",
|
1003 |
+
"job_start": datetime.now().isoformat(),
|
1004 |
+
"docs": 0,
|
1005 |
+
"batchs": 0,
|
1006 |
+
"cur_batch": 0,
|
1007 |
+
"request_pending": False, # Clear any previous request
|
1008 |
+
"latest_message": "Starting document clearing process",
|
1009 |
+
}
|
1010 |
)
|
1011 |
+
# Cleaning history_messages without breaking it as a shared list object
|
1012 |
+
del pipeline_status["history_messages"][:]
|
1013 |
+
pipeline_status["history_messages"].append(
|
1014 |
+
"Starting document clearing process"
|
1015 |
+
)
|
1016 |
+
|
1017 |
+
try:
|
1018 |
+
# Use drop method to clear all data
|
1019 |
+
drop_tasks = []
|
1020 |
+
storages = [
|
1021 |
+
rag.text_chunks,
|
1022 |
+
rag.full_docs,
|
1023 |
+
rag.entities_vdb,
|
1024 |
+
rag.relationships_vdb,
|
1025 |
+
rag.chunks_vdb,
|
1026 |
+
rag.chunk_entity_relation_graph,
|
1027 |
+
rag.doc_status,
|
1028 |
+
]
|
1029 |
+
|
1030 |
+
# Log storage drop start
|
1031 |
+
if "history_messages" in pipeline_status:
|
1032 |
+
pipeline_status["history_messages"].append(
|
1033 |
+
"Starting to drop storage components"
|
1034 |
+
)
|
1035 |
+
|
1036 |
+
for storage in storages:
|
1037 |
+
if storage is not None:
|
1038 |
+
drop_tasks.append(storage.drop())
|
1039 |
+
|
1040 |
+
# Wait for all drop tasks to complete
|
1041 |
+
drop_results = await asyncio.gather(*drop_tasks, return_exceptions=True)
|
1042 |
+
|
1043 |
+
# Check for errors and log results
|
1044 |
+
errors = []
|
1045 |
+
storage_success_count = 0
|
1046 |
+
storage_error_count = 0
|
1047 |
+
|
1048 |
+
for i, result in enumerate(drop_results):
|
1049 |
+
storage_name = storages[i].__class__.__name__
|
1050 |
+
if isinstance(result, Exception):
|
1051 |
+
error_msg = f"Error dropping {storage_name}: {str(result)}"
|
1052 |
+
errors.append(error_msg)
|
1053 |
+
logger.error(error_msg)
|
1054 |
+
storage_error_count += 1
|
1055 |
+
else:
|
1056 |
+
logger.info(f"Successfully dropped {storage_name}")
|
1057 |
+
storage_success_count += 1
|
1058 |
+
|
1059 |
+
# Log storage drop results
|
1060 |
+
if "history_messages" in pipeline_status:
|
1061 |
+
if storage_error_count > 0:
|
1062 |
+
pipeline_status["history_messages"].append(
|
1063 |
+
f"Dropped {storage_success_count} storage components with {storage_error_count} errors"
|
1064 |
+
)
|
1065 |
+
else:
|
1066 |
+
pipeline_status["history_messages"].append(
|
1067 |
+
f"Successfully dropped all {storage_success_count} storage components"
|
1068 |
+
)
|
1069 |
+
|
1070 |
+
# If all storage operations failed, return error status and don't proceed with file deletion
|
1071 |
+
if storage_success_count == 0 and storage_error_count > 0:
|
1072 |
+
error_message = "All storage drop operations failed. Aborting document clearing process."
|
1073 |
+
logger.error(error_message)
|
1074 |
+
if "history_messages" in pipeline_status:
|
1075 |
+
pipeline_status["history_messages"].append(error_message)
|
1076 |
+
return ClearDocumentsResponse(status="fail", message=error_message)
|
1077 |
+
|
1078 |
+
# Log file deletion start
|
1079 |
+
if "history_messages" in pipeline_status:
|
1080 |
+
pipeline_status["history_messages"].append(
|
1081 |
+
"Starting to delete files in input directory"
|
1082 |
+
)
|
1083 |
+
|
1084 |
+
# Delete all files in input_dir
|
1085 |
+
deleted_files_count = 0
|
1086 |
+
file_errors_count = 0
|
1087 |
+
|
1088 |
+
for file_path in doc_manager.input_dir.glob("**/*"):
|
1089 |
+
if file_path.is_file():
|
1090 |
+
try:
|
1091 |
+
file_path.unlink()
|
1092 |
+
deleted_files_count += 1
|
1093 |
+
except Exception as e:
|
1094 |
+
logger.error(f"Error deleting file {file_path}: {str(e)}")
|
1095 |
+
file_errors_count += 1
|
1096 |
+
|
1097 |
+
# Log file deletion results
|
1098 |
+
if "history_messages" in pipeline_status:
|
1099 |
+
if file_errors_count > 0:
|
1100 |
+
pipeline_status["history_messages"].append(
|
1101 |
+
f"Deleted {deleted_files_count} files with {file_errors_count} errors"
|
1102 |
+
)
|
1103 |
+
errors.append(f"Failed to delete {file_errors_count} files")
|
1104 |
+
else:
|
1105 |
+
pipeline_status["history_messages"].append(
|
1106 |
+
f"Successfully deleted {deleted_files_count} files"
|
1107 |
+
)
|
1108 |
+
|
1109 |
+
# Prepare final result message
|
1110 |
+
final_message = ""
|
1111 |
+
if errors:
|
1112 |
+
final_message = f"Cleared documents with some errors. Deleted {deleted_files_count} files."
|
1113 |
+
status = "partial_success"
|
1114 |
+
else:
|
1115 |
+
final_message = f"All documents cleared successfully. Deleted {deleted_files_count} files."
|
1116 |
+
status = "success"
|
1117 |
+
|
1118 |
+
# Log final result
|
1119 |
+
if "history_messages" in pipeline_status:
|
1120 |
+
pipeline_status["history_messages"].append(final_message)
|
1121 |
+
|
1122 |
+
# Return response based on results
|
1123 |
+
return ClearDocumentsResponse(status=status, message=final_message)
|
1124 |
except Exception as e:
|
1125 |
+
error_msg = f"Error clearing documents: {str(e)}"
|
1126 |
+
logger.error(error_msg)
|
1127 |
logger.error(traceback.format_exc())
|
1128 |
+
if "history_messages" in pipeline_status:
|
1129 |
+
pipeline_status["history_messages"].append(error_msg)
|
1130 |
raise HTTPException(status_code=500, detail=str(e))
|
1131 |
+
finally:
|
1132 |
+
# Reset busy status after completion
|
1133 |
+
async with pipeline_status_lock:
|
1134 |
+
pipeline_status["busy"] = False
|
1135 |
+
completion_msg = "Document clearing process completed"
|
1136 |
+
pipeline_status["latest_message"] = completion_msg
|
1137 |
+
if "history_messages" in pipeline_status:
|
1138 |
+
pipeline_status["history_messages"].append(completion_msg)
|
1139 |
|
1140 |
@router.get(
|
1141 |
"/pipeline_status",
|
|
|
1208 |
logger.error(traceback.format_exc())
|
1209 |
raise HTTPException(status_code=500, detail=str(e))
|
1210 |
|
1211 |
+
@router.get(
|
1212 |
+
"", response_model=DocsStatusesResponse, dependencies=[Depends(combined_auth)]
|
1213 |
+
)
|
1214 |
async def documents() -> DocsStatusesResponse:
|
1215 |
"""
|
1216 |
Get the status of all documents in the system.
|
|
|
1268 |
logger.error(traceback.format_exc())
|
1269 |
raise HTTPException(status_code=500, detail=str(e))
|
1270 |
|
1271 |
+
@router.post(
|
1272 |
+
"/clear_cache",
|
1273 |
+
response_model=ClearCacheResponse,
|
1274 |
+
dependencies=[Depends(combined_auth)],
|
1275 |
+
)
|
1276 |
+
async def clear_cache(request: ClearCacheRequest):
|
1277 |
+
"""
|
1278 |
+
Clear cache data from the LLM response cache storage.
|
1279 |
+
|
1280 |
+
This endpoint allows clearing specific modes of cache or all cache if no modes are specified.
|
1281 |
+
Valid modes include: "default", "naive", "local", "global", "hybrid", "mix".
|
1282 |
+
- "default" represents extraction cache.
|
1283 |
+
- Other modes correspond to different query modes.
|
1284 |
+
|
1285 |
+
Args:
|
1286 |
+
request (ClearCacheRequest): The request body containing optional modes to clear.
|
1287 |
+
|
1288 |
+
Returns:
|
1289 |
+
ClearCacheResponse: A response object containing the status and message.
|
1290 |
+
|
1291 |
+
Raises:
|
1292 |
+
HTTPException: If an error occurs during cache clearing (400 for invalid modes, 500 for other errors).
|
1293 |
+
"""
|
1294 |
+
try:
|
1295 |
+
# Validate modes if provided
|
1296 |
+
valid_modes = ["default", "naive", "local", "global", "hybrid", "mix"]
|
1297 |
+
if request.modes and not all(mode in valid_modes for mode in request.modes):
|
1298 |
+
invalid_modes = [
|
1299 |
+
mode for mode in request.modes if mode not in valid_modes
|
1300 |
+
]
|
1301 |
+
raise HTTPException(
|
1302 |
+
status_code=400,
|
1303 |
+
detail=f"Invalid mode(s): {invalid_modes}. Valid modes are: {valid_modes}",
|
1304 |
+
)
|
1305 |
+
|
1306 |
+
# Call the aclear_cache method
|
1307 |
+
await rag.aclear_cache(request.modes)
|
1308 |
+
|
1309 |
+
# Prepare success message
|
1310 |
+
if request.modes:
|
1311 |
+
message = f"Successfully cleared cache for modes: {request.modes}"
|
1312 |
+
else:
|
1313 |
+
message = "Successfully cleared all cache"
|
1314 |
+
|
1315 |
+
return ClearCacheResponse(status="success", message=message)
|
1316 |
+
except HTTPException:
|
1317 |
+
# Re-raise HTTP exceptions
|
1318 |
+
raise
|
1319 |
+
except Exception as e:
|
1320 |
+
logger.error(f"Error clearing cache: {str(e)}")
|
1321 |
+
logger.error(traceback.format_exc())
|
1322 |
+
raise HTTPException(status_code=500, detail=str(e))
|
1323 |
+
|
1324 |
return router
|
lightrag/api/routers/graph_routes.py
CHANGED
@@ -3,7 +3,7 @@ This module contains all graph-related routes for the LightRAG API.
|
|
3 |
"""
|
4 |
|
5 |
from typing import Optional
|
6 |
-
from fastapi import APIRouter, Depends
|
7 |
|
8 |
from ..utils_api import get_combined_auth_dependency
|
9 |
|
@@ -25,23 +25,20 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
|
|
25 |
|
26 |
@router.get("/graphs", dependencies=[Depends(combined_auth)])
|
27 |
async def get_knowledge_graph(
|
28 |
-
label: str
|
|
|
|
|
29 |
):
|
30 |
"""
|
31 |
Retrieve a connected subgraph of nodes where the label includes the specified label.
|
32 |
-
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
|
33 |
When reducing the number of nodes, the prioritization criteria are as follows:
|
34 |
-
1.
|
35 |
-
2.
|
36 |
-
3. Followed by nodes directly connected to the matching nodes
|
37 |
-
4. Finally, the degree of the nodes
|
38 |
-
Maximum number of nodes is limited to env MAX_GRAPH_NODES(default: 1000)
|
39 |
|
40 |
Args:
|
41 |
-
label (str): Label
|
42 |
-
max_depth (int, optional): Maximum depth of
|
43 |
-
|
44 |
-
min_degree (int, optional): Minimum degree of nodes. Defaults to 0.
|
45 |
|
46 |
Returns:
|
47 |
Dict[str, List[str]]: Knowledge graph for label
|
@@ -49,8 +46,7 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
|
|
49 |
return await rag.get_knowledge_graph(
|
50 |
node_label=label,
|
51 |
max_depth=max_depth,
|
52 |
-
|
53 |
-
min_degree=min_degree,
|
54 |
)
|
55 |
|
56 |
return router
|
|
|
3 |
"""
|
4 |
|
5 |
from typing import Optional
|
6 |
+
from fastapi import APIRouter, Depends, Query
|
7 |
|
8 |
from ..utils_api import get_combined_auth_dependency
|
9 |
|
|
|
25 |
|
26 |
@router.get("/graphs", dependencies=[Depends(combined_auth)])
|
27 |
async def get_knowledge_graph(
|
28 |
+
label: str = Query(..., description="Label to get knowledge graph for"),
|
29 |
+
max_depth: int = Query(3, description="Maximum depth of graph", ge=1),
|
30 |
+
max_nodes: int = Query(1000, description="Maximum nodes to return", ge=1),
|
31 |
):
|
32 |
"""
|
33 |
Retrieve a connected subgraph of nodes where the label includes the specified label.
|
|
|
34 |
When reducing the number of nodes, the prioritization criteria are as follows:
|
35 |
+
1. Hops(path) to the staring node take precedence
|
36 |
+
2. Followed by the degree of the nodes
|
|
|
|
|
|
|
37 |
|
38 |
Args:
|
39 |
+
label (str): Label of the starting node
|
40 |
+
max_depth (int, optional): Maximum depth of the subgraph,Defaults to 3
|
41 |
+
max_nodes: Maxiumu nodes to return
|
|
|
42 |
|
43 |
Returns:
|
44 |
Dict[str, List[str]]: Knowledge graph for label
|
|
|
46 |
return await rag.get_knowledge_graph(
|
47 |
node_label=label,
|
48 |
max_depth=max_depth,
|
49 |
+
max_nodes=max_nodes,
|
|
|
50 |
)
|
51 |
|
52 |
return router
|
lightrag/api/run_with_gunicorn.py
CHANGED
@@ -7,14 +7,9 @@ import os
|
|
7 |
import sys
|
8 |
import signal
|
9 |
import pipmaster as pm
|
10 |
-
from lightrag.api.utils_api import
|
11 |
from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data
|
12 |
-
from
|
13 |
-
|
14 |
-
# use the .env that is inside the current folder
|
15 |
-
# allows to use different .env file for each lightrag instance
|
16 |
-
# the OS environment variables take precedence over the .env file
|
17 |
-
load_dotenv(dotenv_path=".env", override=False)
|
18 |
|
19 |
|
20 |
def check_and_install_dependencies():
|
@@ -59,20 +54,17 @@ def main():
|
|
59 |
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
|
60 |
signal.signal(signal.SIGTERM, signal_handler) # kill command
|
61 |
|
62 |
-
# Parse all arguments using parse_args
|
63 |
-
args = parse_args(is_uvicorn_mode=False)
|
64 |
-
|
65 |
# Display startup information
|
66 |
-
display_splash_screen(
|
67 |
|
68 |
print("🚀 Starting LightRAG with Gunicorn")
|
69 |
-
print(f"🔄 Worker management: Gunicorn (workers={
|
70 |
print("🔍 Preloading app: Enabled")
|
71 |
print("📝 Note: Using Gunicorn's preload feature for shared data initialization")
|
72 |
print("\n\n" + "=" * 80)
|
73 |
print("MAIN PROCESS INITIALIZATION")
|
74 |
print(f"Process ID: {os.getpid()}")
|
75 |
-
print(f"Workers setting: {
|
76 |
print("=" * 80 + "\n")
|
77 |
|
78 |
# Import Gunicorn's StandaloneApplication
|
@@ -128,31 +120,43 @@ def main():
|
|
128 |
|
129 |
# Set configuration variables in gunicorn_config, prioritizing command line arguments
|
130 |
gunicorn_config.workers = (
|
131 |
-
|
|
|
|
|
132 |
)
|
133 |
|
134 |
# Bind configuration prioritizes command line arguments
|
135 |
-
host =
|
136 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
gunicorn_config.bind = f"{host}:{port}"
|
138 |
|
139 |
# Log level configuration prioritizes command line arguments
|
140 |
gunicorn_config.loglevel = (
|
141 |
-
|
142 |
-
if
|
143 |
else os.getenv("LOG_LEVEL", "info")
|
144 |
)
|
145 |
|
146 |
# Timeout configuration prioritizes command line arguments
|
147 |
gunicorn_config.timeout = (
|
148 |
-
|
|
|
|
|
149 |
)
|
150 |
|
151 |
# Keepalive configuration
|
152 |
gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5))
|
153 |
|
154 |
# SSL configuration prioritizes command line arguments
|
155 |
-
if
|
156 |
"true",
|
157 |
"1",
|
158 |
"yes",
|
@@ -160,12 +164,14 @@ def main():
|
|
160 |
"on",
|
161 |
):
|
162 |
gunicorn_config.certfile = (
|
163 |
-
|
164 |
-
if
|
165 |
else os.getenv("SSL_CERTFILE")
|
166 |
)
|
167 |
gunicorn_config.keyfile = (
|
168 |
-
|
|
|
|
|
169 |
)
|
170 |
|
171 |
# Set configuration options from the module
|
@@ -190,13 +196,13 @@ def main():
|
|
190 |
# Import the application
|
191 |
from lightrag.api.lightrag_server import get_application
|
192 |
|
193 |
-
return get_application(
|
194 |
|
195 |
# Create the application
|
196 |
app = GunicornApp("")
|
197 |
|
198 |
# Force workers to be an integer and greater than 1 for multi-process mode
|
199 |
-
workers_count = int(
|
200 |
if workers_count > 1:
|
201 |
# Set a flag to indicate we're in the main process
|
202 |
os.environ["LIGHTRAG_MAIN_PROCESS"] = "1"
|
|
|
7 |
import sys
|
8 |
import signal
|
9 |
import pipmaster as pm
|
10 |
+
from lightrag.api.utils_api import display_splash_screen, check_env_file
|
11 |
from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data
|
12 |
+
from .config import global_args
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
|
15 |
def check_and_install_dependencies():
|
|
|
54 |
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
|
55 |
signal.signal(signal.SIGTERM, signal_handler) # kill command
|
56 |
|
|
|
|
|
|
|
57 |
# Display startup information
|
58 |
+
display_splash_screen(global_args)
|
59 |
|
60 |
print("🚀 Starting LightRAG with Gunicorn")
|
61 |
+
print(f"🔄 Worker management: Gunicorn (workers={global_args.workers})")
|
62 |
print("🔍 Preloading app: Enabled")
|
63 |
print("📝 Note: Using Gunicorn's preload feature for shared data initialization")
|
64 |
print("\n\n" + "=" * 80)
|
65 |
print("MAIN PROCESS INITIALIZATION")
|
66 |
print(f"Process ID: {os.getpid()}")
|
67 |
+
print(f"Workers setting: {global_args.workers}")
|
68 |
print("=" * 80 + "\n")
|
69 |
|
70 |
# Import Gunicorn's StandaloneApplication
|
|
|
120 |
|
121 |
# Set configuration variables in gunicorn_config, prioritizing command line arguments
|
122 |
gunicorn_config.workers = (
|
123 |
+
global_args.workers
|
124 |
+
if global_args.workers
|
125 |
+
else int(os.getenv("WORKERS", 1))
|
126 |
)
|
127 |
|
128 |
# Bind configuration prioritizes command line arguments
|
129 |
+
host = (
|
130 |
+
global_args.host
|
131 |
+
if global_args.host != "0.0.0.0"
|
132 |
+
else os.getenv("HOST", "0.0.0.0")
|
133 |
+
)
|
134 |
+
port = (
|
135 |
+
global_args.port
|
136 |
+
if global_args.port != 9621
|
137 |
+
else int(os.getenv("PORT", 9621))
|
138 |
+
)
|
139 |
gunicorn_config.bind = f"{host}:{port}"
|
140 |
|
141 |
# Log level configuration prioritizes command line arguments
|
142 |
gunicorn_config.loglevel = (
|
143 |
+
global_args.log_level.lower()
|
144 |
+
if global_args.log_level
|
145 |
else os.getenv("LOG_LEVEL", "info")
|
146 |
)
|
147 |
|
148 |
# Timeout configuration prioritizes command line arguments
|
149 |
gunicorn_config.timeout = (
|
150 |
+
global_args.timeout
|
151 |
+
if global_args.timeout * 2
|
152 |
+
else int(os.getenv("TIMEOUT", 150 * 2))
|
153 |
)
|
154 |
|
155 |
# Keepalive configuration
|
156 |
gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5))
|
157 |
|
158 |
# SSL configuration prioritizes command line arguments
|
159 |
+
if global_args.ssl or os.getenv("SSL", "").lower() in (
|
160 |
"true",
|
161 |
"1",
|
162 |
"yes",
|
|
|
164 |
"on",
|
165 |
):
|
166 |
gunicorn_config.certfile = (
|
167 |
+
global_args.ssl_certfile
|
168 |
+
if global_args.ssl_certfile
|
169 |
else os.getenv("SSL_CERTFILE")
|
170 |
)
|
171 |
gunicorn_config.keyfile = (
|
172 |
+
global_args.ssl_keyfile
|
173 |
+
if global_args.ssl_keyfile
|
174 |
+
else os.getenv("SSL_KEYFILE")
|
175 |
)
|
176 |
|
177 |
# Set configuration options from the module
|
|
|
196 |
# Import the application
|
197 |
from lightrag.api.lightrag_server import get_application
|
198 |
|
199 |
+
return get_application(global_args)
|
200 |
|
201 |
# Create the application
|
202 |
app = GunicornApp("")
|
203 |
|
204 |
# Force workers to be an integer and greater than 1 for multi-process mode
|
205 |
+
workers_count = int(global_args.workers)
|
206 |
if workers_count > 1:
|
207 |
# Set a flag to indicate we're in the main process
|
208 |
os.environ["LIGHTRAG_MAIN_PROCESS"] = "1"
|
lightrag/api/utils_api.py
CHANGED
@@ -7,15 +7,13 @@ import argparse
|
|
7 |
from typing import Optional, List, Tuple
|
8 |
import sys
|
9 |
from ascii_colors import ASCIIColors
|
10 |
-
import logging
|
11 |
from lightrag.api import __api_version__ as api_version
|
12 |
from lightrag import __version__ as core_version
|
13 |
from fastapi import HTTPException, Security, Request, status
|
14 |
-
from dotenv import load_dotenv
|
15 |
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
|
16 |
from starlette.status import HTTP_403_FORBIDDEN
|
17 |
from .auth import auth_handler
|
18 |
-
from
|
19 |
|
20 |
|
21 |
def check_env_file():
|
@@ -36,16 +34,8 @@ def check_env_file():
|
|
36 |
return True
|
37 |
|
38 |
|
39 |
-
#
|
40 |
-
|
41 |
-
# the OS environment variables take precedence over the .env file
|
42 |
-
load_dotenv(dotenv_path=".env", override=False)
|
43 |
-
|
44 |
-
global_args = {"main_args": None}
|
45 |
-
|
46 |
-
# Get whitelist paths from environment variable, only once during initialization
|
47 |
-
default_whitelist = "/health,/api/*"
|
48 |
-
whitelist_paths = os.getenv("WHITELIST_PATHS", default_whitelist).split(",")
|
49 |
|
50 |
# Pre-compile path matching patterns
|
51 |
whitelist_patterns: List[Tuple[str, bool]] = []
|
@@ -63,19 +53,6 @@ for path in whitelist_paths:
|
|
63 |
auth_configured = bool(auth_handler.accounts)
|
64 |
|
65 |
|
66 |
-
class OllamaServerInfos:
|
67 |
-
# Constants for emulated Ollama model information
|
68 |
-
LIGHTRAG_NAME = "lightrag"
|
69 |
-
LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
|
70 |
-
LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
|
71 |
-
LIGHTRAG_SIZE = 7365960935 # it's a dummy value
|
72 |
-
LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
|
73 |
-
LIGHTRAG_DIGEST = "sha256:lightrag"
|
74 |
-
|
75 |
-
|
76 |
-
ollama_server_infos = OllamaServerInfos()
|
77 |
-
|
78 |
-
|
79 |
def get_combined_auth_dependency(api_key: Optional[str] = None):
|
80 |
"""
|
81 |
Create a combined authentication dependency that implements authentication logic
|
@@ -186,299 +163,6 @@ def get_combined_auth_dependency(api_key: Optional[str] = None):
|
|
186 |
return combined_dependency
|
187 |
|
188 |
|
189 |
-
class DefaultRAGStorageConfig:
|
190 |
-
KV_STORAGE = "JsonKVStorage"
|
191 |
-
VECTOR_STORAGE = "NanoVectorDBStorage"
|
192 |
-
GRAPH_STORAGE = "NetworkXStorage"
|
193 |
-
DOC_STATUS_STORAGE = "JsonDocStatusStorage"
|
194 |
-
|
195 |
-
|
196 |
-
def get_default_host(binding_type: str) -> str:
|
197 |
-
default_hosts = {
|
198 |
-
"ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
|
199 |
-
"lollms": os.getenv("LLM_BINDING_HOST", "http://localhost:9600"),
|
200 |
-
"azure_openai": os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com/v1"),
|
201 |
-
"openai": os.getenv("LLM_BINDING_HOST", "https://api.openai.com/v1"),
|
202 |
-
}
|
203 |
-
return default_hosts.get(
|
204 |
-
binding_type, os.getenv("LLM_BINDING_HOST", "http://localhost:11434")
|
205 |
-
) # fallback to ollama if unknown
|
206 |
-
|
207 |
-
|
208 |
-
def get_env_value(env_key: str, default: any, value_type: type = str) -> any:
|
209 |
-
"""
|
210 |
-
Get value from environment variable with type conversion
|
211 |
-
|
212 |
-
Args:
|
213 |
-
env_key (str): Environment variable key
|
214 |
-
default (any): Default value if env variable is not set
|
215 |
-
value_type (type): Type to convert the value to
|
216 |
-
|
217 |
-
Returns:
|
218 |
-
any: Converted value from environment or default
|
219 |
-
"""
|
220 |
-
value = os.getenv(env_key)
|
221 |
-
if value is None:
|
222 |
-
return default
|
223 |
-
|
224 |
-
if value_type is bool:
|
225 |
-
return value.lower() in ("true", "1", "yes", "t", "on")
|
226 |
-
try:
|
227 |
-
return value_type(value)
|
228 |
-
except ValueError:
|
229 |
-
return default
|
230 |
-
|
231 |
-
|
232 |
-
def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
|
233 |
-
"""
|
234 |
-
Parse command line arguments with environment variable fallback
|
235 |
-
|
236 |
-
Args:
|
237 |
-
is_uvicorn_mode: Whether running under uvicorn mode
|
238 |
-
|
239 |
-
Returns:
|
240 |
-
argparse.Namespace: Parsed arguments
|
241 |
-
"""
|
242 |
-
|
243 |
-
parser = argparse.ArgumentParser(
|
244 |
-
description="LightRAG FastAPI Server with separate working and input directories"
|
245 |
-
)
|
246 |
-
|
247 |
-
# Server configuration
|
248 |
-
parser.add_argument(
|
249 |
-
"--host",
|
250 |
-
default=get_env_value("HOST", "0.0.0.0"),
|
251 |
-
help="Server host (default: from env or 0.0.0.0)",
|
252 |
-
)
|
253 |
-
parser.add_argument(
|
254 |
-
"--port",
|
255 |
-
type=int,
|
256 |
-
default=get_env_value("PORT", 9621, int),
|
257 |
-
help="Server port (default: from env or 9621)",
|
258 |
-
)
|
259 |
-
|
260 |
-
# Directory configuration
|
261 |
-
parser.add_argument(
|
262 |
-
"--working-dir",
|
263 |
-
default=get_env_value("WORKING_DIR", "./rag_storage"),
|
264 |
-
help="Working directory for RAG storage (default: from env or ./rag_storage)",
|
265 |
-
)
|
266 |
-
parser.add_argument(
|
267 |
-
"--input-dir",
|
268 |
-
default=get_env_value("INPUT_DIR", "./inputs"),
|
269 |
-
help="Directory containing input documents (default: from env or ./inputs)",
|
270 |
-
)
|
271 |
-
|
272 |
-
def timeout_type(value):
|
273 |
-
if value is None:
|
274 |
-
return 150
|
275 |
-
if value is None or value == "None":
|
276 |
-
return None
|
277 |
-
return int(value)
|
278 |
-
|
279 |
-
parser.add_argument(
|
280 |
-
"--timeout",
|
281 |
-
default=get_env_value("TIMEOUT", None, timeout_type),
|
282 |
-
type=timeout_type,
|
283 |
-
help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout",
|
284 |
-
)
|
285 |
-
|
286 |
-
# RAG configuration
|
287 |
-
parser.add_argument(
|
288 |
-
"--max-async",
|
289 |
-
type=int,
|
290 |
-
default=get_env_value("MAX_ASYNC", 4, int),
|
291 |
-
help="Maximum async operations (default: from env or 4)",
|
292 |
-
)
|
293 |
-
parser.add_argument(
|
294 |
-
"--max-tokens",
|
295 |
-
type=int,
|
296 |
-
default=get_env_value("MAX_TOKENS", 32768, int),
|
297 |
-
help="Maximum token size (default: from env or 32768)",
|
298 |
-
)
|
299 |
-
|
300 |
-
# Logging configuration
|
301 |
-
parser.add_argument(
|
302 |
-
"--log-level",
|
303 |
-
default=get_env_value("LOG_LEVEL", "INFO"),
|
304 |
-
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
305 |
-
help="Logging level (default: from env or INFO)",
|
306 |
-
)
|
307 |
-
parser.add_argument(
|
308 |
-
"--verbose",
|
309 |
-
action="store_true",
|
310 |
-
default=get_env_value("VERBOSE", False, bool),
|
311 |
-
help="Enable verbose debug output(only valid for DEBUG log-level)",
|
312 |
-
)
|
313 |
-
|
314 |
-
parser.add_argument(
|
315 |
-
"--key",
|
316 |
-
type=str,
|
317 |
-
default=get_env_value("LIGHTRAG_API_KEY", None),
|
318 |
-
help="API key for authentication. This protects lightrag server against unauthorized access",
|
319 |
-
)
|
320 |
-
|
321 |
-
# Optional https parameters
|
322 |
-
parser.add_argument(
|
323 |
-
"--ssl",
|
324 |
-
action="store_true",
|
325 |
-
default=get_env_value("SSL", False, bool),
|
326 |
-
help="Enable HTTPS (default: from env or False)",
|
327 |
-
)
|
328 |
-
parser.add_argument(
|
329 |
-
"--ssl-certfile",
|
330 |
-
default=get_env_value("SSL_CERTFILE", None),
|
331 |
-
help="Path to SSL certificate file (required if --ssl is enabled)",
|
332 |
-
)
|
333 |
-
parser.add_argument(
|
334 |
-
"--ssl-keyfile",
|
335 |
-
default=get_env_value("SSL_KEYFILE", None),
|
336 |
-
help="Path to SSL private key file (required if --ssl is enabled)",
|
337 |
-
)
|
338 |
-
|
339 |
-
parser.add_argument(
|
340 |
-
"--history-turns",
|
341 |
-
type=int,
|
342 |
-
default=get_env_value("HISTORY_TURNS", 3, int),
|
343 |
-
help="Number of conversation history turns to include (default: from env or 3)",
|
344 |
-
)
|
345 |
-
|
346 |
-
# Search parameters
|
347 |
-
parser.add_argument(
|
348 |
-
"--top-k",
|
349 |
-
type=int,
|
350 |
-
default=get_env_value("TOP_K", 60, int),
|
351 |
-
help="Number of most similar results to return (default: from env or 60)",
|
352 |
-
)
|
353 |
-
parser.add_argument(
|
354 |
-
"--cosine-threshold",
|
355 |
-
type=float,
|
356 |
-
default=get_env_value("COSINE_THRESHOLD", 0.2, float),
|
357 |
-
help="Cosine similarity threshold (default: from env or 0.4)",
|
358 |
-
)
|
359 |
-
|
360 |
-
# Ollama model name
|
361 |
-
parser.add_argument(
|
362 |
-
"--simulated-model-name",
|
363 |
-
type=str,
|
364 |
-
default=get_env_value(
|
365 |
-
"SIMULATED_MODEL_NAME", ollama_server_infos.LIGHTRAG_MODEL
|
366 |
-
),
|
367 |
-
help="Number of conversation history turns to include (default: from env or 3)",
|
368 |
-
)
|
369 |
-
|
370 |
-
# Namespace
|
371 |
-
parser.add_argument(
|
372 |
-
"--namespace-prefix",
|
373 |
-
type=str,
|
374 |
-
default=get_env_value("NAMESPACE_PREFIX", ""),
|
375 |
-
help="Prefix of the namespace",
|
376 |
-
)
|
377 |
-
|
378 |
-
parser.add_argument(
|
379 |
-
"--auto-scan-at-startup",
|
380 |
-
action="store_true",
|
381 |
-
default=False,
|
382 |
-
help="Enable automatic scanning when the program starts",
|
383 |
-
)
|
384 |
-
|
385 |
-
# Server workers configuration
|
386 |
-
parser.add_argument(
|
387 |
-
"--workers",
|
388 |
-
type=int,
|
389 |
-
default=get_env_value("WORKERS", 1, int),
|
390 |
-
help="Number of worker processes (default: from env or 1)",
|
391 |
-
)
|
392 |
-
|
393 |
-
# LLM and embedding bindings
|
394 |
-
parser.add_argument(
|
395 |
-
"--llm-binding",
|
396 |
-
type=str,
|
397 |
-
default=get_env_value("LLM_BINDING", "ollama"),
|
398 |
-
choices=["lollms", "ollama", "openai", "openai-ollama", "azure_openai"],
|
399 |
-
help="LLM binding type (default: from env or ollama)",
|
400 |
-
)
|
401 |
-
parser.add_argument(
|
402 |
-
"--embedding-binding",
|
403 |
-
type=str,
|
404 |
-
default=get_env_value("EMBEDDING_BINDING", "ollama"),
|
405 |
-
choices=["lollms", "ollama", "openai", "azure_openai"],
|
406 |
-
help="Embedding binding type (default: from env or ollama)",
|
407 |
-
)
|
408 |
-
|
409 |
-
args = parser.parse_args()
|
410 |
-
|
411 |
-
# If in uvicorn mode and workers > 1, force it to 1 and log warning
|
412 |
-
if is_uvicorn_mode and args.workers > 1:
|
413 |
-
original_workers = args.workers
|
414 |
-
args.workers = 1
|
415 |
-
# Log warning directly here
|
416 |
-
logging.warning(
|
417 |
-
f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1"
|
418 |
-
)
|
419 |
-
|
420 |
-
# convert relative path to absolute path
|
421 |
-
args.working_dir = os.path.abspath(args.working_dir)
|
422 |
-
args.input_dir = os.path.abspath(args.input_dir)
|
423 |
-
|
424 |
-
# Inject storage configuration from environment variables
|
425 |
-
args.kv_storage = get_env_value(
|
426 |
-
"LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
|
427 |
-
)
|
428 |
-
args.doc_status_storage = get_env_value(
|
429 |
-
"LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
|
430 |
-
)
|
431 |
-
args.graph_storage = get_env_value(
|
432 |
-
"LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
|
433 |
-
)
|
434 |
-
args.vector_storage = get_env_value(
|
435 |
-
"LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
|
436 |
-
)
|
437 |
-
|
438 |
-
# Get MAX_PARALLEL_INSERT from environment
|
439 |
-
args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int)
|
440 |
-
|
441 |
-
# Handle openai-ollama special case
|
442 |
-
if args.llm_binding == "openai-ollama":
|
443 |
-
args.llm_binding = "openai"
|
444 |
-
args.embedding_binding = "ollama"
|
445 |
-
|
446 |
-
args.llm_binding_host = get_env_value(
|
447 |
-
"LLM_BINDING_HOST", get_default_host(args.llm_binding)
|
448 |
-
)
|
449 |
-
args.embedding_binding_host = get_env_value(
|
450 |
-
"EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding)
|
451 |
-
)
|
452 |
-
args.llm_binding_api_key = get_env_value("LLM_BINDING_API_KEY", None)
|
453 |
-
args.embedding_binding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "")
|
454 |
-
|
455 |
-
# Inject model configuration
|
456 |
-
args.llm_model = get_env_value("LLM_MODEL", "mistral-nemo:latest")
|
457 |
-
args.embedding_model = get_env_value("EMBEDDING_MODEL", "bge-m3:latest")
|
458 |
-
args.embedding_dim = get_env_value("EMBEDDING_DIM", 1024, int)
|
459 |
-
args.max_embed_tokens = get_env_value("MAX_EMBED_TOKENS", 8192, int)
|
460 |
-
|
461 |
-
# Inject chunk configuration
|
462 |
-
args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int)
|
463 |
-
args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)
|
464 |
-
|
465 |
-
# Inject LLM cache configuration
|
466 |
-
args.enable_llm_cache_for_extract = get_env_value(
|
467 |
-
"ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool
|
468 |
-
)
|
469 |
-
|
470 |
-
# Inject LLM temperature configuration
|
471 |
-
args.temperature = get_env_value("TEMPERATURE", 0.5, float)
|
472 |
-
|
473 |
-
# Select Document loading tool (DOCLING, DEFAULT)
|
474 |
-
args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT")
|
475 |
-
|
476 |
-
ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
|
477 |
-
|
478 |
-
global_args["main_args"] = args
|
479 |
-
return args
|
480 |
-
|
481 |
-
|
482 |
def display_splash_screen(args: argparse.Namespace) -> None:
|
483 |
"""
|
484 |
Display a colorful splash screen showing LightRAG server configuration
|
@@ -503,7 +187,7 @@ def display_splash_screen(args: argparse.Namespace) -> None:
|
|
503 |
ASCIIColors.white(" ├─ Workers: ", end="")
|
504 |
ASCIIColors.yellow(f"{args.workers}")
|
505 |
ASCIIColors.white(" ├─ CORS Origins: ", end="")
|
506 |
-
ASCIIColors.yellow(f"{
|
507 |
ASCIIColors.white(" ├─ SSL Enabled: ", end="")
|
508 |
ASCIIColors.yellow(f"{args.ssl}")
|
509 |
if args.ssl:
|
@@ -519,8 +203,10 @@ def display_splash_screen(args: argparse.Namespace) -> None:
|
|
519 |
ASCIIColors.yellow(f"{args.verbose}")
|
520 |
ASCIIColors.white(" ├─ History Turns: ", end="")
|
521 |
ASCIIColors.yellow(f"{args.history_turns}")
|
522 |
-
ASCIIColors.white("
|
523 |
ASCIIColors.yellow("Set" if args.key else "Not Set")
|
|
|
|
|
524 |
|
525 |
# Directory Configuration
|
526 |
ASCIIColors.magenta("\n📂 Directory Configuration:")
|
@@ -558,10 +244,9 @@ def display_splash_screen(args: argparse.Namespace) -> None:
|
|
558 |
ASCIIColors.yellow(f"{args.embedding_dim}")
|
559 |
|
560 |
# RAG Configuration
|
561 |
-
summary_language = os.getenv("SUMMARY_LANGUAGE", PROMPTS["DEFAULT_LANGUAGE"])
|
562 |
ASCIIColors.magenta("\n⚙️ RAG Configuration:")
|
563 |
ASCIIColors.white(" ├─ Summary Language: ", end="")
|
564 |
-
ASCIIColors.yellow(f"{summary_language}")
|
565 |
ASCIIColors.white(" ├─ Max Parallel Insert: ", end="")
|
566 |
ASCIIColors.yellow(f"{args.max_parallel_insert}")
|
567 |
ASCIIColors.white(" ├─ Max Embed Tokens: ", end="")
|
@@ -595,19 +280,17 @@ def display_splash_screen(args: argparse.Namespace) -> None:
|
|
595 |
protocol = "https" if args.ssl else "http"
|
596 |
if args.host == "0.0.0.0":
|
597 |
ASCIIColors.magenta("\n🌐 Server Access Information:")
|
598 |
-
ASCIIColors.white(" ├─
|
599 |
ASCIIColors.yellow(f"{protocol}://localhost:{args.port}")
|
600 |
ASCIIColors.white(" ├─ Remote Access: ", end="")
|
601 |
ASCIIColors.yellow(f"{protocol}://<your-ip-address>:{args.port}")
|
602 |
ASCIIColors.white(" ├─ API Documentation (local): ", end="")
|
603 |
ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/docs")
|
604 |
-
ASCIIColors.white("
|
605 |
ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/redoc")
|
606 |
-
ASCIIColors.white(" └─ WebUI (local): ", end="")
|
607 |
-
ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/webui")
|
608 |
|
609 |
-
ASCIIColors.
|
610 |
-
ASCIIColors.
|
611 |
- Use 'localhost' or '127.0.0.1' for local access
|
612 |
- Use your machine's IP address for remote access
|
613 |
- To find your IP address:
|
@@ -617,42 +300,24 @@ def display_splash_screen(args: argparse.Namespace) -> None:
|
|
617 |
else:
|
618 |
base_url = f"{protocol}://{args.host}:{args.port}"
|
619 |
ASCIIColors.magenta("\n🌐 Server Access Information:")
|
620 |
-
ASCIIColors.white(" ├─
|
621 |
ASCIIColors.yellow(f"{base_url}")
|
622 |
ASCIIColors.white(" ├─ API Documentation: ", end="")
|
623 |
ASCIIColors.yellow(f"{base_url}/docs")
|
624 |
ASCIIColors.white(" └─ Alternative Documentation: ", end="")
|
625 |
ASCIIColors.yellow(f"{base_url}/redoc")
|
626 |
|
627 |
-
# Usage Examples
|
628 |
-
ASCIIColors.magenta("\n📚 Quick Start Guide:")
|
629 |
-
ASCIIColors.cyan("""
|
630 |
-
1. Access the Swagger UI:
|
631 |
-
Open your browser and navigate to the API documentation URL above
|
632 |
-
|
633 |
-
2. API Authentication:""")
|
634 |
-
if args.key:
|
635 |
-
ASCIIColors.cyan(""" Add the following header to your requests:
|
636 |
-
X-API-Key: <your-api-key>
|
637 |
-
""")
|
638 |
-
else:
|
639 |
-
ASCIIColors.cyan(" No authentication required\n")
|
640 |
-
|
641 |
-
ASCIIColors.cyan(""" 3. Basic Operations:
|
642 |
-
- POST /upload_document: Upload new documents to RAG
|
643 |
-
- POST /query: Query your document collection
|
644 |
-
|
645 |
-
4. Monitor the server:
|
646 |
-
- Check server logs for detailed operation information
|
647 |
-
- Use healthcheck endpoint: GET /health
|
648 |
-
""")
|
649 |
-
|
650 |
# Security Notice
|
651 |
if args.key:
|
652 |
ASCIIColors.yellow("\n⚠️ Security Notice:")
|
653 |
ASCIIColors.white(""" API Key authentication is enabled.
|
654 |
Make sure to include the X-API-Key header in all your requests.
|
655 |
""")
|
|
|
|
|
|
|
|
|
|
|
656 |
|
657 |
# Ensure splash output flush to system log
|
658 |
sys.stdout.flush()
|
|
|
7 |
from typing import Optional, List, Tuple
|
8 |
import sys
|
9 |
from ascii_colors import ASCIIColors
|
|
|
10 |
from lightrag.api import __api_version__ as api_version
|
11 |
from lightrag import __version__ as core_version
|
12 |
from fastapi import HTTPException, Security, Request, status
|
|
|
13 |
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
|
14 |
from starlette.status import HTTP_403_FORBIDDEN
|
15 |
from .auth import auth_handler
|
16 |
+
from .config import ollama_server_infos, global_args
|
17 |
|
18 |
|
19 |
def check_env_file():
|
|
|
34 |
return True
|
35 |
|
36 |
|
37 |
+
# Get whitelist paths from global_args, only once during initialization
|
38 |
+
whitelist_paths = global_args.whitelist_paths.split(",")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
# Pre-compile path matching patterns
|
41 |
whitelist_patterns: List[Tuple[str, bool]] = []
|
|
|
53 |
auth_configured = bool(auth_handler.accounts)
|
54 |
|
55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
def get_combined_auth_dependency(api_key: Optional[str] = None):
|
57 |
"""
|
58 |
Create a combined authentication dependency that implements authentication logic
|
|
|
163 |
return combined_dependency
|
164 |
|
165 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
def display_splash_screen(args: argparse.Namespace) -> None:
|
167 |
"""
|
168 |
Display a colorful splash screen showing LightRAG server configuration
|
|
|
187 |
ASCIIColors.white(" ├─ Workers: ", end="")
|
188 |
ASCIIColors.yellow(f"{args.workers}")
|
189 |
ASCIIColors.white(" ├─ CORS Origins: ", end="")
|
190 |
+
ASCIIColors.yellow(f"{args.cors_origins}")
|
191 |
ASCIIColors.white(" ├─ SSL Enabled: ", end="")
|
192 |
ASCIIColors.yellow(f"{args.ssl}")
|
193 |
if args.ssl:
|
|
|
203 |
ASCIIColors.yellow(f"{args.verbose}")
|
204 |
ASCIIColors.white(" ├─ History Turns: ", end="")
|
205 |
ASCIIColors.yellow(f"{args.history_turns}")
|
206 |
+
ASCIIColors.white(" ├─ API Key: ", end="")
|
207 |
ASCIIColors.yellow("Set" if args.key else "Not Set")
|
208 |
+
ASCIIColors.white(" └─ JWT Auth: ", end="")
|
209 |
+
ASCIIColors.yellow("Enabled" if args.auth_accounts else "Disabled")
|
210 |
|
211 |
# Directory Configuration
|
212 |
ASCIIColors.magenta("\n📂 Directory Configuration:")
|
|
|
244 |
ASCIIColors.yellow(f"{args.embedding_dim}")
|
245 |
|
246 |
# RAG Configuration
|
|
|
247 |
ASCIIColors.magenta("\n⚙️ RAG Configuration:")
|
248 |
ASCIIColors.white(" ├─ Summary Language: ", end="")
|
249 |
+
ASCIIColors.yellow(f"{args.summary_language}")
|
250 |
ASCIIColors.white(" ├─ Max Parallel Insert: ", end="")
|
251 |
ASCIIColors.yellow(f"{args.max_parallel_insert}")
|
252 |
ASCIIColors.white(" ├─ Max Embed Tokens: ", end="")
|
|
|
280 |
protocol = "https" if args.ssl else "http"
|
281 |
if args.host == "0.0.0.0":
|
282 |
ASCIIColors.magenta("\n🌐 Server Access Information:")
|
283 |
+
ASCIIColors.white(" ├─ WebUI (local): ", end="")
|
284 |
ASCIIColors.yellow(f"{protocol}://localhost:{args.port}")
|
285 |
ASCIIColors.white(" ├─ Remote Access: ", end="")
|
286 |
ASCIIColors.yellow(f"{protocol}://<your-ip-address>:{args.port}")
|
287 |
ASCIIColors.white(" ├─ API Documentation (local): ", end="")
|
288 |
ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/docs")
|
289 |
+
ASCIIColors.white(" └─ Alternative Documentation (local): ", end="")
|
290 |
ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/redoc")
|
|
|
|
|
291 |
|
292 |
+
ASCIIColors.magenta("\n📝 Note:")
|
293 |
+
ASCIIColors.cyan(""" Since the server is running on 0.0.0.0:
|
294 |
- Use 'localhost' or '127.0.0.1' for local access
|
295 |
- Use your machine's IP address for remote access
|
296 |
- To find your IP address:
|
|
|
300 |
else:
|
301 |
base_url = f"{protocol}://{args.host}:{args.port}"
|
302 |
ASCIIColors.magenta("\n🌐 Server Access Information:")
|
303 |
+
ASCIIColors.white(" ├─ WebUI (local): ", end="")
|
304 |
ASCIIColors.yellow(f"{base_url}")
|
305 |
ASCIIColors.white(" ├─ API Documentation: ", end="")
|
306 |
ASCIIColors.yellow(f"{base_url}/docs")
|
307 |
ASCIIColors.white(" └─ Alternative Documentation: ", end="")
|
308 |
ASCIIColors.yellow(f"{base_url}/redoc")
|
309 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
310 |
# Security Notice
|
311 |
if args.key:
|
312 |
ASCIIColors.yellow("\n⚠️ Security Notice:")
|
313 |
ASCIIColors.white(""" API Key authentication is enabled.
|
314 |
Make sure to include the X-API-Key header in all your requests.
|
315 |
""")
|
316 |
+
if args.auth_accounts:
|
317 |
+
ASCIIColors.yellow("\n⚠️ Security Notice:")
|
318 |
+
ASCIIColors.white(""" JWT authentication is enabled.
|
319 |
+
Make sure to login before making the request, and include the 'Authorization' in the header.
|
320 |
+
""")
|
321 |
|
322 |
# Ensure splash output flush to system log
|
323 |
sys.stdout.flush()
|
lightrag/api/webui/assets/{index-D8zGvNlV.js → index-BaHKTcxB.js}
RENAMED
Binary files a/lightrag/api/webui/assets/index-D8zGvNlV.js and b/lightrag/api/webui/assets/index-BaHKTcxB.js differ
|
|
lightrag/api/webui/assets/index-CD5HxTy1.css
DELETED
Binary file (55.1 kB)
|
|
lightrag/api/webui/assets/index-f0HMqdqP.css
ADDED
Binary file (57.1 kB). View file
|
|
lightrag/api/webui/index.html
CHANGED
Binary files a/lightrag/api/webui/index.html and b/lightrag/api/webui/index.html differ
|
|
lightrag/base.py
CHANGED
@@ -112,6 +112,32 @@ class StorageNameSpace(ABC):
|
|
112 |
async def index_done_callback(self) -> None:
|
113 |
"""Commit the storage operations after indexing"""
|
114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
|
116 |
@dataclass
|
117 |
class BaseVectorStorage(StorageNameSpace, ABC):
|
@@ -127,15 +153,33 @@ class BaseVectorStorage(StorageNameSpace, ABC):
|
|
127 |
|
128 |
@abstractmethod
|
129 |
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
130 |
-
"""Insert or update vectors in the storage.
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
|
132 |
@abstractmethod
|
133 |
async def delete_entity(self, entity_name: str) -> None:
|
134 |
-
"""Delete a single entity by its name.
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
|
136 |
@abstractmethod
|
137 |
async def delete_entity_relation(self, entity_name: str) -> None:
|
138 |
-
"""Delete relations for a given entity.
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
|
140 |
@abstractmethod
|
141 |
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
@@ -161,6 +205,19 @@ class BaseVectorStorage(StorageNameSpace, ABC):
|
|
161 |
"""
|
162 |
pass
|
163 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
|
165 |
@dataclass
|
166 |
class BaseKVStorage(StorageNameSpace, ABC):
|
@@ -180,7 +237,42 @@ class BaseKVStorage(StorageNameSpace, ABC):
|
|
180 |
|
181 |
@abstractmethod
|
182 |
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
183 |
-
"""Upsert data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
184 |
|
185 |
|
186 |
@dataclass
|
@@ -205,13 +297,13 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|
205 |
|
206 |
@abstractmethod
|
207 |
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
208 |
-
"""Get
|
209 |
|
210 |
@abstractmethod
|
211 |
async def get_edge(
|
212 |
self, source_node_id: str, target_node_id: str
|
213 |
) -> dict[str, str] | None:
|
214 |
-
"""Get
|
215 |
|
216 |
@abstractmethod
|
217 |
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
@@ -225,7 +317,13 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|
225 |
async def upsert_edge(
|
226 |
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
227 |
) -> None:
|
228 |
-
"""Delete a node from the graph.
|
|
|
|
|
|
|
|
|
|
|
|
|
229 |
|
230 |
@abstractmethod
|
231 |
async def delete_node(self, node_id: str) -> None:
|
@@ -243,9 +341,20 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|
243 |
|
244 |
@abstractmethod
|
245 |
async def get_knowledge_graph(
|
246 |
-
self, node_label: str, max_depth: int = 3
|
247 |
) -> KnowledgeGraph:
|
248 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
249 |
|
250 |
|
251 |
class DocStatus(str, Enum):
|
@@ -297,6 +406,10 @@ class DocStatusStorage(BaseKVStorage, ABC):
|
|
297 |
) -> dict[str, DocProcessingStatus]:
|
298 |
"""Get all documents with a specific status"""
|
299 |
|
|
|
|
|
|
|
|
|
300 |
|
301 |
class StoragesStatus(str, Enum):
|
302 |
"""Storages status"""
|
|
|
112 |
async def index_done_callback(self) -> None:
|
113 |
"""Commit the storage operations after indexing"""
|
114 |
|
115 |
+
@abstractmethod
|
116 |
+
async def drop(self) -> dict[str, str]:
|
117 |
+
"""Drop all data from storage and clean up resources
|
118 |
+
|
119 |
+
This abstract method defines the contract for dropping all data from a storage implementation.
|
120 |
+
Each storage type must implement this method to:
|
121 |
+
1. Clear all data from memory and/or external storage
|
122 |
+
2. Remove any associated storage files if applicable
|
123 |
+
3. Reset the storage to its initial state
|
124 |
+
4. Handle cleanup of any resources
|
125 |
+
5. Notify other processes if necessary
|
126 |
+
6. This action should persistent the data to disk immediately.
|
127 |
+
|
128 |
+
Returns:
|
129 |
+
dict[str, str]: Operation status and message with the following format:
|
130 |
+
{
|
131 |
+
"status": str, # "success" or "error"
|
132 |
+
"message": str # "data dropped" on success, error details on failure
|
133 |
+
}
|
134 |
+
|
135 |
+
Implementation specific:
|
136 |
+
- On success: return {"status": "success", "message": "data dropped"}
|
137 |
+
- On failure: return {"status": "error", "message": "<error details>"}
|
138 |
+
- If not supported: return {"status": "error", "message": "unsupported"}
|
139 |
+
"""
|
140 |
+
|
141 |
|
142 |
@dataclass
|
143 |
class BaseVectorStorage(StorageNameSpace, ABC):
|
|
|
153 |
|
154 |
@abstractmethod
|
155 |
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
156 |
+
"""Insert or update vectors in the storage.
|
157 |
+
|
158 |
+
Importance notes for in-memory storage:
|
159 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
160 |
+
2. Only one process should updating the storage at a time before index_done_callback,
|
161 |
+
KG-storage-log should be used to avoid data corruption
|
162 |
+
"""
|
163 |
|
164 |
@abstractmethod
|
165 |
async def delete_entity(self, entity_name: str) -> None:
|
166 |
+
"""Delete a single entity by its name.
|
167 |
+
|
168 |
+
Importance notes for in-memory storage:
|
169 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
170 |
+
2. Only one process should updating the storage at a time before index_done_callback,
|
171 |
+
KG-storage-log should be used to avoid data corruption
|
172 |
+
"""
|
173 |
|
174 |
@abstractmethod
|
175 |
async def delete_entity_relation(self, entity_name: str) -> None:
|
176 |
+
"""Delete relations for a given entity.
|
177 |
+
|
178 |
+
Importance notes for in-memory storage:
|
179 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
180 |
+
2. Only one process should updating the storage at a time before index_done_callback,
|
181 |
+
KG-storage-log should be used to avoid data corruption
|
182 |
+
"""
|
183 |
|
184 |
@abstractmethod
|
185 |
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
|
|
205 |
"""
|
206 |
pass
|
207 |
|
208 |
+
@abstractmethod
|
209 |
+
async def delete(self, ids: list[str]):
|
210 |
+
"""Delete vectors with specified IDs
|
211 |
+
|
212 |
+
Importance notes for in-memory storage:
|
213 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
214 |
+
2. Only one process should updating the storage at a time before index_done_callback,
|
215 |
+
KG-storage-log should be used to avoid data corruption
|
216 |
+
|
217 |
+
Args:
|
218 |
+
ids: List of vector IDs to be deleted
|
219 |
+
"""
|
220 |
+
|
221 |
|
222 |
@dataclass
|
223 |
class BaseKVStorage(StorageNameSpace, ABC):
|
|
|
237 |
|
238 |
@abstractmethod
|
239 |
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
240 |
+
"""Upsert data
|
241 |
+
|
242 |
+
Importance notes for in-memory storage:
|
243 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
244 |
+
2. update flags to notify other processes that data persistence is needed
|
245 |
+
"""
|
246 |
+
|
247 |
+
@abstractmethod
|
248 |
+
async def delete(self, ids: list[str]) -> None:
|
249 |
+
"""Delete specific records from storage by their IDs
|
250 |
+
|
251 |
+
Importance notes for in-memory storage:
|
252 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
253 |
+
2. update flags to notify other processes that data persistence is needed
|
254 |
+
|
255 |
+
Args:
|
256 |
+
ids (list[str]): List of document IDs to be deleted from storage
|
257 |
+
|
258 |
+
Returns:
|
259 |
+
None
|
260 |
+
"""
|
261 |
+
|
262 |
+
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
|
263 |
+
"""Delete specific records from storage by cache mode
|
264 |
+
|
265 |
+
Importance notes for in-memory storage:
|
266 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
267 |
+
2. update flags to notify other processes that data persistence is needed
|
268 |
+
|
269 |
+
Args:
|
270 |
+
modes (list[str]): List of cache modes to be dropped from storage
|
271 |
+
|
272 |
+
Returns:
|
273 |
+
True: if the cache drop successfully
|
274 |
+
False: if the cache drop failed, or the cache mode is not supported
|
275 |
+
"""
|
276 |
|
277 |
|
278 |
@dataclass
|
|
|
297 |
|
298 |
@abstractmethod
|
299 |
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
300 |
+
"""Get node by its label identifier, return only node properties"""
|
301 |
|
302 |
@abstractmethod
|
303 |
async def get_edge(
|
304 |
self, source_node_id: str, target_node_id: str
|
305 |
) -> dict[str, str] | None:
|
306 |
+
"""Get edge properties between two nodes"""
|
307 |
|
308 |
@abstractmethod
|
309 |
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
|
|
317 |
async def upsert_edge(
|
318 |
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
319 |
) -> None:
|
320 |
+
"""Delete a node from the graph.
|
321 |
+
|
322 |
+
Importance notes for in-memory storage:
|
323 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
324 |
+
2. Only one process should updating the storage at a time before index_done_callback,
|
325 |
+
KG-storage-log should be used to avoid data corruption
|
326 |
+
"""
|
327 |
|
328 |
@abstractmethod
|
329 |
async def delete_node(self, node_id: str) -> None:
|
|
|
341 |
|
342 |
@abstractmethod
|
343 |
async def get_knowledge_graph(
|
344 |
+
self, node_label: str, max_depth: int = 3, max_nodes: int = 1000
|
345 |
) -> KnowledgeGraph:
|
346 |
+
"""
|
347 |
+
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
348 |
+
|
349 |
+
Args:
|
350 |
+
node_label: Label of the starting node,* means all nodes
|
351 |
+
max_depth: Maximum depth of the subgraph, Defaults to 3
|
352 |
+
max_nodes: Maxiumu nodes to return, Defaults to 1000(BFS if possible)
|
353 |
+
|
354 |
+
Returns:
|
355 |
+
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
|
356 |
+
indicating whether the graph was truncated due to max_nodes limit
|
357 |
+
"""
|
358 |
|
359 |
|
360 |
class DocStatus(str, Enum):
|
|
|
406 |
) -> dict[str, DocProcessingStatus]:
|
407 |
"""Get all documents with a specific status"""
|
408 |
|
409 |
+
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
|
410 |
+
"""Drop cache is not supported for Doc Status storage"""
|
411 |
+
return False
|
412 |
+
|
413 |
|
414 |
class StoragesStatus(str, Enum):
|
415 |
"""Storages status"""
|
lightrag/kg/__init__.py
CHANGED
@@ -2,11 +2,10 @@ STORAGE_IMPLEMENTATIONS = {
|
|
2 |
"KV_STORAGE": {
|
3 |
"implementations": [
|
4 |
"JsonKVStorage",
|
5 |
-
"MongoKVStorage",
|
6 |
"RedisKVStorage",
|
7 |
-
"TiDBKVStorage",
|
8 |
"PGKVStorage",
|
9 |
-
"
|
|
|
10 |
],
|
11 |
"required_methods": ["get_by_id", "upsert"],
|
12 |
},
|
@@ -14,12 +13,11 @@ STORAGE_IMPLEMENTATIONS = {
|
|
14 |
"implementations": [
|
15 |
"NetworkXStorage",
|
16 |
"Neo4JStorage",
|
17 |
-
"MongoGraphStorage",
|
18 |
-
"TiDBGraphStorage",
|
19 |
-
"AGEStorage",
|
20 |
-
"GremlinStorage",
|
21 |
"PGGraphStorage",
|
22 |
-
"
|
|
|
|
|
|
|
23 |
],
|
24 |
"required_methods": ["upsert_node", "upsert_edge"],
|
25 |
},
|
@@ -28,12 +26,11 @@ STORAGE_IMPLEMENTATIONS = {
|
|
28 |
"NanoVectorDBStorage",
|
29 |
"MilvusVectorDBStorage",
|
30 |
"ChromaVectorDBStorage",
|
31 |
-
"TiDBVectorDBStorage",
|
32 |
"PGVectorStorage",
|
33 |
"FaissVectorDBStorage",
|
34 |
"QdrantVectorDBStorage",
|
35 |
-
"OracleVectorDBStorage",
|
36 |
"MongoVectorDBStorage",
|
|
|
37 |
],
|
38 |
"required_methods": ["query", "upsert"],
|
39 |
},
|
@@ -41,7 +38,6 @@ STORAGE_IMPLEMENTATIONS = {
|
|
41 |
"implementations": [
|
42 |
"JsonDocStatusStorage",
|
43 |
"PGDocStatusStorage",
|
44 |
-
"PGDocStatusStorage",
|
45 |
"MongoDocStatusStorage",
|
46 |
],
|
47 |
"required_methods": ["get_docs_by_status"],
|
@@ -54,50 +50,32 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
|
|
54 |
"JsonKVStorage": [],
|
55 |
"MongoKVStorage": [],
|
56 |
"RedisKVStorage": ["REDIS_URI"],
|
57 |
-
"TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
|
58 |
"PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
|
59 |
-
"OracleKVStorage": [
|
60 |
-
"ORACLE_DSN",
|
61 |
-
"ORACLE_USER",
|
62 |
-
"ORACLE_PASSWORD",
|
63 |
-
"ORACLE_CONFIG_DIR",
|
64 |
-
],
|
65 |
# Graph Storage Implementations
|
66 |
"NetworkXStorage": [],
|
67 |
"Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
|
68 |
"MongoGraphStorage": [],
|
69 |
-
"TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
|
70 |
"AGEStorage": [
|
71 |
"AGE_POSTGRES_DB",
|
72 |
"AGE_POSTGRES_USER",
|
73 |
"AGE_POSTGRES_PASSWORD",
|
74 |
],
|
75 |
-
"GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
|
76 |
"PGGraphStorage": [
|
77 |
"POSTGRES_USER",
|
78 |
"POSTGRES_PASSWORD",
|
79 |
"POSTGRES_DATABASE",
|
80 |
],
|
81 |
-
"OracleGraphStorage": [
|
82 |
-
"ORACLE_DSN",
|
83 |
-
"ORACLE_USER",
|
84 |
-
"ORACLE_PASSWORD",
|
85 |
-
"ORACLE_CONFIG_DIR",
|
86 |
-
],
|
87 |
# Vector Storage Implementations
|
88 |
"NanoVectorDBStorage": [],
|
89 |
"MilvusVectorDBStorage": [],
|
90 |
"ChromaVectorDBStorage": [],
|
91 |
-
"TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
|
92 |
"PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
|
93 |
"FaissVectorDBStorage": [],
|
94 |
"QdrantVectorDBStorage": ["QDRANT_URL"], # QDRANT_API_KEY has default value None
|
95 |
-
"OracleVectorDBStorage": [
|
96 |
-
"ORACLE_DSN",
|
97 |
-
"ORACLE_USER",
|
98 |
-
"ORACLE_PASSWORD",
|
99 |
-
"ORACLE_CONFIG_DIR",
|
100 |
-
],
|
101 |
"MongoVectorDBStorage": [],
|
102 |
# Document Status Storage Implementations
|
103 |
"JsonDocStatusStorage": [],
|
@@ -112,9 +90,6 @@ STORAGES = {
|
|
112 |
"NanoVectorDBStorage": ".kg.nano_vector_db_impl",
|
113 |
"JsonDocStatusStorage": ".kg.json_doc_status_impl",
|
114 |
"Neo4JStorage": ".kg.neo4j_impl",
|
115 |
-
"OracleKVStorage": ".kg.oracle_impl",
|
116 |
-
"OracleGraphStorage": ".kg.oracle_impl",
|
117 |
-
"OracleVectorDBStorage": ".kg.oracle_impl",
|
118 |
"MilvusVectorDBStorage": ".kg.milvus_impl",
|
119 |
"MongoKVStorage": ".kg.mongo_impl",
|
120 |
"MongoDocStatusStorage": ".kg.mongo_impl",
|
@@ -122,14 +97,14 @@ STORAGES = {
|
|
122 |
"MongoVectorDBStorage": ".kg.mongo_impl",
|
123 |
"RedisKVStorage": ".kg.redis_impl",
|
124 |
"ChromaVectorDBStorage": ".kg.chroma_impl",
|
125 |
-
"TiDBKVStorage": ".kg.tidb_impl",
|
126 |
-
"TiDBVectorDBStorage": ".kg.tidb_impl",
|
127 |
-
"TiDBGraphStorage": ".kg.tidb_impl",
|
128 |
"PGKVStorage": ".kg.postgres_impl",
|
129 |
"PGVectorStorage": ".kg.postgres_impl",
|
130 |
"AGEStorage": ".kg.age_impl",
|
131 |
"PGGraphStorage": ".kg.postgres_impl",
|
132 |
-
"GremlinStorage": ".kg.gremlin_impl",
|
133 |
"PGDocStatusStorage": ".kg.postgres_impl",
|
134 |
"FaissVectorDBStorage": ".kg.faiss_impl",
|
135 |
"QdrantVectorDBStorage": ".kg.qdrant_impl",
|
|
|
2 |
"KV_STORAGE": {
|
3 |
"implementations": [
|
4 |
"JsonKVStorage",
|
|
|
5 |
"RedisKVStorage",
|
|
|
6 |
"PGKVStorage",
|
7 |
+
"MongoKVStorage",
|
8 |
+
# "TiDBKVStorage",
|
9 |
],
|
10 |
"required_methods": ["get_by_id", "upsert"],
|
11 |
},
|
|
|
13 |
"implementations": [
|
14 |
"NetworkXStorage",
|
15 |
"Neo4JStorage",
|
|
|
|
|
|
|
|
|
16 |
"PGGraphStorage",
|
17 |
+
# "AGEStorage",
|
18 |
+
# "MongoGraphStorage",
|
19 |
+
# "TiDBGraphStorage",
|
20 |
+
# "GremlinStorage",
|
21 |
],
|
22 |
"required_methods": ["upsert_node", "upsert_edge"],
|
23 |
},
|
|
|
26 |
"NanoVectorDBStorage",
|
27 |
"MilvusVectorDBStorage",
|
28 |
"ChromaVectorDBStorage",
|
|
|
29 |
"PGVectorStorage",
|
30 |
"FaissVectorDBStorage",
|
31 |
"QdrantVectorDBStorage",
|
|
|
32 |
"MongoVectorDBStorage",
|
33 |
+
# "TiDBVectorDBStorage",
|
34 |
],
|
35 |
"required_methods": ["query", "upsert"],
|
36 |
},
|
|
|
38 |
"implementations": [
|
39 |
"JsonDocStatusStorage",
|
40 |
"PGDocStatusStorage",
|
|
|
41 |
"MongoDocStatusStorage",
|
42 |
],
|
43 |
"required_methods": ["get_docs_by_status"],
|
|
|
50 |
"JsonKVStorage": [],
|
51 |
"MongoKVStorage": [],
|
52 |
"RedisKVStorage": ["REDIS_URI"],
|
53 |
+
# "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
|
54 |
"PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
# Graph Storage Implementations
|
56 |
"NetworkXStorage": [],
|
57 |
"Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
|
58 |
"MongoGraphStorage": [],
|
59 |
+
# "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
|
60 |
"AGEStorage": [
|
61 |
"AGE_POSTGRES_DB",
|
62 |
"AGE_POSTGRES_USER",
|
63 |
"AGE_POSTGRES_PASSWORD",
|
64 |
],
|
65 |
+
# "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
|
66 |
"PGGraphStorage": [
|
67 |
"POSTGRES_USER",
|
68 |
"POSTGRES_PASSWORD",
|
69 |
"POSTGRES_DATABASE",
|
70 |
],
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
# Vector Storage Implementations
|
72 |
"NanoVectorDBStorage": [],
|
73 |
"MilvusVectorDBStorage": [],
|
74 |
"ChromaVectorDBStorage": [],
|
75 |
+
# "TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
|
76 |
"PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
|
77 |
"FaissVectorDBStorage": [],
|
78 |
"QdrantVectorDBStorage": ["QDRANT_URL"], # QDRANT_API_KEY has default value None
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
"MongoVectorDBStorage": [],
|
80 |
# Document Status Storage Implementations
|
81 |
"JsonDocStatusStorage": [],
|
|
|
90 |
"NanoVectorDBStorage": ".kg.nano_vector_db_impl",
|
91 |
"JsonDocStatusStorage": ".kg.json_doc_status_impl",
|
92 |
"Neo4JStorage": ".kg.neo4j_impl",
|
|
|
|
|
|
|
93 |
"MilvusVectorDBStorage": ".kg.milvus_impl",
|
94 |
"MongoKVStorage": ".kg.mongo_impl",
|
95 |
"MongoDocStatusStorage": ".kg.mongo_impl",
|
|
|
97 |
"MongoVectorDBStorage": ".kg.mongo_impl",
|
98 |
"RedisKVStorage": ".kg.redis_impl",
|
99 |
"ChromaVectorDBStorage": ".kg.chroma_impl",
|
100 |
+
# "TiDBKVStorage": ".kg.tidb_impl",
|
101 |
+
# "TiDBVectorDBStorage": ".kg.tidb_impl",
|
102 |
+
# "TiDBGraphStorage": ".kg.tidb_impl",
|
103 |
"PGKVStorage": ".kg.postgres_impl",
|
104 |
"PGVectorStorage": ".kg.postgres_impl",
|
105 |
"AGEStorage": ".kg.age_impl",
|
106 |
"PGGraphStorage": ".kg.postgres_impl",
|
107 |
+
# "GremlinStorage": ".kg.gremlin_impl",
|
108 |
"PGDocStatusStorage": ".kg.postgres_impl",
|
109 |
"FaissVectorDBStorage": ".kg.faiss_impl",
|
110 |
"QdrantVectorDBStorage": ".kg.qdrant_impl",
|
lightrag/kg/age_impl.py
CHANGED
@@ -34,9 +34,9 @@ if not pm.is_installed("psycopg-pool"):
|
|
34 |
if not pm.is_installed("asyncpg"):
|
35 |
pm.install("asyncpg")
|
36 |
|
37 |
-
import psycopg
|
38 |
-
from psycopg.rows import namedtuple_row
|
39 |
-
from psycopg_pool import AsyncConnectionPool, PoolTimeout
|
40 |
|
41 |
|
42 |
class AGEQueryException(Exception):
|
@@ -871,3 +871,21 @@ class AGEStorage(BaseGraphStorage):
|
|
871 |
async def index_done_callback(self) -> None:
|
872 |
# AGES handles persistence automatically
|
873 |
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
if not pm.is_installed("asyncpg"):
|
35 |
pm.install("asyncpg")
|
36 |
|
37 |
+
import psycopg # type: ignore
|
38 |
+
from psycopg.rows import namedtuple_row # type: ignore
|
39 |
+
from psycopg_pool import AsyncConnectionPool, PoolTimeout # type: ignore
|
40 |
|
41 |
|
42 |
class AGEQueryException(Exception):
|
|
|
871 |
async def index_done_callback(self) -> None:
|
872 |
# AGES handles persistence automatically
|
873 |
pass
|
874 |
+
|
875 |
+
async def drop(self) -> dict[str, str]:
|
876 |
+
"""Drop the storage by removing all nodes and relationships in the graph.
|
877 |
+
|
878 |
+
Returns:
|
879 |
+
dict[str, str]: Status of the operation with keys 'status' and 'message'
|
880 |
+
"""
|
881 |
+
try:
|
882 |
+
query = """
|
883 |
+
MATCH (n)
|
884 |
+
DETACH DELETE n
|
885 |
+
"""
|
886 |
+
await self._query(query)
|
887 |
+
logger.info(f"Successfully dropped all data from graph {self.graph_name}")
|
888 |
+
return {"status": "success", "message": "graph data dropped"}
|
889 |
+
except Exception as e:
|
890 |
+
logger.error(f"Error dropping graph {self.graph_name}: {e}")
|
891 |
+
return {"status": "error", "message": str(e)}
|
lightrag/kg/chroma_impl.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import asyncio
|
|
|
2 |
from dataclasses import dataclass
|
3 |
from typing import Any, final
|
4 |
import numpy as np
|
@@ -10,8 +11,8 @@ import pipmaster as pm
|
|
10 |
if not pm.is_installed("chromadb"):
|
11 |
pm.install("chromadb")
|
12 |
|
13 |
-
from chromadb import HttpClient, PersistentClient
|
14 |
-
from chromadb.config import Settings
|
15 |
|
16 |
|
17 |
@final
|
@@ -335,3 +336,28 @@ class ChromaVectorDBStorage(BaseVectorStorage):
|
|
335 |
except Exception as e:
|
336 |
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
|
337 |
return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import asyncio
|
2 |
+
import os
|
3 |
from dataclasses import dataclass
|
4 |
from typing import Any, final
|
5 |
import numpy as np
|
|
|
11 |
if not pm.is_installed("chromadb"):
|
12 |
pm.install("chromadb")
|
13 |
|
14 |
+
from chromadb import HttpClient, PersistentClient # type: ignore
|
15 |
+
from chromadb.config import Settings # type: ignore
|
16 |
|
17 |
|
18 |
@final
|
|
|
336 |
except Exception as e:
|
337 |
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
|
338 |
return []
|
339 |
+
|
340 |
+
async def drop(self) -> dict[str, str]:
|
341 |
+
"""Drop all vector data from storage and clean up resources
|
342 |
+
|
343 |
+
This method will delete all documents from the ChromaDB collection.
|
344 |
+
|
345 |
+
Returns:
|
346 |
+
dict[str, str]: Operation status and message
|
347 |
+
- On success: {"status": "success", "message": "data dropped"}
|
348 |
+
- On failure: {"status": "error", "message": "<error details>"}
|
349 |
+
"""
|
350 |
+
try:
|
351 |
+
# Get all IDs in the collection
|
352 |
+
result = self._collection.get(include=[])
|
353 |
+
if result and result["ids"] and len(result["ids"]) > 0:
|
354 |
+
# Delete all documents
|
355 |
+
self._collection.delete(ids=result["ids"])
|
356 |
+
|
357 |
+
logger.info(
|
358 |
+
f"Process {os.getpid()} drop ChromaDB collection {self.namespace}"
|
359 |
+
)
|
360 |
+
return {"status": "success", "message": "data dropped"}
|
361 |
+
except Exception as e:
|
362 |
+
logger.error(f"Error dropping ChromaDB collection {self.namespace}: {e}")
|
363 |
+
return {"status": "error", "message": str(e)}
|
lightrag/kg/faiss_impl.py
CHANGED
@@ -11,16 +11,20 @@ import pipmaster as pm
|
|
11 |
from lightrag.utils import logger, compute_mdhash_id
|
12 |
from lightrag.base import BaseVectorStorage
|
13 |
|
14 |
-
if not pm.is_installed("faiss"):
|
15 |
-
pm.install("faiss")
|
16 |
-
|
17 |
-
import faiss # type: ignore
|
18 |
from .shared_storage import (
|
19 |
get_storage_lock,
|
20 |
get_update_flag,
|
21 |
set_all_update_flags,
|
22 |
)
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
@final
|
26 |
@dataclass
|
@@ -217,6 +221,11 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
|
217 |
async def delete(self, ids: list[str]):
|
218 |
"""
|
219 |
Delete vectors for the provided custom IDs.
|
|
|
|
|
|
|
|
|
|
|
220 |
"""
|
221 |
logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
|
222 |
to_remove = []
|
@@ -232,13 +241,22 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
|
232 |
)
|
233 |
|
234 |
async def delete_entity(self, entity_name: str) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
235 |
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
|
236 |
logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
|
237 |
await self.delete([entity_id])
|
238 |
|
239 |
async def delete_entity_relation(self, entity_name: str) -> None:
|
240 |
"""
|
241 |
-
|
|
|
|
|
|
|
242 |
"""
|
243 |
logger.debug(f"Searching relations for entity {entity_name}")
|
244 |
relations = []
|
@@ -429,3 +447,44 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
|
429 |
results.append({**metadata, "id": metadata.get("__id__")})
|
430 |
|
431 |
return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
from lightrag.utils import logger, compute_mdhash_id
|
12 |
from lightrag.base import BaseVectorStorage
|
13 |
|
|
|
|
|
|
|
|
|
14 |
from .shared_storage import (
|
15 |
get_storage_lock,
|
16 |
get_update_flag,
|
17 |
set_all_update_flags,
|
18 |
)
|
19 |
|
20 |
+
import faiss # type: ignore
|
21 |
+
|
22 |
+
USE_GPU = os.getenv("FAISS_USE_GPU", "0") == "1"
|
23 |
+
FAISS_PACKAGE = "faiss-gpu" if USE_GPU else "faiss-cpu"
|
24 |
+
|
25 |
+
if not pm.is_installed(FAISS_PACKAGE):
|
26 |
+
pm.install(FAISS_PACKAGE)
|
27 |
+
|
28 |
|
29 |
@final
|
30 |
@dataclass
|
|
|
221 |
async def delete(self, ids: list[str]):
|
222 |
"""
|
223 |
Delete vectors for the provided custom IDs.
|
224 |
+
|
225 |
+
Importance notes:
|
226 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
227 |
+
2. Only one process should updating the storage at a time before index_done_callback,
|
228 |
+
KG-storage-log should be used to avoid data corruption
|
229 |
"""
|
230 |
logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
|
231 |
to_remove = []
|
|
|
241 |
)
|
242 |
|
243 |
async def delete_entity(self, entity_name: str) -> None:
|
244 |
+
"""
|
245 |
+
Importance notes:
|
246 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
247 |
+
2. Only one process should updating the storage at a time before index_done_callback,
|
248 |
+
KG-storage-log should be used to avoid data corruption
|
249 |
+
"""
|
250 |
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
|
251 |
logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
|
252 |
await self.delete([entity_id])
|
253 |
|
254 |
async def delete_entity_relation(self, entity_name: str) -> None:
|
255 |
"""
|
256 |
+
Importance notes:
|
257 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
258 |
+
2. Only one process should updating the storage at a time before index_done_callback,
|
259 |
+
KG-storage-log should be used to avoid data corruption
|
260 |
"""
|
261 |
logger.debug(f"Searching relations for entity {entity_name}")
|
262 |
relations = []
|
|
|
447 |
results.append({**metadata, "id": metadata.get("__id__")})
|
448 |
|
449 |
return results
|
450 |
+
|
451 |
+
async def drop(self) -> dict[str, str]:
|
452 |
+
"""Drop all vector data from storage and clean up resources
|
453 |
+
|
454 |
+
This method will:
|
455 |
+
1. Remove the vector database storage file if it exists
|
456 |
+
2. Reinitialize the vector database client
|
457 |
+
3. Update flags to notify other processes
|
458 |
+
4. Changes is persisted to disk immediately
|
459 |
+
|
460 |
+
This method will remove all vectors from the Faiss index and delete the storage files.
|
461 |
+
|
462 |
+
Returns:
|
463 |
+
dict[str, str]: Operation status and message
|
464 |
+
- On success: {"status": "success", "message": "data dropped"}
|
465 |
+
- On failure: {"status": "error", "message": "<error details>"}
|
466 |
+
"""
|
467 |
+
try:
|
468 |
+
async with self._storage_lock:
|
469 |
+
# Reset the index
|
470 |
+
self._index = faiss.IndexFlatIP(self._dim)
|
471 |
+
self._id_to_meta = {}
|
472 |
+
|
473 |
+
# Remove storage files if they exist
|
474 |
+
if os.path.exists(self._faiss_index_file):
|
475 |
+
os.remove(self._faiss_index_file)
|
476 |
+
if os.path.exists(self._meta_file):
|
477 |
+
os.remove(self._meta_file)
|
478 |
+
|
479 |
+
self._id_to_meta = {}
|
480 |
+
self._load_faiss_index()
|
481 |
+
|
482 |
+
# Notify other processes
|
483 |
+
await set_all_update_flags(self.namespace)
|
484 |
+
self.storage_updated.value = False
|
485 |
+
|
486 |
+
logger.info(f"Process {os.getpid()} drop FAISS index {self.namespace}")
|
487 |
+
return {"status": "success", "message": "data dropped"}
|
488 |
+
except Exception as e:
|
489 |
+
logger.error(f"Error dropping FAISS index {self.namespace}: {e}")
|
490 |
+
return {"status": "error", "message": str(e)}
|
lightrag/kg/gremlin_impl.py
CHANGED
@@ -24,9 +24,9 @@ from ..base import BaseGraphStorage
|
|
24 |
if not pm.is_installed("gremlinpython"):
|
25 |
pm.install("gremlinpython")
|
26 |
|
27 |
-
from gremlin_python.driver import client, serializer
|
28 |
-
from gremlin_python.driver.aiohttp.transport import AiohttpTransport
|
29 |
-
from gremlin_python.driver.protocol import GremlinServerError
|
30 |
|
31 |
|
32 |
@final
|
@@ -695,3 +695,24 @@ class GremlinStorage(BaseGraphStorage):
|
|
695 |
except Exception as e:
|
696 |
logger.error(f"Error during edge deletion: {str(e)}")
|
697 |
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
if not pm.is_installed("gremlinpython"):
|
25 |
pm.install("gremlinpython")
|
26 |
|
27 |
+
from gremlin_python.driver import client, serializer # type: ignore
|
28 |
+
from gremlin_python.driver.aiohttp.transport import AiohttpTransport # type: ignore
|
29 |
+
from gremlin_python.driver.protocol import GremlinServerError # type: ignore
|
30 |
|
31 |
|
32 |
@final
|
|
|
695 |
except Exception as e:
|
696 |
logger.error(f"Error during edge deletion: {str(e)}")
|
697 |
raise
|
698 |
+
|
699 |
+
async def drop(self) -> dict[str, str]:
|
700 |
+
"""Drop the storage by removing all nodes and relationships in the graph.
|
701 |
+
|
702 |
+
This function deletes all nodes with the specified graph name property,
|
703 |
+
which automatically removes all associated edges.
|
704 |
+
|
705 |
+
Returns:
|
706 |
+
dict[str, str]: Status of the operation with keys 'status' and 'message'
|
707 |
+
"""
|
708 |
+
try:
|
709 |
+
query = f"""g
|
710 |
+
.V().has('graph', {self.graph_name})
|
711 |
+
.drop()
|
712 |
+
"""
|
713 |
+
await self._query(query)
|
714 |
+
logger.info(f"Successfully dropped all data from graph {self.graph_name}")
|
715 |
+
return {"status": "success", "message": "graph data dropped"}
|
716 |
+
except Exception as e:
|
717 |
+
logger.error(f"Error dropping graph {self.graph_name}: {e}")
|
718 |
+
return {"status": "error", "message": str(e)}
|
lightrag/kg/json_doc_status_impl.py
CHANGED
@@ -109,6 +109,11 @@ class JsonDocStatusStorage(DocStatusStorage):
|
|
109 |
await clear_all_update_flags(self.namespace)
|
110 |
|
111 |
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
|
|
|
|
|
|
|
|
|
|
112 |
if not data:
|
113 |
return
|
114 |
logger.info(f"Inserting {len(data)} records to {self.namespace}")
|
@@ -122,16 +127,50 @@ class JsonDocStatusStorage(DocStatusStorage):
|
|
122 |
async with self._storage_lock:
|
123 |
return self._data.get(id)
|
124 |
|
125 |
-
async def delete(self, doc_ids: list[str]):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
async with self._storage_lock:
|
|
|
127 |
for doc_id in doc_ids:
|
128 |
-
self._data.pop(doc_id, None)
|
129 |
-
|
130 |
-
|
131 |
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
await clear_all_update_flags(self.namespace)
|
110 |
|
111 |
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
112 |
+
"""
|
113 |
+
Importance notes for in-memory storage:
|
114 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
115 |
+
2. update flags to notify other processes that data persistence is needed
|
116 |
+
"""
|
117 |
if not data:
|
118 |
return
|
119 |
logger.info(f"Inserting {len(data)} records to {self.namespace}")
|
|
|
127 |
async with self._storage_lock:
|
128 |
return self._data.get(id)
|
129 |
|
130 |
+
async def delete(self, doc_ids: list[str]) -> None:
|
131 |
+
"""Delete specific records from storage by their IDs
|
132 |
+
|
133 |
+
Importance notes for in-memory storage:
|
134 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
135 |
+
2. update flags to notify other processes that data persistence is needed
|
136 |
+
|
137 |
+
Args:
|
138 |
+
ids (list[str]): List of document IDs to be deleted from storage
|
139 |
+
|
140 |
+
Returns:
|
141 |
+
None
|
142 |
+
"""
|
143 |
async with self._storage_lock:
|
144 |
+
any_deleted = False
|
145 |
for doc_id in doc_ids:
|
146 |
+
result = self._data.pop(doc_id, None)
|
147 |
+
if result is not None:
|
148 |
+
any_deleted = True
|
149 |
|
150 |
+
if any_deleted:
|
151 |
+
await set_all_update_flags(self.namespace)
|
152 |
+
|
153 |
+
async def drop(self) -> dict[str, str]:
|
154 |
+
"""Drop all document status data from storage and clean up resources
|
155 |
+
|
156 |
+
This method will:
|
157 |
+
1. Clear all document status data from memory
|
158 |
+
2. Update flags to notify other processes
|
159 |
+
3. Trigger index_done_callback to save the empty state
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
dict[str, str]: Operation status and message
|
163 |
+
- On success: {"status": "success", "message": "data dropped"}
|
164 |
+
- On failure: {"status": "error", "message": "<error details>"}
|
165 |
+
"""
|
166 |
+
try:
|
167 |
+
async with self._storage_lock:
|
168 |
+
self._data.clear()
|
169 |
+
await set_all_update_flags(self.namespace)
|
170 |
+
|
171 |
+
await self.index_done_callback()
|
172 |
+
logger.info(f"Process {os.getpid()} drop {self.namespace}")
|
173 |
+
return {"status": "success", "message": "data dropped"}
|
174 |
+
except Exception as e:
|
175 |
+
logger.error(f"Error dropping {self.namespace}: {e}")
|
176 |
+
return {"status": "error", "message": str(e)}
|
lightrag/kg/json_kv_impl.py
CHANGED
@@ -114,6 +114,11 @@ class JsonKVStorage(BaseKVStorage):
|
|
114 |
return set(keys) - set(self._data.keys())
|
115 |
|
116 |
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
|
|
|
|
|
|
|
|
|
|
117 |
if not data:
|
118 |
return
|
119 |
logger.info(f"Inserting {len(data)} records to {self.namespace}")
|
@@ -122,8 +127,73 @@ class JsonKVStorage(BaseKVStorage):
|
|
122 |
await set_all_update_flags(self.namespace)
|
123 |
|
124 |
async def delete(self, ids: list[str]) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
async with self._storage_lock:
|
|
|
126 |
for doc_id in ids:
|
127 |
-
self._data.pop(doc_id, None)
|
128 |
-
|
129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
return set(keys) - set(self._data.keys())
|
115 |
|
116 |
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
117 |
+
"""
|
118 |
+
Importance notes for in-memory storage:
|
119 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
120 |
+
2. update flags to notify other processes that data persistence is needed
|
121 |
+
"""
|
122 |
if not data:
|
123 |
return
|
124 |
logger.info(f"Inserting {len(data)} records to {self.namespace}")
|
|
|
127 |
await set_all_update_flags(self.namespace)
|
128 |
|
129 |
async def delete(self, ids: list[str]) -> None:
|
130 |
+
"""Delete specific records from storage by their IDs
|
131 |
+
|
132 |
+
Importance notes for in-memory storage:
|
133 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
134 |
+
2. update flags to notify other processes that data persistence is needed
|
135 |
+
|
136 |
+
Args:
|
137 |
+
ids (list[str]): List of document IDs to be deleted from storage
|
138 |
+
|
139 |
+
Returns:
|
140 |
+
None
|
141 |
+
"""
|
142 |
async with self._storage_lock:
|
143 |
+
any_deleted = False
|
144 |
for doc_id in ids:
|
145 |
+
result = self._data.pop(doc_id, None)
|
146 |
+
if result is not None:
|
147 |
+
any_deleted = True
|
148 |
+
|
149 |
+
if any_deleted:
|
150 |
+
await set_all_update_flags(self.namespace)
|
151 |
+
|
152 |
+
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
|
153 |
+
"""Delete specific records from storage by by cache mode
|
154 |
+
|
155 |
+
Importance notes for in-memory storage:
|
156 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
157 |
+
2. update flags to notify other processes that data persistence is needed
|
158 |
+
|
159 |
+
Args:
|
160 |
+
ids (list[str]): List of cache mode to be drop from storage
|
161 |
+
|
162 |
+
Returns:
|
163 |
+
True: if the cache drop successfully
|
164 |
+
False: if the cache drop failed
|
165 |
+
"""
|
166 |
+
if not modes:
|
167 |
+
return False
|
168 |
+
|
169 |
+
try:
|
170 |
+
await self.delete(modes)
|
171 |
+
return True
|
172 |
+
except Exception:
|
173 |
+
return False
|
174 |
+
|
175 |
+
async def drop(self) -> dict[str, str]:
|
176 |
+
"""Drop all data from storage and clean up resources
|
177 |
+
This action will persistent the data to disk immediately.
|
178 |
+
|
179 |
+
This method will:
|
180 |
+
1. Clear all data from memory
|
181 |
+
2. Update flags to notify other processes
|
182 |
+
3. Trigger index_done_callback to save the empty state
|
183 |
+
|
184 |
+
Returns:
|
185 |
+
dict[str, str]: Operation status and message
|
186 |
+
- On success: {"status": "success", "message": "data dropped"}
|
187 |
+
- On failure: {"status": "error", "message": "<error details>"}
|
188 |
+
"""
|
189 |
+
try:
|
190 |
+
async with self._storage_lock:
|
191 |
+
self._data.clear()
|
192 |
+
await set_all_update_flags(self.namespace)
|
193 |
+
|
194 |
+
await self.index_done_callback()
|
195 |
+
logger.info(f"Process {os.getpid()} drop {self.namespace}")
|
196 |
+
return {"status": "success", "message": "data dropped"}
|
197 |
+
except Exception as e:
|
198 |
+
logger.error(f"Error dropping {self.namespace}: {e}")
|
199 |
+
return {"status": "error", "message": str(e)}
|
lightrag/kg/milvus_impl.py
CHANGED
@@ -15,7 +15,7 @@ if not pm.is_installed("pymilvus"):
|
|
15 |
pm.install("pymilvus")
|
16 |
|
17 |
import configparser
|
18 |
-
from pymilvus import MilvusClient
|
19 |
|
20 |
config = configparser.ConfigParser()
|
21 |
config.read("config.ini", "utf-8")
|
@@ -287,3 +287,33 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
|
287 |
except Exception as e:
|
288 |
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
|
289 |
return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
pm.install("pymilvus")
|
16 |
|
17 |
import configparser
|
18 |
+
from pymilvus import MilvusClient # type: ignore
|
19 |
|
20 |
config = configparser.ConfigParser()
|
21 |
config.read("config.ini", "utf-8")
|
|
|
287 |
except Exception as e:
|
288 |
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
|
289 |
return []
|
290 |
+
|
291 |
+
async def drop(self) -> dict[str, str]:
|
292 |
+
"""Drop all vector data from storage and clean up resources
|
293 |
+
|
294 |
+
This method will delete all data from the Milvus collection.
|
295 |
+
|
296 |
+
Returns:
|
297 |
+
dict[str, str]: Operation status and message
|
298 |
+
- On success: {"status": "success", "message": "data dropped"}
|
299 |
+
- On failure: {"status": "error", "message": "<error details>"}
|
300 |
+
"""
|
301 |
+
try:
|
302 |
+
# Drop the collection and recreate it
|
303 |
+
if self._client.has_collection(self.namespace):
|
304 |
+
self._client.drop_collection(self.namespace)
|
305 |
+
|
306 |
+
# Recreate the collection
|
307 |
+
MilvusVectorDBStorage.create_collection_if_not_exist(
|
308 |
+
self._client,
|
309 |
+
self.namespace,
|
310 |
+
dimension=self.embedding_func.embedding_dim,
|
311 |
+
)
|
312 |
+
|
313 |
+
logger.info(
|
314 |
+
f"Process {os.getpid()} drop Milvus collection {self.namespace}"
|
315 |
+
)
|
316 |
+
return {"status": "success", "message": "data dropped"}
|
317 |
+
except Exception as e:
|
318 |
+
logger.error(f"Error dropping Milvus collection {self.namespace}: {e}")
|
319 |
+
return {"status": "error", "message": str(e)}
|
lightrag/kg/mongo_impl.py
CHANGED
@@ -25,13 +25,13 @@ if not pm.is_installed("pymongo"):
|
|
25 |
if not pm.is_installed("motor"):
|
26 |
pm.install("motor")
|
27 |
|
28 |
-
from motor.motor_asyncio import (
|
29 |
AsyncIOMotorClient,
|
30 |
AsyncIOMotorDatabase,
|
31 |
AsyncIOMotorCollection,
|
32 |
)
|
33 |
-
from pymongo.operations import SearchIndexModel
|
34 |
-
from pymongo.errors import PyMongoError
|
35 |
|
36 |
config = configparser.ConfigParser()
|
37 |
config.read("config.ini", "utf-8")
|
@@ -150,6 +150,66 @@ class MongoKVStorage(BaseKVStorage):
|
|
150 |
# Mongo handles persistence automatically
|
151 |
pass
|
152 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
|
154 |
@final
|
155 |
@dataclass
|
@@ -230,6 +290,27 @@ class MongoDocStatusStorage(DocStatusStorage):
|
|
230 |
# Mongo handles persistence automatically
|
231 |
pass
|
232 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
233 |
|
234 |
@final
|
235 |
@dataclass
|
@@ -840,6 +921,27 @@ class MongoGraphStorage(BaseGraphStorage):
|
|
840 |
|
841 |
logger.debug(f"Successfully deleted edges: {edges}")
|
842 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
843 |
|
844 |
@final
|
845 |
@dataclass
|
@@ -1127,6 +1229,31 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
|
1127 |
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
|
1128 |
return []
|
1129 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1130 |
|
1131 |
async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
|
1132 |
collection_names = await db.list_collection_names()
|
|
|
25 |
if not pm.is_installed("motor"):
|
26 |
pm.install("motor")
|
27 |
|
28 |
+
from motor.motor_asyncio import ( # type: ignore
|
29 |
AsyncIOMotorClient,
|
30 |
AsyncIOMotorDatabase,
|
31 |
AsyncIOMotorCollection,
|
32 |
)
|
33 |
+
from pymongo.operations import SearchIndexModel # type: ignore
|
34 |
+
from pymongo.errors import PyMongoError # type: ignore
|
35 |
|
36 |
config = configparser.ConfigParser()
|
37 |
config.read("config.ini", "utf-8")
|
|
|
150 |
# Mongo handles persistence automatically
|
151 |
pass
|
152 |
|
153 |
+
async def delete(self, ids: list[str]) -> None:
|
154 |
+
"""Delete documents with specified IDs
|
155 |
+
|
156 |
+
Args:
|
157 |
+
ids: List of document IDs to be deleted
|
158 |
+
"""
|
159 |
+
if not ids:
|
160 |
+
return
|
161 |
+
|
162 |
+
try:
|
163 |
+
result = await self._data.delete_many({"_id": {"$in": ids}})
|
164 |
+
logger.info(
|
165 |
+
f"Deleted {result.deleted_count} documents from {self.namespace}"
|
166 |
+
)
|
167 |
+
except PyMongoError as e:
|
168 |
+
logger.error(f"Error deleting documents from {self.namespace}: {e}")
|
169 |
+
|
170 |
+
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
|
171 |
+
"""Delete specific records from storage by cache mode
|
172 |
+
|
173 |
+
Args:
|
174 |
+
modes (list[str]): List of cache modes to be dropped from storage
|
175 |
+
|
176 |
+
Returns:
|
177 |
+
bool: True if successful, False otherwise
|
178 |
+
"""
|
179 |
+
if not modes:
|
180 |
+
return False
|
181 |
+
|
182 |
+
try:
|
183 |
+
# Build regex pattern to match documents with the specified modes
|
184 |
+
pattern = f"^({'|'.join(modes)})_"
|
185 |
+
result = await self._data.delete_many({"_id": {"$regex": pattern}})
|
186 |
+
logger.info(f"Deleted {result.deleted_count} documents by modes: {modes}")
|
187 |
+
return True
|
188 |
+
except Exception as e:
|
189 |
+
logger.error(f"Error deleting cache by modes {modes}: {e}")
|
190 |
+
return False
|
191 |
+
|
192 |
+
async def drop(self) -> dict[str, str]:
|
193 |
+
"""Drop the storage by removing all documents in the collection.
|
194 |
+
|
195 |
+
Returns:
|
196 |
+
dict[str, str]: Status of the operation with keys 'status' and 'message'
|
197 |
+
"""
|
198 |
+
try:
|
199 |
+
result = await self._data.delete_many({})
|
200 |
+
deleted_count = result.deleted_count
|
201 |
+
|
202 |
+
logger.info(
|
203 |
+
f"Dropped {deleted_count} documents from doc status {self._collection_name}"
|
204 |
+
)
|
205 |
+
return {
|
206 |
+
"status": "success",
|
207 |
+
"message": f"{deleted_count} documents dropped",
|
208 |
+
}
|
209 |
+
except PyMongoError as e:
|
210 |
+
logger.error(f"Error dropping doc status {self._collection_name}: {e}")
|
211 |
+
return {"status": "error", "message": str(e)}
|
212 |
+
|
213 |
|
214 |
@final
|
215 |
@dataclass
|
|
|
290 |
# Mongo handles persistence automatically
|
291 |
pass
|
292 |
|
293 |
+
async def drop(self) -> dict[str, str]:
|
294 |
+
"""Drop the storage by removing all documents in the collection.
|
295 |
+
|
296 |
+
Returns:
|
297 |
+
dict[str, str]: Status of the operation with keys 'status' and 'message'
|
298 |
+
"""
|
299 |
+
try:
|
300 |
+
result = await self._data.delete_many({})
|
301 |
+
deleted_count = result.deleted_count
|
302 |
+
|
303 |
+
logger.info(
|
304 |
+
f"Dropped {deleted_count} documents from doc status {self._collection_name}"
|
305 |
+
)
|
306 |
+
return {
|
307 |
+
"status": "success",
|
308 |
+
"message": f"{deleted_count} documents dropped",
|
309 |
+
}
|
310 |
+
except PyMongoError as e:
|
311 |
+
logger.error(f"Error dropping doc status {self._collection_name}: {e}")
|
312 |
+
return {"status": "error", "message": str(e)}
|
313 |
+
|
314 |
|
315 |
@final
|
316 |
@dataclass
|
|
|
921 |
|
922 |
logger.debug(f"Successfully deleted edges: {edges}")
|
923 |
|
924 |
+
async def drop(self) -> dict[str, str]:
|
925 |
+
"""Drop the storage by removing all documents in the collection.
|
926 |
+
|
927 |
+
Returns:
|
928 |
+
dict[str, str]: Status of the operation with keys 'status' and 'message'
|
929 |
+
"""
|
930 |
+
try:
|
931 |
+
result = await self.collection.delete_many({})
|
932 |
+
deleted_count = result.deleted_count
|
933 |
+
|
934 |
+
logger.info(
|
935 |
+
f"Dropped {deleted_count} documents from graph {self._collection_name}"
|
936 |
+
)
|
937 |
+
return {
|
938 |
+
"status": "success",
|
939 |
+
"message": f"{deleted_count} documents dropped",
|
940 |
+
}
|
941 |
+
except PyMongoError as e:
|
942 |
+
logger.error(f"Error dropping graph {self._collection_name}: {e}")
|
943 |
+
return {"status": "error", "message": str(e)}
|
944 |
+
|
945 |
|
946 |
@final
|
947 |
@dataclass
|
|
|
1229 |
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
|
1230 |
return []
|
1231 |
|
1232 |
+
async def drop(self) -> dict[str, str]:
|
1233 |
+
"""Drop the storage by removing all documents in the collection and recreating vector index.
|
1234 |
+
|
1235 |
+
Returns:
|
1236 |
+
dict[str, str]: Status of the operation with keys 'status' and 'message'
|
1237 |
+
"""
|
1238 |
+
try:
|
1239 |
+
# Delete all documents
|
1240 |
+
result = await self._data.delete_many({})
|
1241 |
+
deleted_count = result.deleted_count
|
1242 |
+
|
1243 |
+
# Recreate vector index
|
1244 |
+
await self.create_vector_index_if_not_exists()
|
1245 |
+
|
1246 |
+
logger.info(
|
1247 |
+
f"Dropped {deleted_count} documents from vector storage {self._collection_name} and recreated vector index"
|
1248 |
+
)
|
1249 |
+
return {
|
1250 |
+
"status": "success",
|
1251 |
+
"message": f"{deleted_count} documents dropped and vector index recreated",
|
1252 |
+
}
|
1253 |
+
except PyMongoError as e:
|
1254 |
+
logger.error(f"Error dropping vector storage {self._collection_name}: {e}")
|
1255 |
+
return {"status": "error", "message": str(e)}
|
1256 |
+
|
1257 |
|
1258 |
async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
|
1259 |
collection_names = await db.list_collection_names()
|
lightrag/kg/nano_vector_db_impl.py
CHANGED
@@ -78,6 +78,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
|
78 |
return self._client
|
79 |
|
80 |
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
logger.info(f"Inserting {len(data)} to {self.namespace}")
|
82 |
if not data:
|
83 |
return
|
@@ -146,6 +153,11 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
|
146 |
async def delete(self, ids: list[str]):
|
147 |
"""Delete vectors with specified IDs
|
148 |
|
|
|
|
|
|
|
|
|
|
|
149 |
Args:
|
150 |
ids: List of vector IDs to be deleted
|
151 |
"""
|
@@ -159,6 +171,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
|
159 |
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
|
160 |
|
161 |
async def delete_entity(self, entity_name: str) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
try:
|
163 |
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
|
164 |
logger.debug(
|
@@ -176,6 +195,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
|
176 |
logger.error(f"Error deleting entity {entity_name}: {e}")
|
177 |
|
178 |
async def delete_entity_relation(self, entity_name: str) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
try:
|
180 |
client = await self._get_client()
|
181 |
storage = getattr(client, "_NanoVectorDB__storage")
|
@@ -280,3 +306,43 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
|
280 |
|
281 |
client = await self._get_client()
|
282 |
return client.get(ids)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
return self._client
|
79 |
|
80 |
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
81 |
+
"""
|
82 |
+
Importance notes:
|
83 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
84 |
+
2. Only one process should updating the storage at a time before index_done_callback,
|
85 |
+
KG-storage-log should be used to avoid data corruption
|
86 |
+
"""
|
87 |
+
|
88 |
logger.info(f"Inserting {len(data)} to {self.namespace}")
|
89 |
if not data:
|
90 |
return
|
|
|
153 |
async def delete(self, ids: list[str]):
|
154 |
"""Delete vectors with specified IDs
|
155 |
|
156 |
+
Importance notes:
|
157 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
158 |
+
2. Only one process should updating the storage at a time before index_done_callback,
|
159 |
+
KG-storage-log should be used to avoid data corruption
|
160 |
+
|
161 |
Args:
|
162 |
ids: List of vector IDs to be deleted
|
163 |
"""
|
|
|
171 |
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
|
172 |
|
173 |
async def delete_entity(self, entity_name: str) -> None:
|
174 |
+
"""
|
175 |
+
Importance notes:
|
176 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
177 |
+
2. Only one process should updating the storage at a time before index_done_callback,
|
178 |
+
KG-storage-log should be used to avoid data corruption
|
179 |
+
"""
|
180 |
+
|
181 |
try:
|
182 |
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
|
183 |
logger.debug(
|
|
|
195 |
logger.error(f"Error deleting entity {entity_name}: {e}")
|
196 |
|
197 |
async def delete_entity_relation(self, entity_name: str) -> None:
|
198 |
+
"""
|
199 |
+
Importance notes:
|
200 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
201 |
+
2. Only one process should updating the storage at a time before index_done_callback,
|
202 |
+
KG-storage-log should be used to avoid data corruption
|
203 |
+
"""
|
204 |
+
|
205 |
try:
|
206 |
client = await self._get_client()
|
207 |
storage = getattr(client, "_NanoVectorDB__storage")
|
|
|
306 |
|
307 |
client = await self._get_client()
|
308 |
return client.get(ids)
|
309 |
+
|
310 |
+
async def drop(self) -> dict[str, str]:
|
311 |
+
"""Drop all vector data from storage and clean up resources
|
312 |
+
|
313 |
+
This method will:
|
314 |
+
1. Remove the vector database storage file if it exists
|
315 |
+
2. Reinitialize the vector database client
|
316 |
+
3. Update flags to notify other processes
|
317 |
+
4. Changes is persisted to disk immediately
|
318 |
+
|
319 |
+
This method is intended for use in scenarios where all data needs to be removed,
|
320 |
+
|
321 |
+
Returns:
|
322 |
+
dict[str, str]: Operation status and message
|
323 |
+
- On success: {"status": "success", "message": "data dropped"}
|
324 |
+
- On failure: {"status": "error", "message": "<error details>"}
|
325 |
+
"""
|
326 |
+
try:
|
327 |
+
async with self._storage_lock:
|
328 |
+
# delete _client_file_name
|
329 |
+
if os.path.exists(self._client_file_name):
|
330 |
+
os.remove(self._client_file_name)
|
331 |
+
|
332 |
+
self._client = NanoVectorDB(
|
333 |
+
self.embedding_func.embedding_dim,
|
334 |
+
storage_file=self._client_file_name,
|
335 |
+
)
|
336 |
+
|
337 |
+
# Notify other processes that data has been updated
|
338 |
+
await set_all_update_flags(self.namespace)
|
339 |
+
# Reset own update flag to avoid self-reloading
|
340 |
+
self.storage_updated.value = False
|
341 |
+
|
342 |
+
logger.info(
|
343 |
+
f"Process {os.getpid()} drop {self.namespace}(file:{self._client_file_name})"
|
344 |
+
)
|
345 |
+
return {"status": "success", "message": "data dropped"}
|
346 |
+
except Exception as e:
|
347 |
+
logger.error(f"Error dropping {self.namespace}: {e}")
|
348 |
+
return {"status": "error", "message": str(e)}
|
lightrag/kg/neo4j_impl.py
CHANGED
@@ -1,9 +1,8 @@
|
|
1 |
-
import asyncio
|
2 |
import inspect
|
3 |
import os
|
4 |
import re
|
5 |
from dataclasses import dataclass
|
6 |
-
from typing import Any, final
|
7 |
import numpy as np
|
8 |
import configparser
|
9 |
|
@@ -29,7 +28,6 @@ from neo4j import ( # type: ignore
|
|
29 |
exceptions as neo4jExceptions,
|
30 |
AsyncDriver,
|
31 |
AsyncManagedTransaction,
|
32 |
-
GraphDatabase,
|
33 |
)
|
34 |
|
35 |
config = configparser.ConfigParser()
|
@@ -52,8 +50,13 @@ class Neo4JStorage(BaseGraphStorage):
|
|
52 |
embedding_func=embedding_func,
|
53 |
)
|
54 |
self._driver = None
|
55 |
-
self._driver_lock = asyncio.Lock()
|
56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
|
58 |
USERNAME = os.environ.get(
|
59 |
"NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
|
@@ -86,7 +89,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|
86 |
),
|
87 |
)
|
88 |
DATABASE = os.environ.get(
|
89 |
-
"NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace)
|
90 |
)
|
91 |
|
92 |
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
|
@@ -98,71 +101,92 @@ class Neo4JStorage(BaseGraphStorage):
|
|
98 |
max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
|
99 |
)
|
100 |
|
101 |
-
# Try to connect to the database
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
|
106 |
-
connection_timeout=CONNECTION_TIMEOUT,
|
107 |
-
connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
|
108 |
-
) as _sync_driver:
|
109 |
-
for database in (DATABASE, None):
|
110 |
-
self._DATABASE = database
|
111 |
-
connected = False
|
112 |
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
raise e
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
)
|
|
|
|
|
|
|
132 |
try:
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
)
|
137 |
-
logger.info(f"{database} at {URI} created".capitalize())
|
138 |
-
connected = True
|
139 |
-
except (
|
140 |
-
neo4jExceptions.ClientError,
|
141 |
-
neo4jExceptions.DatabaseError,
|
142 |
-
) as e:
|
143 |
-
if (
|
144 |
-
e.code
|
145 |
-
== "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
|
146 |
-
) or (
|
147 |
-
e.code == "Neo.DatabaseError.Statement.ExecutionFailed"
|
148 |
-
):
|
149 |
-
if database is not None:
|
150 |
-
logger.warning(
|
151 |
-
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
|
152 |
-
)
|
153 |
-
if database is None:
|
154 |
-
logger.error(f"Failed to create {database} at {URI}")
|
155 |
-
raise e
|
156 |
|
157 |
-
|
158 |
-
break
|
159 |
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
|
165 |
-
async def
|
166 |
"""Close the Neo4j driver and release all resources"""
|
167 |
if self._driver:
|
168 |
await self._driver.close()
|
@@ -170,7 +194,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|
170 |
|
171 |
async def __aexit__(self, exc_type, exc, tb):
|
172 |
"""Ensure driver is closed when context manager exits"""
|
173 |
-
await self.
|
174 |
|
175 |
async def index_done_callback(self) -> None:
|
176 |
# Noe4J handles persistence automatically
|
@@ -243,7 +267,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|
243 |
raise
|
244 |
|
245 |
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
246 |
-
"""Get node by its label identifier
|
247 |
|
248 |
Args:
|
249 |
node_id: The node label to look up
|
@@ -428,13 +452,8 @@ class Neo4JStorage(BaseGraphStorage):
|
|
428 |
logger.debug(
|
429 |
f"{inspect.currentframe().f_code.co_name}: No edge found between {source_node_id} and {target_node_id}"
|
430 |
)
|
431 |
-
# Return
|
432 |
-
return
|
433 |
-
"weight": 0.0,
|
434 |
-
"source_id": None,
|
435 |
-
"description": None,
|
436 |
-
"keywords": None,
|
437 |
-
}
|
438 |
finally:
|
439 |
await result.consume() # Ensure result is fully consumed
|
440 |
|
@@ -526,7 +545,6 @@ class Neo4JStorage(BaseGraphStorage):
|
|
526 |
"""
|
527 |
properties = node_data
|
528 |
entity_type = properties["entity_type"]
|
529 |
-
entity_id = properties["entity_id"]
|
530 |
if "entity_id" not in properties:
|
531 |
raise ValueError("Neo4j: node properties must contain an 'entity_id' field")
|
532 |
|
@@ -536,15 +554,17 @@ class Neo4JStorage(BaseGraphStorage):
|
|
536 |
async def execute_upsert(tx: AsyncManagedTransaction):
|
537 |
query = (
|
538 |
"""
|
539 |
-
MERGE (n:base {entity_id: $
|
540 |
SET n += $properties
|
541 |
SET n:`%s`
|
542 |
"""
|
543 |
% entity_type
|
544 |
)
|
545 |
-
result = await tx.run(
|
|
|
|
|
546 |
logger.debug(
|
547 |
-
f"Upserted node with entity_id '{
|
548 |
)
|
549 |
await result.consume() # Ensure result is fully consumed
|
550 |
|
@@ -622,25 +642,19 @@ class Neo4JStorage(BaseGraphStorage):
|
|
622 |
self,
|
623 |
node_label: str,
|
624 |
max_depth: int = 3,
|
625 |
-
|
626 |
-
inclusive: bool = False,
|
627 |
) -> KnowledgeGraph:
|
628 |
"""
|
629 |
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
630 |
-
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
|
631 |
-
When reducing the number of nodes, the prioritization criteria are as follows:
|
632 |
-
1. min_degree does not affect nodes directly connected to the matching nodes
|
633 |
-
2. Label matching nodes take precedence
|
634 |
-
3. Followed by nodes directly connected to the matching nodes
|
635 |
-
4. Finally, the degree of the nodes
|
636 |
|
637 |
Args:
|
638 |
-
node_label: Label of the starting node
|
639 |
-
max_depth: Maximum depth of the subgraph
|
640 |
-
|
641 |
-
|
642 |
Returns:
|
643 |
-
KnowledgeGraph
|
|
|
644 |
"""
|
645 |
result = KnowledgeGraph()
|
646 |
seen_nodes = set()
|
@@ -651,11 +665,27 @@ class Neo4JStorage(BaseGraphStorage):
|
|
651 |
) as session:
|
652 |
try:
|
653 |
if node_label == "*":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
654 |
main_query = """
|
655 |
MATCH (n)
|
656 |
OPTIONAL MATCH (n)-[r]-()
|
657 |
WITH n, COALESCE(count(r), 0) AS degree
|
658 |
-
WHERE degree >= $min_degree
|
659 |
ORDER BY degree DESC
|
660 |
LIMIT $max_nodes
|
661 |
WITH collect({node: n}) AS filtered_nodes
|
@@ -666,20 +696,23 @@ class Neo4JStorage(BaseGraphStorage):
|
|
666 |
RETURN filtered_nodes AS node_info,
|
667 |
collect(DISTINCT r) AS relationships
|
668 |
"""
|
669 |
-
result_set =
|
670 |
-
|
671 |
-
|
672 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
673 |
|
674 |
else:
|
675 |
-
#
|
676 |
-
|
|
|
677 |
MATCH (start)
|
678 |
-
WHERE
|
679 |
-
CASE
|
680 |
-
WHEN $inclusive THEN start.entity_id CONTAINS $entity_id
|
681 |
-
ELSE start.entity_id = $entity_id
|
682 |
-
END
|
683 |
WITH start
|
684 |
CALL apoc.path.subgraphAll(start, {
|
685 |
relationshipFilter: '',
|
@@ -688,78 +721,115 @@ class Neo4JStorage(BaseGraphStorage):
|
|
688 |
bfs: true
|
689 |
})
|
690 |
YIELD nodes, relationships
|
691 |
-
WITH
|
692 |
UNWIND nodes AS node
|
693 |
-
|
694 |
-
|
695 |
-
WHERE node = start OR EXISTS((start)--(node)) OR degree >= $min_degree
|
696 |
-
ORDER BY
|
697 |
-
CASE
|
698 |
-
WHEN node = start THEN 3
|
699 |
-
WHEN EXISTS((start)--(node)) THEN 2
|
700 |
-
ELSE 1
|
701 |
-
END DESC,
|
702 |
-
degree DESC
|
703 |
-
LIMIT $max_nodes
|
704 |
-
WITH collect({node: node}) AS filtered_nodes
|
705 |
-
UNWIND filtered_nodes AS node_info
|
706 |
-
WITH collect(node_info.node) AS kept_nodes, filtered_nodes
|
707 |
-
OPTIONAL MATCH (a)-[r]-(b)
|
708 |
-
WHERE a IN kept_nodes AND b IN kept_nodes
|
709 |
-
RETURN filtered_nodes AS node_info,
|
710 |
-
collect(DISTINCT r) AS relationships
|
711 |
"""
|
712 |
-
result_set = await session.run(
|
713 |
-
main_query,
|
714 |
-
{
|
715 |
-
"max_nodes": MAX_GRAPH_NODES,
|
716 |
-
"entity_id": node_label,
|
717 |
-
"inclusive": inclusive,
|
718 |
-
"max_depth": max_depth,
|
719 |
-
"min_degree": min_degree,
|
720 |
-
},
|
721 |
-
)
|
722 |
|
723 |
-
|
724 |
-
|
725 |
-
|
726 |
-
|
727 |
-
|
728 |
-
|
729 |
-
|
730 |
-
|
731 |
-
|
732 |
-
|
733 |
-
|
734 |
-
|
735 |
-
|
736 |
-
|
737 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
738 |
)
|
739 |
-
|
740 |
-
|
741 |
-
|
742 |
-
|
743 |
-
|
744 |
-
|
745 |
-
|
746 |
-
|
747 |
-
|
748 |
-
|
749 |
-
|
750 |
-
|
751 |
-
|
752 |
-
|
753 |
-
|
754 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
755 |
)
|
756 |
-
|
|
|
757 |
|
758 |
-
|
759 |
-
|
760 |
-
|
761 |
-
finally:
|
762 |
-
await result_set.consume() # Ensure result set is consumed
|
763 |
|
764 |
except neo4jExceptions.ClientError as e:
|
765 |
logger.warning(f"APOC plugin error: {str(e)}")
|
@@ -767,46 +837,89 @@ class Neo4JStorage(BaseGraphStorage):
|
|
767 |
logger.warning(
|
768 |
"Neo4j: falling back to basic Cypher recursive search..."
|
769 |
)
|
770 |
-
|
771 |
-
|
772 |
-
|
773 |
-
|
774 |
-
return await self._robust_fallback(
|
775 |
-
node_label, max_depth, min_degree
|
776 |
)
|
777 |
|
778 |
return result
|
779 |
|
780 |
async def _robust_fallback(
|
781 |
-
self, node_label: str, max_depth: int,
|
782 |
) -> KnowledgeGraph:
|
783 |
"""
|
784 |
Fallback implementation when APOC plugin is not available or incompatible.
|
785 |
This method implements the same functionality as get_knowledge_graph but uses
|
786 |
-
only basic Cypher queries and
|
787 |
"""
|
|
|
|
|
788 |
result = KnowledgeGraph()
|
789 |
visited_nodes = set()
|
790 |
visited_edges = set()
|
|
|
791 |
|
792 |
-
|
793 |
-
|
794 |
-
|
795 |
-
|
796 |
-
|
797 |
-
|
798 |
-
|
799 |
-
|
800 |
-
|
801 |
-
|
802 |
-
|
803 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
804 |
|
805 |
-
|
806 |
-
|
807 |
-
|
808 |
|
809 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
810 |
async with self._driver.session(
|
811 |
database=self._DATABASE, default_access_mode="READ"
|
812 |
) as session:
|
@@ -815,32 +928,17 @@ class Neo4JStorage(BaseGraphStorage):
|
|
815 |
WITH r, b, id(r) as edge_id, id(b) as target_id
|
816 |
RETURN r, b, edge_id, target_id
|
817 |
"""
|
818 |
-
results = await session.run(query, entity_id=
|
819 |
|
820 |
# Get all records and release database connection
|
821 |
-
records = await results.fetch(
|
822 |
-
1000
|
823 |
-
) # Max neighbour nodes we can handled
|
824 |
await results.consume() # Ensure results are consumed
|
825 |
|
826 |
-
#
|
827 |
-
if current_depth > 1 and len(records) < min_degree:
|
828 |
-
return
|
829 |
-
|
830 |
-
# Add current node to result
|
831 |
-
result.nodes.append(node)
|
832 |
-
visited_nodes.add(node.id)
|
833 |
-
|
834 |
-
# Add edge to result if it exists and not already added
|
835 |
-
if edge and edge.id not in visited_edges:
|
836 |
-
result.edges.append(edge)
|
837 |
-
visited_edges.add(edge.id)
|
838 |
-
|
839 |
-
# Prepare nodes and edges for recursive processing
|
840 |
-
nodes_to_process = []
|
841 |
for record in records:
|
842 |
rel = record["r"]
|
843 |
edge_id = str(record["edge_id"])
|
|
|
844 |
if edge_id not in visited_edges:
|
845 |
b_node = record["b"]
|
846 |
target_id = b_node.get("entity_id")
|
@@ -849,55 +947,59 @@ class Neo4JStorage(BaseGraphStorage):
|
|
849 |
# Create KnowledgeGraphNode for target
|
850 |
target_node = KnowledgeGraphNode(
|
851 |
id=f"{target_id}",
|
852 |
-
labels=
|
853 |
-
properties=dict(b_node.
|
854 |
)
|
855 |
|
856 |
# Create KnowledgeGraphEdge
|
857 |
target_edge = KnowledgeGraphEdge(
|
858 |
id=f"{edge_id}",
|
859 |
type=rel.type,
|
860 |
-
source=f"{
|
861 |
target=f"{target_id}",
|
862 |
properties=dict(rel),
|
863 |
)
|
864 |
|
865 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
866 |
else:
|
867 |
logger.warning(
|
868 |
-
f"Skipping edge {edge_id} due to missing
|
869 |
)
|
870 |
|
871 |
-
|
872 |
-
|
873 |
-
|
874 |
-
|
875 |
-
# Get the starting node's data
|
876 |
-
async with self._driver.session(
|
877 |
-
database=self._DATABASE, default_access_mode="READ"
|
878 |
-
) as session:
|
879 |
-
query = """
|
880 |
-
MATCH (n:base {entity_id: $entity_id})
|
881 |
-
RETURN id(n) as node_id, n
|
882 |
-
"""
|
883 |
-
node_result = await session.run(query, entity_id=node_label)
|
884 |
-
try:
|
885 |
-
node_record = await node_result.single()
|
886 |
-
if not node_record:
|
887 |
-
return result
|
888 |
-
|
889 |
-
# Create initial KnowledgeGraphNode
|
890 |
-
start_node = KnowledgeGraphNode(
|
891 |
-
id=f"{node_record['n'].get('entity_id')}",
|
892 |
-
labels=list(f"{node_record['n'].get('entity_id')}"),
|
893 |
-
properties=dict(node_record["n"].properties),
|
894 |
-
)
|
895 |
-
finally:
|
896 |
-
await node_result.consume() # Ensure results are consumed
|
897 |
-
|
898 |
-
# Start traversal with the initial node
|
899 |
-
await traverse(start_node, None, 0)
|
900 |
-
|
901 |
return result
|
902 |
|
903 |
async def get_all_labels(self) -> list[str]:
|
@@ -914,7 +1016,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|
914 |
|
915 |
# Method 2: Query compatible with older versions
|
916 |
query = """
|
917 |
-
MATCH (n)
|
918 |
WHERE n.entity_id IS NOT NULL
|
919 |
RETURN DISTINCT n.entity_id AS label
|
920 |
ORDER BY label
|
@@ -1028,3 +1130,28 @@ class Neo4JStorage(BaseGraphStorage):
|
|
1028 |
self, algorithm: str
|
1029 |
) -> tuple[np.ndarray[Any, Any], list[str]]:
|
1030 |
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import inspect
|
2 |
import os
|
3 |
import re
|
4 |
from dataclasses import dataclass
|
5 |
+
from typing import Any, final
|
6 |
import numpy as np
|
7 |
import configparser
|
8 |
|
|
|
28 |
exceptions as neo4jExceptions,
|
29 |
AsyncDriver,
|
30 |
AsyncManagedTransaction,
|
|
|
31 |
)
|
32 |
|
33 |
config = configparser.ConfigParser()
|
|
|
50 |
embedding_func=embedding_func,
|
51 |
)
|
52 |
self._driver = None
|
|
|
53 |
|
54 |
+
def __post_init__(self):
|
55 |
+
self._node_embed_algorithms = {
|
56 |
+
"node2vec": self._node2vec_embed,
|
57 |
+
}
|
58 |
+
|
59 |
+
async def initialize(self):
|
60 |
URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
|
61 |
USERNAME = os.environ.get(
|
62 |
"NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
|
|
|
89 |
),
|
90 |
)
|
91 |
DATABASE = os.environ.get(
|
92 |
+
"NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace)
|
93 |
)
|
94 |
|
95 |
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
|
|
|
101 |
max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
|
102 |
)
|
103 |
|
104 |
+
# Try to connect to the database and create it if it doesn't exist
|
105 |
+
for database in (DATABASE, None):
|
106 |
+
self._DATABASE = database
|
107 |
+
connected = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
|
109 |
+
try:
|
110 |
+
async with self._driver.session(database=database) as session:
|
111 |
+
try:
|
112 |
+
result = await session.run("MATCH (n) RETURN n LIMIT 0")
|
113 |
+
await result.consume() # Ensure result is consumed
|
114 |
+
logger.info(f"Connected to {database} at {URI}")
|
115 |
+
connected = True
|
116 |
+
except neo4jExceptions.ServiceUnavailable as e:
|
117 |
+
logger.error(
|
118 |
+
f"{database} at {URI} is not available".capitalize()
|
119 |
+
)
|
120 |
+
raise e
|
121 |
+
except neo4jExceptions.AuthError as e:
|
122 |
+
logger.error(f"Authentication failed for {database} at {URI}")
|
123 |
+
raise e
|
124 |
+
except neo4jExceptions.ClientError as e:
|
125 |
+
if e.code == "Neo.ClientError.Database.DatabaseNotFound":
|
126 |
+
logger.info(
|
127 |
+
f"{database} at {URI} not found. Try to create specified database.".capitalize()
|
128 |
+
)
|
129 |
+
try:
|
130 |
+
async with self._driver.session() as session:
|
131 |
+
result = await session.run(
|
132 |
+
f"CREATE DATABASE `{database}` IF NOT EXISTS"
|
133 |
)
|
134 |
+
await result.consume() # Ensure result is consumed
|
135 |
+
logger.info(f"{database} at {URI} created".capitalize())
|
136 |
+
connected = True
|
137 |
+
except (
|
138 |
+
neo4jExceptions.ClientError,
|
139 |
+
neo4jExceptions.DatabaseError,
|
140 |
+
) as e:
|
141 |
+
if (
|
142 |
+
e.code
|
143 |
+
== "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
|
144 |
+
) or (e.code == "Neo.DatabaseError.Statement.ExecutionFailed"):
|
145 |
+
if database is not None:
|
146 |
+
logger.warning(
|
147 |
+
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
|
148 |
+
)
|
149 |
+
if database is None:
|
150 |
+
logger.error(f"Failed to create {database} at {URI}")
|
151 |
raise e
|
152 |
+
|
153 |
+
if connected:
|
154 |
+
# Create index for base nodes on entity_id if it doesn't exist
|
155 |
+
try:
|
156 |
+
async with self._driver.session(database=database) as session:
|
157 |
+
# Check if index exists first
|
158 |
+
check_query = """
|
159 |
+
CALL db.indexes() YIELD name, labelsOrTypes, properties
|
160 |
+
WHERE labelsOrTypes = ['base'] AND properties = ['entity_id']
|
161 |
+
RETURN count(*) > 0 AS exists
|
162 |
+
"""
|
163 |
try:
|
164 |
+
check_result = await session.run(check_query)
|
165 |
+
record = await check_result.single()
|
166 |
+
await check_result.consume()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
|
168 |
+
index_exists = record and record.get("exists", False)
|
|
|
169 |
|
170 |
+
if not index_exists:
|
171 |
+
# Create index only if it doesn't exist
|
172 |
+
result = await session.run(
|
173 |
+
"CREATE INDEX FOR (n:base) ON (n.entity_id)"
|
174 |
+
)
|
175 |
+
await result.consume()
|
176 |
+
logger.info(
|
177 |
+
f"Created index for base nodes on entity_id in {database}"
|
178 |
+
)
|
179 |
+
except Exception:
|
180 |
+
# Fallback if db.indexes() is not supported in this Neo4j version
|
181 |
+
result = await session.run(
|
182 |
+
"CREATE INDEX IF NOT EXISTS FOR (n:base) ON (n.entity_id)"
|
183 |
+
)
|
184 |
+
await result.consume()
|
185 |
+
except Exception as e:
|
186 |
+
logger.warning(f"Failed to create index: {str(e)}")
|
187 |
+
break
|
188 |
|
189 |
+
async def finalize(self):
|
190 |
"""Close the Neo4j driver and release all resources"""
|
191 |
if self._driver:
|
192 |
await self._driver.close()
|
|
|
194 |
|
195 |
async def __aexit__(self, exc_type, exc, tb):
|
196 |
"""Ensure driver is closed when context manager exits"""
|
197 |
+
await self.finalize()
|
198 |
|
199 |
async def index_done_callback(self) -> None:
|
200 |
# Noe4J handles persistence automatically
|
|
|
267 |
raise
|
268 |
|
269 |
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
270 |
+
"""Get node by its label identifier, return only node properties
|
271 |
|
272 |
Args:
|
273 |
node_id: The node label to look up
|
|
|
452 |
logger.debug(
|
453 |
f"{inspect.currentframe().f_code.co_name}: No edge found between {source_node_id} and {target_node_id}"
|
454 |
)
|
455 |
+
# Return None when no edge found
|
456 |
+
return None
|
|
|
|
|
|
|
|
|
|
|
457 |
finally:
|
458 |
await result.consume() # Ensure result is fully consumed
|
459 |
|
|
|
545 |
"""
|
546 |
properties = node_data
|
547 |
entity_type = properties["entity_type"]
|
|
|
548 |
if "entity_id" not in properties:
|
549 |
raise ValueError("Neo4j: node properties must contain an 'entity_id' field")
|
550 |
|
|
|
554 |
async def execute_upsert(tx: AsyncManagedTransaction):
|
555 |
query = (
|
556 |
"""
|
557 |
+
MERGE (n:base {entity_id: $entity_id})
|
558 |
SET n += $properties
|
559 |
SET n:`%s`
|
560 |
"""
|
561 |
% entity_type
|
562 |
)
|
563 |
+
result = await tx.run(
|
564 |
+
query, entity_id=node_id, properties=properties
|
565 |
+
)
|
566 |
logger.debug(
|
567 |
+
f"Upserted node with entity_id '{node_id}' and properties: {properties}"
|
568 |
)
|
569 |
await result.consume() # Ensure result is fully consumed
|
570 |
|
|
|
642 |
self,
|
643 |
node_label: str,
|
644 |
max_depth: int = 3,
|
645 |
+
max_nodes: int = MAX_GRAPH_NODES,
|
|
|
646 |
) -> KnowledgeGraph:
|
647 |
"""
|
648 |
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
|
|
|
|
|
|
|
|
|
|
|
|
649 |
|
650 |
Args:
|
651 |
+
node_label: Label of the starting node, * means all nodes
|
652 |
+
max_depth: Maximum depth of the subgraph, Defaults to 3
|
653 |
+
max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000
|
654 |
+
|
655 |
Returns:
|
656 |
+
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
|
657 |
+
indicating whether the graph was truncated due to max_nodes limit
|
658 |
"""
|
659 |
result = KnowledgeGraph()
|
660 |
seen_nodes = set()
|
|
|
665 |
) as session:
|
666 |
try:
|
667 |
if node_label == "*":
|
668 |
+
# First check total node count to determine if graph is truncated
|
669 |
+
count_query = "MATCH (n) RETURN count(n) as total"
|
670 |
+
count_result = None
|
671 |
+
try:
|
672 |
+
count_result = await session.run(count_query)
|
673 |
+
count_record = await count_result.single()
|
674 |
+
|
675 |
+
if count_record and count_record["total"] > max_nodes:
|
676 |
+
result.is_truncated = True
|
677 |
+
logger.info(
|
678 |
+
f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}"
|
679 |
+
)
|
680 |
+
finally:
|
681 |
+
if count_result:
|
682 |
+
await count_result.consume()
|
683 |
+
|
684 |
+
# Run main query to get nodes with highest degree
|
685 |
main_query = """
|
686 |
MATCH (n)
|
687 |
OPTIONAL MATCH (n)-[r]-()
|
688 |
WITH n, COALESCE(count(r), 0) AS degree
|
|
|
689 |
ORDER BY degree DESC
|
690 |
LIMIT $max_nodes
|
691 |
WITH collect({node: n}) AS filtered_nodes
|
|
|
696 |
RETURN filtered_nodes AS node_info,
|
697 |
collect(DISTINCT r) AS relationships
|
698 |
"""
|
699 |
+
result_set = None
|
700 |
+
try:
|
701 |
+
result_set = await session.run(
|
702 |
+
main_query,
|
703 |
+
{"max_nodes": max_nodes},
|
704 |
+
)
|
705 |
+
record = await result_set.single()
|
706 |
+
finally:
|
707 |
+
if result_set:
|
708 |
+
await result_set.consume()
|
709 |
|
710 |
else:
|
711 |
+
# return await self._robust_fallback(node_label, max_depth, max_nodes)
|
712 |
+
# First try without limit to check if we need to truncate
|
713 |
+
full_query = """
|
714 |
MATCH (start)
|
715 |
+
WHERE start.entity_id = $entity_id
|
|
|
|
|
|
|
|
|
716 |
WITH start
|
717 |
CALL apoc.path.subgraphAll(start, {
|
718 |
relationshipFilter: '',
|
|
|
721 |
bfs: true
|
722 |
})
|
723 |
YIELD nodes, relationships
|
724 |
+
WITH nodes, relationships, size(nodes) AS total_nodes
|
725 |
UNWIND nodes AS node
|
726 |
+
WITH collect({node: node}) AS node_info, relationships, total_nodes
|
727 |
+
RETURN node_info, relationships, total_nodes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
728 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
729 |
|
730 |
+
# Try to get full result
|
731 |
+
full_result = None
|
732 |
+
try:
|
733 |
+
full_result = await session.run(
|
734 |
+
full_query,
|
735 |
+
{
|
736 |
+
"entity_id": node_label,
|
737 |
+
"max_depth": max_depth,
|
738 |
+
},
|
739 |
+
)
|
740 |
+
full_record = await full_result.single()
|
741 |
+
|
742 |
+
# If no record found, return empty KnowledgeGraph
|
743 |
+
if not full_record:
|
744 |
+
logger.debug(f"No nodes found for entity_id: {node_label}")
|
745 |
+
return result
|
746 |
+
|
747 |
+
# If record found, check node count
|
748 |
+
total_nodes = full_record["total_nodes"]
|
749 |
+
|
750 |
+
if total_nodes <= max_nodes:
|
751 |
+
# If node count is within limit, use full result directly
|
752 |
+
logger.debug(
|
753 |
+
f"Using full result with {total_nodes} nodes (no truncation needed)"
|
754 |
+
)
|
755 |
+
record = full_record
|
756 |
+
else:
|
757 |
+
# If node count exceeds limit, set truncated flag and run limited query
|
758 |
+
result.is_truncated = True
|
759 |
+
logger.info(
|
760 |
+
f"Graph truncated: {total_nodes} nodes found, breadth-first search limited to {max_nodes}"
|
761 |
+
)
|
762 |
+
|
763 |
+
# Run limited query
|
764 |
+
limited_query = """
|
765 |
+
MATCH (start)
|
766 |
+
WHERE start.entity_id = $entity_id
|
767 |
+
WITH start
|
768 |
+
CALL apoc.path.subgraphAll(start, {
|
769 |
+
relationshipFilter: '',
|
770 |
+
minLevel: 0,
|
771 |
+
maxLevel: $max_depth,
|
772 |
+
limit: $max_nodes,
|
773 |
+
bfs: true
|
774 |
+
})
|
775 |
+
YIELD nodes, relationships
|
776 |
+
UNWIND nodes AS node
|
777 |
+
WITH collect({node: node}) AS node_info, relationships
|
778 |
+
RETURN node_info, relationships
|
779 |
+
"""
|
780 |
+
result_set = None
|
781 |
+
try:
|
782 |
+
result_set = await session.run(
|
783 |
+
limited_query,
|
784 |
+
{
|
785 |
+
"entity_id": node_label,
|
786 |
+
"max_depth": max_depth,
|
787 |
+
"max_nodes": max_nodes,
|
788 |
+
},
|
789 |
)
|
790 |
+
record = await result_set.single()
|
791 |
+
finally:
|
792 |
+
if result_set:
|
793 |
+
await result_set.consume()
|
794 |
+
finally:
|
795 |
+
if full_result:
|
796 |
+
await full_result.consume()
|
797 |
+
|
798 |
+
if record:
|
799 |
+
# Handle nodes (compatible with multi-label cases)
|
800 |
+
for node_info in record["node_info"]:
|
801 |
+
node = node_info["node"]
|
802 |
+
node_id = node.id
|
803 |
+
if node_id not in seen_nodes:
|
804 |
+
result.nodes.append(
|
805 |
+
KnowledgeGraphNode(
|
806 |
+
id=f"{node_id}",
|
807 |
+
labels=[node.get("entity_id")],
|
808 |
+
properties=dict(node),
|
809 |
+
)
|
810 |
+
)
|
811 |
+
seen_nodes.add(node_id)
|
812 |
+
|
813 |
+
# Handle relationships (including direction information)
|
814 |
+
for rel in record["relationships"]:
|
815 |
+
edge_id = rel.id
|
816 |
+
if edge_id not in seen_edges:
|
817 |
+
start = rel.start_node
|
818 |
+
end = rel.end_node
|
819 |
+
result.edges.append(
|
820 |
+
KnowledgeGraphEdge(
|
821 |
+
id=f"{edge_id}",
|
822 |
+
type=rel.type,
|
823 |
+
source=f"{start.id}",
|
824 |
+
target=f"{end.id}",
|
825 |
+
properties=dict(rel),
|
826 |
)
|
827 |
+
)
|
828 |
+
seen_edges.add(edge_id)
|
829 |
|
830 |
+
logger.info(
|
831 |
+
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
|
832 |
+
)
|
|
|
|
|
833 |
|
834 |
except neo4jExceptions.ClientError as e:
|
835 |
logger.warning(f"APOC plugin error: {str(e)}")
|
|
|
837 |
logger.warning(
|
838 |
"Neo4j: falling back to basic Cypher recursive search..."
|
839 |
)
|
840 |
+
return await self._robust_fallback(node_label, max_depth, max_nodes)
|
841 |
+
else:
|
842 |
+
logger.warning(
|
843 |
+
"Neo4j: APOC plugin error with wildcard query, returning empty result"
|
|
|
|
|
844 |
)
|
845 |
|
846 |
return result
|
847 |
|
848 |
async def _robust_fallback(
|
849 |
+
self, node_label: str, max_depth: int, max_nodes: int
|
850 |
) -> KnowledgeGraph:
|
851 |
"""
|
852 |
Fallback implementation when APOC plugin is not available or incompatible.
|
853 |
This method implements the same functionality as get_knowledge_graph but uses
|
854 |
+
only basic Cypher queries and true breadth-first traversal instead of APOC procedures.
|
855 |
"""
|
856 |
+
from collections import deque
|
857 |
+
|
858 |
result = KnowledgeGraph()
|
859 |
visited_nodes = set()
|
860 |
visited_edges = set()
|
861 |
+
visited_edge_pairs = set() # 用于跟踪已处理的边对(排序后的source_id, target_id)
|
862 |
|
863 |
+
# Get the starting node's data
|
864 |
+
async with self._driver.session(
|
865 |
+
database=self._DATABASE, default_access_mode="READ"
|
866 |
+
) as session:
|
867 |
+
query = """
|
868 |
+
MATCH (n:base {entity_id: $entity_id})
|
869 |
+
RETURN id(n) as node_id, n
|
870 |
+
"""
|
871 |
+
node_result = await session.run(query, entity_id=node_label)
|
872 |
+
try:
|
873 |
+
node_record = await node_result.single()
|
874 |
+
if not node_record:
|
875 |
+
return result
|
876 |
+
|
877 |
+
# Create initial KnowledgeGraphNode
|
878 |
+
start_node = KnowledgeGraphNode(
|
879 |
+
id=f"{node_record['n'].get('entity_id')}",
|
880 |
+
labels=[node_record["n"].get("entity_id")],
|
881 |
+
properties=dict(node_record["n"]._properties),
|
882 |
+
)
|
883 |
+
finally:
|
884 |
+
await node_result.consume() # Ensure results are consumed
|
885 |
|
886 |
+
# Initialize queue for BFS with (node, edge, depth) tuples
|
887 |
+
# edge is None for the starting node
|
888 |
+
queue = deque([(start_node, None, 0)])
|
889 |
|
890 |
+
# True BFS implementation using a queue
|
891 |
+
while queue and len(visited_nodes) < max_nodes:
|
892 |
+
# Dequeue the next node to process
|
893 |
+
current_node, current_edge, current_depth = queue.popleft()
|
894 |
+
|
895 |
+
# Skip if already visited or exceeds max depth
|
896 |
+
if current_node.id in visited_nodes:
|
897 |
+
continue
|
898 |
+
|
899 |
+
if current_depth > max_depth:
|
900 |
+
logger.debug(
|
901 |
+
f"Skipping node at depth {current_depth} (max_depth: {max_depth})"
|
902 |
+
)
|
903 |
+
continue
|
904 |
+
|
905 |
+
# Add current node to result
|
906 |
+
result.nodes.append(current_node)
|
907 |
+
visited_nodes.add(current_node.id)
|
908 |
+
|
909 |
+
# Add edge to result if it exists and not already added
|
910 |
+
if current_edge and current_edge.id not in visited_edges:
|
911 |
+
result.edges.append(current_edge)
|
912 |
+
visited_edges.add(current_edge.id)
|
913 |
+
|
914 |
+
# Stop if we've reached the node limit
|
915 |
+
if len(visited_nodes) >= max_nodes:
|
916 |
+
result.is_truncated = True
|
917 |
+
logger.info(
|
918 |
+
f"Graph truncated: breadth-first search limited to: {max_nodes} nodes"
|
919 |
+
)
|
920 |
+
break
|
921 |
+
|
922 |
+
# Get all edges and target nodes for the current node (even at max_depth)
|
923 |
async with self._driver.session(
|
924 |
database=self._DATABASE, default_access_mode="READ"
|
925 |
) as session:
|
|
|
928 |
WITH r, b, id(r) as edge_id, id(b) as target_id
|
929 |
RETURN r, b, edge_id, target_id
|
930 |
"""
|
931 |
+
results = await session.run(query, entity_id=current_node.id)
|
932 |
|
933 |
# Get all records and release database connection
|
934 |
+
records = await results.fetch(1000) # Max neighbor nodes we can handle
|
|
|
|
|
935 |
await results.consume() # Ensure results are consumed
|
936 |
|
937 |
+
# Process all neighbors - capture all edges but only queue unvisited nodes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
938 |
for record in records:
|
939 |
rel = record["r"]
|
940 |
edge_id = str(record["edge_id"])
|
941 |
+
|
942 |
if edge_id not in visited_edges:
|
943 |
b_node = record["b"]
|
944 |
target_id = b_node.get("entity_id")
|
|
|
947 |
# Create KnowledgeGraphNode for target
|
948 |
target_node = KnowledgeGraphNode(
|
949 |
id=f"{target_id}",
|
950 |
+
labels=[target_id],
|
951 |
+
properties=dict(b_node._properties),
|
952 |
)
|
953 |
|
954 |
# Create KnowledgeGraphEdge
|
955 |
target_edge = KnowledgeGraphEdge(
|
956 |
id=f"{edge_id}",
|
957 |
type=rel.type,
|
958 |
+
source=f"{current_node.id}",
|
959 |
target=f"{target_id}",
|
960 |
properties=dict(rel),
|
961 |
)
|
962 |
|
963 |
+
# 对source_id和target_id进行排序,确保(A,B)和(B,A)被视为同一条边
|
964 |
+
sorted_pair = tuple(sorted([current_node.id, target_id]))
|
965 |
+
|
966 |
+
# 检查是否已存在相同的边(考虑无向性)
|
967 |
+
if sorted_pair not in visited_edge_pairs:
|
968 |
+
# 只有当目标节点已经在结果中或将被添加到结果中时,才添加边
|
969 |
+
if target_id in visited_nodes or (
|
970 |
+
target_id not in visited_nodes
|
971 |
+
and current_depth < max_depth
|
972 |
+
):
|
973 |
+
result.edges.append(target_edge)
|
974 |
+
visited_edges.add(edge_id)
|
975 |
+
visited_edge_pairs.add(sorted_pair)
|
976 |
+
|
977 |
+
# Only add unvisited nodes to the queue for further expansion
|
978 |
+
if target_id not in visited_nodes:
|
979 |
+
# Only add to queue if we're not at max depth yet
|
980 |
+
if current_depth < max_depth:
|
981 |
+
# Add node to queue with incremented depth
|
982 |
+
# Edge is already added to result, so we pass None as edge
|
983 |
+
queue.append((target_node, None, current_depth + 1))
|
984 |
+
else:
|
985 |
+
# At max depth, we've already added the edge but we don't add the node
|
986 |
+
# This prevents adding nodes beyond max_depth to the result
|
987 |
+
logger.debug(
|
988 |
+
f"Node {target_id} beyond max depth {max_depth}, edge added but node not included"
|
989 |
+
)
|
990 |
+
else:
|
991 |
+
# If target node already exists in result, we don't need to add it again
|
992 |
+
logger.debug(
|
993 |
+
f"Node {target_id} already visited, edge added but node not queued"
|
994 |
+
)
|
995 |
else:
|
996 |
logger.warning(
|
997 |
+
f"Skipping edge {edge_id} due to missing entity_id on target node"
|
998 |
)
|
999 |
|
1000 |
+
logger.info(
|
1001 |
+
f"BFS subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
|
1002 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1003 |
return result
|
1004 |
|
1005 |
async def get_all_labels(self) -> list[str]:
|
|
|
1016 |
|
1017 |
# Method 2: Query compatible with older versions
|
1018 |
query = """
|
1019 |
+
MATCH (n:base)
|
1020 |
WHERE n.entity_id IS NOT NULL
|
1021 |
RETURN DISTINCT n.entity_id AS label
|
1022 |
ORDER BY label
|
|
|
1130 |
self, algorithm: str
|
1131 |
) -> tuple[np.ndarray[Any, Any], list[str]]:
|
1132 |
raise NotImplementedError
|
1133 |
+
|
1134 |
+
async def drop(self) -> dict[str, str]:
|
1135 |
+
"""Drop all data from storage and clean up resources
|
1136 |
+
|
1137 |
+
This method will delete all nodes and relationships in the Neo4j database.
|
1138 |
+
|
1139 |
+
Returns:
|
1140 |
+
dict[str, str]: Operation status and message
|
1141 |
+
- On success: {"status": "success", "message": "data dropped"}
|
1142 |
+
- On failure: {"status": "error", "message": "<error details>"}
|
1143 |
+
"""
|
1144 |
+
try:
|
1145 |
+
async with self._driver.session(database=self._DATABASE) as session:
|
1146 |
+
# Delete all nodes and relationships
|
1147 |
+
query = "MATCH (n) DETACH DELETE n"
|
1148 |
+
result = await session.run(query)
|
1149 |
+
await result.consume() # Ensure result is fully consumed
|
1150 |
+
|
1151 |
+
logger.info(
|
1152 |
+
f"Process {os.getpid()} drop Neo4j database {self._DATABASE}"
|
1153 |
+
)
|
1154 |
+
return {"status": "success", "message": "data dropped"}
|
1155 |
+
except Exception as e:
|
1156 |
+
logger.error(f"Error dropping Neo4j database {self._DATABASE}: {e}")
|
1157 |
+
return {"status": "error", "message": str(e)}
|
lightrag/kg/networkx_impl.py
CHANGED
@@ -42,6 +42,7 @@ class NetworkXStorage(BaseGraphStorage):
|
|
42 |
)
|
43 |
nx.write_graphml(graph, file_name)
|
44 |
|
|
|
45 |
@staticmethod
|
46 |
def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
|
47 |
"""Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
|
@@ -155,16 +156,34 @@ class NetworkXStorage(BaseGraphStorage):
|
|
155 |
return None
|
156 |
|
157 |
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
graph = await self._get_graph()
|
159 |
graph.add_node(node_id, **node_data)
|
160 |
|
161 |
async def upsert_edge(
|
162 |
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
163 |
) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
graph = await self._get_graph()
|
165 |
graph.add_edge(source_node_id, target_node_id, **edge_data)
|
166 |
|
167 |
async def delete_node(self, node_id: str) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
graph = await self._get_graph()
|
169 |
if graph.has_node(node_id):
|
170 |
graph.remove_node(node_id)
|
@@ -172,6 +191,7 @@ class NetworkXStorage(BaseGraphStorage):
|
|
172 |
else:
|
173 |
logger.warning(f"Node {node_id} not found in the graph for deletion.")
|
174 |
|
|
|
175 |
async def embed_nodes(
|
176 |
self, algorithm: str
|
177 |
) -> tuple[np.ndarray[Any, Any], list[str]]:
|
@@ -192,6 +212,11 @@ class NetworkXStorage(BaseGraphStorage):
|
|
192 |
async def remove_nodes(self, nodes: list[str]):
|
193 |
"""Delete multiple nodes
|
194 |
|
|
|
|
|
|
|
|
|
|
|
195 |
Args:
|
196 |
nodes: List of node IDs to be deleted
|
197 |
"""
|
@@ -203,6 +228,11 @@ class NetworkXStorage(BaseGraphStorage):
|
|
203 |
async def remove_edges(self, edges: list[tuple[str, str]]):
|
204 |
"""Delete multiple edges
|
205 |
|
|
|
|
|
|
|
|
|
|
|
206 |
Args:
|
207 |
edges: List of edges to be deleted, each edge is a (source, target) tuple
|
208 |
"""
|
@@ -229,118 +259,81 @@ class NetworkXStorage(BaseGraphStorage):
|
|
229 |
self,
|
230 |
node_label: str,
|
231 |
max_depth: int = 3,
|
232 |
-
|
233 |
-
inclusive: bool = False,
|
234 |
) -> KnowledgeGraph:
|
235 |
"""
|
236 |
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
237 |
-
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
|
238 |
-
When reducing the number of nodes, the prioritization criteria are as follows:
|
239 |
-
1. min_degree does not affect nodes directly connected to the matching nodes
|
240 |
-
2. Label matching nodes take precedence
|
241 |
-
3. Followed by nodes directly connected to the matching nodes
|
242 |
-
4. Finally, the degree of the nodes
|
243 |
|
244 |
Args:
|
245 |
-
node_label: Label of the starting node
|
246 |
-
max_depth: Maximum depth of the subgraph
|
247 |
-
|
248 |
-
inclusive: Do an inclusive search if true
|
249 |
|
250 |
Returns:
|
251 |
-
KnowledgeGraph object containing nodes and edges
|
|
|
252 |
"""
|
253 |
-
result = KnowledgeGraph()
|
254 |
-
seen_nodes = set()
|
255 |
-
seen_edges = set()
|
256 |
-
|
257 |
graph = await self._get_graph()
|
258 |
|
259 |
-
|
260 |
-
start_nodes = set()
|
261 |
-
direct_connected_nodes = set()
|
262 |
|
263 |
# Handle special case for "*" label
|
264 |
if node_label == "*":
|
265 |
-
#
|
266 |
-
|
267 |
-
|
268 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
269 |
else:
|
270 |
-
#
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
subgraph
|
304 |
-
|
305 |
-
# Filter nodes based on min_degree, but keep start nodes and direct connected nodes
|
306 |
-
if min_degree > 0:
|
307 |
-
nodes_to_keep = [
|
308 |
-
node
|
309 |
-
for node, degree in subgraph.degree()
|
310 |
-
if node in start_nodes
|
311 |
-
or node in direct_connected_nodes
|
312 |
-
or degree >= min_degree
|
313 |
-
]
|
314 |
-
subgraph = subgraph.subgraph(nodes_to_keep)
|
315 |
-
|
316 |
-
# Check if number of nodes exceeds max_graph_nodes
|
317 |
-
if len(subgraph.nodes()) > MAX_GRAPH_NODES:
|
318 |
-
origin_nodes = len(subgraph.nodes())
|
319 |
-
node_degrees = dict(subgraph.degree())
|
320 |
-
|
321 |
-
def priority_key(node_item):
|
322 |
-
node, degree = node_item
|
323 |
-
# Priority order: start(2) > directly connected(1) > other nodes(0)
|
324 |
-
if node in start_nodes:
|
325 |
-
priority = 2
|
326 |
-
elif node in direct_connected_nodes:
|
327 |
-
priority = 1
|
328 |
-
else:
|
329 |
-
priority = 0
|
330 |
-
return (priority, degree)
|
331 |
-
|
332 |
-
# Sort by priority and degree and select top MAX_GRAPH_NODES nodes
|
333 |
-
top_nodes = sorted(node_degrees.items(), key=priority_key, reverse=True)[
|
334 |
-
:MAX_GRAPH_NODES
|
335 |
-
]
|
336 |
-
top_node_ids = [node[0] for node in top_nodes]
|
337 |
-
# Create new subgraph and keep nodes only with most degree
|
338 |
-
subgraph = subgraph.subgraph(top_node_ids)
|
339 |
-
logger.info(
|
340 |
-
f"Reduced graph from {origin_nodes} nodes to {MAX_GRAPH_NODES} nodes (depth={max_depth})"
|
341 |
-
)
|
342 |
|
343 |
# Add nodes to result
|
|
|
|
|
344 |
for node in subgraph.nodes():
|
345 |
if str(node) in seen_nodes:
|
346 |
continue
|
@@ -368,7 +361,7 @@ class NetworkXStorage(BaseGraphStorage):
|
|
368 |
for edge in subgraph.edges():
|
369 |
source, target = edge
|
370 |
# Esure unique edge_id for undirect graph
|
371 |
-
if source > target:
|
372 |
source, target = target, source
|
373 |
edge_id = f"{source}-{target}"
|
374 |
if edge_id in seen_edges:
|
@@ -424,3 +417,35 @@ class NetworkXStorage(BaseGraphStorage):
|
|
424 |
return False # Return error
|
425 |
|
426 |
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
)
|
43 |
nx.write_graphml(graph, file_name)
|
44 |
|
45 |
+
# TODO:deprecated, remove later
|
46 |
@staticmethod
|
47 |
def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
|
48 |
"""Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
|
|
|
156 |
return None
|
157 |
|
158 |
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
159 |
+
"""
|
160 |
+
Importance notes:
|
161 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
162 |
+
2. Only one process should updating the storage at a time before index_done_callback,
|
163 |
+
KG-storage-log should be used to avoid data corruption
|
164 |
+
"""
|
165 |
graph = await self._get_graph()
|
166 |
graph.add_node(node_id, **node_data)
|
167 |
|
168 |
async def upsert_edge(
|
169 |
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
170 |
) -> None:
|
171 |
+
"""
|
172 |
+
Importance notes:
|
173 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
174 |
+
2. Only one process should updating the storage at a time before index_done_callback,
|
175 |
+
KG-storage-log should be used to avoid data corruption
|
176 |
+
"""
|
177 |
graph = await self._get_graph()
|
178 |
graph.add_edge(source_node_id, target_node_id, **edge_data)
|
179 |
|
180 |
async def delete_node(self, node_id: str) -> None:
|
181 |
+
"""
|
182 |
+
Importance notes:
|
183 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
184 |
+
2. Only one process should updating the storage at a time before index_done_callback,
|
185 |
+
KG-storage-log should be used to avoid data corruption
|
186 |
+
"""
|
187 |
graph = await self._get_graph()
|
188 |
if graph.has_node(node_id):
|
189 |
graph.remove_node(node_id)
|
|
|
191 |
else:
|
192 |
logger.warning(f"Node {node_id} not found in the graph for deletion.")
|
193 |
|
194 |
+
# TODO: NOT USED
|
195 |
async def embed_nodes(
|
196 |
self, algorithm: str
|
197 |
) -> tuple[np.ndarray[Any, Any], list[str]]:
|
|
|
212 |
async def remove_nodes(self, nodes: list[str]):
|
213 |
"""Delete multiple nodes
|
214 |
|
215 |
+
Importance notes:
|
216 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
217 |
+
2. Only one process should updating the storage at a time before index_done_callback,
|
218 |
+
KG-storage-log should be used to avoid data corruption
|
219 |
+
|
220 |
Args:
|
221 |
nodes: List of node IDs to be deleted
|
222 |
"""
|
|
|
228 |
async def remove_edges(self, edges: list[tuple[str, str]]):
|
229 |
"""Delete multiple edges
|
230 |
|
231 |
+
Importance notes:
|
232 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
233 |
+
2. Only one process should updating the storage at a time before index_done_callback,
|
234 |
+
KG-storage-log should be used to avoid data corruption
|
235 |
+
|
236 |
Args:
|
237 |
edges: List of edges to be deleted, each edge is a (source, target) tuple
|
238 |
"""
|
|
|
259 |
self,
|
260 |
node_label: str,
|
261 |
max_depth: int = 3,
|
262 |
+
max_nodes: int = MAX_GRAPH_NODES,
|
|
|
263 |
) -> KnowledgeGraph:
|
264 |
"""
|
265 |
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
|
|
|
|
|
|
|
|
|
|
|
|
266 |
|
267 |
Args:
|
268 |
+
node_label: Label of the starting node,* means all nodes
|
269 |
+
max_depth: Maximum depth of the subgraph, Defaults to 3
|
270 |
+
max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000
|
|
|
271 |
|
272 |
Returns:
|
273 |
+
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
|
274 |
+
indicating whether the graph was truncated due to max_nodes limit
|
275 |
"""
|
|
|
|
|
|
|
|
|
276 |
graph = await self._get_graph()
|
277 |
|
278 |
+
result = KnowledgeGraph()
|
|
|
|
|
279 |
|
280 |
# Handle special case for "*" label
|
281 |
if node_label == "*":
|
282 |
+
# Get degrees of all nodes
|
283 |
+
degrees = dict(graph.degree())
|
284 |
+
# Sort nodes by degree in descending order and take top max_nodes
|
285 |
+
sorted_nodes = sorted(degrees.items(), key=lambda x: x[1], reverse=True)
|
286 |
+
|
287 |
+
# Check if graph is truncated
|
288 |
+
if len(sorted_nodes) > max_nodes:
|
289 |
+
result.is_truncated = True
|
290 |
+
logger.info(
|
291 |
+
f"Graph truncated: {len(sorted_nodes)} nodes found, limited to {max_nodes}"
|
292 |
+
)
|
293 |
+
|
294 |
+
limited_nodes = [node for node, _ in sorted_nodes[:max_nodes]]
|
295 |
+
# Create subgraph with the highest degree nodes
|
296 |
+
subgraph = graph.subgraph(limited_nodes)
|
297 |
else:
|
298 |
+
# Check if node exists
|
299 |
+
if node_label not in graph:
|
300 |
+
logger.warning(f"Node {node_label} not found in the graph")
|
301 |
+
return KnowledgeGraph() # Return empty graph
|
302 |
+
|
303 |
+
# Use BFS to get nodes
|
304 |
+
bfs_nodes = []
|
305 |
+
visited = set()
|
306 |
+
queue = [(node_label, 0)] # (node, depth) tuple
|
307 |
+
|
308 |
+
# Breadth-first search
|
309 |
+
while queue and len(bfs_nodes) < max_nodes:
|
310 |
+
current, depth = queue.pop(0)
|
311 |
+
if current not in visited:
|
312 |
+
visited.add(current)
|
313 |
+
bfs_nodes.append(current)
|
314 |
+
|
315 |
+
# Only explore neighbors if we haven't reached max_depth
|
316 |
+
if depth < max_depth:
|
317 |
+
# Add neighbor nodes to queue with incremented depth
|
318 |
+
neighbors = list(graph.neighbors(current))
|
319 |
+
queue.extend(
|
320 |
+
[(n, depth + 1) for n in neighbors if n not in visited]
|
321 |
+
)
|
322 |
+
|
323 |
+
# Check if graph is truncated - if we still have nodes in the queue
|
324 |
+
# and we've reached max_nodes, then the graph is truncated
|
325 |
+
if queue and len(bfs_nodes) >= max_nodes:
|
326 |
+
result.is_truncated = True
|
327 |
+
logger.info(
|
328 |
+
f"Graph truncated: breadth-first search limited to {max_nodes} nodes"
|
329 |
+
)
|
330 |
+
|
331 |
+
# Create subgraph with BFS discovered nodes
|
332 |
+
subgraph = graph.subgraph(bfs_nodes)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
333 |
|
334 |
# Add nodes to result
|
335 |
+
seen_nodes = set()
|
336 |
+
seen_edges = set()
|
337 |
for node in subgraph.nodes():
|
338 |
if str(node) in seen_nodes:
|
339 |
continue
|
|
|
361 |
for edge in subgraph.edges():
|
362 |
source, target = edge
|
363 |
# Esure unique edge_id for undirect graph
|
364 |
+
if str(source) > str(target):
|
365 |
source, target = target, source
|
366 |
edge_id = f"{source}-{target}"
|
367 |
if edge_id in seen_edges:
|
|
|
417 |
return False # Return error
|
418 |
|
419 |
return True
|
420 |
+
|
421 |
+
async def drop(self) -> dict[str, str]:
|
422 |
+
"""Drop all graph data from storage and clean up resources
|
423 |
+
|
424 |
+
This method will:
|
425 |
+
1. Remove the graph storage file if it exists
|
426 |
+
2. Reset the graph to an empty state
|
427 |
+
3. Update flags to notify other processes
|
428 |
+
4. Changes is persisted to disk immediately
|
429 |
+
|
430 |
+
Returns:
|
431 |
+
dict[str, str]: Operation status and message
|
432 |
+
- On success: {"status": "success", "message": "data dropped"}
|
433 |
+
- On failure: {"status": "error", "message": "<error details>"}
|
434 |
+
"""
|
435 |
+
try:
|
436 |
+
async with self._storage_lock:
|
437 |
+
# delete _client_file_name
|
438 |
+
if os.path.exists(self._graphml_xml_file):
|
439 |
+
os.remove(self._graphml_xml_file)
|
440 |
+
self._graph = nx.Graph()
|
441 |
+
# Notify other processes that data has been updated
|
442 |
+
await set_all_update_flags(self.namespace)
|
443 |
+
# Reset own update flag to avoid self-reloading
|
444 |
+
self.storage_updated.value = False
|
445 |
+
logger.info(
|
446 |
+
f"Process {os.getpid()} drop graph {self.namespace} (file:{self._graphml_xml_file})"
|
447 |
+
)
|
448 |
+
return {"status": "success", "message": "data dropped"}
|
449 |
+
except Exception as e:
|
450 |
+
logger.error(f"Error dropping graph {self.namespace}: {e}")
|
451 |
+
return {"status": "error", "message": str(e)}
|
lightrag/kg/oracle_impl.py
DELETED
@@ -1,1346 +0,0 @@
|
|
1 |
-
import array
|
2 |
-
import asyncio
|
3 |
-
|
4 |
-
# import html
|
5 |
-
import os
|
6 |
-
from dataclasses import dataclass, field
|
7 |
-
from typing import Any, Union, final
|
8 |
-
import numpy as np
|
9 |
-
import configparser
|
10 |
-
|
11 |
-
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
12 |
-
|
13 |
-
from ..base import (
|
14 |
-
BaseGraphStorage,
|
15 |
-
BaseKVStorage,
|
16 |
-
BaseVectorStorage,
|
17 |
-
)
|
18 |
-
from ..namespace import NameSpace, is_namespace
|
19 |
-
from ..utils import logger
|
20 |
-
|
21 |
-
import pipmaster as pm
|
22 |
-
|
23 |
-
if not pm.is_installed("graspologic"):
|
24 |
-
pm.install("graspologic")
|
25 |
-
|
26 |
-
if not pm.is_installed("oracledb"):
|
27 |
-
pm.install("oracledb")
|
28 |
-
|
29 |
-
from graspologic import embed
|
30 |
-
import oracledb
|
31 |
-
|
32 |
-
|
33 |
-
class OracleDB:
|
34 |
-
def __init__(self, config, **kwargs):
|
35 |
-
self.host = config.get("host", None)
|
36 |
-
self.port = config.get("port", None)
|
37 |
-
self.user = config.get("user", None)
|
38 |
-
self.password = config.get("password", None)
|
39 |
-
self.dsn = config.get("dsn", None)
|
40 |
-
self.config_dir = config.get("config_dir", None)
|
41 |
-
self.wallet_location = config.get("wallet_location", None)
|
42 |
-
self.wallet_password = config.get("wallet_password", None)
|
43 |
-
self.workspace = config.get("workspace", None)
|
44 |
-
self.max = 12
|
45 |
-
self.increment = 1
|
46 |
-
logger.info(f"Using the label {self.workspace} for Oracle Graph as identifier")
|
47 |
-
if self.user is None or self.password is None:
|
48 |
-
raise ValueError("Missing database user or password")
|
49 |
-
|
50 |
-
try:
|
51 |
-
oracledb.defaults.fetch_lobs = False
|
52 |
-
|
53 |
-
self.pool = oracledb.create_pool_async(
|
54 |
-
user=self.user,
|
55 |
-
password=self.password,
|
56 |
-
dsn=self.dsn,
|
57 |
-
config_dir=self.config_dir,
|
58 |
-
wallet_location=self.wallet_location,
|
59 |
-
wallet_password=self.wallet_password,
|
60 |
-
min=1,
|
61 |
-
max=self.max,
|
62 |
-
increment=self.increment,
|
63 |
-
)
|
64 |
-
logger.info(f"Connected to Oracle database at {self.dsn}")
|
65 |
-
except Exception as e:
|
66 |
-
logger.error(f"Failed to connect to Oracle database at {self.dsn}")
|
67 |
-
logger.error(f"Oracle database error: {e}")
|
68 |
-
raise
|
69 |
-
|
70 |
-
def numpy_converter_in(self, value):
|
71 |
-
"""Convert numpy array to array.array"""
|
72 |
-
if value.dtype == np.float64:
|
73 |
-
dtype = "d"
|
74 |
-
elif value.dtype == np.float32:
|
75 |
-
dtype = "f"
|
76 |
-
else:
|
77 |
-
dtype = "b"
|
78 |
-
return array.array(dtype, value)
|
79 |
-
|
80 |
-
def input_type_handler(self, cursor, value, arraysize):
|
81 |
-
"""Set the type handler for the input data"""
|
82 |
-
if isinstance(value, np.ndarray):
|
83 |
-
return cursor.var(
|
84 |
-
oracledb.DB_TYPE_VECTOR,
|
85 |
-
arraysize=arraysize,
|
86 |
-
inconverter=self.numpy_converter_in,
|
87 |
-
)
|
88 |
-
|
89 |
-
def numpy_converter_out(self, value):
|
90 |
-
"""Convert array.array to numpy array"""
|
91 |
-
if value.typecode == "b":
|
92 |
-
dtype = np.int8
|
93 |
-
elif value.typecode == "f":
|
94 |
-
dtype = np.float32
|
95 |
-
else:
|
96 |
-
dtype = np.float64
|
97 |
-
return np.array(value, copy=False, dtype=dtype)
|
98 |
-
|
99 |
-
def output_type_handler(self, cursor, metadata):
|
100 |
-
"""Set the type handler for the output data"""
|
101 |
-
if metadata.type_code is oracledb.DB_TYPE_VECTOR:
|
102 |
-
return cursor.var(
|
103 |
-
metadata.type_code,
|
104 |
-
arraysize=cursor.arraysize,
|
105 |
-
outconverter=self.numpy_converter_out,
|
106 |
-
)
|
107 |
-
|
108 |
-
async def check_tables(self):
|
109 |
-
for k, v in TABLES.items():
|
110 |
-
try:
|
111 |
-
if k.lower() == "lightrag_graph":
|
112 |
-
await self.query(
|
113 |
-
"SELECT id FROM GRAPH_TABLE (lightrag_graph MATCH (a) COLUMNS (a.id)) fetch first row only"
|
114 |
-
)
|
115 |
-
else:
|
116 |
-
await self.query(f"SELECT 1 FROM {k}")
|
117 |
-
except Exception as e:
|
118 |
-
logger.error(f"Failed to check table {k} in Oracle database")
|
119 |
-
logger.error(f"Oracle database error: {e}")
|
120 |
-
try:
|
121 |
-
# print(v["ddl"])
|
122 |
-
await self.execute(v["ddl"])
|
123 |
-
logger.info(f"Created table {k} in Oracle database")
|
124 |
-
except Exception as e:
|
125 |
-
logger.error(f"Failed to create table {k} in Oracle database")
|
126 |
-
logger.error(f"Oracle database error: {e}")
|
127 |
-
|
128 |
-
logger.info("Finished check all tables in Oracle database")
|
129 |
-
|
130 |
-
async def query(
|
131 |
-
self, sql: str, params: dict = None, multirows: bool = False
|
132 |
-
) -> Union[dict, None]:
|
133 |
-
async with self.pool.acquire() as connection:
|
134 |
-
connection.inputtypehandler = self.input_type_handler
|
135 |
-
connection.outputtypehandler = self.output_type_handler
|
136 |
-
with connection.cursor() as cursor:
|
137 |
-
try:
|
138 |
-
await cursor.execute(sql, params)
|
139 |
-
except Exception as e:
|
140 |
-
logger.error(f"Oracle database error: {e}")
|
141 |
-
raise
|
142 |
-
columns = [column[0].lower() for column in cursor.description]
|
143 |
-
if multirows:
|
144 |
-
rows = await cursor.fetchall()
|
145 |
-
if rows:
|
146 |
-
data = [dict(zip(columns, row)) for row in rows]
|
147 |
-
else:
|
148 |
-
data = []
|
149 |
-
else:
|
150 |
-
row = await cursor.fetchone()
|
151 |
-
if row:
|
152 |
-
data = dict(zip(columns, row))
|
153 |
-
else:
|
154 |
-
data = None
|
155 |
-
return data
|
156 |
-
|
157 |
-
async def execute(self, sql: str, data: Union[list, dict] = None):
|
158 |
-
# logger.info("go into OracleDB execute method")
|
159 |
-
try:
|
160 |
-
async with self.pool.acquire() as connection:
|
161 |
-
connection.inputtypehandler = self.input_type_handler
|
162 |
-
connection.outputtypehandler = self.output_type_handler
|
163 |
-
with connection.cursor() as cursor:
|
164 |
-
if data is None:
|
165 |
-
await cursor.execute(sql)
|
166 |
-
else:
|
167 |
-
await cursor.execute(sql, data)
|
168 |
-
await connection.commit()
|
169 |
-
except Exception as e:
|
170 |
-
logger.error(f"Oracle database error: {e}")
|
171 |
-
raise
|
172 |
-
|
173 |
-
|
174 |
-
class ClientManager:
|
175 |
-
_instances: dict[str, Any] = {"db": None, "ref_count": 0}
|
176 |
-
_lock = asyncio.Lock()
|
177 |
-
|
178 |
-
@staticmethod
|
179 |
-
def get_config() -> dict[str, Any]:
|
180 |
-
config = configparser.ConfigParser()
|
181 |
-
config.read("config.ini", "utf-8")
|
182 |
-
|
183 |
-
return {
|
184 |
-
"user": os.environ.get(
|
185 |
-
"ORACLE_USER",
|
186 |
-
config.get("oracle", "user", fallback=None),
|
187 |
-
),
|
188 |
-
"password": os.environ.get(
|
189 |
-
"ORACLE_PASSWORD",
|
190 |
-
config.get("oracle", "password", fallback=None),
|
191 |
-
),
|
192 |
-
"dsn": os.environ.get(
|
193 |
-
"ORACLE_DSN",
|
194 |
-
config.get("oracle", "dsn", fallback=None),
|
195 |
-
),
|
196 |
-
"config_dir": os.environ.get(
|
197 |
-
"ORACLE_CONFIG_DIR",
|
198 |
-
config.get("oracle", "config_dir", fallback=None),
|
199 |
-
),
|
200 |
-
"wallet_location": os.environ.get(
|
201 |
-
"ORACLE_WALLET_LOCATION",
|
202 |
-
config.get("oracle", "wallet_location", fallback=None),
|
203 |
-
),
|
204 |
-
"wallet_password": os.environ.get(
|
205 |
-
"ORACLE_WALLET_PASSWORD",
|
206 |
-
config.get("oracle", "wallet_password", fallback=None),
|
207 |
-
),
|
208 |
-
"workspace": os.environ.get(
|
209 |
-
"ORACLE_WORKSPACE",
|
210 |
-
config.get("oracle", "workspace", fallback="default"),
|
211 |
-
),
|
212 |
-
}
|
213 |
-
|
214 |
-
@classmethod
|
215 |
-
async def get_client(cls) -> OracleDB:
|
216 |
-
async with cls._lock:
|
217 |
-
if cls._instances["db"] is None:
|
218 |
-
config = ClientManager.get_config()
|
219 |
-
db = OracleDB(config)
|
220 |
-
await db.check_tables()
|
221 |
-
cls._instances["db"] = db
|
222 |
-
cls._instances["ref_count"] = 0
|
223 |
-
cls._instances["ref_count"] += 1
|
224 |
-
return cls._instances["db"]
|
225 |
-
|
226 |
-
@classmethod
|
227 |
-
async def release_client(cls, db: OracleDB):
|
228 |
-
async with cls._lock:
|
229 |
-
if db is not None:
|
230 |
-
if db is cls._instances["db"]:
|
231 |
-
cls._instances["ref_count"] -= 1
|
232 |
-
if cls._instances["ref_count"] == 0:
|
233 |
-
await db.pool.close()
|
234 |
-
logger.info("Closed OracleDB database connection pool")
|
235 |
-
cls._instances["db"] = None
|
236 |
-
else:
|
237 |
-
await db.pool.close()
|
238 |
-
|
239 |
-
|
240 |
-
@final
|
241 |
-
@dataclass
|
242 |
-
class OracleKVStorage(BaseKVStorage):
|
243 |
-
db: OracleDB = field(default=None)
|
244 |
-
meta_fields = None
|
245 |
-
|
246 |
-
def __post_init__(self):
|
247 |
-
self._data = {}
|
248 |
-
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
|
249 |
-
|
250 |
-
async def initialize(self):
|
251 |
-
if self.db is None:
|
252 |
-
self.db = await ClientManager.get_client()
|
253 |
-
|
254 |
-
async def finalize(self):
|
255 |
-
if self.db is not None:
|
256 |
-
await ClientManager.release_client(self.db)
|
257 |
-
self.db = None
|
258 |
-
|
259 |
-
################ QUERY METHODS ################
|
260 |
-
|
261 |
-
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
262 |
-
"""Get doc_full data based on id."""
|
263 |
-
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
|
264 |
-
params = {"workspace": self.db.workspace, "id": id}
|
265 |
-
# print("get_by_id:"+SQL)
|
266 |
-
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
267 |
-
array_res = await self.db.query(SQL, params, multirows=True)
|
268 |
-
res = {}
|
269 |
-
for row in array_res:
|
270 |
-
res[row["id"]] = row
|
271 |
-
if res:
|
272 |
-
return res
|
273 |
-
else:
|
274 |
-
return None
|
275 |
-
else:
|
276 |
-
return await self.db.query(SQL, params)
|
277 |
-
|
278 |
-
async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
|
279 |
-
"""Specifically for llm_response_cache."""
|
280 |
-
SQL = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
|
281 |
-
params = {"workspace": self.db.workspace, "cache_mode": mode, "id": id}
|
282 |
-
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
283 |
-
array_res = await self.db.query(SQL, params, multirows=True)
|
284 |
-
res = {}
|
285 |
-
for row in array_res:
|
286 |
-
res[row["id"]] = row
|
287 |
-
return res
|
288 |
-
else:
|
289 |
-
return None
|
290 |
-
|
291 |
-
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
292 |
-
"""Get doc_chunks data based on id"""
|
293 |
-
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
|
294 |
-
ids=",".join([f"'{id}'" for id in ids])
|
295 |
-
)
|
296 |
-
params = {"workspace": self.db.workspace}
|
297 |
-
# print("get_by_ids:"+SQL)
|
298 |
-
res = await self.db.query(SQL, params, multirows=True)
|
299 |
-
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
300 |
-
modes = set()
|
301 |
-
dict_res: dict[str, dict] = {}
|
302 |
-
for row in res:
|
303 |
-
modes.add(row["mode"])
|
304 |
-
for mode in modes:
|
305 |
-
if mode not in dict_res:
|
306 |
-
dict_res[mode] = {}
|
307 |
-
for row in res:
|
308 |
-
dict_res[row["mode"]][row["id"]] = row
|
309 |
-
res = [{k: v} for k, v in dict_res.items()]
|
310 |
-
return res
|
311 |
-
|
312 |
-
async def filter_keys(self, keys: set[str]) -> set[str]:
|
313 |
-
"""Return keys that don't exist in storage"""
|
314 |
-
SQL = SQL_TEMPLATES["filter_keys"].format(
|
315 |
-
table_name=namespace_to_table_name(self.namespace),
|
316 |
-
ids=",".join([f"'{id}'" for id in keys]),
|
317 |
-
)
|
318 |
-
params = {"workspace": self.db.workspace}
|
319 |
-
res = await self.db.query(SQL, params, multirows=True)
|
320 |
-
if res:
|
321 |
-
exist_keys = [key["id"] for key in res]
|
322 |
-
data = set([s for s in keys if s not in exist_keys])
|
323 |
-
return data
|
324 |
-
else:
|
325 |
-
return set(keys)
|
326 |
-
|
327 |
-
################ INSERT METHODS ################
|
328 |
-
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
329 |
-
logger.info(f"Inserting {len(data)} to {self.namespace}")
|
330 |
-
if not data:
|
331 |
-
return
|
332 |
-
|
333 |
-
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
|
334 |
-
list_data = [
|
335 |
-
{
|
336 |
-
"id": k,
|
337 |
-
**{k1: v1 for k1, v1 in v.items()},
|
338 |
-
}
|
339 |
-
for k, v in data.items()
|
340 |
-
]
|
341 |
-
contents = [v["content"] for v in data.values()]
|
342 |
-
batches = [
|
343 |
-
contents[i : i + self._max_batch_size]
|
344 |
-
for i in range(0, len(contents), self._max_batch_size)
|
345 |
-
]
|
346 |
-
embeddings_list = await asyncio.gather(
|
347 |
-
*[self.embedding_func(batch) for batch in batches]
|
348 |
-
)
|
349 |
-
embeddings = np.concatenate(embeddings_list)
|
350 |
-
for i, d in enumerate(list_data):
|
351 |
-
d["__vector__"] = embeddings[i]
|
352 |
-
|
353 |
-
merge_sql = SQL_TEMPLATES["merge_chunk"]
|
354 |
-
for item in list_data:
|
355 |
-
_data = {
|
356 |
-
"id": item["id"],
|
357 |
-
"content": item["content"],
|
358 |
-
"workspace": self.db.workspace,
|
359 |
-
"tokens": item["tokens"],
|
360 |
-
"chunk_order_index": item["chunk_order_index"],
|
361 |
-
"full_doc_id": item["full_doc_id"],
|
362 |
-
"content_vector": item["__vector__"],
|
363 |
-
"status": item["status"],
|
364 |
-
}
|
365 |
-
await self.db.execute(merge_sql, _data)
|
366 |
-
if is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
|
367 |
-
for k, v in data.items():
|
368 |
-
# values.clear()
|
369 |
-
merge_sql = SQL_TEMPLATES["merge_doc_full"]
|
370 |
-
_data = {
|
371 |
-
"id": k,
|
372 |
-
"content": v["content"],
|
373 |
-
"workspace": self.db.workspace,
|
374 |
-
}
|
375 |
-
await self.db.execute(merge_sql, _data)
|
376 |
-
|
377 |
-
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
378 |
-
for mode, items in data.items():
|
379 |
-
for k, v in items.items():
|
380 |
-
upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
|
381 |
-
_data = {
|
382 |
-
"workspace": self.db.workspace,
|
383 |
-
"id": k,
|
384 |
-
"original_prompt": v["original_prompt"],
|
385 |
-
"return_value": v["return"],
|
386 |
-
"cache_mode": mode,
|
387 |
-
}
|
388 |
-
|
389 |
-
await self.db.execute(upsert_sql, _data)
|
390 |
-
|
391 |
-
async def index_done_callback(self) -> None:
|
392 |
-
# Oracle handles persistence automatically
|
393 |
-
pass
|
394 |
-
|
395 |
-
|
396 |
-
@final
|
397 |
-
@dataclass
|
398 |
-
class OracleVectorDBStorage(BaseVectorStorage):
|
399 |
-
db: OracleDB | None = field(default=None)
|
400 |
-
|
401 |
-
def __post_init__(self):
|
402 |
-
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
403 |
-
cosine_threshold = config.get("cosine_better_than_threshold")
|
404 |
-
if cosine_threshold is None:
|
405 |
-
raise ValueError(
|
406 |
-
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
407 |
-
)
|
408 |
-
self.cosine_better_than_threshold = cosine_threshold
|
409 |
-
|
410 |
-
async def initialize(self):
|
411 |
-
if self.db is None:
|
412 |
-
self.db = await ClientManager.get_client()
|
413 |
-
|
414 |
-
async def finalize(self):
|
415 |
-
if self.db is not None:
|
416 |
-
await ClientManager.release_client(self.db)
|
417 |
-
self.db = None
|
418 |
-
|
419 |
-
#################### query method ###############
|
420 |
-
async def query(
|
421 |
-
self, query: str, top_k: int, ids: list[str] | None = None
|
422 |
-
) -> list[dict[str, Any]]:
|
423 |
-
embeddings = await self.embedding_func([query])
|
424 |
-
embedding = embeddings[0]
|
425 |
-
# 转换精度
|
426 |
-
dtype = str(embedding.dtype).upper()
|
427 |
-
dimension = embedding.shape[0]
|
428 |
-
embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]"
|
429 |
-
|
430 |
-
SQL = SQL_TEMPLATES[self.namespace].format(dimension=dimension, dtype=dtype)
|
431 |
-
params = {
|
432 |
-
"embedding_string": embedding_string,
|
433 |
-
"workspace": self.db.workspace,
|
434 |
-
"top_k": top_k,
|
435 |
-
"better_than_threshold": self.cosine_better_than_threshold,
|
436 |
-
}
|
437 |
-
results = await self.db.query(SQL, params=params, multirows=True)
|
438 |
-
return results
|
439 |
-
|
440 |
-
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
441 |
-
raise NotImplementedError
|
442 |
-
|
443 |
-
async def index_done_callback(self) -> None:
|
444 |
-
# Oracles handles persistence automatically
|
445 |
-
pass
|
446 |
-
|
447 |
-
async def delete(self, ids: list[str]) -> None:
|
448 |
-
"""Delete vectors with specified IDs
|
449 |
-
|
450 |
-
Args:
|
451 |
-
ids: List of vector IDs to be deleted
|
452 |
-
"""
|
453 |
-
if not ids:
|
454 |
-
return
|
455 |
-
|
456 |
-
try:
|
457 |
-
SQL = SQL_TEMPLATES["delete_vectors"].format(
|
458 |
-
ids=",".join([f"'{id}'" for id in ids])
|
459 |
-
)
|
460 |
-
params = {"workspace": self.db.workspace}
|
461 |
-
await self.db.execute(SQL, params)
|
462 |
-
logger.info(
|
463 |
-
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
|
464 |
-
)
|
465 |
-
except Exception as e:
|
466 |
-
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
|
467 |
-
raise
|
468 |
-
|
469 |
-
async def delete_entity(self, entity_name: str) -> None:
|
470 |
-
"""Delete entity by name
|
471 |
-
|
472 |
-
Args:
|
473 |
-
entity_name: Name of the entity to delete
|
474 |
-
"""
|
475 |
-
try:
|
476 |
-
SQL = SQL_TEMPLATES["delete_entity"]
|
477 |
-
params = {"workspace": self.db.workspace, "entity_name": entity_name}
|
478 |
-
await self.db.execute(SQL, params)
|
479 |
-
logger.info(f"Successfully deleted entity {entity_name}")
|
480 |
-
except Exception as e:
|
481 |
-
logger.error(f"Error deleting entity {entity_name}: {e}")
|
482 |
-
raise
|
483 |
-
|
484 |
-
async def delete_entity_relation(self, entity_name: str) -> None:
|
485 |
-
"""Delete all relations connected to an entity
|
486 |
-
|
487 |
-
Args:
|
488 |
-
entity_name: Name of the entity whose relations should be deleted
|
489 |
-
"""
|
490 |
-
try:
|
491 |
-
SQL = SQL_TEMPLATES["delete_entity_relations"]
|
492 |
-
params = {"workspace": self.db.workspace, "entity_name": entity_name}
|
493 |
-
await self.db.execute(SQL, params)
|
494 |
-
logger.info(f"Successfully deleted relations for entity {entity_name}")
|
495 |
-
except Exception as e:
|
496 |
-
logger.error(f"Error deleting relations for entity {entity_name}: {e}")
|
497 |
-
raise
|
498 |
-
|
499 |
-
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
|
500 |
-
"""Search for records with IDs starting with a specific prefix.
|
501 |
-
|
502 |
-
Args:
|
503 |
-
prefix: The prefix to search for in record IDs
|
504 |
-
|
505 |
-
Returns:
|
506 |
-
List of records with matching ID prefixes
|
507 |
-
"""
|
508 |
-
try:
|
509 |
-
# Determine the appropriate table based on namespace
|
510 |
-
table_name = namespace_to_table_name(self.namespace)
|
511 |
-
|
512 |
-
# Create SQL query to find records with IDs starting with prefix
|
513 |
-
search_sql = f"""
|
514 |
-
SELECT * FROM {table_name}
|
515 |
-
WHERE workspace = :workspace
|
516 |
-
AND id LIKE :prefix_pattern
|
517 |
-
ORDER BY id
|
518 |
-
"""
|
519 |
-
|
520 |
-
params = {"workspace": self.db.workspace, "prefix_pattern": f"{prefix}%"}
|
521 |
-
|
522 |
-
# Execute query and get results
|
523 |
-
results = await self.db.query(search_sql, params, multirows=True)
|
524 |
-
|
525 |
-
logger.debug(
|
526 |
-
f"Found {len(results) if results else 0} records with prefix '{prefix}'"
|
527 |
-
)
|
528 |
-
return results or []
|
529 |
-
|
530 |
-
except Exception as e:
|
531 |
-
logger.error(f"Error searching records with prefix '{prefix}': {e}")
|
532 |
-
return []
|
533 |
-
|
534 |
-
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
535 |
-
"""Get vector data by its ID
|
536 |
-
|
537 |
-
Args:
|
538 |
-
id: The unique identifier of the vector
|
539 |
-
|
540 |
-
Returns:
|
541 |
-
The vector data if found, or None if not found
|
542 |
-
"""
|
543 |
-
try:
|
544 |
-
# Determine the table name based on namespace
|
545 |
-
table_name = namespace_to_table_name(self.namespace)
|
546 |
-
if not table_name:
|
547 |
-
logger.error(f"Unknown namespace for ID lookup: {self.namespace}")
|
548 |
-
return None
|
549 |
-
|
550 |
-
# Create the appropriate ID field name based on namespace
|
551 |
-
id_field = "entity_id" if "NODES" in table_name else "relation_id"
|
552 |
-
if "CHUNKS" in table_name:
|
553 |
-
id_field = "chunk_id"
|
554 |
-
|
555 |
-
# Prepare and execute the query
|
556 |
-
query = f"""
|
557 |
-
SELECT * FROM {table_name}
|
558 |
-
WHERE {id_field} = :id AND workspace = :workspace
|
559 |
-
"""
|
560 |
-
params = {"id": id, "workspace": self.db.workspace}
|
561 |
-
|
562 |
-
result = await self.db.query(query, params)
|
563 |
-
return result
|
564 |
-
except Exception as e:
|
565 |
-
logger.error(f"Error retrieving vector data for ID {id}: {e}")
|
566 |
-
return None
|
567 |
-
|
568 |
-
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
569 |
-
"""Get multiple vector data by their IDs
|
570 |
-
|
571 |
-
Args:
|
572 |
-
ids: List of unique identifiers
|
573 |
-
|
574 |
-
Returns:
|
575 |
-
List of vector data objects that were found
|
576 |
-
"""
|
577 |
-
if not ids:
|
578 |
-
return []
|
579 |
-
|
580 |
-
try:
|
581 |
-
# Determine the table name based on namespace
|
582 |
-
table_name = namespace_to_table_name(self.namespace)
|
583 |
-
if not table_name:
|
584 |
-
logger.error(f"Unknown namespace for IDs lookup: {self.namespace}")
|
585 |
-
return []
|
586 |
-
|
587 |
-
# Create the appropriate ID field name based on namespace
|
588 |
-
id_field = "entity_id" if "NODES" in table_name else "relation_id"
|
589 |
-
if "CHUNKS" in table_name:
|
590 |
-
id_field = "chunk_id"
|
591 |
-
|
592 |
-
# Format the list of IDs for SQL IN clause
|
593 |
-
ids_list = ", ".join([f"'{id}'" for id in ids])
|
594 |
-
|
595 |
-
# Prepare and execute the query
|
596 |
-
query = f"""
|
597 |
-
SELECT * FROM {table_name}
|
598 |
-
WHERE {id_field} IN ({ids_list}) AND workspace = :workspace
|
599 |
-
"""
|
600 |
-
params = {"workspace": self.db.workspace}
|
601 |
-
|
602 |
-
results = await self.db.query(query, params, multirows=True)
|
603 |
-
return results or []
|
604 |
-
except Exception as e:
|
605 |
-
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
|
606 |
-
return []
|
607 |
-
|
608 |
-
|
609 |
-
@final
|
610 |
-
@dataclass
|
611 |
-
class OracleGraphStorage(BaseGraphStorage):
|
612 |
-
db: OracleDB = field(default=None)
|
613 |
-
|
614 |
-
def __post_init__(self):
|
615 |
-
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
|
616 |
-
|
617 |
-
async def initialize(self):
|
618 |
-
if self.db is None:
|
619 |
-
self.db = await ClientManager.get_client()
|
620 |
-
|
621 |
-
async def finalize(self):
|
622 |
-
if self.db is not None:
|
623 |
-
await ClientManager.release_client(self.db)
|
624 |
-
self.db = None
|
625 |
-
|
626 |
-
#################### insert method ################
|
627 |
-
|
628 |
-
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
629 |
-
entity_name = node_id
|
630 |
-
entity_type = node_data["entity_type"]
|
631 |
-
description = node_data["description"]
|
632 |
-
source_id = node_data["source_id"]
|
633 |
-
logger.debug(f"entity_name:{entity_name}, entity_type:{entity_type}")
|
634 |
-
|
635 |
-
content = entity_name + description
|
636 |
-
contents = [content]
|
637 |
-
batches = [
|
638 |
-
contents[i : i + self._max_batch_size]
|
639 |
-
for i in range(0, len(contents), self._max_batch_size)
|
640 |
-
]
|
641 |
-
embeddings_list = await asyncio.gather(
|
642 |
-
*[self.embedding_func(batch) for batch in batches]
|
643 |
-
)
|
644 |
-
embeddings = np.concatenate(embeddings_list)
|
645 |
-
content_vector = embeddings[0]
|
646 |
-
merge_sql = SQL_TEMPLATES["merge_node"]
|
647 |
-
data = {
|
648 |
-
"workspace": self.db.workspace,
|
649 |
-
"name": entity_name,
|
650 |
-
"entity_type": entity_type,
|
651 |
-
"description": description,
|
652 |
-
"source_chunk_id": source_id,
|
653 |
-
"content": content,
|
654 |
-
"content_vector": content_vector,
|
655 |
-
}
|
656 |
-
await self.db.execute(merge_sql, data)
|
657 |
-
# self._graph.add_node(node_id, **node_data)
|
658 |
-
|
659 |
-
async def upsert_edge(
|
660 |
-
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
661 |
-
) -> None:
|
662 |
-
"""插入或更新边"""
|
663 |
-
# print("go into upsert edge method")
|
664 |
-
source_name = source_node_id
|
665 |
-
target_name = target_node_id
|
666 |
-
weight = edge_data["weight"]
|
667 |
-
keywords = edge_data["keywords"]
|
668 |
-
description = edge_data["description"]
|
669 |
-
source_chunk_id = edge_data["source_id"]
|
670 |
-
logger.debug(
|
671 |
-
f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}"
|
672 |
-
)
|
673 |
-
|
674 |
-
content = keywords + source_name + target_name + description
|
675 |
-
contents = [content]
|
676 |
-
batches = [
|
677 |
-
contents[i : i + self._max_batch_size]
|
678 |
-
for i in range(0, len(contents), self._max_batch_size)
|
679 |
-
]
|
680 |
-
embeddings_list = await asyncio.gather(
|
681 |
-
*[self.embedding_func(batch) for batch in batches]
|
682 |
-
)
|
683 |
-
embeddings = np.concatenate(embeddings_list)
|
684 |
-
content_vector = embeddings[0]
|
685 |
-
merge_sql = SQL_TEMPLATES["merge_edge"]
|
686 |
-
data = {
|
687 |
-
"workspace": self.db.workspace,
|
688 |
-
"source_name": source_name,
|
689 |
-
"target_name": target_name,
|
690 |
-
"weight": weight,
|
691 |
-
"keywords": keywords,
|
692 |
-
"description": description,
|
693 |
-
"source_chunk_id": source_chunk_id,
|
694 |
-
"content": content,
|
695 |
-
"content_vector": content_vector,
|
696 |
-
}
|
697 |
-
# print(merge_sql)
|
698 |
-
await self.db.execute(merge_sql, data)
|
699 |
-
# self._graph.add_edge(source_node_id, target_node_id, **edge_data)
|
700 |
-
|
701 |
-
async def embed_nodes(
|
702 |
-
self, algorithm: str
|
703 |
-
) -> tuple[np.ndarray[Any, Any], list[str]]:
|
704 |
-
if algorithm not in self._node_embed_algorithms:
|
705 |
-
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
|
706 |
-
return await self._node_embed_algorithms[algorithm]()
|
707 |
-
|
708 |
-
async def _node2vec_embed(self):
|
709 |
-
"""为节点生成向量"""
|
710 |
-
embeddings, nodes = embed.node2vec_embed(
|
711 |
-
self._graph,
|
712 |
-
**self.config["node2vec_params"],
|
713 |
-
)
|
714 |
-
|
715 |
-
nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
|
716 |
-
return embeddings, nodes_ids
|
717 |
-
|
718 |
-
async def index_done_callback(self) -> None:
|
719 |
-
# Oracles handles persistence automatically
|
720 |
-
pass
|
721 |
-
|
722 |
-
#################### query method #################
|
723 |
-
async def has_node(self, node_id: str) -> bool:
|
724 |
-
"""根据节点id检查节点是否存在"""
|
725 |
-
SQL = SQL_TEMPLATES["has_node"]
|
726 |
-
params = {"workspace": self.db.workspace, "node_id": node_id}
|
727 |
-
res = await self.db.query(SQL, params)
|
728 |
-
if res:
|
729 |
-
# print("Node exist!",res)
|
730 |
-
return True
|
731 |
-
else:
|
732 |
-
# print("Node not exist!")
|
733 |
-
return False
|
734 |
-
|
735 |
-
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
736 |
-
SQL = SQL_TEMPLATES["has_edge"]
|
737 |
-
params = {
|
738 |
-
"workspace": self.db.workspace,
|
739 |
-
"source_node_id": source_node_id,
|
740 |
-
"target_node_id": target_node_id,
|
741 |
-
}
|
742 |
-
res = await self.db.query(SQL, params)
|
743 |
-
if res:
|
744 |
-
# print("Edge exist!",res)
|
745 |
-
return True
|
746 |
-
else:
|
747 |
-
# print("Edge not exist!")
|
748 |
-
return False
|
749 |
-
|
750 |
-
async def node_degree(self, node_id: str) -> int:
|
751 |
-
SQL = SQL_TEMPLATES["node_degree"]
|
752 |
-
params = {"workspace": self.db.workspace, "node_id": node_id}
|
753 |
-
res = await self.db.query(SQL, params)
|
754 |
-
if res:
|
755 |
-
return res["degree"]
|
756 |
-
else:
|
757 |
-
return 0
|
758 |
-
|
759 |
-
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
760 |
-
"""根据源和目标节点id获取边的度"""
|
761 |
-
degree = await self.node_degree(src_id) + await self.node_degree(tgt_id)
|
762 |
-
return degree
|
763 |
-
|
764 |
-
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
765 |
-
"""根据节点id获取节点数据"""
|
766 |
-
SQL = SQL_TEMPLATES["get_node"]
|
767 |
-
params = {"workspace": self.db.workspace, "node_id": node_id}
|
768 |
-
res = await self.db.query(SQL, params)
|
769 |
-
if res:
|
770 |
-
return res
|
771 |
-
else:
|
772 |
-
return None
|
773 |
-
|
774 |
-
async def get_edge(
|
775 |
-
self, source_node_id: str, target_node_id: str
|
776 |
-
) -> dict[str, str] | None:
|
777 |
-
SQL = SQL_TEMPLATES["get_edge"]
|
778 |
-
params = {
|
779 |
-
"workspace": self.db.workspace,
|
780 |
-
"source_node_id": source_node_id,
|
781 |
-
"target_node_id": target_node_id,
|
782 |
-
}
|
783 |
-
res = await self.db.query(SQL, params)
|
784 |
-
if res:
|
785 |
-
# print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
|
786 |
-
return res
|
787 |
-
else:
|
788 |
-
# print("Edge not exist!",self.db.workspace, source_node_id, target_node_id)
|
789 |
-
return None
|
790 |
-
|
791 |
-
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
792 |
-
if await self.has_node(source_node_id):
|
793 |
-
SQL = SQL_TEMPLATES["get_node_edges"]
|
794 |
-
params = {"workspace": self.db.workspace, "source_node_id": source_node_id}
|
795 |
-
res = await self.db.query(sql=SQL, params=params, multirows=True)
|
796 |
-
if res:
|
797 |
-
data = [(i["source_name"], i["target_name"]) for i in res]
|
798 |
-
# print("Get node edge!",self.db.workspace, source_node_id,data)
|
799 |
-
return data
|
800 |
-
else:
|
801 |
-
# print("Node Edge not exist!",self.db.workspace, source_node_id)
|
802 |
-
return []
|
803 |
-
|
804 |
-
async def get_all_nodes(self, limit: int):
|
805 |
-
"""查询所有节点"""
|
806 |
-
SQL = SQL_TEMPLATES["get_all_nodes"]
|
807 |
-
params = {"workspace": self.db.workspace, "limit": str(limit)}
|
808 |
-
res = await self.db.query(sql=SQL, params=params, multirows=True)
|
809 |
-
if res:
|
810 |
-
return res
|
811 |
-
|
812 |
-
async def get_all_edges(self, limit: int):
|
813 |
-
"""查询所有边"""
|
814 |
-
SQL = SQL_TEMPLATES["get_all_edges"]
|
815 |
-
params = {"workspace": self.db.workspace, "limit": str(limit)}
|
816 |
-
res = await self.db.query(sql=SQL, params=params, multirows=True)
|
817 |
-
if res:
|
818 |
-
return res
|
819 |
-
|
820 |
-
async def get_statistics(self):
|
821 |
-
SQL = SQL_TEMPLATES["get_statistics"]
|
822 |
-
params = {"workspace": self.db.workspace}
|
823 |
-
res = await self.db.query(sql=SQL, params=params, multirows=True)
|
824 |
-
if res:
|
825 |
-
return res
|
826 |
-
|
827 |
-
async def delete_node(self, node_id: str) -> None:
|
828 |
-
"""Delete a node from the graph
|
829 |
-
|
830 |
-
Args:
|
831 |
-
node_id: ID of the node to delete
|
832 |
-
"""
|
833 |
-
try:
|
834 |
-
# First delete all relations connected to this node
|
835 |
-
delete_relations_sql = SQL_TEMPLATES["delete_entity_relations"]
|
836 |
-
params_relations = {"workspace": self.db.workspace, "entity_name": node_id}
|
837 |
-
await self.db.execute(delete_relations_sql, params_relations)
|
838 |
-
|
839 |
-
# Then delete the node itself
|
840 |
-
delete_node_sql = SQL_TEMPLATES["delete_entity"]
|
841 |
-
params_node = {"workspace": self.db.workspace, "entity_name": node_id}
|
842 |
-
await self.db.execute(delete_node_sql, params_node)
|
843 |
-
|
844 |
-
logger.info(
|
845 |
-
f"Successfully deleted node {node_id} and all its relationships"
|
846 |
-
)
|
847 |
-
except Exception as e:
|
848 |
-
logger.error(f"Error deleting node {node_id}: {e}")
|
849 |
-
raise
|
850 |
-
|
851 |
-
async def remove_nodes(self, nodes: list[str]) -> None:
|
852 |
-
"""Delete multiple nodes from the graph
|
853 |
-
|
854 |
-
Args:
|
855 |
-
nodes: List of node IDs to be deleted
|
856 |
-
"""
|
857 |
-
if not nodes:
|
858 |
-
return
|
859 |
-
|
860 |
-
try:
|
861 |
-
for node in nodes:
|
862 |
-
# For each node, first delete all its relationships
|
863 |
-
delete_relations_sql = SQL_TEMPLATES["delete_entity_relations"]
|
864 |
-
params_relations = {"workspace": self.db.workspace, "entity_name": node}
|
865 |
-
await self.db.execute(delete_relations_sql, params_relations)
|
866 |
-
|
867 |
-
# Then delete the node itself
|
868 |
-
delete_node_sql = SQL_TEMPLATES["delete_entity"]
|
869 |
-
params_node = {"workspace": self.db.workspace, "entity_name": node}
|
870 |
-
await self.db.execute(delete_node_sql, params_node)
|
871 |
-
|
872 |
-
logger.info(
|
873 |
-
f"Successfully deleted {len(nodes)} nodes and their relationships"
|
874 |
-
)
|
875 |
-
except Exception as e:
|
876 |
-
logger.error(f"Error during batch node deletion: {e}")
|
877 |
-
raise
|
878 |
-
|
879 |
-
async def remove_edges(self, edges: list[tuple[str, str]]) -> None:
|
880 |
-
"""Delete multiple edges from the graph
|
881 |
-
|
882 |
-
Args:
|
883 |
-
edges: List of edges to be deleted, each edge is a (source, target) tuple
|
884 |
-
"""
|
885 |
-
if not edges:
|
886 |
-
return
|
887 |
-
|
888 |
-
try:
|
889 |
-
for source, target in edges:
|
890 |
-
# Check if the edge exists before attempting to delete
|
891 |
-
if await self.has_edge(source, target):
|
892 |
-
# Delete the edge using a SQL query that matches both source and target
|
893 |
-
delete_edge_sql = """
|
894 |
-
DELETE FROM LIGHTRAG_GRAPH_EDGES
|
895 |
-
WHERE workspace = :workspace
|
896 |
-
AND source_name = :source_name
|
897 |
-
AND target_name = :target_name
|
898 |
-
"""
|
899 |
-
params = {
|
900 |
-
"workspace": self.db.workspace,
|
901 |
-
"source_name": source,
|
902 |
-
"target_name": target,
|
903 |
-
}
|
904 |
-
await self.db.execute(delete_edge_sql, params)
|
905 |
-
|
906 |
-
logger.info(f"Successfully deleted {len(edges)} edges from the graph")
|
907 |
-
except Exception as e:
|
908 |
-
logger.error(f"Error during batch edge deletion: {e}")
|
909 |
-
raise
|
910 |
-
|
911 |
-
async def get_all_labels(self) -> list[str]:
|
912 |
-
"""Get all unique entity types (labels) in the graph
|
913 |
-
|
914 |
-
Returns:
|
915 |
-
List of unique entity types/labels
|
916 |
-
"""
|
917 |
-
try:
|
918 |
-
SQL = """
|
919 |
-
SELECT DISTINCT entity_type
|
920 |
-
FROM LIGHTRAG_GRAPH_NODES
|
921 |
-
WHERE workspace = :workspace
|
922 |
-
ORDER BY entity_type
|
923 |
-
"""
|
924 |
-
params = {"workspace": self.db.workspace}
|
925 |
-
results = await self.db.query(SQL, params, multirows=True)
|
926 |
-
|
927 |
-
if results:
|
928 |
-
labels = [row["entity_type"] for row in results]
|
929 |
-
return labels
|
930 |
-
else:
|
931 |
-
return []
|
932 |
-
except Exception as e:
|
933 |
-
logger.error(f"Error retrieving entity types: {e}")
|
934 |
-
return []
|
935 |
-
|
936 |
-
async def get_knowledge_graph(
|
937 |
-
self, node_label: str, max_depth: int = 5
|
938 |
-
) -> KnowledgeGraph:
|
939 |
-
"""Retrieve a connected subgraph starting from nodes matching the given label
|
940 |
-
|
941 |
-
Maximum number of nodes is constrained by MAX_GRAPH_NODES environment variable.
|
942 |
-
Prioritizes nodes by:
|
943 |
-
1. Nodes matching the specified label
|
944 |
-
2. Nodes directly connected to matching nodes
|
945 |
-
3. Node degree (number of connections)
|
946 |
-
|
947 |
-
Args:
|
948 |
-
node_label: Label to match for starting nodes (use "*" for all nodes)
|
949 |
-
max_depth: Maximum depth of traversal from starting nodes
|
950 |
-
|
951 |
-
Returns:
|
952 |
-
KnowledgeGraph object containing nodes and edges
|
953 |
-
"""
|
954 |
-
result = KnowledgeGraph()
|
955 |
-
|
956 |
-
try:
|
957 |
-
# Define maximum number of nodes to return
|
958 |
-
max_graph_nodes = int(os.environ.get("MAX_GRAPH_NODES", 1000))
|
959 |
-
|
960 |
-
if node_label == "*":
|
961 |
-
# For "*" label, get all nodes up to the limit
|
962 |
-
nodes_sql = """
|
963 |
-
SELECT name, entity_type, description, source_chunk_id
|
964 |
-
FROM LIGHTRAG_GRAPH_NODES
|
965 |
-
WHERE workspace = :workspace
|
966 |
-
ORDER BY id
|
967 |
-
FETCH FIRST :limit ROWS ONLY
|
968 |
-
"""
|
969 |
-
nodes_params = {
|
970 |
-
"workspace": self.db.workspace,
|
971 |
-
"limit": max_graph_nodes,
|
972 |
-
}
|
973 |
-
nodes = await self.db.query(nodes_sql, nodes_params, multirows=True)
|
974 |
-
else:
|
975 |
-
# For specific label, find matching nodes and related nodes
|
976 |
-
nodes_sql = """
|
977 |
-
WITH matching_nodes AS (
|
978 |
-
SELECT name
|
979 |
-
FROM LIGHTRAG_GRAPH_NODES
|
980 |
-
WHERE workspace = :workspace
|
981 |
-
AND (name LIKE '%' || :node_label || '%' OR entity_type LIKE '%' || :node_label || '%')
|
982 |
-
)
|
983 |
-
SELECT n.name, n.entity_type, n.description, n.source_chunk_id,
|
984 |
-
CASE
|
985 |
-
WHEN n.name IN (SELECT name FROM matching_nodes) THEN 2
|
986 |
-
WHEN EXISTS (
|
987 |
-
SELECT 1 FROM LIGHTRAG_GRAPH_EDGES e
|
988 |
-
WHERE workspace = :workspace
|
989 |
-
AND ((e.source_name = n.name AND e.target_name IN (SELECT name FROM matching_nodes))
|
990 |
-
OR (e.target_name = n.name AND e.source_name IN (SELECT name FROM matching_nodes)))
|
991 |
-
) THEN 1
|
992 |
-
ELSE 0
|
993 |
-
END AS priority,
|
994 |
-
(SELECT COUNT(*) FROM LIGHTRAG_GRAPH_EDGES e
|
995 |
-
WHERE workspace = :workspace
|
996 |
-
AND (e.source_name = n.name OR e.target_name = n.name)) AS degree
|
997 |
-
FROM LIGHTRAG_GRAPH_NODES n
|
998 |
-
WHERE workspace = :workspace
|
999 |
-
ORDER BY priority DESC, degree DESC
|
1000 |
-
FETCH FIRST :limit ROWS ONLY
|
1001 |
-
"""
|
1002 |
-
nodes_params = {
|
1003 |
-
"workspace": self.db.workspace,
|
1004 |
-
"node_label": node_label,
|
1005 |
-
"limit": max_graph_nodes,
|
1006 |
-
}
|
1007 |
-
nodes = await self.db.query(nodes_sql, nodes_params, multirows=True)
|
1008 |
-
|
1009 |
-
if not nodes:
|
1010 |
-
logger.warning(f"No nodes found matching '{node_label}'")
|
1011 |
-
return result
|
1012 |
-
|
1013 |
-
# Create mapping of node IDs to be used to filter edges
|
1014 |
-
node_names = [node["name"] for node in nodes]
|
1015 |
-
|
1016 |
-
# Add nodes to result
|
1017 |
-
seen_nodes = set()
|
1018 |
-
for node in nodes:
|
1019 |
-
node_id = node["name"]
|
1020 |
-
if node_id in seen_nodes:
|
1021 |
-
continue
|
1022 |
-
|
1023 |
-
# Create node properties dictionary
|
1024 |
-
properties = {
|
1025 |
-
"entity_type": node["entity_type"],
|
1026 |
-
"description": node["description"] or "",
|
1027 |
-
"source_id": node["source_chunk_id"] or "",
|
1028 |
-
}
|
1029 |
-
|
1030 |
-
# Add node to result
|
1031 |
-
result.nodes.append(
|
1032 |
-
KnowledgeGraphNode(
|
1033 |
-
id=node_id, labels=[node["entity_type"]], properties=properties
|
1034 |
-
)
|
1035 |
-
)
|
1036 |
-
seen_nodes.add(node_id)
|
1037 |
-
|
1038 |
-
# Get edges between these nodes
|
1039 |
-
edges_sql = """
|
1040 |
-
SELECT source_name, target_name, weight, keywords, description, source_chunk_id
|
1041 |
-
FROM LIGHTRAG_GRAPH_EDGES
|
1042 |
-
WHERE workspace = :workspace
|
1043 |
-
AND source_name IN (SELECT COLUMN_VALUE FROM TABLE(CAST(:node_names AS SYS.ODCIVARCHAR2LIST)))
|
1044 |
-
AND target_name IN (SELECT COLUMN_VALUE FROM TABLE(CAST(:node_names AS SYS.ODCIVARCHAR2LIST)))
|
1045 |
-
ORDER BY id
|
1046 |
-
"""
|
1047 |
-
edges_params = {"workspace": self.db.workspace, "node_names": node_names}
|
1048 |
-
edges = await self.db.query(edges_sql, edges_params, multirows=True)
|
1049 |
-
|
1050 |
-
# Add edges to result
|
1051 |
-
seen_edges = set()
|
1052 |
-
for edge in edges:
|
1053 |
-
source = edge["source_name"]
|
1054 |
-
target = edge["target_name"]
|
1055 |
-
edge_id = f"{source}-{target}"
|
1056 |
-
|
1057 |
-
if edge_id in seen_edges:
|
1058 |
-
continue
|
1059 |
-
|
1060 |
-
# Create edge properties dictionary
|
1061 |
-
properties = {
|
1062 |
-
"weight": edge["weight"] or 0.0,
|
1063 |
-
"keywords": edge["keywords"] or "",
|
1064 |
-
"description": edge["description"] or "",
|
1065 |
-
"source_id": edge["source_chunk_id"] or "",
|
1066 |
-
}
|
1067 |
-
|
1068 |
-
# Add edge to result
|
1069 |
-
result.edges.append(
|
1070 |
-
KnowledgeGraphEdge(
|
1071 |
-
id=edge_id,
|
1072 |
-
type="RELATED",
|
1073 |
-
source=source,
|
1074 |
-
target=target,
|
1075 |
-
properties=properties,
|
1076 |
-
)
|
1077 |
-
)
|
1078 |
-
seen_edges.add(edge_id)
|
1079 |
-
|
1080 |
-
logger.info(
|
1081 |
-
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
|
1082 |
-
)
|
1083 |
-
|
1084 |
-
except Exception as e:
|
1085 |
-
logger.error(f"Error retrieving knowledge graph: {e}")
|
1086 |
-
|
1087 |
-
return result
|
1088 |
-
|
1089 |
-
|
1090 |
-
N_T = {
|
1091 |
-
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
|
1092 |
-
NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
|
1093 |
-
NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
|
1094 |
-
NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_GRAPH_NODES",
|
1095 |
-
NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_GRAPH_EDGES",
|
1096 |
-
}
|
1097 |
-
|
1098 |
-
|
1099 |
-
def namespace_to_table_name(namespace: str) -> str:
|
1100 |
-
for k, v in N_T.items():
|
1101 |
-
if is_namespace(namespace, k):
|
1102 |
-
return v
|
1103 |
-
|
1104 |
-
|
1105 |
-
TABLES = {
|
1106 |
-
"LIGHTRAG_DOC_FULL": {
|
1107 |
-
"ddl": """CREATE TABLE LIGHTRAG_DOC_FULL (
|
1108 |
-
id varchar(256),
|
1109 |
-
workspace varchar(1024),
|
1110 |
-
doc_name varchar(1024),
|
1111 |
-
content CLOB,
|
1112 |
-
meta JSON,
|
1113 |
-
content_summary varchar(1024),
|
1114 |
-
content_length NUMBER,
|
1115 |
-
status varchar(256),
|
1116 |
-
chunks_count NUMBER,
|
1117 |
-
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
1118 |
-
updatetime TIMESTAMP DEFAULT NULL,
|
1119 |
-
error varchar(4096)
|
1120 |
-
)"""
|
1121 |
-
},
|
1122 |
-
"LIGHTRAG_DOC_CHUNKS": {
|
1123 |
-
"ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS (
|
1124 |
-
id varchar(256),
|
1125 |
-
workspace varchar(1024),
|
1126 |
-
full_doc_id varchar(256),
|
1127 |
-
status varchar(256),
|
1128 |
-
chunk_order_index NUMBER,
|
1129 |
-
tokens NUMBER,
|
1130 |
-
content CLOB,
|
1131 |
-
content_vector VECTOR,
|
1132 |
-
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
1133 |
-
updatetime TIMESTAMP DEFAULT NULL
|
1134 |
-
)"""
|
1135 |
-
},
|
1136 |
-
"LIGHTRAG_GRAPH_NODES": {
|
1137 |
-
"ddl": """CREATE TABLE LIGHTRAG_GRAPH_NODES (
|
1138 |
-
id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
|
1139 |
-
workspace varchar(1024),
|
1140 |
-
name varchar(2048),
|
1141 |
-
entity_type varchar(1024),
|
1142 |
-
description CLOB,
|
1143 |
-
source_chunk_id varchar(256),
|
1144 |
-
content CLOB,
|
1145 |
-
content_vector VECTOR,
|
1146 |
-
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
1147 |
-
updatetime TIMESTAMP DEFAULT NULL
|
1148 |
-
)"""
|
1149 |
-
},
|
1150 |
-
"LIGHTRAG_GRAPH_EDGES": {
|
1151 |
-
"ddl": """CREATE TABLE LIGHTRAG_GRAPH_EDGES (
|
1152 |
-
id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
|
1153 |
-
workspace varchar(1024),
|
1154 |
-
source_name varchar(2048),
|
1155 |
-
target_name varchar(2048),
|
1156 |
-
weight NUMBER,
|
1157 |
-
keywords CLOB,
|
1158 |
-
description CLOB,
|
1159 |
-
source_chunk_id varchar(256),
|
1160 |
-
content CLOB,
|
1161 |
-
content_vector VECTOR,
|
1162 |
-
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
1163 |
-
updatetime TIMESTAMP DEFAULT NULL
|
1164 |
-
)"""
|
1165 |
-
},
|
1166 |
-
"LIGHTRAG_LLM_CACHE": {
|
1167 |
-
"ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE (
|
1168 |
-
id varchar(256) PRIMARY KEY,
|
1169 |
-
workspace varchar(1024),
|
1170 |
-
cache_mode varchar(256),
|
1171 |
-
model_name varchar(256),
|
1172 |
-
original_prompt clob,
|
1173 |
-
return_value clob,
|
1174 |
-
embedding CLOB,
|
1175 |
-
embedding_shape NUMBER,
|
1176 |
-
embedding_min NUMBER,
|
1177 |
-
embedding_max NUMBER,
|
1178 |
-
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
1179 |
-
updatetime TIMESTAMP DEFAULT NULL
|
1180 |
-
)"""
|
1181 |
-
},
|
1182 |
-
"LIGHTRAG_GRAPH": {
|
1183 |
-
"ddl": """CREATE OR REPLACE PROPERTY GRAPH lightrag_graph
|
1184 |
-
VERTEX TABLES (
|
1185 |
-
lightrag_graph_nodes KEY (id)
|
1186 |
-
LABEL entity
|
1187 |
-
PROPERTIES (id,workspace,name) -- ,entity_type,description,source_chunk_id)
|
1188 |
-
)
|
1189 |
-
EDGE TABLES (
|
1190 |
-
lightrag_graph_edges KEY (id)
|
1191 |
-
SOURCE KEY (source_name) REFERENCES lightrag_graph_nodes(name)
|
1192 |
-
DESTINATION KEY (target_name) REFERENCES lightrag_graph_nodes(name)
|
1193 |
-
LABEL has_relation
|
1194 |
-
PROPERTIES (id,workspace,source_name,target_name) -- ,weight, keywords,description,source_chunk_id)
|
1195 |
-
) OPTIONS(ALLOW MIXED PROPERTY TYPES)"""
|
1196 |
-
},
|
1197 |
-
}
|
1198 |
-
|
1199 |
-
|
1200 |
-
SQL_TEMPLATES = {
|
1201 |
-
# SQL for KVStorage
|
1202 |
-
"get_by_id_full_docs": "select ID,content,status from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id",
|
1203 |
-
"get_by_id_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id",
|
1204 |
-
"get_by_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
|
1205 |
-
FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id=:id""",
|
1206 |
-
"get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
|
1207 |
-
FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND cache_mode=:cache_mode AND id=:id""",
|
1208 |
-
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
|
1209 |
-
FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id IN ({ids})""",
|
1210 |
-
"get_by_ids_full_docs": "select t.*,createtime as created_at from LIGHTRAG_DOC_FULL t where workspace=:workspace and ID in ({ids})",
|
1211 |
-
"get_by_ids_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})",
|
1212 |
-
"get_by_status_ids_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status and ID in ({ids})",
|
1213 |
-
"get_by_status_ids_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status ID in ({ids})",
|
1214 |
-
"get_by_status_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status",
|
1215 |
-
"get_by_status_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status",
|
1216 |
-
"filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})",
|
1217 |
-
"merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a
|
1218 |
-
USING DUAL
|
1219 |
-
ON (a.id = :id and a.workspace = :workspace)
|
1220 |
-
WHEN NOT MATCHED THEN
|
1221 |
-
INSERT(id,content,workspace) values(:id,:content,:workspace)""",
|
1222 |
-
"merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS
|
1223 |
-
USING DUAL
|
1224 |
-
ON (id = :id and workspace = :workspace)
|
1225 |
-
WHEN NOT MATCHED THEN INSERT
|
1226 |
-
(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector,status)
|
1227 |
-
values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector,:status) """,
|
1228 |
-
"upsert_llm_response_cache": """MERGE INTO LIGHTRAG_LLM_CACHE a
|
1229 |
-
USING DUAL
|
1230 |
-
ON (a.id = :id)
|
1231 |
-
WHEN NOT MATCHED THEN
|
1232 |
-
INSERT (workspace,id,original_prompt,return_value,cache_mode)
|
1233 |
-
VALUES (:workspace,:id,:original_prompt,:return_value,:cache_mode)
|
1234 |
-
WHEN MATCHED THEN UPDATE
|
1235 |
-
SET original_prompt = :original_prompt,
|
1236 |
-
return_value = :return_value,
|
1237 |
-
cache_mode = :cache_mode,
|
1238 |
-
updatetime = SYSDATE""",
|
1239 |
-
# SQL for VectorStorage
|
1240 |
-
"entities": """SELECT name as entity_name FROM
|
1241 |
-
(SELECT id,name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
|
1242 |
-
FROM LIGHTRAG_GRAPH_NODES WHERE workspace=:workspace)
|
1243 |
-
WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
|
1244 |
-
"relationships": """SELECT source_name as src_id, target_name as tgt_id FROM
|
1245 |
-
(SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
|
1246 |
-
FROM LIGHTRAG_GRAPH_EDGES WHERE workspace=:workspace)
|
1247 |
-
WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
|
1248 |
-
"chunks": """SELECT id FROM
|
1249 |
-
(SELECT id,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
|
1250 |
-
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=:workspace)
|
1251 |
-
WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
|
1252 |
-
# SQL for GraphStorage
|
1253 |
-
"has_node": """SELECT * FROM GRAPH_TABLE (lightrag_graph
|
1254 |
-
MATCH (a)
|
1255 |
-
WHERE a.workspace=:workspace AND a.name=:node_id
|
1256 |
-
COLUMNS (a.name))""",
|
1257 |
-
"has_edge": """SELECT * FROM GRAPH_TABLE (lightrag_graph
|
1258 |
-
MATCH (a) -[e]-> (b)
|
1259 |
-
WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
|
1260 |
-
AND a.name=:source_node_id AND b.name=:target_node_id
|
1261 |
-
COLUMNS (e.source_name,e.target_name) )""",
|
1262 |
-
"node_degree": """SELECT count(1) as degree FROM GRAPH_TABLE (lightrag_graph
|
1263 |
-
MATCH (a)-[e]->(b)
|
1264 |
-
WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
|
1265 |
-
AND a.name=:node_id or b.name = :node_id
|
1266 |
-
COLUMNS (a.name))""",
|
1267 |
-
"get_node": """SELECT t1.name,t2.entity_type,t2.source_chunk_id as source_id,NVL(t2.description,'') AS description
|
1268 |
-
FROM GRAPH_TABLE (lightrag_graph
|
1269 |
-
MATCH (a)
|
1270 |
-
WHERE a.workspace=:workspace AND a.name=:node_id
|
1271 |
-
COLUMNS (a.name)
|
1272 |
-
) t1 JOIN LIGHTRAG_GRAPH_NODES t2 on t1.name=t2.name
|
1273 |
-
WHERE t2.workspace=:workspace""",
|
1274 |
-
"get_edge": """SELECT t1.source_id,t2.weight,t2.source_chunk_id as source_id,t2.keywords,
|
1275 |
-
NVL(t2.description,'') AS description,NVL(t2.KEYWORDS,'') AS keywords
|
1276 |
-
FROM GRAPH_TABLE (lightrag_graph
|
1277 |
-
MATCH (a)-[e]->(b)
|
1278 |
-
WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
|
1279 |
-
AND a.name=:source_node_id and b.name = :target_node_id
|
1280 |
-
COLUMNS (e.id,a.name as source_id)
|
1281 |
-
) t1 JOIN LIGHTRAG_GRAPH_EDGES t2 on t1.id=t2.id""",
|
1282 |
-
"get_node_edges": """SELECT source_name,target_name
|
1283 |
-
FROM GRAPH_TABLE (lightrag_graph
|
1284 |
-
MATCH (a)-[e]->(b)
|
1285 |
-
WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
|
1286 |
-
AND a.name=:source_node_id
|
1287 |
-
COLUMNS (a.name as source_name,b.name as target_name))""",
|
1288 |
-
"merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a
|
1289 |
-
USING DUAL
|
1290 |
-
ON (a.workspace=:workspace and a.name=:name)
|
1291 |
-
WHEN NOT MATCHED THEN
|
1292 |
-
INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector)
|
1293 |
-
values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector)
|
1294 |
-
WHEN MATCHED THEN
|
1295 |
-
UPDATE SET
|
1296 |
-
entity_type=:entity_type,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""",
|
1297 |
-
"merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a
|
1298 |
-
USING DUAL
|
1299 |
-
ON (a.workspace=:workspace and a.source_name=:source_name and a.target_name=:target_name)
|
1300 |
-
WHEN NOT MATCHED THEN
|
1301 |
-
INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
|
1302 |
-
values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector)
|
1303 |
-
WHEN MATCHED THEN
|
1304 |
-
UPDATE SET
|
1305 |
-
weight=:weight,keywords=:keywords,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""",
|
1306 |
-
"get_all_nodes": """WITH t0 AS (
|
1307 |
-
SELECT name AS id, entity_type AS label, entity_type, description,
|
1308 |
-
'["' || replace(source_chunk_id, '<SEP>', '","') || '"]' source_chunk_ids
|
1309 |
-
FROM lightrag_graph_nodes
|
1310 |
-
WHERE workspace = :workspace
|
1311 |
-
ORDER BY createtime DESC fetch first :limit rows only
|
1312 |
-
), t1 AS (
|
1313 |
-
SELECT t0.id, source_chunk_id
|
1314 |
-
FROM t0, JSON_TABLE ( source_chunk_ids, '$[*]' COLUMNS ( source_chunk_id PATH '$' ) )
|
1315 |
-
), t2 AS (
|
1316 |
-
SELECT t1.id, LISTAGG(t2.content, '\n') content
|
1317 |
-
FROM t1 LEFT JOIN lightrag_doc_chunks t2 ON t1.source_chunk_id = t2.id
|
1318 |
-
GROUP BY t1.id
|
1319 |
-
)
|
1320 |
-
SELECT t0.id, label, entity_type, description, t2.content
|
1321 |
-
FROM t0 LEFT JOIN t2 ON t0.id = t2.id""",
|
1322 |
-
"get_all_edges": """SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target,
|
1323 |
-
t1.weight,t1.DESCRIPTION,t2.content
|
1324 |
-
FROM LIGHTRAG_GRAPH_EDGES t1
|
1325 |
-
LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
|
1326 |
-
WHERE t1.workspace=:workspace
|
1327 |
-
order by t1.CREATETIME DESC
|
1328 |
-
fetch first :limit rows only""",
|
1329 |
-
"get_statistics": """select count(distinct CASE WHEN type='node' THEN id END) as nodes_count,
|
1330 |
-
count(distinct CASE WHEN type='edge' THEN id END) as edges_count
|
1331 |
-
FROM (
|
1332 |
-
select 'node' as type, id FROM GRAPH_TABLE (lightrag_graph
|
1333 |
-
MATCH (a) WHERE a.workspace=:workspace columns(a.name as id))
|
1334 |
-
UNION
|
1335 |
-
select 'edge' as type, TO_CHAR(id) id FROM GRAPH_TABLE (lightrag_graph
|
1336 |
-
MATCH (a)-[e]->(b) WHERE e.workspace=:workspace columns(e.id))
|
1337 |
-
)""",
|
1338 |
-
# SQL for deletion
|
1339 |
-
"delete_vectors": "DELETE FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=:workspace AND id IN ({ids})",
|
1340 |
-
"delete_entity": "DELETE FROM LIGHTRAG_GRAPH_NODES WHERE workspace=:workspace AND name=:entity_name",
|
1341 |
-
"delete_entity_relations": "DELETE FROM LIGHTRAG_GRAPH_EDGES WHERE workspace=:workspace AND (source_name=:entity_name OR target_name=:entity_name)",
|
1342 |
-
"delete_node": """DELETE FROM GRAPH_TABLE (lightrag_graph
|
1343 |
-
MATCH (a)
|
1344 |
-
WHERE a.workspace=:workspace AND a.name=:node_id
|
1345 |
-
ACTION DELETE a)""",
|
1346 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lightrag/kg/postgres_impl.py
CHANGED
@@ -9,7 +9,6 @@ import configparser
|
|
9 |
|
10 |
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
11 |
|
12 |
-
import sys
|
13 |
from tenacity import (
|
14 |
retry,
|
15 |
retry_if_exception_type,
|
@@ -28,11 +27,6 @@ from ..base import (
|
|
28 |
from ..namespace import NameSpace, is_namespace
|
29 |
from ..utils import logger
|
30 |
|
31 |
-
if sys.platform.startswith("win"):
|
32 |
-
import asyncio.windows_events
|
33 |
-
|
34 |
-
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
35 |
-
|
36 |
import pipmaster as pm
|
37 |
|
38 |
if not pm.is_installed("asyncpg"):
|
@@ -41,6 +35,9 @@ if not pm.is_installed("asyncpg"):
|
|
41 |
import asyncpg # type: ignore
|
42 |
from asyncpg import Pool # type: ignore
|
43 |
|
|
|
|
|
|
|
44 |
|
45 |
class PostgreSQLDB:
|
46 |
def __init__(self, config: dict[str, Any], **kwargs: Any):
|
@@ -118,6 +115,25 @@ class PostgreSQLDB:
|
|
118 |
)
|
119 |
raise e
|
120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
async def query(
|
122 |
self,
|
123 |
sql: str,
|
@@ -254,8 +270,6 @@ class PGKVStorage(BaseKVStorage):
|
|
254 |
db: PostgreSQLDB = field(default=None)
|
255 |
|
256 |
def __post_init__(self):
|
257 |
-
namespace_prefix = self.global_config.get("namespace_prefix")
|
258 |
-
self.base_namespace = self.namespace.replace(namespace_prefix, "")
|
259 |
self._max_batch_size = self.global_config["embedding_batch_num"]
|
260 |
|
261 |
async def initialize(self):
|
@@ -271,7 +285,7 @@ class PGKVStorage(BaseKVStorage):
|
|
271 |
|
272 |
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
273 |
"""Get doc_full data by id."""
|
274 |
-
sql = SQL_TEMPLATES["get_by_id_" + self.
|
275 |
params = {"workspace": self.db.workspace, "id": id}
|
276 |
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
277 |
array_res = await self.db.query(sql, params, multirows=True)
|
@@ -285,7 +299,7 @@ class PGKVStorage(BaseKVStorage):
|
|
285 |
|
286 |
async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
|
287 |
"""Specifically for llm_response_cache."""
|
288 |
-
sql = SQL_TEMPLATES["get_by_mode_id_" + self.
|
289 |
params = {"workspace": self.db.workspace, mode: mode, "id": id}
|
290 |
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
291 |
array_res = await self.db.query(sql, params, multirows=True)
|
@@ -299,7 +313,7 @@ class PGKVStorage(BaseKVStorage):
|
|
299 |
# Query by id
|
300 |
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
301 |
"""Get doc_chunks data by id"""
|
302 |
-
sql = SQL_TEMPLATES["get_by_ids_" + self.
|
303 |
ids=",".join([f"'{id}'" for id in ids])
|
304 |
)
|
305 |
params = {"workspace": self.db.workspace}
|
@@ -320,7 +334,7 @@ class PGKVStorage(BaseKVStorage):
|
|
320 |
|
321 |
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
|
322 |
"""Specifically for llm_response_cache."""
|
323 |
-
SQL = SQL_TEMPLATES["get_by_status_" + self.
|
324 |
params = {"workspace": self.db.workspace, "status": status}
|
325 |
return await self.db.query(SQL, params, multirows=True)
|
326 |
|
@@ -380,10 +394,85 @@ class PGKVStorage(BaseKVStorage):
|
|
380 |
# PG handles persistence automatically
|
381 |
pass
|
382 |
|
383 |
-
async def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
384 |
"""Drop the storage"""
|
385 |
-
|
386 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
387 |
|
388 |
|
389 |
@final
|
@@ -393,8 +482,6 @@ class PGVectorStorage(BaseVectorStorage):
|
|
393 |
|
394 |
def __post_init__(self):
|
395 |
self._max_batch_size = self.global_config["embedding_batch_num"]
|
396 |
-
namespace_prefix = self.global_config.get("namespace_prefix")
|
397 |
-
self.base_namespace = self.namespace.replace(namespace_prefix, "")
|
398 |
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
399 |
cosine_threshold = config.get("cosine_better_than_threshold")
|
400 |
if cosine_threshold is None:
|
@@ -523,7 +610,7 @@ class PGVectorStorage(BaseVectorStorage):
|
|
523 |
else:
|
524 |
formatted_ids = "NULL"
|
525 |
|
526 |
-
sql = SQL_TEMPLATES[self.
|
527 |
embedding_string=embedding_string, doc_ids=formatted_ids
|
528 |
)
|
529 |
params = {
|
@@ -552,13 +639,12 @@ class PGVectorStorage(BaseVectorStorage):
|
|
552 |
logger.error(f"Unknown namespace for vector deletion: {self.namespace}")
|
553 |
return
|
554 |
|
555 |
-
|
556 |
-
delete_sql = (
|
557 |
-
f"DELETE FROM {table_name} WHERE workspace=$1 AND id IN ({ids_list})"
|
558 |
-
)
|
559 |
|
560 |
try:
|
561 |
-
await self.db.execute(
|
|
|
|
|
562 |
logger.debug(
|
563 |
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
|
564 |
)
|
@@ -690,6 +776,24 @@ class PGVectorStorage(BaseVectorStorage):
|
|
690 |
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
|
691 |
return []
|
692 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
693 |
|
694 |
@final
|
695 |
@dataclass
|
@@ -810,6 +914,35 @@ class PGDocStatusStorage(DocStatusStorage):
|
|
810 |
# PG handles persistence automatically
|
811 |
pass
|
812 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
813 |
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
814 |
"""Update or insert document status
|
815 |
|
@@ -846,10 +979,23 @@ class PGDocStatusStorage(DocStatusStorage):
|
|
846 |
},
|
847 |
)
|
848 |
|
849 |
-
async def drop(self) ->
|
850 |
"""Drop the storage"""
|
851 |
-
|
852 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
853 |
|
854 |
|
855 |
class PGGraphQueryException(Exception):
|
@@ -937,31 +1083,11 @@ class PGGraphStorage(BaseGraphStorage):
|
|
937 |
if v.startswith("[") and v.endswith("]"):
|
938 |
if "::vertex" in v:
|
939 |
v = v.replace("::vertex", "")
|
940 |
-
|
941 |
-
dl = []
|
942 |
-
for vertex in vertexes:
|
943 |
-
prop = vertex.get("properties")
|
944 |
-
if not prop:
|
945 |
-
prop = {}
|
946 |
-
prop["label"] = PGGraphStorage._decode_graph_label(
|
947 |
-
prop["node_id"]
|
948 |
-
)
|
949 |
-
dl.append(prop)
|
950 |
-
d[k] = dl
|
951 |
|
952 |
elif "::edge" in v:
|
953 |
v = v.replace("::edge", "")
|
954 |
-
|
955 |
-
dl = []
|
956 |
-
for edge in edges:
|
957 |
-
dl.append(
|
958 |
-
(
|
959 |
-
vertices[edge["start_id"]],
|
960 |
-
edge["label"],
|
961 |
-
vertices[edge["end_id"]],
|
962 |
-
)
|
963 |
-
)
|
964 |
-
d[k] = dl
|
965 |
else:
|
966 |
print("WARNING: unsupported type")
|
967 |
continue
|
@@ -970,32 +1096,19 @@ class PGGraphStorage(BaseGraphStorage):
|
|
970 |
dtype = v.split("::")[-1]
|
971 |
v = v.split("::")[0]
|
972 |
if dtype == "vertex":
|
973 |
-
|
974 |
-
field = vertex.get("properties")
|
975 |
-
if not field:
|
976 |
-
field = {}
|
977 |
-
field["label"] = PGGraphStorage._decode_graph_label(
|
978 |
-
field["node_id"]
|
979 |
-
)
|
980 |
-
d[k] = field
|
981 |
-
# convert edge from id-label->id by replacing id with node information
|
982 |
-
# we only do this if the vertex was also returned in the query
|
983 |
-
# this is an attempt to be consistent with neo4j implementation
|
984 |
elif dtype == "edge":
|
985 |
-
|
986 |
-
d[k] = (
|
987 |
-
vertices.get(edge["start_id"], {}),
|
988 |
-
edge[
|
989 |
-
"label"
|
990 |
-
], # we don't use decode_graph_label(), since edge label is always "DIRECTED"
|
991 |
-
vertices.get(edge["end_id"], {}),
|
992 |
-
)
|
993 |
else:
|
994 |
-
|
995 |
-
|
996 |
-
|
997 |
-
|
998 |
-
|
|
|
|
|
|
|
|
|
999 |
|
1000 |
return d
|
1001 |
|
@@ -1025,56 +1138,6 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1025 |
)
|
1026 |
return "{" + ", ".join(props) + "}"
|
1027 |
|
1028 |
-
@staticmethod
|
1029 |
-
def _encode_graph_label(label: str) -> str:
|
1030 |
-
"""
|
1031 |
-
Since AGE supports only alphanumerical labels, we will encode generic label as HEX string
|
1032 |
-
|
1033 |
-
Args:
|
1034 |
-
label (str): the original label
|
1035 |
-
|
1036 |
-
Returns:
|
1037 |
-
str: the encoded label
|
1038 |
-
"""
|
1039 |
-
return "x" + label.encode().hex()
|
1040 |
-
|
1041 |
-
@staticmethod
|
1042 |
-
def _decode_graph_label(encoded_label: str) -> str:
|
1043 |
-
"""
|
1044 |
-
Since AGE supports only alphanumerical labels, we will encode generic label as HEX string
|
1045 |
-
|
1046 |
-
Args:
|
1047 |
-
encoded_label (str): the encoded label
|
1048 |
-
|
1049 |
-
Returns:
|
1050 |
-
str: the decoded label
|
1051 |
-
"""
|
1052 |
-
return bytes.fromhex(encoded_label.removeprefix("x")).decode()
|
1053 |
-
|
1054 |
-
@staticmethod
|
1055 |
-
def _get_col_name(field: str, idx: int) -> str:
|
1056 |
-
"""
|
1057 |
-
Convert a cypher return field to a pgsql select field
|
1058 |
-
If possible keep the cypher column name, but create a generic name if necessary
|
1059 |
-
|
1060 |
-
Args:
|
1061 |
-
field (str): a return field from a cypher query to be formatted for pgsql
|
1062 |
-
idx (int): the position of the field in the return statement
|
1063 |
-
|
1064 |
-
Returns:
|
1065 |
-
str: the field to be used in the pgsql select statement
|
1066 |
-
"""
|
1067 |
-
# remove white space
|
1068 |
-
field = field.strip()
|
1069 |
-
# if an alias is provided for the field, use it
|
1070 |
-
if " as " in field:
|
1071 |
-
return field.split(" as ")[-1].strip()
|
1072 |
-
# if the return value is an unnamed primitive, give it a generic name
|
1073 |
-
if field.isnumeric() or field in ("true", "false", "null"):
|
1074 |
-
return f"column_{idx}"
|
1075 |
-
# otherwise return the value stripping out some common special chars
|
1076 |
-
return field.replace("(", "_").replace(")", "")
|
1077 |
-
|
1078 |
async def _query(
|
1079 |
self,
|
1080 |
query: str,
|
@@ -1125,10 +1188,10 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1125 |
return result
|
1126 |
|
1127 |
async def has_node(self, node_id: str) -> bool:
|
1128 |
-
entity_name_label =
|
1129 |
|
1130 |
query = """SELECT * FROM cypher('%s', $$
|
1131 |
-
MATCH (n:
|
1132 |
RETURN count(n) > 0 AS node_exists
|
1133 |
$$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
|
1134 |
|
@@ -1137,11 +1200,11 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1137 |
return single_result["node_exists"]
|
1138 |
|
1139 |
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
1140 |
-
src_label =
|
1141 |
-
tgt_label =
|
1142 |
|
1143 |
query = """SELECT * FROM cypher('%s', $$
|
1144 |
-
MATCH (a:
|
1145 |
RETURN COUNT(r) > 0 AS edge_exists
|
1146 |
$$) AS (edge_exists bool)""" % (
|
1147 |
self.graph_name,
|
@@ -1154,30 +1217,31 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1154 |
return single_result["edge_exists"]
|
1155 |
|
1156 |
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
1157 |
-
label
|
|
|
|
|
1158 |
query = """SELECT * FROM cypher('%s', $$
|
1159 |
-
MATCH (n:
|
1160 |
RETURN n
|
1161 |
$$) AS (n agtype)""" % (self.graph_name, label)
|
1162 |
record = await self._query(query)
|
1163 |
if record:
|
1164 |
node = record[0]
|
1165 |
-
node_dict = node["n"]
|
1166 |
|
1167 |
return node_dict
|
1168 |
return None
|
1169 |
|
1170 |
async def node_degree(self, node_id: str) -> int:
|
1171 |
-
label =
|
1172 |
|
1173 |
query = """SELECT * FROM cypher('%s', $$
|
1174 |
-
MATCH (n:
|
1175 |
RETURN count(x) AS total_edge_count
|
1176 |
$$) AS (total_edge_count integer)""" % (self.graph_name, label)
|
1177 |
record = (await self._query(query))[0]
|
1178 |
if record:
|
1179 |
edge_count = int(record["total_edge_count"])
|
1180 |
-
|
1181 |
return edge_count
|
1182 |
|
1183 |
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
@@ -1195,11 +1259,13 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1195 |
async def get_edge(
|
1196 |
self, source_node_id: str, target_node_id: str
|
1197 |
) -> dict[str, str] | None:
|
1198 |
-
|
1199 |
-
|
|
|
|
|
1200 |
|
1201 |
query = """SELECT * FROM cypher('%s', $$
|
1202 |
-
MATCH (a:
|
1203 |
RETURN properties(r) as edge_properties
|
1204 |
LIMIT 1
|
1205 |
$$) AS (edge_properties agtype)""" % (
|
@@ -1218,11 +1284,11 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1218 |
Retrieves all edges (relationships) for a particular node identified by its label.
|
1219 |
:return: list of dictionaries containing edge information
|
1220 |
"""
|
1221 |
-
label =
|
1222 |
|
1223 |
query = """SELECT * FROM cypher('%s', $$
|
1224 |
-
MATCH (n:
|
1225 |
-
OPTIONAL MATCH (n)-[]-(connected)
|
1226 |
RETURN n, connected
|
1227 |
$$) AS (n agtype, connected agtype)""" % (
|
1228 |
self.graph_name,
|
@@ -1235,24 +1301,17 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1235 |
source_node = record["n"] if record["n"] else None
|
1236 |
connected_node = record["connected"] if record["connected"] else None
|
1237 |
|
1238 |
-
|
1239 |
-
source_node
|
1240 |
-
|
1241 |
-
|
1242 |
-
|
1243 |
-
|
1244 |
-
|
1245 |
-
|
1246 |
-
else None
|
1247 |
-
)
|
1248 |
|
1249 |
-
|
1250 |
-
|
1251 |
-
(
|
1252 |
-
self._decode_graph_label(source_label),
|
1253 |
-
self._decode_graph_label(target_label),
|
1254 |
-
)
|
1255 |
-
)
|
1256 |
|
1257 |
return edges
|
1258 |
|
@@ -1262,24 +1321,36 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1262 |
retry=retry_if_exception_type((PGGraphQueryException,)),
|
1263 |
)
|
1264 |
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
1265 |
-
|
1266 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1267 |
|
1268 |
query = """SELECT * FROM cypher('%s', $$
|
1269 |
-
MERGE (n:
|
1270 |
SET n += %s
|
1271 |
RETURN n
|
1272 |
$$) AS (n agtype)""" % (
|
1273 |
self.graph_name,
|
1274 |
label,
|
1275 |
-
|
1276 |
)
|
1277 |
|
1278 |
try:
|
1279 |
await self._query(query, readonly=False, upsert=True)
|
1280 |
|
1281 |
-
except Exception
|
1282 |
-
logger.error("POSTGRES,
|
1283 |
raise
|
1284 |
|
1285 |
@retry(
|
@@ -1298,14 +1369,14 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1298 |
target_node_id (str): Label of the target node (used as identifier)
|
1299 |
edge_data (dict): dictionary of properties to set on the edge
|
1300 |
"""
|
1301 |
-
src_label =
|
1302 |
-
tgt_label =
|
1303 |
-
edge_properties = edge_data
|
1304 |
|
1305 |
query = """SELECT * FROM cypher('%s', $$
|
1306 |
-
MATCH (source:
|
1307 |
WITH source
|
1308 |
-
MATCH (target:
|
1309 |
MERGE (source)-[r:DIRECTED]->(target)
|
1310 |
SET r += %s
|
1311 |
RETURN r
|
@@ -1313,14 +1384,16 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1313 |
self.graph_name,
|
1314 |
src_label,
|
1315 |
tgt_label,
|
1316 |
-
|
1317 |
)
|
1318 |
|
1319 |
try:
|
1320 |
await self._query(query, readonly=False, upsert=True)
|
1321 |
|
1322 |
-
except Exception
|
1323 |
-
logger.error(
|
|
|
|
|
1324 |
raise
|
1325 |
|
1326 |
async def _node2vec_embed(self):
|
@@ -1333,10 +1406,10 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1333 |
Args:
|
1334 |
node_id (str): The ID of the node to delete.
|
1335 |
"""
|
1336 |
-
label =
|
1337 |
|
1338 |
query = """SELECT * FROM cypher('%s', $$
|
1339 |
-
MATCH (n:
|
1340 |
DETACH DELETE n
|
1341 |
$$) AS (n agtype)""" % (self.graph_name, label)
|
1342 |
|
@@ -1353,14 +1426,12 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1353 |
Args:
|
1354 |
node_ids (list[str]): A list of node IDs to remove.
|
1355 |
"""
|
1356 |
-
|
1357 |
-
|
1358 |
-
]
|
1359 |
-
node_id_list = ", ".join([f'"{node_id}"' for node_id in encoded_node_ids])
|
1360 |
|
1361 |
query = """SELECT * FROM cypher('%s', $$
|
1362 |
-
MATCH (n:
|
1363 |
-
WHERE n.
|
1364 |
DETACH DELETE n
|
1365 |
$$) AS (n agtype)""" % (self.graph_name, node_id_list)
|
1366 |
|
@@ -1377,26 +1448,21 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1377 |
Args:
|
1378 |
edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
|
1379 |
"""
|
1380 |
-
|
1381 |
-
(
|
1382 |
-
|
1383 |
-
self._encode_graph_label(tgt.strip('"')),
|
1384 |
-
)
|
1385 |
-
for src, tgt in edges
|
1386 |
-
]
|
1387 |
-
edge_list = ", ".join([f'["{src}", "{tgt}"]' for src, tgt in encoded_edges])
|
1388 |
|
1389 |
-
|
1390 |
-
|
1391 |
-
|
1392 |
-
|
1393 |
-
$$) AS (r agtype)""" % (self.graph_name, edge_list)
|
1394 |
|
1395 |
-
|
1396 |
-
|
1397 |
-
|
1398 |
-
|
1399 |
-
|
|
|
1400 |
|
1401 |
async def get_all_labels(self) -> list[str]:
|
1402 |
"""
|
@@ -1407,15 +1473,16 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1407 |
"""
|
1408 |
query = (
|
1409 |
"""SELECT * FROM cypher('%s', $$
|
1410 |
-
MATCH (n:
|
1411 |
-
|
|
|
|
|
1412 |
$$) AS (label text)"""
|
1413 |
% self.graph_name
|
1414 |
)
|
1415 |
|
1416 |
results = await self._query(query)
|
1417 |
-
labels = [
|
1418 |
-
|
1419 |
return labels
|
1420 |
|
1421 |
async def embed_nodes(
|
@@ -1437,105 +1504,135 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1437 |
return await embed_func()
|
1438 |
|
1439 |
async def get_knowledge_graph(
|
1440 |
-
self,
|
|
|
|
|
|
|
1441 |
) -> KnowledgeGraph:
|
1442 |
"""
|
1443 |
-
Retrieve a subgraph
|
1444 |
|
1445 |
Args:
|
1446 |
-
node_label
|
1447 |
-
max_depth
|
|
|
1448 |
|
1449 |
Returns:
|
1450 |
-
KnowledgeGraph
|
|
|
1451 |
"""
|
1452 |
-
|
1453 |
-
|
1454 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1455 |
if node_label == "*":
|
1456 |
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
1457 |
-
|
1458 |
-
|
1459 |
-
|
1460 |
-
|
1461 |
-
|
1462 |
else:
|
1463 |
-
|
1464 |
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
1465 |
-
|
1466 |
-
|
1467 |
-
|
1468 |
-
|
1469 |
-
|
1470 |
|
1471 |
results = await self._query(query)
|
1472 |
|
1473 |
-
|
1474 |
-
|
1475 |
-
|
1476 |
-
|
1477 |
-
|
1478 |
-
|
1479 |
-
|
1480 |
-
|
1481 |
-
|
1482 |
-
|
1483 |
-
|
1484 |
-
|
1485 |
-
edge_key = f"{src_id},{tgt_id}"
|
1486 |
-
if edge_key not in unique_edge_ids:
|
1487 |
-
unique_edge_ids.add(edge_key)
|
1488 |
-
edges.append(
|
1489 |
-
(
|
1490 |
-
edge_key,
|
1491 |
-
src_id,
|
1492 |
-
tgt_id,
|
1493 |
-
{"source": edge_data[0], "target": edge_data[2]},
|
1494 |
)
|
1495 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1496 |
|
1497 |
-
|
1498 |
-
|
1499 |
-
|
1500 |
-
if
|
1501 |
-
|
1502 |
-
|
1503 |
-
|
1504 |
-
|
1505 |
-
|
1506 |
-
|
1507 |
-
|
1508 |
-
|
1509 |
-
|
1510 |
-
for edge in result
|
1511 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1512 |
|
1513 |
-
# Construct and return the KnowledgeGraph
|
1514 |
kg = KnowledgeGraph(
|
1515 |
-
nodes=
|
1516 |
-
|
1517 |
-
|
1518 |
-
],
|
1519 |
-
edges=[
|
1520 |
-
KnowledgeGraphEdge(
|
1521 |
-
id=edge_id,
|
1522 |
-
type="DIRECTED",
|
1523 |
-
source=src,
|
1524 |
-
target=tgt,
|
1525 |
-
properties=props,
|
1526 |
-
)
|
1527 |
-
for edge_id, src, tgt, props in edges
|
1528 |
-
],
|
1529 |
)
|
1530 |
|
|
|
|
|
|
|
1531 |
return kg
|
1532 |
|
1533 |
-
async def drop(self) ->
|
1534 |
"""Drop the storage"""
|
1535 |
-
|
1536 |
-
|
1537 |
-
|
1538 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1539 |
|
1540 |
|
1541 |
NAMESPACE_TABLE_MAP = {
|
@@ -1693,6 +1790,7 @@ SQL_TEMPLATES = {
|
|
1693 |
file_path=EXCLUDED.file_path,
|
1694 |
update_time = CURRENT_TIMESTAMP
|
1695 |
""",
|
|
|
1696 |
"upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
|
1697 |
content_vector, chunk_ids, file_path)
|
1698 |
VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7)
|
@@ -1716,45 +1814,6 @@ SQL_TEMPLATES = {
|
|
1716 |
file_path=EXCLUDED.file_path,
|
1717 |
update_time = CURRENT_TIMESTAMP
|
1718 |
""",
|
1719 |
-
# SQL for VectorStorage
|
1720 |
-
# "entities": """SELECT entity_name FROM
|
1721 |
-
# (SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
1722 |
-
# FROM LIGHTRAG_VDB_ENTITY where workspace=$1)
|
1723 |
-
# WHERE distance>$2 ORDER BY distance DESC LIMIT $3
|
1724 |
-
# """,
|
1725 |
-
# "relationships": """SELECT source_id as src_id, target_id as tgt_id FROM
|
1726 |
-
# (SELECT id, source_id,target_id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
1727 |
-
# FROM LIGHTRAG_VDB_RELATION where workspace=$1)
|
1728 |
-
# WHERE distance>$2 ORDER BY distance DESC LIMIT $3
|
1729 |
-
# """,
|
1730 |
-
# "chunks": """SELECT id FROM
|
1731 |
-
# (SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
1732 |
-
# FROM LIGHTRAG_DOC_CHUNKS where workspace=$1)
|
1733 |
-
# WHERE distance>$2 ORDER BY distance DESC LIMIT $3
|
1734 |
-
# """,
|
1735 |
-
# DROP tables
|
1736 |
-
"drop_all": """
|
1737 |
-
DROP TABLE IF EXISTS LIGHTRAG_DOC_FULL CASCADE;
|
1738 |
-
DROP TABLE IF EXISTS LIGHTRAG_DOC_CHUNKS CASCADE;
|
1739 |
-
DROP TABLE IF EXISTS LIGHTRAG_LLM_CACHE CASCADE;
|
1740 |
-
DROP TABLE IF EXISTS LIGHTRAG_VDB_ENTITY CASCADE;
|
1741 |
-
DROP TABLE IF EXISTS LIGHTRAG_VDB_RELATION CASCADE;
|
1742 |
-
""",
|
1743 |
-
"drop_doc_full": """
|
1744 |
-
DROP TABLE IF EXISTS LIGHTRAG_DOC_FULL CASCADE;
|
1745 |
-
""",
|
1746 |
-
"drop_doc_chunks": """
|
1747 |
-
DROP TABLE IF EXISTS LIGHTRAG_DOC_CHUNKS CASCADE;
|
1748 |
-
""",
|
1749 |
-
"drop_llm_cache": """
|
1750 |
-
DROP TABLE IF EXISTS LIGHTRAG_LLM_CACHE CASCADE;
|
1751 |
-
""",
|
1752 |
-
"drop_vdb_entity": """
|
1753 |
-
DROP TABLE IF EXISTS LIGHTRAG_VDB_ENTITY CASCADE;
|
1754 |
-
""",
|
1755 |
-
"drop_vdb_relation": """
|
1756 |
-
DROP TABLE IF EXISTS LIGHTRAG_VDB_RELATION CASCADE;
|
1757 |
-
""",
|
1758 |
"relationships": """
|
1759 |
WITH relevant_chunks AS (
|
1760 |
SELECT id as chunk_id
|
@@ -1795,9 +1854,9 @@ SQL_TEMPLATES = {
|
|
1795 |
FROM LIGHTRAG_DOC_CHUNKS
|
1796 |
WHERE {doc_ids} IS NULL OR full_doc_id = ANY(ARRAY[{doc_ids}])
|
1797 |
)
|
1798 |
-
SELECT id FROM
|
1799 |
(
|
1800 |
-
SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
1801 |
FROM LIGHTRAG_DOC_CHUNKS
|
1802 |
where workspace=$1
|
1803 |
AND id IN (SELECT chunk_id FROM relevant_chunks)
|
@@ -1806,4 +1865,8 @@ SQL_TEMPLATES = {
|
|
1806 |
ORDER BY distance DESC
|
1807 |
LIMIT $3
|
1808 |
""",
|
|
|
|
|
|
|
|
|
1809 |
}
|
|
|
9 |
|
10 |
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
11 |
|
|
|
12 |
from tenacity import (
|
13 |
retry,
|
14 |
retry_if_exception_type,
|
|
|
27 |
from ..namespace import NameSpace, is_namespace
|
28 |
from ..utils import logger
|
29 |
|
|
|
|
|
|
|
|
|
|
|
30 |
import pipmaster as pm
|
31 |
|
32 |
if not pm.is_installed("asyncpg"):
|
|
|
35 |
import asyncpg # type: ignore
|
36 |
from asyncpg import Pool # type: ignore
|
37 |
|
38 |
+
# Get maximum number of graph nodes from environment variable, default is 1000
|
39 |
+
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
|
40 |
+
|
41 |
|
42 |
class PostgreSQLDB:
|
43 |
def __init__(self, config: dict[str, Any], **kwargs: Any):
|
|
|
115 |
)
|
116 |
raise e
|
117 |
|
118 |
+
# Create index for id column in each table
|
119 |
+
try:
|
120 |
+
index_name = f"idx_{k.lower()}_id"
|
121 |
+
check_index_sql = f"""
|
122 |
+
SELECT 1 FROM pg_indexes
|
123 |
+
WHERE indexname = '{index_name}'
|
124 |
+
AND tablename = '{k.lower()}'
|
125 |
+
"""
|
126 |
+
index_exists = await self.query(check_index_sql)
|
127 |
+
|
128 |
+
if not index_exists:
|
129 |
+
create_index_sql = f"CREATE INDEX {index_name} ON {k}(id)"
|
130 |
+
logger.info(f"PostgreSQL, Creating index {index_name} on table {k}")
|
131 |
+
await self.execute(create_index_sql)
|
132 |
+
except Exception as e:
|
133 |
+
logger.error(
|
134 |
+
f"PostgreSQL, Failed to create index on table {k}, Got: {e}"
|
135 |
+
)
|
136 |
+
|
137 |
async def query(
|
138 |
self,
|
139 |
sql: str,
|
|
|
270 |
db: PostgreSQLDB = field(default=None)
|
271 |
|
272 |
def __post_init__(self):
|
|
|
|
|
273 |
self._max_batch_size = self.global_config["embedding_batch_num"]
|
274 |
|
275 |
async def initialize(self):
|
|
|
285 |
|
286 |
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
287 |
"""Get doc_full data by id."""
|
288 |
+
sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
|
289 |
params = {"workspace": self.db.workspace, "id": id}
|
290 |
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
291 |
array_res = await self.db.query(sql, params, multirows=True)
|
|
|
299 |
|
300 |
async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
|
301 |
"""Specifically for llm_response_cache."""
|
302 |
+
sql = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
|
303 |
params = {"workspace": self.db.workspace, mode: mode, "id": id}
|
304 |
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
305 |
array_res = await self.db.query(sql, params, multirows=True)
|
|
|
313 |
# Query by id
|
314 |
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
315 |
"""Get doc_chunks data by id"""
|
316 |
+
sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
|
317 |
ids=",".join([f"'{id}'" for id in ids])
|
318 |
)
|
319 |
params = {"workspace": self.db.workspace}
|
|
|
334 |
|
335 |
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
|
336 |
"""Specifically for llm_response_cache."""
|
337 |
+
SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
|
338 |
params = {"workspace": self.db.workspace, "status": status}
|
339 |
return await self.db.query(SQL, params, multirows=True)
|
340 |
|
|
|
394 |
# PG handles persistence automatically
|
395 |
pass
|
396 |
|
397 |
+
async def delete(self, ids: list[str]) -> None:
|
398 |
+
"""Delete specific records from storage by their IDs
|
399 |
+
|
400 |
+
Args:
|
401 |
+
ids (list[str]): List of document IDs to be deleted from storage
|
402 |
+
|
403 |
+
Returns:
|
404 |
+
None
|
405 |
+
"""
|
406 |
+
if not ids:
|
407 |
+
return
|
408 |
+
|
409 |
+
table_name = namespace_to_table_name(self.namespace)
|
410 |
+
if not table_name:
|
411 |
+
logger.error(f"Unknown namespace for deletion: {self.namespace}")
|
412 |
+
return
|
413 |
+
|
414 |
+
delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
|
415 |
+
|
416 |
+
try:
|
417 |
+
await self.db.execute(
|
418 |
+
delete_sql, {"workspace": self.db.workspace, "ids": ids}
|
419 |
+
)
|
420 |
+
logger.debug(
|
421 |
+
f"Successfully deleted {len(ids)} records from {self.namespace}"
|
422 |
+
)
|
423 |
+
except Exception as e:
|
424 |
+
logger.error(f"Error while deleting records from {self.namespace}: {e}")
|
425 |
+
|
426 |
+
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
|
427 |
+
"""Delete specific records from storage by cache mode
|
428 |
+
|
429 |
+
Args:
|
430 |
+
modes (list[str]): List of cache modes to be dropped from storage
|
431 |
+
|
432 |
+
Returns:
|
433 |
+
bool: True if successful, False otherwise
|
434 |
+
"""
|
435 |
+
if not modes:
|
436 |
+
return False
|
437 |
+
|
438 |
+
try:
|
439 |
+
table_name = namespace_to_table_name(self.namespace)
|
440 |
+
if not table_name:
|
441 |
+
return False
|
442 |
+
|
443 |
+
if table_name != "LIGHTRAG_LLM_CACHE":
|
444 |
+
return False
|
445 |
+
|
446 |
+
sql = f"""
|
447 |
+
DELETE FROM {table_name}
|
448 |
+
WHERE workspace = $1 AND mode = ANY($2)
|
449 |
+
"""
|
450 |
+
params = {"workspace": self.db.workspace, "modes": modes}
|
451 |
+
|
452 |
+
logger.info(f"Deleting cache by modes: {modes}")
|
453 |
+
await self.db.execute(sql, params)
|
454 |
+
return True
|
455 |
+
except Exception as e:
|
456 |
+
logger.error(f"Error deleting cache by modes {modes}: {e}")
|
457 |
+
return False
|
458 |
+
|
459 |
+
async def drop(self) -> dict[str, str]:
|
460 |
"""Drop the storage"""
|
461 |
+
try:
|
462 |
+
table_name = namespace_to_table_name(self.namespace)
|
463 |
+
if not table_name:
|
464 |
+
return {
|
465 |
+
"status": "error",
|
466 |
+
"message": f"Unknown namespace: {self.namespace}",
|
467 |
+
}
|
468 |
+
|
469 |
+
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
|
470 |
+
table_name=table_name
|
471 |
+
)
|
472 |
+
await self.db.execute(drop_sql, {"workspace": self.db.workspace})
|
473 |
+
return {"status": "success", "message": "data dropped"}
|
474 |
+
except Exception as e:
|
475 |
+
return {"status": "error", "message": str(e)}
|
476 |
|
477 |
|
478 |
@final
|
|
|
482 |
|
483 |
def __post_init__(self):
|
484 |
self._max_batch_size = self.global_config["embedding_batch_num"]
|
|
|
|
|
485 |
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
486 |
cosine_threshold = config.get("cosine_better_than_threshold")
|
487 |
if cosine_threshold is None:
|
|
|
610 |
else:
|
611 |
formatted_ids = "NULL"
|
612 |
|
613 |
+
sql = SQL_TEMPLATES[self.namespace].format(
|
614 |
embedding_string=embedding_string, doc_ids=formatted_ids
|
615 |
)
|
616 |
params = {
|
|
|
639 |
logger.error(f"Unknown namespace for vector deletion: {self.namespace}")
|
640 |
return
|
641 |
|
642 |
+
delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
|
|
|
|
|
|
|
643 |
|
644 |
try:
|
645 |
+
await self.db.execute(
|
646 |
+
delete_sql, {"workspace": self.db.workspace, "ids": ids}
|
647 |
+
)
|
648 |
logger.debug(
|
649 |
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
|
650 |
)
|
|
|
776 |
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
|
777 |
return []
|
778 |
|
779 |
+
async def drop(self) -> dict[str, str]:
|
780 |
+
"""Drop the storage"""
|
781 |
+
try:
|
782 |
+
table_name = namespace_to_table_name(self.namespace)
|
783 |
+
if not table_name:
|
784 |
+
return {
|
785 |
+
"status": "error",
|
786 |
+
"message": f"Unknown namespace: {self.namespace}",
|
787 |
+
}
|
788 |
+
|
789 |
+
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
|
790 |
+
table_name=table_name
|
791 |
+
)
|
792 |
+
await self.db.execute(drop_sql, {"workspace": self.db.workspace})
|
793 |
+
return {"status": "success", "message": "data dropped"}
|
794 |
+
except Exception as e:
|
795 |
+
return {"status": "error", "message": str(e)}
|
796 |
+
|
797 |
|
798 |
@final
|
799 |
@dataclass
|
|
|
914 |
# PG handles persistence automatically
|
915 |
pass
|
916 |
|
917 |
+
async def delete(self, ids: list[str]) -> None:
|
918 |
+
"""Delete specific records from storage by their IDs
|
919 |
+
|
920 |
+
Args:
|
921 |
+
ids (list[str]): List of document IDs to be deleted from storage
|
922 |
+
|
923 |
+
Returns:
|
924 |
+
None
|
925 |
+
"""
|
926 |
+
if not ids:
|
927 |
+
return
|
928 |
+
|
929 |
+
table_name = namespace_to_table_name(self.namespace)
|
930 |
+
if not table_name:
|
931 |
+
logger.error(f"Unknown namespace for deletion: {self.namespace}")
|
932 |
+
return
|
933 |
+
|
934 |
+
delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
|
935 |
+
|
936 |
+
try:
|
937 |
+
await self.db.execute(
|
938 |
+
delete_sql, {"workspace": self.db.workspace, "ids": ids}
|
939 |
+
)
|
940 |
+
logger.debug(
|
941 |
+
f"Successfully deleted {len(ids)} records from {self.namespace}"
|
942 |
+
)
|
943 |
+
except Exception as e:
|
944 |
+
logger.error(f"Error while deleting records from {self.namespace}: {e}")
|
945 |
+
|
946 |
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
947 |
"""Update or insert document status
|
948 |
|
|
|
979 |
},
|
980 |
)
|
981 |
|
982 |
+
async def drop(self) -> dict[str, str]:
|
983 |
"""Drop the storage"""
|
984 |
+
try:
|
985 |
+
table_name = namespace_to_table_name(self.namespace)
|
986 |
+
if not table_name:
|
987 |
+
return {
|
988 |
+
"status": "error",
|
989 |
+
"message": f"Unknown namespace: {self.namespace}",
|
990 |
+
}
|
991 |
+
|
992 |
+
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
|
993 |
+
table_name=table_name
|
994 |
+
)
|
995 |
+
await self.db.execute(drop_sql, {"workspace": self.db.workspace})
|
996 |
+
return {"status": "success", "message": "data dropped"}
|
997 |
+
except Exception as e:
|
998 |
+
return {"status": "error", "message": str(e)}
|
999 |
|
1000 |
|
1001 |
class PGGraphQueryException(Exception):
|
|
|
1083 |
if v.startswith("[") and v.endswith("]"):
|
1084 |
if "::vertex" in v:
|
1085 |
v = v.replace("::vertex", "")
|
1086 |
+
d[k] = json.loads(v)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1087 |
|
1088 |
elif "::edge" in v:
|
1089 |
v = v.replace("::edge", "")
|
1090 |
+
d[k] = json.loads(v)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1091 |
else:
|
1092 |
print("WARNING: unsupported type")
|
1093 |
continue
|
|
|
1096 |
dtype = v.split("::")[-1]
|
1097 |
v = v.split("::")[0]
|
1098 |
if dtype == "vertex":
|
1099 |
+
d[k] = json.loads(v)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1100 |
elif dtype == "edge":
|
1101 |
+
d[k] = json.loads(v)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1102 |
else:
|
1103 |
+
try:
|
1104 |
+
d[k] = (
|
1105 |
+
json.loads(v)
|
1106 |
+
if isinstance(v, str)
|
1107 |
+
and (v.startswith("{") or v.startswith("["))
|
1108 |
+
else v
|
1109 |
+
)
|
1110 |
+
except json.JSONDecodeError:
|
1111 |
+
d[k] = v
|
1112 |
|
1113 |
return d
|
1114 |
|
|
|
1138 |
)
|
1139 |
return "{" + ", ".join(props) + "}"
|
1140 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1141 |
async def _query(
|
1142 |
self,
|
1143 |
query: str,
|
|
|
1188 |
return result
|
1189 |
|
1190 |
async def has_node(self, node_id: str) -> bool:
|
1191 |
+
entity_name_label = node_id.strip('"')
|
1192 |
|
1193 |
query = """SELECT * FROM cypher('%s', $$
|
1194 |
+
MATCH (n:base {entity_id: "%s"})
|
1195 |
RETURN count(n) > 0 AS node_exists
|
1196 |
$$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
|
1197 |
|
|
|
1200 |
return single_result["node_exists"]
|
1201 |
|
1202 |
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
1203 |
+
src_label = source_node_id.strip('"')
|
1204 |
+
tgt_label = target_node_id.strip('"')
|
1205 |
|
1206 |
query = """SELECT * FROM cypher('%s', $$
|
1207 |
+
MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
|
1208 |
RETURN COUNT(r) > 0 AS edge_exists
|
1209 |
$$) AS (edge_exists bool)""" % (
|
1210 |
self.graph_name,
|
|
|
1217 |
return single_result["edge_exists"]
|
1218 |
|
1219 |
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
1220 |
+
"""Get node by its label identifier, return only node properties"""
|
1221 |
+
|
1222 |
+
label = node_id.strip('"')
|
1223 |
query = """SELECT * FROM cypher('%s', $$
|
1224 |
+
MATCH (n:base {entity_id: "%s"})
|
1225 |
RETURN n
|
1226 |
$$) AS (n agtype)""" % (self.graph_name, label)
|
1227 |
record = await self._query(query)
|
1228 |
if record:
|
1229 |
node = record[0]
|
1230 |
+
node_dict = node["n"]["properties"]
|
1231 |
|
1232 |
return node_dict
|
1233 |
return None
|
1234 |
|
1235 |
async def node_degree(self, node_id: str) -> int:
|
1236 |
+
label = node_id.strip('"')
|
1237 |
|
1238 |
query = """SELECT * FROM cypher('%s', $$
|
1239 |
+
MATCH (n:base {entity_id: "%s"})-[]-(x)
|
1240 |
RETURN count(x) AS total_edge_count
|
1241 |
$$) AS (total_edge_count integer)""" % (self.graph_name, label)
|
1242 |
record = (await self._query(query))[0]
|
1243 |
if record:
|
1244 |
edge_count = int(record["total_edge_count"])
|
|
|
1245 |
return edge_count
|
1246 |
|
1247 |
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
|
|
1259 |
async def get_edge(
|
1260 |
self, source_node_id: str, target_node_id: str
|
1261 |
) -> dict[str, str] | None:
|
1262 |
+
"""Get edge properties between two nodes"""
|
1263 |
+
|
1264 |
+
src_label = source_node_id.strip('"')
|
1265 |
+
tgt_label = target_node_id.strip('"')
|
1266 |
|
1267 |
query = """SELECT * FROM cypher('%s', $$
|
1268 |
+
MATCH (a:base {entity_id: "%s"})-[r]->(b:base {entity_id: "%s"})
|
1269 |
RETURN properties(r) as edge_properties
|
1270 |
LIMIT 1
|
1271 |
$$) AS (edge_properties agtype)""" % (
|
|
|
1284 |
Retrieves all edges (relationships) for a particular node identified by its label.
|
1285 |
:return: list of dictionaries containing edge information
|
1286 |
"""
|
1287 |
+
label = source_node_id.strip('"')
|
1288 |
|
1289 |
query = """SELECT * FROM cypher('%s', $$
|
1290 |
+
MATCH (n:base {entity_id: "%s"})
|
1291 |
+
OPTIONAL MATCH (n)-[]-(connected:base)
|
1292 |
RETURN n, connected
|
1293 |
$$) AS (n agtype, connected agtype)""" % (
|
1294 |
self.graph_name,
|
|
|
1301 |
source_node = record["n"] if record["n"] else None
|
1302 |
connected_node = record["connected"] if record["connected"] else None
|
1303 |
|
1304 |
+
if (
|
1305 |
+
source_node
|
1306 |
+
and connected_node
|
1307 |
+
and "properties" in source_node
|
1308 |
+
and "properties" in connected_node
|
1309 |
+
):
|
1310 |
+
source_label = source_node["properties"].get("entity_id")
|
1311 |
+
target_label = connected_node["properties"].get("entity_id")
|
|
|
|
|
1312 |
|
1313 |
+
if source_label and target_label:
|
1314 |
+
edges.append((source_label, target_label))
|
|
|
|
|
|
|
|
|
|
|
1315 |
|
1316 |
return edges
|
1317 |
|
|
|
1321 |
retry=retry_if_exception_type((PGGraphQueryException,)),
|
1322 |
)
|
1323 |
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
1324 |
+
"""
|
1325 |
+
Upsert a node in the Neo4j database.
|
1326 |
+
|
1327 |
+
Args:
|
1328 |
+
node_id: The unique identifier for the node (used as label)
|
1329 |
+
node_data: Dictionary of node properties
|
1330 |
+
"""
|
1331 |
+
if "entity_id" not in node_data:
|
1332 |
+
raise ValueError(
|
1333 |
+
"PostgreSQL: node properties must contain an 'entity_id' field"
|
1334 |
+
)
|
1335 |
+
|
1336 |
+
label = node_id.strip('"')
|
1337 |
+
properties = self._format_properties(node_data)
|
1338 |
|
1339 |
query = """SELECT * FROM cypher('%s', $$
|
1340 |
+
MERGE (n:base {entity_id: "%s"})
|
1341 |
SET n += %s
|
1342 |
RETURN n
|
1343 |
$$) AS (n agtype)""" % (
|
1344 |
self.graph_name,
|
1345 |
label,
|
1346 |
+
properties,
|
1347 |
)
|
1348 |
|
1349 |
try:
|
1350 |
await self._query(query, readonly=False, upsert=True)
|
1351 |
|
1352 |
+
except Exception:
|
1353 |
+
logger.error(f"POSTGRES, upsert_node error on node_id: `{node_id}`")
|
1354 |
raise
|
1355 |
|
1356 |
@retry(
|
|
|
1369 |
target_node_id (str): Label of the target node (used as identifier)
|
1370 |
edge_data (dict): dictionary of properties to set on the edge
|
1371 |
"""
|
1372 |
+
src_label = source_node_id.strip('"')
|
1373 |
+
tgt_label = target_node_id.strip('"')
|
1374 |
+
edge_properties = self._format_properties(edge_data)
|
1375 |
|
1376 |
query = """SELECT * FROM cypher('%s', $$
|
1377 |
+
MATCH (source:base {entity_id: "%s"})
|
1378 |
WITH source
|
1379 |
+
MATCH (target:base {entity_id: "%s"})
|
1380 |
MERGE (source)-[r:DIRECTED]->(target)
|
1381 |
SET r += %s
|
1382 |
RETURN r
|
|
|
1384 |
self.graph_name,
|
1385 |
src_label,
|
1386 |
tgt_label,
|
1387 |
+
edge_properties,
|
1388 |
)
|
1389 |
|
1390 |
try:
|
1391 |
await self._query(query, readonly=False, upsert=True)
|
1392 |
|
1393 |
+
except Exception:
|
1394 |
+
logger.error(
|
1395 |
+
f"POSTGRES, upsert_edge error on edge: `{source_node_id}`-`{target_node_id}`"
|
1396 |
+
)
|
1397 |
raise
|
1398 |
|
1399 |
async def _node2vec_embed(self):
|
|
|
1406 |
Args:
|
1407 |
node_id (str): The ID of the node to delete.
|
1408 |
"""
|
1409 |
+
label = node_id.strip('"')
|
1410 |
|
1411 |
query = """SELECT * FROM cypher('%s', $$
|
1412 |
+
MATCH (n:base {entity_id: "%s"})
|
1413 |
DETACH DELETE n
|
1414 |
$$) AS (n agtype)""" % (self.graph_name, label)
|
1415 |
|
|
|
1426 |
Args:
|
1427 |
node_ids (list[str]): A list of node IDs to remove.
|
1428 |
"""
|
1429 |
+
node_ids = [node_id.strip('"') for node_id in node_ids]
|
1430 |
+
node_id_list = ", ".join([f'"{node_id}"' for node_id in node_ids])
|
|
|
|
|
1431 |
|
1432 |
query = """SELECT * FROM cypher('%s', $$
|
1433 |
+
MATCH (n:base)
|
1434 |
+
WHERE n.entity_id IN [%s]
|
1435 |
DETACH DELETE n
|
1436 |
$$) AS (n agtype)""" % (self.graph_name, node_id_list)
|
1437 |
|
|
|
1448 |
Args:
|
1449 |
edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
|
1450 |
"""
|
1451 |
+
for source, target in edges:
|
1452 |
+
src_label = source.strip('"')
|
1453 |
+
tgt_label = target.strip('"')
|
|
|
|
|
|
|
|
|
|
|
1454 |
|
1455 |
+
query = """SELECT * FROM cypher('%s', $$
|
1456 |
+
MATCH (a:base {entity_id: "%s"})-[r]->(b:base {entity_id: "%s"})
|
1457 |
+
DELETE r
|
1458 |
+
$$) AS (r agtype)""" % (self.graph_name, src_label, tgt_label)
|
|
|
1459 |
|
1460 |
+
try:
|
1461 |
+
await self._query(query, readonly=False)
|
1462 |
+
logger.debug(f"Deleted edge from '{source}' to '{target}'")
|
1463 |
+
except Exception as e:
|
1464 |
+
logger.error(f"Error during edge deletion: {str(e)}")
|
1465 |
+
raise
|
1466 |
|
1467 |
async def get_all_labels(self) -> list[str]:
|
1468 |
"""
|
|
|
1473 |
"""
|
1474 |
query = (
|
1475 |
"""SELECT * FROM cypher('%s', $$
|
1476 |
+
MATCH (n:base)
|
1477 |
+
WHERE n.entity_id IS NOT NULL
|
1478 |
+
RETURN DISTINCT n.entity_id AS label
|
1479 |
+
ORDER BY n.entity_id
|
1480 |
$$) AS (label text)"""
|
1481 |
% self.graph_name
|
1482 |
)
|
1483 |
|
1484 |
results = await self._query(query)
|
1485 |
+
labels = [result["label"] for result in results]
|
|
|
1486 |
return labels
|
1487 |
|
1488 |
async def embed_nodes(
|
|
|
1504 |
return await embed_func()
|
1505 |
|
1506 |
async def get_knowledge_graph(
|
1507 |
+
self,
|
1508 |
+
node_label: str,
|
1509 |
+
max_depth: int = 3,
|
1510 |
+
max_nodes: int = MAX_GRAPH_NODES,
|
1511 |
) -> KnowledgeGraph:
|
1512 |
"""
|
1513 |
+
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
1514 |
|
1515 |
Args:
|
1516 |
+
node_label: Label of the starting node, * means all nodes
|
1517 |
+
max_depth: Maximum depth of the subgraph, Defaults to 3
|
1518 |
+
max_nodes: Maxiumu nodes to return, Defaults to 1000 (not BFS nor DFS garanteed)
|
1519 |
|
1520 |
Returns:
|
1521 |
+
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
|
1522 |
+
indicating whether the graph was truncated due to max_nodes limit
|
1523 |
"""
|
1524 |
+
# First, count the total number of nodes that would be returned without limit
|
1525 |
+
if node_label == "*":
|
1526 |
+
count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
1527 |
+
MATCH (n:base)
|
1528 |
+
RETURN count(distinct n) AS total_nodes
|
1529 |
+
$$) AS (total_nodes bigint)"""
|
1530 |
+
else:
|
1531 |
+
strip_label = node_label.strip('"')
|
1532 |
+
count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
1533 |
+
MATCH (n:base {{entity_id: "{strip_label}"}})
|
1534 |
+
OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m)
|
1535 |
+
RETURN count(distinct m) AS total_nodes
|
1536 |
+
$$) AS (total_nodes bigint)"""
|
1537 |
+
|
1538 |
+
count_result = await self._query(count_query)
|
1539 |
+
total_nodes = count_result[0]["total_nodes"] if count_result else 0
|
1540 |
+
is_truncated = total_nodes > max_nodes
|
1541 |
+
|
1542 |
+
# Now get the actual data with limit
|
1543 |
if node_label == "*":
|
1544 |
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
1545 |
+
MATCH (n:base)
|
1546 |
+
OPTIONAL MATCH (n)-[r]->(target:base)
|
1547 |
+
RETURN collect(distinct n) AS n, collect(distinct r) AS r
|
1548 |
+
LIMIT {max_nodes}
|
1549 |
+
$$) AS (n agtype, r agtype)"""
|
1550 |
else:
|
1551 |
+
strip_label = node_label.strip('"')
|
1552 |
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
1553 |
+
MATCH (n:base {{entity_id: "{strip_label}"}})
|
1554 |
+
OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m)
|
1555 |
+
RETURN nodes(p) AS n, relationships(p) AS r
|
1556 |
+
LIMIT {max_nodes}
|
1557 |
+
$$) AS (n agtype, r agtype)"""
|
1558 |
|
1559 |
results = await self._query(query)
|
1560 |
|
1561 |
+
# Process the query results with deduplication by node and edge IDs
|
1562 |
+
nodes_dict = {}
|
1563 |
+
edges_dict = {}
|
1564 |
+
for result in results:
|
1565 |
+
# Handle single node cases
|
1566 |
+
if result.get("n") and isinstance(result["n"], dict):
|
1567 |
+
node_id = str(result["n"]["id"])
|
1568 |
+
if node_id not in nodes_dict:
|
1569 |
+
nodes_dict[node_id] = KnowledgeGraphNode(
|
1570 |
+
id=node_id,
|
1571 |
+
labels=[result["n"]["properties"]["entity_id"]],
|
1572 |
+
properties=result["n"]["properties"],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1573 |
)
|
1574 |
+
# Handle node list cases
|
1575 |
+
elif result.get("n") and isinstance(result["n"], list):
|
1576 |
+
for node in result["n"]:
|
1577 |
+
if isinstance(node, dict) and "id" in node:
|
1578 |
+
node_id = str(node["id"])
|
1579 |
+
if node_id not in nodes_dict and "properties" in node:
|
1580 |
+
nodes_dict[node_id] = KnowledgeGraphNode(
|
1581 |
+
id=node_id,
|
1582 |
+
labels=[node["properties"]["entity_id"]],
|
1583 |
+
properties=node["properties"],
|
1584 |
+
)
|
1585 |
|
1586 |
+
# Handle single edge cases
|
1587 |
+
if result.get("r") and isinstance(result["r"], dict):
|
1588 |
+
edge_id = str(result["r"]["id"])
|
1589 |
+
if edge_id not in edges_dict:
|
1590 |
+
edges_dict[edge_id] = KnowledgeGraphEdge(
|
1591 |
+
id=edge_id,
|
1592 |
+
type="DIRECTED",
|
1593 |
+
source=str(result["r"]["start_id"]),
|
1594 |
+
target=str(result["r"]["end_id"]),
|
1595 |
+
properties=result["r"]["properties"],
|
1596 |
+
)
|
1597 |
+
# Handle edge list cases
|
1598 |
+
elif result.get("r") and isinstance(result["r"], list):
|
1599 |
+
for edge in result["r"]:
|
1600 |
+
if isinstance(edge, dict) and "id" in edge:
|
1601 |
+
edge_id = str(edge["id"])
|
1602 |
+
if edge_id not in edges_dict:
|
1603 |
+
edges_dict[edge_id] = KnowledgeGraphEdge(
|
1604 |
+
id=edge_id,
|
1605 |
+
type="DIRECTED",
|
1606 |
+
source=str(edge["start_id"]),
|
1607 |
+
target=str(edge["end_id"]),
|
1608 |
+
properties=edge["properties"],
|
1609 |
+
)
|
1610 |
|
1611 |
+
# Construct and return the KnowledgeGraph with deduplicated nodes and edges
|
1612 |
kg = KnowledgeGraph(
|
1613 |
+
nodes=list(nodes_dict.values()),
|
1614 |
+
edges=list(edges_dict.values()),
|
1615 |
+
is_truncated=is_truncated,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1616 |
)
|
1617 |
|
1618 |
+
logger.info(
|
1619 |
+
f"Subgraph query successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
|
1620 |
+
)
|
1621 |
return kg
|
1622 |
|
1623 |
+
async def drop(self) -> dict[str, str]:
|
1624 |
"""Drop the storage"""
|
1625 |
+
try:
|
1626 |
+
drop_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
1627 |
+
MATCH (n)
|
1628 |
+
DETACH DELETE n
|
1629 |
+
$$) AS (result agtype)"""
|
1630 |
+
|
1631 |
+
await self._query(drop_query, readonly=False)
|
1632 |
+
return {"status": "success", "message": "graph data dropped"}
|
1633 |
+
except Exception as e:
|
1634 |
+
logger.error(f"Error dropping graph: {e}")
|
1635 |
+
return {"status": "error", "message": str(e)}
|
1636 |
|
1637 |
|
1638 |
NAMESPACE_TABLE_MAP = {
|
|
|
1790 |
file_path=EXCLUDED.file_path,
|
1791 |
update_time = CURRENT_TIMESTAMP
|
1792 |
""",
|
1793 |
+
# SQL for VectorStorage
|
1794 |
"upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
|
1795 |
content_vector, chunk_ids, file_path)
|
1796 |
VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7)
|
|
|
1814 |
file_path=EXCLUDED.file_path,
|
1815 |
update_time = CURRENT_TIMESTAMP
|
1816 |
""",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1817 |
"relationships": """
|
1818 |
WITH relevant_chunks AS (
|
1819 |
SELECT id as chunk_id
|
|
|
1854 |
FROM LIGHTRAG_DOC_CHUNKS
|
1855 |
WHERE {doc_ids} IS NULL OR full_doc_id = ANY(ARRAY[{doc_ids}])
|
1856 |
)
|
1857 |
+
SELECT id, content, file_path FROM
|
1858 |
(
|
1859 |
+
SELECT id, content, file_path, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
1860 |
FROM LIGHTRAG_DOC_CHUNKS
|
1861 |
where workspace=$1
|
1862 |
AND id IN (SELECT chunk_id FROM relevant_chunks)
|
|
|
1865 |
ORDER BY distance DESC
|
1866 |
LIMIT $3
|
1867 |
""",
|
1868 |
+
# DROP tables
|
1869 |
+
"drop_specifiy_table_workspace": """
|
1870 |
+
DELETE FROM {table_name} WHERE workspace=$1
|
1871 |
+
""",
|
1872 |
}
|
lightrag/kg/qdrant_impl.py
CHANGED
@@ -8,17 +8,15 @@ import uuid
|
|
8 |
from ..utils import logger
|
9 |
from ..base import BaseVectorStorage
|
10 |
import configparser
|
11 |
-
|
12 |
-
|
13 |
-
config = configparser.ConfigParser()
|
14 |
-
config.read("config.ini", "utf-8")
|
15 |
-
|
16 |
import pipmaster as pm
|
17 |
|
18 |
if not pm.is_installed("qdrant-client"):
|
19 |
pm.install("qdrant-client")
|
20 |
|
21 |
-
from qdrant_client import QdrantClient, models
|
|
|
|
|
|
|
22 |
|
23 |
|
24 |
def compute_mdhash_id_for_qdrant(
|
@@ -275,3 +273,92 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
|
275 |
except Exception as e:
|
276 |
logger.error(f"Error searching for prefix '{prefix}': {e}")
|
277 |
return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
from ..utils import logger
|
9 |
from ..base import BaseVectorStorage
|
10 |
import configparser
|
|
|
|
|
|
|
|
|
|
|
11 |
import pipmaster as pm
|
12 |
|
13 |
if not pm.is_installed("qdrant-client"):
|
14 |
pm.install("qdrant-client")
|
15 |
|
16 |
+
from qdrant_client import QdrantClient, models # type: ignore
|
17 |
+
|
18 |
+
config = configparser.ConfigParser()
|
19 |
+
config.read("config.ini", "utf-8")
|
20 |
|
21 |
|
22 |
def compute_mdhash_id_for_qdrant(
|
|
|
273 |
except Exception as e:
|
274 |
logger.error(f"Error searching for prefix '{prefix}': {e}")
|
275 |
return []
|
276 |
+
|
277 |
+
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
278 |
+
"""Get vector data by its ID
|
279 |
+
|
280 |
+
Args:
|
281 |
+
id: The unique identifier of the vector
|
282 |
+
|
283 |
+
Returns:
|
284 |
+
The vector data if found, or None if not found
|
285 |
+
"""
|
286 |
+
try:
|
287 |
+
# Convert to Qdrant compatible ID
|
288 |
+
qdrant_id = compute_mdhash_id_for_qdrant(id)
|
289 |
+
|
290 |
+
# Retrieve the point by ID
|
291 |
+
result = self._client.retrieve(
|
292 |
+
collection_name=self.namespace,
|
293 |
+
ids=[qdrant_id],
|
294 |
+
with_payload=True,
|
295 |
+
)
|
296 |
+
|
297 |
+
if not result:
|
298 |
+
return None
|
299 |
+
|
300 |
+
return result[0].payload
|
301 |
+
except Exception as e:
|
302 |
+
logger.error(f"Error retrieving vector data for ID {id}: {e}")
|
303 |
+
return None
|
304 |
+
|
305 |
+
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
306 |
+
"""Get multiple vector data by their IDs
|
307 |
+
|
308 |
+
Args:
|
309 |
+
ids: List of unique identifiers
|
310 |
+
|
311 |
+
Returns:
|
312 |
+
List of vector data objects that were found
|
313 |
+
"""
|
314 |
+
if not ids:
|
315 |
+
return []
|
316 |
+
|
317 |
+
try:
|
318 |
+
# Convert to Qdrant compatible IDs
|
319 |
+
qdrant_ids = [compute_mdhash_id_for_qdrant(id) for id in ids]
|
320 |
+
|
321 |
+
# Retrieve the points by IDs
|
322 |
+
results = self._client.retrieve(
|
323 |
+
collection_name=self.namespace,
|
324 |
+
ids=qdrant_ids,
|
325 |
+
with_payload=True,
|
326 |
+
)
|
327 |
+
|
328 |
+
return [point.payload for point in results]
|
329 |
+
except Exception as e:
|
330 |
+
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
|
331 |
+
return []
|
332 |
+
|
333 |
+
async def drop(self) -> dict[str, str]:
|
334 |
+
"""Drop all vector data from storage and clean up resources
|
335 |
+
|
336 |
+
This method will delete all data from the Qdrant collection.
|
337 |
+
|
338 |
+
Returns:
|
339 |
+
dict[str, str]: Operation status and message
|
340 |
+
- On success: {"status": "success", "message": "data dropped"}
|
341 |
+
- On failure: {"status": "error", "message": "<error details>"}
|
342 |
+
"""
|
343 |
+
try:
|
344 |
+
# Delete the collection and recreate it
|
345 |
+
if self._client.collection_exists(self.namespace):
|
346 |
+
self._client.delete_collection(self.namespace)
|
347 |
+
|
348 |
+
# Recreate the collection
|
349 |
+
QdrantVectorDBStorage.create_collection_if_not_exist(
|
350 |
+
self._client,
|
351 |
+
self.namespace,
|
352 |
+
vectors_config=models.VectorParams(
|
353 |
+
size=self.embedding_func.embedding_dim,
|
354 |
+
distance=models.Distance.COSINE,
|
355 |
+
),
|
356 |
+
)
|
357 |
+
|
358 |
+
logger.info(
|
359 |
+
f"Process {os.getpid()} drop Qdrant collection {self.namespace}"
|
360 |
+
)
|
361 |
+
return {"status": "success", "message": "data dropped"}
|
362 |
+
except Exception as e:
|
363 |
+
logger.error(f"Error dropping Qdrant collection {self.namespace}: {e}")
|
364 |
+
return {"status": "error", "message": str(e)}
|
lightrag/kg/redis_impl.py
CHANGED
@@ -8,8 +8,8 @@ if not pm.is_installed("redis"):
|
|
8 |
pm.install("redis")
|
9 |
|
10 |
# aioredis is a depricated library, replaced with redis
|
11 |
-
from redis.asyncio import Redis
|
12 |
-
from lightrag.utils import logger
|
13 |
from lightrag.base import BaseKVStorage
|
14 |
import json
|
15 |
|
@@ -84,66 +84,50 @@ class RedisKVStorage(BaseKVStorage):
|
|
84 |
f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}"
|
85 |
)
|
86 |
|
87 |
-
async def
|
88 |
-
"""Delete
|
|
|
|
|
|
|
89 |
|
90 |
Args:
|
91 |
-
|
|
|
|
|
|
|
|
|
92 |
"""
|
|
|
|
|
93 |
|
94 |
try:
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
# Delete the entity
|
101 |
-
result = await self._redis.delete(f"{self.namespace}:{entity_id}")
|
102 |
|
103 |
-
|
104 |
-
|
105 |
-
else:
|
106 |
-
logger.debug(f"Entity {entity_name} not found in storage")
|
107 |
-
except Exception as e:
|
108 |
-
logger.error(f"Error deleting entity {entity_name}: {e}")
|
109 |
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
Args:
|
114 |
-
entity_name: Name of the entity whose relations should be deleted
|
115 |
"""
|
116 |
try:
|
117 |
-
|
118 |
-
cursor = 0
|
119 |
-
relation_keys = []
|
120 |
-
pattern = f"{self.namespace}:*"
|
121 |
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
# For each key, get the value and check if it's related to entity_name
|
126 |
for key in keys:
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
or data.get("tgt_id") == entity_name
|
134 |
-
):
|
135 |
-
relation_keys.append(key)
|
136 |
-
|
137 |
-
# Exit loop when cursor returns to 0
|
138 |
-
if cursor == 0:
|
139 |
-
break
|
140 |
-
|
141 |
-
# Delete the relation keys
|
142 |
-
if relation_keys:
|
143 |
-
deleted = await self._redis.delete(*relation_keys)
|
144 |
-
logger.debug(f"Deleted {deleted} relations for {entity_name}")
|
145 |
else:
|
146 |
-
logger.
|
|
|
147 |
|
148 |
except Exception as e:
|
149 |
-
logger.error(f"Error
|
|
|
|
8 |
pm.install("redis")
|
9 |
|
10 |
# aioredis is a depricated library, replaced with redis
|
11 |
+
from redis.asyncio import Redis # type: ignore
|
12 |
+
from lightrag.utils import logger
|
13 |
from lightrag.base import BaseKVStorage
|
14 |
import json
|
15 |
|
|
|
84 |
f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}"
|
85 |
)
|
86 |
|
87 |
+
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
|
88 |
+
"""Delete specific records from storage by by cache mode
|
89 |
+
|
90 |
+
Importance notes for Redis storage:
|
91 |
+
1. This will immediately delete the specified cache modes from Redis
|
92 |
|
93 |
Args:
|
94 |
+
modes (list[str]): List of cache mode to be drop from storage
|
95 |
+
|
96 |
+
Returns:
|
97 |
+
True: if the cache drop successfully
|
98 |
+
False: if the cache drop failed
|
99 |
"""
|
100 |
+
if not modes:
|
101 |
+
return False
|
102 |
|
103 |
try:
|
104 |
+
await self.delete(modes)
|
105 |
+
return True
|
106 |
+
except Exception:
|
107 |
+
return False
|
|
|
|
|
|
|
108 |
|
109 |
+
async def drop(self) -> dict[str, str]:
|
110 |
+
"""Drop the storage by removing all keys under the current namespace.
|
|
|
|
|
|
|
|
|
111 |
|
112 |
+
Returns:
|
113 |
+
dict[str, str]: Status of the operation with keys 'status' and 'message'
|
|
|
|
|
|
|
114 |
"""
|
115 |
try:
|
116 |
+
keys = await self._redis.keys(f"{self.namespace}:*")
|
|
|
|
|
|
|
117 |
|
118 |
+
if keys:
|
119 |
+
pipe = self._redis.pipeline()
|
|
|
|
|
120 |
for key in keys:
|
121 |
+
pipe.delete(key)
|
122 |
+
results = await pipe.execute()
|
123 |
+
deleted_count = sum(results)
|
124 |
+
|
125 |
+
logger.info(f"Dropped {deleted_count} keys from {self.namespace}")
|
126 |
+
return {"status": "success", "message": f"{deleted_count} keys dropped"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
else:
|
128 |
+
logger.info(f"No keys found to drop in {self.namespace}")
|
129 |
+
return {"status": "success", "message": "no keys to drop"}
|
130 |
|
131 |
except Exception as e:
|
132 |
+
logger.error(f"Error dropping keys from {self.namespace}: {e}")
|
133 |
+
return {"status": "error", "message": str(e)}
|
lightrag/kg/tidb_impl.py
CHANGED
@@ -20,7 +20,7 @@ if not pm.is_installed("pymysql"):
|
|
20 |
if not pm.is_installed("sqlalchemy"):
|
21 |
pm.install("sqlalchemy")
|
22 |
|
23 |
-
from sqlalchemy import create_engine, text
|
24 |
|
25 |
|
26 |
class TiDB:
|
@@ -278,6 +278,86 @@ class TiDBKVStorage(BaseKVStorage):
|
|
278 |
# Ti handles persistence automatically
|
279 |
pass
|
280 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
281 |
|
282 |
@final
|
283 |
@dataclass
|
@@ -406,16 +486,91 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
|
406 |
params = {"workspace": self.db.workspace, "status": status}
|
407 |
return await self.db.query(SQL, params, multirows=True)
|
408 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
409 |
async def delete_entity(self, entity_name: str) -> None:
|
410 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
411 |
|
412 |
async def delete_entity_relation(self, entity_name: str) -> None:
|
413 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
414 |
|
415 |
async def index_done_callback(self) -> None:
|
416 |
# Ti handles persistence automatically
|
417 |
pass
|
418 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
419 |
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
|
420 |
"""Search for records with IDs starting with a specific prefix.
|
421 |
|
@@ -710,6 +865,18 @@ class TiDBGraphStorage(BaseGraphStorage):
|
|
710 |
# Ti handles persistence automatically
|
711 |
pass
|
712 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
713 |
async def delete_node(self, node_id: str) -> None:
|
714 |
"""Delete a node and all its related edges
|
715 |
|
@@ -1129,4 +1296,6 @@ SQL_TEMPLATES = {
|
|
1129 |
FROM LIGHTRAG_DOC_CHUNKS
|
1130 |
WHERE chunk_id LIKE :prefix_pattern AND workspace = :workspace
|
1131 |
""",
|
|
|
|
|
1132 |
}
|
|
|
20 |
if not pm.is_installed("sqlalchemy"):
|
21 |
pm.install("sqlalchemy")
|
22 |
|
23 |
+
from sqlalchemy import create_engine, text # type: ignore
|
24 |
|
25 |
|
26 |
class TiDB:
|
|
|
278 |
# Ti handles persistence automatically
|
279 |
pass
|
280 |
|
281 |
+
async def delete(self, ids: list[str]) -> None:
|
282 |
+
"""Delete records with specified IDs from the storage.
|
283 |
+
|
284 |
+
Args:
|
285 |
+
ids: List of record IDs to be deleted
|
286 |
+
"""
|
287 |
+
if not ids:
|
288 |
+
return
|
289 |
+
|
290 |
+
try:
|
291 |
+
table_name = namespace_to_table_name(self.namespace)
|
292 |
+
id_field = namespace_to_id(self.namespace)
|
293 |
+
|
294 |
+
if not table_name or not id_field:
|
295 |
+
logger.error(f"Unknown namespace for deletion: {self.namespace}")
|
296 |
+
return
|
297 |
+
|
298 |
+
ids_list = ",".join([f"'{id}'" for id in ids])
|
299 |
+
delete_sql = f"DELETE FROM {table_name} WHERE workspace = :workspace AND {id_field} IN ({ids_list})"
|
300 |
+
|
301 |
+
await self.db.execute(delete_sql, {"workspace": self.db.workspace})
|
302 |
+
logger.info(
|
303 |
+
f"Successfully deleted {len(ids)} records from {self.namespace}"
|
304 |
+
)
|
305 |
+
except Exception as e:
|
306 |
+
logger.error(f"Error deleting records from {self.namespace}: {e}")
|
307 |
+
|
308 |
+
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
|
309 |
+
"""Delete specific records from storage by cache mode
|
310 |
+
|
311 |
+
Args:
|
312 |
+
modes (list[str]): List of cache modes to be dropped from storage
|
313 |
+
|
314 |
+
Returns:
|
315 |
+
bool: True if successful, False otherwise
|
316 |
+
"""
|
317 |
+
if not modes:
|
318 |
+
return False
|
319 |
+
|
320 |
+
try:
|
321 |
+
table_name = namespace_to_table_name(self.namespace)
|
322 |
+
if not table_name:
|
323 |
+
return False
|
324 |
+
|
325 |
+
if table_name != "LIGHTRAG_LLM_CACHE":
|
326 |
+
return False
|
327 |
+
|
328 |
+
# 构建MySQL风格的IN查询
|
329 |
+
modes_list = ", ".join([f"'{mode}'" for mode in modes])
|
330 |
+
sql = f"""
|
331 |
+
DELETE FROM {table_name}
|
332 |
+
WHERE workspace = :workspace
|
333 |
+
AND mode IN ({modes_list})
|
334 |
+
"""
|
335 |
+
|
336 |
+
logger.info(f"Deleting cache by modes: {modes}")
|
337 |
+
await self.db.execute(sql, {"workspace": self.db.workspace})
|
338 |
+
return True
|
339 |
+
except Exception as e:
|
340 |
+
logger.error(f"Error deleting cache by modes {modes}: {e}")
|
341 |
+
return False
|
342 |
+
|
343 |
+
async def drop(self) -> dict[str, str]:
|
344 |
+
"""Drop the storage"""
|
345 |
+
try:
|
346 |
+
table_name = namespace_to_table_name(self.namespace)
|
347 |
+
if not table_name:
|
348 |
+
return {
|
349 |
+
"status": "error",
|
350 |
+
"message": f"Unknown namespace: {self.namespace}",
|
351 |
+
}
|
352 |
+
|
353 |
+
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
|
354 |
+
table_name=table_name
|
355 |
+
)
|
356 |
+
await self.db.execute(drop_sql, {"workspace": self.db.workspace})
|
357 |
+
return {"status": "success", "message": "data dropped"}
|
358 |
+
except Exception as e:
|
359 |
+
return {"status": "error", "message": str(e)}
|
360 |
+
|
361 |
|
362 |
@final
|
363 |
@dataclass
|
|
|
486 |
params = {"workspace": self.db.workspace, "status": status}
|
487 |
return await self.db.query(SQL, params, multirows=True)
|
488 |
|
489 |
+
async def delete(self, ids: list[str]) -> None:
|
490 |
+
"""Delete vectors with specified IDs from the storage.
|
491 |
+
|
492 |
+
Args:
|
493 |
+
ids: List of vector IDs to be deleted
|
494 |
+
"""
|
495 |
+
if not ids:
|
496 |
+
return
|
497 |
+
|
498 |
+
table_name = namespace_to_table_name(self.namespace)
|
499 |
+
id_field = namespace_to_id(self.namespace)
|
500 |
+
|
501 |
+
if not table_name or not id_field:
|
502 |
+
logger.error(f"Unknown namespace for vector deletion: {self.namespace}")
|
503 |
+
return
|
504 |
+
|
505 |
+
ids_list = ",".join([f"'{id}'" for id in ids])
|
506 |
+
delete_sql = f"DELETE FROM {table_name} WHERE workspace = :workspace AND {id_field} IN ({ids_list})"
|
507 |
+
|
508 |
+
try:
|
509 |
+
await self.db.execute(delete_sql, {"workspace": self.db.workspace})
|
510 |
+
logger.debug(
|
511 |
+
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
|
512 |
+
)
|
513 |
+
except Exception as e:
|
514 |
+
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
|
515 |
+
|
516 |
async def delete_entity(self, entity_name: str) -> None:
|
517 |
+
"""Delete an entity by its name from the vector storage.
|
518 |
+
|
519 |
+
Args:
|
520 |
+
entity_name: The name of the entity to delete
|
521 |
+
"""
|
522 |
+
try:
|
523 |
+
# Construct SQL to delete the entity
|
524 |
+
delete_sql = """DELETE FROM LIGHTRAG_GRAPH_NODES
|
525 |
+
WHERE workspace = :workspace AND name = :entity_name"""
|
526 |
+
|
527 |
+
await self.db.execute(
|
528 |
+
delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
|
529 |
+
)
|
530 |
+
logger.debug(f"Successfully deleted entity {entity_name}")
|
531 |
+
except Exception as e:
|
532 |
+
logger.error(f"Error deleting entity {entity_name}: {e}")
|
533 |
|
534 |
async def delete_entity_relation(self, entity_name: str) -> None:
|
535 |
+
"""Delete all relations associated with an entity.
|
536 |
+
|
537 |
+
Args:
|
538 |
+
entity_name: The name of the entity whose relations should be deleted
|
539 |
+
"""
|
540 |
+
try:
|
541 |
+
# Delete relations where the entity is either the source or target
|
542 |
+
delete_sql = """DELETE FROM LIGHTRAG_GRAPH_EDGES
|
543 |
+
WHERE workspace = :workspace AND (source_name = :entity_name OR target_name = :entity_name)"""
|
544 |
+
|
545 |
+
await self.db.execute(
|
546 |
+
delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
|
547 |
+
)
|
548 |
+
logger.debug(f"Successfully deleted relations for entity {entity_name}")
|
549 |
+
except Exception as e:
|
550 |
+
logger.error(f"Error deleting relations for entity {entity_name}: {e}")
|
551 |
|
552 |
async def index_done_callback(self) -> None:
|
553 |
# Ti handles persistence automatically
|
554 |
pass
|
555 |
|
556 |
+
async def drop(self) -> dict[str, str]:
|
557 |
+
"""Drop the storage"""
|
558 |
+
try:
|
559 |
+
table_name = namespace_to_table_name(self.namespace)
|
560 |
+
if not table_name:
|
561 |
+
return {
|
562 |
+
"status": "error",
|
563 |
+
"message": f"Unknown namespace: {self.namespace}",
|
564 |
+
}
|
565 |
+
|
566 |
+
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
|
567 |
+
table_name=table_name
|
568 |
+
)
|
569 |
+
await self.db.execute(drop_sql, {"workspace": self.db.workspace})
|
570 |
+
return {"status": "success", "message": "data dropped"}
|
571 |
+
except Exception as e:
|
572 |
+
return {"status": "error", "message": str(e)}
|
573 |
+
|
574 |
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
|
575 |
"""Search for records with IDs starting with a specific prefix.
|
576 |
|
|
|
865 |
# Ti handles persistence automatically
|
866 |
pass
|
867 |
|
868 |
+
async def drop(self) -> dict[str, str]:
|
869 |
+
"""Drop the storage"""
|
870 |
+
try:
|
871 |
+
drop_sql = """
|
872 |
+
DELETE FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace;
|
873 |
+
DELETE FROM LIGHTRAG_GRAPH_NODES WHERE workspace = :workspace;
|
874 |
+
"""
|
875 |
+
await self.db.execute(drop_sql, {"workspace": self.db.workspace})
|
876 |
+
return {"status": "success", "message": "graph data dropped"}
|
877 |
+
except Exception as e:
|
878 |
+
return {"status": "error", "message": str(e)}
|
879 |
+
|
880 |
async def delete_node(self, node_id: str) -> None:
|
881 |
"""Delete a node and all its related edges
|
882 |
|
|
|
1296 |
FROM LIGHTRAG_DOC_CHUNKS
|
1297 |
WHERE chunk_id LIKE :prefix_pattern AND workspace = :workspace
|
1298 |
""",
|
1299 |
+
# Drop tables
|
1300 |
+
"drop_specifiy_table_workspace": "DELETE FROM {table_name} WHERE workspace = :workspace",
|
1301 |
}
|
lightrag/lightrag.py
CHANGED
@@ -13,7 +13,6 @@ import pandas as pd
|
|
13 |
|
14 |
|
15 |
from lightrag.kg import (
|
16 |
-
STORAGE_ENV_REQUIREMENTS,
|
17 |
STORAGES,
|
18 |
verify_storage_implementation,
|
19 |
)
|
@@ -230,6 +229,7 @@ class LightRAG:
|
|
230 |
vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict)
|
231 |
"""Additional parameters for vector database storage."""
|
232 |
|
|
|
233 |
namespace_prefix: str = field(default="")
|
234 |
"""Prefix for namespacing stored data across different environments."""
|
235 |
|
@@ -510,36 +510,22 @@ class LightRAG:
|
|
510 |
self,
|
511 |
node_label: str,
|
512 |
max_depth: int = 3,
|
513 |
-
|
514 |
-
inclusive: bool = False,
|
515 |
) -> KnowledgeGraph:
|
516 |
"""Get knowledge graph for a given label
|
517 |
|
518 |
Args:
|
519 |
node_label (str): Label to get knowledge graph for
|
520 |
max_depth (int): Maximum depth of graph
|
521 |
-
|
522 |
-
inclusive (bool, optional): Whether to use inclusive search mode. Defaults to False.
|
523 |
|
524 |
Returns:
|
525 |
KnowledgeGraph: Knowledge graph containing nodes and edges
|
526 |
"""
|
527 |
-
# get params supported by get_knowledge_graph of specified storage
|
528 |
-
import inspect
|
529 |
|
530 |
-
|
531 |
-
|
532 |
-
)
|
533 |
-
|
534 |
-
kwargs = {"node_label": node_label, "max_depth": max_depth}
|
535 |
-
|
536 |
-
if "min_degree" in storage_params and min_degree > 0:
|
537 |
-
kwargs["min_degree"] = min_degree
|
538 |
-
|
539 |
-
if "inclusive" in storage_params:
|
540 |
-
kwargs["inclusive"] = inclusive
|
541 |
-
|
542 |
-
return await self.chunk_entity_relation_graph.get_knowledge_graph(**kwargs)
|
543 |
|
544 |
def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
|
545 |
import_path = STORAGES[storage_name]
|
@@ -1449,6 +1435,7 @@ class LightRAG:
|
|
1449 |
loop = always_get_an_event_loop()
|
1450 |
return loop.run_until_complete(self.adelete_by_entity(entity_name))
|
1451 |
|
|
|
1452 |
async def adelete_by_entity(self, entity_name: str) -> None:
|
1453 |
try:
|
1454 |
await self.entities_vdb.delete_entity(entity_name)
|
@@ -1486,6 +1473,7 @@ class LightRAG:
|
|
1486 |
self.adelete_by_relation(source_entity, target_entity)
|
1487 |
)
|
1488 |
|
|
|
1489 |
async def adelete_by_relation(self, source_entity: str, target_entity: str) -> None:
|
1490 |
"""Asynchronously delete a relation between two entities.
|
1491 |
|
@@ -1494,6 +1482,7 @@ class LightRAG:
|
|
1494 |
target_entity: Name of the target entity
|
1495 |
"""
|
1496 |
try:
|
|
|
1497 |
# Check if the relation exists
|
1498 |
edge_exists = await self.chunk_entity_relation_graph.has_edge(
|
1499 |
source_entity, target_entity
|
@@ -1554,6 +1543,7 @@ class LightRAG:
|
|
1554 |
"""
|
1555 |
return await self.doc_status.get_docs_by_status(status)
|
1556 |
|
|
|
1557 |
async def adelete_by_doc_id(self, doc_id: str) -> None:
|
1558 |
"""Delete a document and all its related data
|
1559 |
|
@@ -1586,6 +1576,8 @@ class LightRAG:
|
|
1586 |
chunk_ids = set(related_chunks.keys())
|
1587 |
logger.debug(f"Found {len(chunk_ids)} chunks to delete")
|
1588 |
|
|
|
|
|
1589 |
# 3. Before deleting, check the related entities and relationships for these chunks
|
1590 |
for chunk_id in chunk_ids:
|
1591 |
# Check entities
|
@@ -1857,24 +1849,6 @@ class LightRAG:
|
|
1857 |
|
1858 |
return result
|
1859 |
|
1860 |
-
def check_storage_env_vars(self, storage_name: str) -> None:
|
1861 |
-
"""Check if all required environment variables for storage implementation exist
|
1862 |
-
|
1863 |
-
Args:
|
1864 |
-
storage_name: Storage implementation name
|
1865 |
-
|
1866 |
-
Raises:
|
1867 |
-
ValueError: If required environment variables are missing
|
1868 |
-
"""
|
1869 |
-
required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
|
1870 |
-
missing_vars = [var for var in required_vars if var not in os.environ]
|
1871 |
-
|
1872 |
-
if missing_vars:
|
1873 |
-
raise ValueError(
|
1874 |
-
f"Storage implementation '{storage_name}' requires the following "
|
1875 |
-
f"environment variables: {', '.join(missing_vars)}"
|
1876 |
-
)
|
1877 |
-
|
1878 |
async def aclear_cache(self, modes: list[str] | None = None) -> None:
|
1879 |
"""Clear cache data from the LLM response cache storage.
|
1880 |
|
@@ -1906,12 +1880,18 @@ class LightRAG:
|
|
1906 |
try:
|
1907 |
# Reset the cache storage for specified mode
|
1908 |
if modes:
|
1909 |
-
await self.llm_response_cache.
|
1910 |
-
|
|
|
|
|
|
|
1911 |
else:
|
1912 |
# Clear all modes
|
1913 |
-
await self.llm_response_cache.
|
1914 |
-
|
|
|
|
|
|
|
1915 |
|
1916 |
await self.llm_response_cache.index_done_callback()
|
1917 |
|
@@ -1922,6 +1902,7 @@ class LightRAG:
|
|
1922 |
"""Synchronous version of aclear_cache."""
|
1923 |
return always_get_an_event_loop().run_until_complete(self.aclear_cache(modes))
|
1924 |
|
|
|
1925 |
async def aedit_entity(
|
1926 |
self, entity_name: str, updated_data: dict[str, str], allow_rename: bool = True
|
1927 |
) -> dict[str, Any]:
|
@@ -2134,6 +2115,7 @@ class LightRAG:
|
|
2134 |
]
|
2135 |
)
|
2136 |
|
|
|
2137 |
async def aedit_relation(
|
2138 |
self, source_entity: str, target_entity: str, updated_data: dict[str, Any]
|
2139 |
) -> dict[str, Any]:
|
@@ -2448,6 +2430,7 @@ class LightRAG:
|
|
2448 |
self.acreate_relation(source_entity, target_entity, relation_data)
|
2449 |
)
|
2450 |
|
|
|
2451 |
async def amerge_entities(
|
2452 |
self,
|
2453 |
source_entities: list[str],
|
|
|
13 |
|
14 |
|
15 |
from lightrag.kg import (
|
|
|
16 |
STORAGES,
|
17 |
verify_storage_implementation,
|
18 |
)
|
|
|
229 |
vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict)
|
230 |
"""Additional parameters for vector database storage."""
|
231 |
|
232 |
+
# TODO:deprecated, remove in the future, use WORKSPACE instead
|
233 |
namespace_prefix: str = field(default="")
|
234 |
"""Prefix for namespacing stored data across different environments."""
|
235 |
|
|
|
510 |
self,
|
511 |
node_label: str,
|
512 |
max_depth: int = 3,
|
513 |
+
max_nodes: int = 1000,
|
|
|
514 |
) -> KnowledgeGraph:
|
515 |
"""Get knowledge graph for a given label
|
516 |
|
517 |
Args:
|
518 |
node_label (str): Label to get knowledge graph for
|
519 |
max_depth (int): Maximum depth of graph
|
520 |
+
max_nodes (int, optional): Maximum number of nodes to return. Defaults to 1000.
|
|
|
521 |
|
522 |
Returns:
|
523 |
KnowledgeGraph: Knowledge graph containing nodes and edges
|
524 |
"""
|
|
|
|
|
525 |
|
526 |
+
return await self.chunk_entity_relation_graph.get_knowledge_graph(
|
527 |
+
node_label, max_depth, max_nodes
|
528 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
529 |
|
530 |
def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
|
531 |
import_path = STORAGES[storage_name]
|
|
|
1435 |
loop = always_get_an_event_loop()
|
1436 |
return loop.run_until_complete(self.adelete_by_entity(entity_name))
|
1437 |
|
1438 |
+
# TODO: Lock all KG relative DB to esure consistency across multiple processes
|
1439 |
async def adelete_by_entity(self, entity_name: str) -> None:
|
1440 |
try:
|
1441 |
await self.entities_vdb.delete_entity(entity_name)
|
|
|
1473 |
self.adelete_by_relation(source_entity, target_entity)
|
1474 |
)
|
1475 |
|
1476 |
+
# TODO: Lock all KG relative DB to esure consistency across multiple processes
|
1477 |
async def adelete_by_relation(self, source_entity: str, target_entity: str) -> None:
|
1478 |
"""Asynchronously delete a relation between two entities.
|
1479 |
|
|
|
1482 |
target_entity: Name of the target entity
|
1483 |
"""
|
1484 |
try:
|
1485 |
+
# TODO: check if has_edge function works on reverse relation
|
1486 |
# Check if the relation exists
|
1487 |
edge_exists = await self.chunk_entity_relation_graph.has_edge(
|
1488 |
source_entity, target_entity
|
|
|
1543 |
"""
|
1544 |
return await self.doc_status.get_docs_by_status(status)
|
1545 |
|
1546 |
+
# TODO: Lock all KG relative DB to esure consistency across multiple processes
|
1547 |
async def adelete_by_doc_id(self, doc_id: str) -> None:
|
1548 |
"""Delete a document and all its related data
|
1549 |
|
|
|
1576 |
chunk_ids = set(related_chunks.keys())
|
1577 |
logger.debug(f"Found {len(chunk_ids)} chunks to delete")
|
1578 |
|
1579 |
+
# TODO: self.entities_vdb.client_storage only works for local storage, need to fix this
|
1580 |
+
|
1581 |
# 3. Before deleting, check the related entities and relationships for these chunks
|
1582 |
for chunk_id in chunk_ids:
|
1583 |
# Check entities
|
|
|
1849 |
|
1850 |
return result
|
1851 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1852 |
async def aclear_cache(self, modes: list[str] | None = None) -> None:
|
1853 |
"""Clear cache data from the LLM response cache storage.
|
1854 |
|
|
|
1880 |
try:
|
1881 |
# Reset the cache storage for specified mode
|
1882 |
if modes:
|
1883 |
+
success = await self.llm_response_cache.drop_cache_by_modes(modes)
|
1884 |
+
if success:
|
1885 |
+
logger.info(f"Cleared cache for modes: {modes}")
|
1886 |
+
else:
|
1887 |
+
logger.warning(f"Failed to clear cache for modes: {modes}")
|
1888 |
else:
|
1889 |
# Clear all modes
|
1890 |
+
success = await self.llm_response_cache.drop_cache_by_modes(valid_modes)
|
1891 |
+
if success:
|
1892 |
+
logger.info("Cleared all cache")
|
1893 |
+
else:
|
1894 |
+
logger.warning("Failed to clear all cache")
|
1895 |
|
1896 |
await self.llm_response_cache.index_done_callback()
|
1897 |
|
|
|
1902 |
"""Synchronous version of aclear_cache."""
|
1903 |
return always_get_an_event_loop().run_until_complete(self.aclear_cache(modes))
|
1904 |
|
1905 |
+
# TODO: Lock all KG relative DB to esure consistency across multiple processes
|
1906 |
async def aedit_entity(
|
1907 |
self, entity_name: str, updated_data: dict[str, str], allow_rename: bool = True
|
1908 |
) -> dict[str, Any]:
|
|
|
2115 |
]
|
2116 |
)
|
2117 |
|
2118 |
+
# TODO: Lock all KG relative DB to esure consistency across multiple processes
|
2119 |
async def aedit_relation(
|
2120 |
self, source_entity: str, target_entity: str, updated_data: dict[str, Any]
|
2121 |
) -> dict[str, Any]:
|
|
|
2430 |
self.acreate_relation(source_entity, target_entity, relation_data)
|
2431 |
)
|
2432 |
|
2433 |
+
# TODO: Lock all KG relative DB to esure consistency across multiple processes
|
2434 |
async def amerge_entities(
|
2435 |
self,
|
2436 |
source_entities: list[str],
|
lightrag/operate.py
CHANGED
@@ -25,7 +25,6 @@ from .utils import (
|
|
25 |
CacheData,
|
26 |
statistic_data,
|
27 |
get_conversation_turns,
|
28 |
-
verbose_debug,
|
29 |
)
|
30 |
from .base import (
|
31 |
BaseGraphStorage,
|
@@ -441,6 +440,13 @@ async def extract_entities(
|
|
441 |
|
442 |
processed_chunks = 0
|
443 |
total_chunks = len(ordered_chunks)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
444 |
|
445 |
async def _user_llm_func_with_cache(
|
446 |
input_text: str, history_messages: list[dict[str, str]] = None
|
@@ -539,7 +545,7 @@ async def extract_entities(
|
|
539 |
chunk_key_dp (tuple[str, TextChunkSchema]):
|
540 |
("chunk-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int})
|
541 |
"""
|
542 |
-
nonlocal processed_chunks
|
543 |
chunk_key = chunk_key_dp[0]
|
544 |
chunk_dp = chunk_key_dp[1]
|
545 |
content = chunk_dp["content"]
|
@@ -597,102 +603,74 @@ async def extract_entities(
|
|
597 |
async with pipeline_status_lock:
|
598 |
pipeline_status["latest_message"] = log_message
|
599 |
pipeline_status["history_messages"].append(log_message)
|
600 |
-
return dict(maybe_nodes), dict(maybe_edges)
|
601 |
-
|
602 |
-
tasks = [_process_single_content(c) for c in ordered_chunks]
|
603 |
-
results = await asyncio.gather(*tasks)
|
604 |
|
605 |
-
|
606 |
-
|
607 |
-
|
608 |
-
for k, v in m_nodes.items():
|
609 |
-
maybe_nodes[k].extend(v)
|
610 |
-
for k, v in m_edges.items():
|
611 |
-
maybe_edges[tuple(sorted(k))].extend(v)
|
612 |
|
613 |
-
|
614 |
-
|
615 |
-
|
616 |
-
|
617 |
-
|
618 |
-
async with graph_db_lock:
|
619 |
-
all_entities_data = await asyncio.gather(
|
620 |
-
*[
|
621 |
-
_merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
|
622 |
-
for k, v in maybe_nodes.items()
|
623 |
-
]
|
624 |
-
)
|
625 |
-
|
626 |
-
all_relationships_data = await asyncio.gather(
|
627 |
-
*[
|
628 |
-
_merge_edges_then_upsert(
|
629 |
-
k[0], k[1], v, knowledge_graph_inst, global_config
|
630 |
)
|
631 |
-
|
632 |
-
|
633 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
634 |
|
635 |
-
|
636 |
-
|
637 |
-
|
638 |
-
if pipeline_status is not None:
|
639 |
-
async with pipeline_status_lock:
|
640 |
-
pipeline_status["latest_message"] = log_message
|
641 |
-
pipeline_status["history_messages"].append(log_message)
|
642 |
-
return
|
643 |
|
644 |
-
|
645 |
-
|
646 |
-
|
647 |
-
if pipeline_status is not None:
|
648 |
-
async with pipeline_status_lock:
|
649 |
-
pipeline_status["latest_message"] = log_message
|
650 |
-
pipeline_status["history_messages"].append(log_message)
|
651 |
-
if not all_relationships_data:
|
652 |
-
log_message = "Didn't extract any relationships"
|
653 |
-
logger.info(log_message)
|
654 |
-
if pipeline_status is not None:
|
655 |
-
async with pipeline_status_lock:
|
656 |
-
pipeline_status["latest_message"] = log_message
|
657 |
-
pipeline_status["history_messages"].append(log_message)
|
658 |
|
659 |
-
log_message = f"Extracted {
|
660 |
logger.info(log_message)
|
661 |
if pipeline_status is not None:
|
662 |
async with pipeline_status_lock:
|
663 |
pipeline_status["latest_message"] = log_message
|
664 |
pipeline_status["history_messages"].append(log_message)
|
665 |
-
verbose_debug(
|
666 |
-
f"New entities:{all_entities_data}, relationships:{all_relationships_data}"
|
667 |
-
)
|
668 |
-
verbose_debug(f"New relationships:{all_relationships_data}")
|
669 |
-
|
670 |
-
if entity_vdb is not None:
|
671 |
-
data_for_vdb = {
|
672 |
-
compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
|
673 |
-
"entity_name": dp["entity_name"],
|
674 |
-
"entity_type": dp["entity_type"],
|
675 |
-
"content": f"{dp['entity_name']}\n{dp['description']}",
|
676 |
-
"source_id": dp["source_id"],
|
677 |
-
"file_path": dp.get("file_path", "unknown_source"),
|
678 |
-
}
|
679 |
-
for dp in all_entities_data
|
680 |
-
}
|
681 |
-
await entity_vdb.upsert(data_for_vdb)
|
682 |
-
|
683 |
-
if relationships_vdb is not None:
|
684 |
-
data_for_vdb = {
|
685 |
-
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
|
686 |
-
"src_id": dp["src_id"],
|
687 |
-
"tgt_id": dp["tgt_id"],
|
688 |
-
"keywords": dp["keywords"],
|
689 |
-
"content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
|
690 |
-
"source_id": dp["source_id"],
|
691 |
-
"file_path": dp.get("file_path", "unknown_source"),
|
692 |
-
}
|
693 |
-
for dp in all_relationships_data
|
694 |
-
}
|
695 |
-
await relationships_vdb.upsert(data_for_vdb)
|
696 |
|
697 |
|
698 |
async def kg_query(
|
@@ -1367,7 +1345,9 @@ async def _get_node_data(
|
|
1367 |
|
1368 |
text_units_section_list = [["id", "content", "file_path"]]
|
1369 |
for i, t in enumerate(use_text_units):
|
1370 |
-
text_units_section_list.append(
|
|
|
|
|
1371 |
text_units_context = list_of_list_to_csv(text_units_section_list)
|
1372 |
return entities_context, relations_context, text_units_context
|
1373 |
|
|
|
25 |
CacheData,
|
26 |
statistic_data,
|
27 |
get_conversation_turns,
|
|
|
28 |
)
|
29 |
from .base import (
|
30 |
BaseGraphStorage,
|
|
|
440 |
|
441 |
processed_chunks = 0
|
442 |
total_chunks = len(ordered_chunks)
|
443 |
+
total_entities_count = 0
|
444 |
+
total_relations_count = 0
|
445 |
+
|
446 |
+
# Get lock manager from shared storage
|
447 |
+
from .kg.shared_storage import get_graph_db_lock
|
448 |
+
|
449 |
+
graph_db_lock = get_graph_db_lock(enable_logging=False)
|
450 |
|
451 |
async def _user_llm_func_with_cache(
|
452 |
input_text: str, history_messages: list[dict[str, str]] = None
|
|
|
545 |
chunk_key_dp (tuple[str, TextChunkSchema]):
|
546 |
("chunk-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int})
|
547 |
"""
|
548 |
+
nonlocal processed_chunks, total_entities_count, total_relations_count
|
549 |
chunk_key = chunk_key_dp[0]
|
550 |
chunk_dp = chunk_key_dp[1]
|
551 |
content = chunk_dp["content"]
|
|
|
603 |
async with pipeline_status_lock:
|
604 |
pipeline_status["latest_message"] = log_message
|
605 |
pipeline_status["history_messages"].append(log_message)
|
|
|
|
|
|
|
|
|
606 |
|
607 |
+
# Use graph database lock to ensure atomic merges and updates
|
608 |
+
chunk_entities_data = []
|
609 |
+
chunk_relationships_data = []
|
|
|
|
|
|
|
|
|
610 |
|
611 |
+
async with graph_db_lock:
|
612 |
+
# Process and update entities
|
613 |
+
for entity_name, entities in maybe_nodes.items():
|
614 |
+
entity_data = await _merge_nodes_then_upsert(
|
615 |
+
entity_name, entities, knowledge_graph_inst, global_config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
616 |
)
|
617 |
+
chunk_entities_data.append(entity_data)
|
618 |
+
|
619 |
+
# Process and update relationships
|
620 |
+
for edge_key, edges in maybe_edges.items():
|
621 |
+
# Ensure edge direction consistency
|
622 |
+
sorted_edge_key = tuple(sorted(edge_key))
|
623 |
+
edge_data = await _merge_edges_then_upsert(
|
624 |
+
sorted_edge_key[0],
|
625 |
+
sorted_edge_key[1],
|
626 |
+
edges,
|
627 |
+
knowledge_graph_inst,
|
628 |
+
global_config,
|
629 |
+
)
|
630 |
+
chunk_relationships_data.append(edge_data)
|
631 |
+
|
632 |
+
# Update vector database (within the same lock to ensure atomicity)
|
633 |
+
if entity_vdb is not None and chunk_entities_data:
|
634 |
+
data_for_vdb = {
|
635 |
+
compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
|
636 |
+
"entity_name": dp["entity_name"],
|
637 |
+
"entity_type": dp["entity_type"],
|
638 |
+
"content": f"{dp['entity_name']}\n{dp['description']}",
|
639 |
+
"source_id": dp["source_id"],
|
640 |
+
"file_path": dp.get("file_path", "unknown_source"),
|
641 |
+
}
|
642 |
+
for dp in chunk_entities_data
|
643 |
+
}
|
644 |
+
await entity_vdb.upsert(data_for_vdb)
|
645 |
+
|
646 |
+
if relationships_vdb is not None and chunk_relationships_data:
|
647 |
+
data_for_vdb = {
|
648 |
+
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
|
649 |
+
"src_id": dp["src_id"],
|
650 |
+
"tgt_id": dp["tgt_id"],
|
651 |
+
"keywords": dp["keywords"],
|
652 |
+
"content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
|
653 |
+
"source_id": dp["source_id"],
|
654 |
+
"file_path": dp.get("file_path", "unknown_source"),
|
655 |
+
}
|
656 |
+
for dp in chunk_relationships_data
|
657 |
+
}
|
658 |
+
await relationships_vdb.upsert(data_for_vdb)
|
659 |
|
660 |
+
# Update counters
|
661 |
+
total_entities_count += len(chunk_entities_data)
|
662 |
+
total_relations_count += len(chunk_relationships_data)
|
|
|
|
|
|
|
|
|
|
|
663 |
|
664 |
+
# Handle all chunks in parallel
|
665 |
+
tasks = [_process_single_content(c) for c in ordered_chunks]
|
666 |
+
await asyncio.gather(*tasks)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
667 |
|
668 |
+
log_message = f"Extracted {total_entities_count} entities + {total_relations_count} relationships (total)"
|
669 |
logger.info(log_message)
|
670 |
if pipeline_status is not None:
|
671 |
async with pipeline_status_lock:
|
672 |
pipeline_status["latest_message"] = log_message
|
673 |
pipeline_status["history_messages"].append(log_message)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
674 |
|
675 |
|
676 |
async def kg_query(
|
|
|
1345 |
|
1346 |
text_units_section_list = [["id", "content", "file_path"]]
|
1347 |
for i, t in enumerate(use_text_units):
|
1348 |
+
text_units_section_list.append(
|
1349 |
+
[i, t["content"], t.get("file_path", "unknown_source")]
|
1350 |
+
)
|
1351 |
text_units_context = list_of_list_to_csv(text_units_section_list)
|
1352 |
return entities_context, relations_context, text_units_context
|
1353 |
|
lightrag/types.py
CHANGED
@@ -26,3 +26,4 @@ class KnowledgeGraphEdge(BaseModel):
|
|
26 |
class KnowledgeGraph(BaseModel):
|
27 |
nodes: list[KnowledgeGraphNode] = []
|
28 |
edges: list[KnowledgeGraphEdge] = []
|
|
|
|
26 |
class KnowledgeGraph(BaseModel):
|
27 |
nodes: list[KnowledgeGraphNode] = []
|
28 |
edges: list[KnowledgeGraphEdge] = []
|
29 |
+
is_truncated: bool = False
|
lightrag_webui/src/AppRouter.tsx
CHANGED
@@ -80,7 +80,12 @@ const AppRouter = () => {
|
|
80 |
<ThemeProvider>
|
81 |
<Router>
|
82 |
<AppContent />
|
83 |
-
<Toaster
|
|
|
|
|
|
|
|
|
|
|
84 |
</Router>
|
85 |
</ThemeProvider>
|
86 |
)
|
|
|
80 |
<ThemeProvider>
|
81 |
<Router>
|
82 |
<AppContent />
|
83 |
+
<Toaster
|
84 |
+
position="bottom-center"
|
85 |
+
theme="system"
|
86 |
+
closeButton
|
87 |
+
richColors
|
88 |
+
/>
|
89 |
</Router>
|
90 |
</ThemeProvider>
|
91 |
)
|
lightrag_webui/src/api/lightrag.ts
CHANGED
@@ -3,6 +3,7 @@ import { backendBaseUrl } from '@/lib/constants'
|
|
3 |
import { errorMessage } from '@/lib/utils'
|
4 |
import { useSettingsStore } from '@/stores/settings'
|
5 |
import { navigationService } from '@/services/navigation'
|
|
|
6 |
|
7 |
// Types
|
8 |
export type LightragNodeType = {
|
@@ -46,6 +47,8 @@ export type LightragStatus = {
|
|
46 |
api_version?: string
|
47 |
auth_mode?: 'enabled' | 'disabled'
|
48 |
pipeline_busy: boolean
|
|
|
|
|
49 |
}
|
50 |
|
51 |
export type LightragDocumentsScanProgress = {
|
@@ -140,6 +143,8 @@ export type AuthStatusResponse = {
|
|
140 |
message?: string
|
141 |
core_version?: string
|
142 |
api_version?: string
|
|
|
|
|
143 |
}
|
144 |
|
145 |
export type PipelineStatusResponse = {
|
@@ -163,6 +168,8 @@ export type LoginResponse = {
|
|
163 |
message?: string // Optional message
|
164 |
core_version?: string
|
165 |
api_version?: string
|
|
|
|
|
166 |
}
|
167 |
|
168 |
export const InvalidApiKeyError = 'Invalid API Key'
|
@@ -221,9 +228,9 @@ axiosInstance.interceptors.response.use(
|
|
221 |
export const queryGraphs = async (
|
222 |
label: string,
|
223 |
maxDepth: number,
|
224 |
-
|
225 |
): Promise<LightragGraphType> => {
|
226 |
-
const response = await axiosInstance.get(`/graphs?label=${encodeURIComponent(label)}&max_depth=${maxDepth}&
|
227 |
return response.data
|
228 |
}
|
229 |
|
@@ -382,6 +389,14 @@ export const clearDocuments = async (): Promise<DocActionResponse> => {
|
|
382 |
return response.data
|
383 |
}
|
384 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
385 |
export const getAuthStatus = async (): Promise<AuthStatusResponse> => {
|
386 |
try {
|
387 |
// Add a timeout to the request to prevent hanging
|
@@ -411,12 +426,26 @@ export const getAuthStatus = async (): Promise<AuthStatusResponse> => {
|
|
411 |
// For unconfigured auth, ensure we have an access token
|
412 |
if (!response.data.auth_configured) {
|
413 |
if (response.data.access_token && typeof response.data.access_token === 'string') {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
414 |
return response.data;
|
415 |
} else {
|
416 |
console.warn('Auth not configured but no valid access token provided');
|
417 |
}
|
418 |
} else {
|
419 |
// For configured auth, just return the data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
420 |
return response.data;
|
421 |
}
|
422 |
}
|
@@ -455,5 +484,13 @@ export const loginToServer = async (username: string, password: string): Promise
|
|
455 |
}
|
456 |
});
|
457 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
458 |
return response.data;
|
459 |
}
|
|
|
3 |
import { errorMessage } from '@/lib/utils'
|
4 |
import { useSettingsStore } from '@/stores/settings'
|
5 |
import { navigationService } from '@/services/navigation'
|
6 |
+
import { useAuthStore } from '@/stores/state'
|
7 |
|
8 |
// Types
|
9 |
export type LightragNodeType = {
|
|
|
47 |
api_version?: string
|
48 |
auth_mode?: 'enabled' | 'disabled'
|
49 |
pipeline_busy: boolean
|
50 |
+
webui_title?: string
|
51 |
+
webui_description?: string
|
52 |
}
|
53 |
|
54 |
export type LightragDocumentsScanProgress = {
|
|
|
143 |
message?: string
|
144 |
core_version?: string
|
145 |
api_version?: string
|
146 |
+
webui_title?: string
|
147 |
+
webui_description?: string
|
148 |
}
|
149 |
|
150 |
export type PipelineStatusResponse = {
|
|
|
168 |
message?: string // Optional message
|
169 |
core_version?: string
|
170 |
api_version?: string
|
171 |
+
webui_title?: string
|
172 |
+
webui_description?: string
|
173 |
}
|
174 |
|
175 |
export const InvalidApiKeyError = 'Invalid API Key'
|
|
|
228 |
export const queryGraphs = async (
|
229 |
label: string,
|
230 |
maxDepth: number,
|
231 |
+
maxNodes: number
|
232 |
): Promise<LightragGraphType> => {
|
233 |
+
const response = await axiosInstance.get(`/graphs?label=${encodeURIComponent(label)}&max_depth=${maxDepth}&max_nodes=${maxNodes}`)
|
234 |
return response.data
|
235 |
}
|
236 |
|
|
|
389 |
return response.data
|
390 |
}
|
391 |
|
392 |
+
export const clearCache = async (modes?: string[]): Promise<{
|
393 |
+
status: 'success' | 'fail'
|
394 |
+
message: string
|
395 |
+
}> => {
|
396 |
+
const response = await axiosInstance.post('/documents/clear_cache', { modes })
|
397 |
+
return response.data
|
398 |
+
}
|
399 |
+
|
400 |
export const getAuthStatus = async (): Promise<AuthStatusResponse> => {
|
401 |
try {
|
402 |
// Add a timeout to the request to prevent hanging
|
|
|
426 |
// For unconfigured auth, ensure we have an access token
|
427 |
if (!response.data.auth_configured) {
|
428 |
if (response.data.access_token && typeof response.data.access_token === 'string') {
|
429 |
+
// Update custom title if available
|
430 |
+
if ('webui_title' in response.data || 'webui_description' in response.data) {
|
431 |
+
useAuthStore.getState().setCustomTitle(
|
432 |
+
'webui_title' in response.data ? (response.data.webui_title ?? null) : null,
|
433 |
+
'webui_description' in response.data ? (response.data.webui_description ?? null) : null
|
434 |
+
);
|
435 |
+
}
|
436 |
return response.data;
|
437 |
} else {
|
438 |
console.warn('Auth not configured but no valid access token provided');
|
439 |
}
|
440 |
} else {
|
441 |
// For configured auth, just return the data
|
442 |
+
// Update custom title if available
|
443 |
+
if ('webui_title' in response.data || 'webui_description' in response.data) {
|
444 |
+
useAuthStore.getState().setCustomTitle(
|
445 |
+
'webui_title' in response.data ? (response.data.webui_title ?? null) : null,
|
446 |
+
'webui_description' in response.data ? (response.data.webui_description ?? null) : null
|
447 |
+
);
|
448 |
+
}
|
449 |
return response.data;
|
450 |
}
|
451 |
}
|
|
|
484 |
}
|
485 |
});
|
486 |
|
487 |
+
// Update custom title if available
|
488 |
+
if ('webui_title' in response.data || 'webui_description' in response.data) {
|
489 |
+
useAuthStore.getState().setCustomTitle(
|
490 |
+
'webui_title' in response.data ? (response.data.webui_title ?? null) : null,
|
491 |
+
'webui_description' in response.data ? (response.data.webui_description ?? null) : null
|
492 |
+
);
|
493 |
+
}
|
494 |
+
|
495 |
return response.data;
|
496 |
}
|
lightrag_webui/src/components/documents/ClearDocumentsDialog.tsx
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
import { useState, useCallback } from 'react'
|
2 |
import Button from '@/components/ui/Button'
|
3 |
import {
|
4 |
Dialog,
|
@@ -6,32 +6,88 @@ import {
|
|
6 |
DialogDescription,
|
7 |
DialogHeader,
|
8 |
DialogTitle,
|
9 |
-
DialogTrigger
|
|
|
10 |
} from '@/components/ui/Dialog'
|
|
|
|
|
11 |
import { toast } from 'sonner'
|
12 |
import { errorMessage } from '@/lib/utils'
|
13 |
-
import { clearDocuments } from '@/api/lightrag'
|
14 |
|
15 |
-
import { EraserIcon } from 'lucide-react'
|
16 |
import { useTranslation } from 'react-i18next'
|
17 |
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
const { t } = useTranslation()
|
20 |
const [open, setOpen] = useState(false)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
const handleClear = useCallback(async () => {
|
|
|
|
|
23 |
try {
|
24 |
const result = await clearDocuments()
|
25 |
-
|
26 |
-
|
27 |
-
setOpen(false)
|
28 |
-
} else {
|
29 |
toast.error(t('documentPanel.clearDocuments.failed', { message: result.message }))
|
|
|
|
|
30 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
} catch (err) {
|
32 |
toast.error(t('documentPanel.clearDocuments.error', { error: errorMessage(err) }))
|
|
|
33 |
}
|
34 |
-
}, [setOpen, t])
|
35 |
|
36 |
return (
|
37 |
<Dialog open={open} onOpenChange={setOpen}>
|
@@ -42,12 +98,58 @@ export default function ClearDocumentsDialog() {
|
|
42 |
</DialogTrigger>
|
43 |
<DialogContent className="sm:max-w-xl" onCloseAutoFocus={(e) => e.preventDefault()}>
|
44 |
<DialogHeader>
|
45 |
-
<DialogTitle>
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
</DialogHeader>
|
48 |
-
|
49 |
-
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
</DialogContent>
|
52 |
</Dialog>
|
53 |
)
|
|
|
1 |
+
import { useState, useCallback, useEffect } from 'react'
|
2 |
import Button from '@/components/ui/Button'
|
3 |
import {
|
4 |
Dialog,
|
|
|
6 |
DialogDescription,
|
7 |
DialogHeader,
|
8 |
DialogTitle,
|
9 |
+
DialogTrigger,
|
10 |
+
DialogFooter
|
11 |
} from '@/components/ui/Dialog'
|
12 |
+
import Input from '@/components/ui/Input'
|
13 |
+
import Checkbox from '@/components/ui/Checkbox'
|
14 |
import { toast } from 'sonner'
|
15 |
import { errorMessage } from '@/lib/utils'
|
16 |
+
import { clearDocuments, clearCache } from '@/api/lightrag'
|
17 |
|
18 |
+
import { EraserIcon, AlertTriangleIcon } from 'lucide-react'
|
19 |
import { useTranslation } from 'react-i18next'
|
20 |
|
21 |
+
// 简单的Label组件
|
22 |
+
const Label = ({
|
23 |
+
htmlFor,
|
24 |
+
className,
|
25 |
+
children,
|
26 |
+
...props
|
27 |
+
}: React.LabelHTMLAttributes<HTMLLabelElement>) => (
|
28 |
+
<label
|
29 |
+
htmlFor={htmlFor}
|
30 |
+
className={className}
|
31 |
+
{...props}
|
32 |
+
>
|
33 |
+
{children}
|
34 |
+
</label>
|
35 |
+
)
|
36 |
+
|
37 |
+
interface ClearDocumentsDialogProps {
|
38 |
+
onDocumentsCleared?: () => Promise<void>
|
39 |
+
}
|
40 |
+
|
41 |
+
export default function ClearDocumentsDialog({ onDocumentsCleared }: ClearDocumentsDialogProps) {
|
42 |
const { t } = useTranslation()
|
43 |
const [open, setOpen] = useState(false)
|
44 |
+
const [confirmText, setConfirmText] = useState('')
|
45 |
+
const [clearCacheOption, setClearCacheOption] = useState(false)
|
46 |
+
const isConfirmEnabled = confirmText.toLowerCase() === 'yes'
|
47 |
+
|
48 |
+
// 重置状态当对话框关闭时
|
49 |
+
useEffect(() => {
|
50 |
+
if (!open) {
|
51 |
+
setConfirmText('')
|
52 |
+
setClearCacheOption(false)
|
53 |
+
}
|
54 |
+
}, [open])
|
55 |
|
56 |
const handleClear = useCallback(async () => {
|
57 |
+
if (!isConfirmEnabled) return
|
58 |
+
|
59 |
try {
|
60 |
const result = await clearDocuments()
|
61 |
+
|
62 |
+
if (result.status !== 'success') {
|
|
|
|
|
63 |
toast.error(t('documentPanel.clearDocuments.failed', { message: result.message }))
|
64 |
+
setConfirmText('')
|
65 |
+
return
|
66 |
}
|
67 |
+
|
68 |
+
toast.success(t('documentPanel.clearDocuments.success'))
|
69 |
+
|
70 |
+
if (clearCacheOption) {
|
71 |
+
try {
|
72 |
+
await clearCache()
|
73 |
+
toast.success(t('documentPanel.clearDocuments.cacheCleared'))
|
74 |
+
} catch (cacheErr) {
|
75 |
+
toast.error(t('documentPanel.clearDocuments.cacheClearFailed', { error: errorMessage(cacheErr) }))
|
76 |
+
}
|
77 |
+
}
|
78 |
+
|
79 |
+
// Refresh document list if provided
|
80 |
+
if (onDocumentsCleared) {
|
81 |
+
onDocumentsCleared().catch(console.error)
|
82 |
+
}
|
83 |
+
|
84 |
+
// 所有操作成功后关闭对话框
|
85 |
+
setOpen(false)
|
86 |
} catch (err) {
|
87 |
toast.error(t('documentPanel.clearDocuments.error', { error: errorMessage(err) }))
|
88 |
+
setConfirmText('')
|
89 |
}
|
90 |
+
}, [isConfirmEnabled, clearCacheOption, setOpen, t, onDocumentsCleared])
|
91 |
|
92 |
return (
|
93 |
<Dialog open={open} onOpenChange={setOpen}>
|
|
|
98 |
</DialogTrigger>
|
99 |
<DialogContent className="sm:max-w-xl" onCloseAutoFocus={(e) => e.preventDefault()}>
|
100 |
<DialogHeader>
|
101 |
+
<DialogTitle className="flex items-center gap-2 text-red-500 dark:text-red-400 font-bold">
|
102 |
+
<AlertTriangleIcon className="h-5 w-5" />
|
103 |
+
{t('documentPanel.clearDocuments.title')}
|
104 |
+
</DialogTitle>
|
105 |
+
<DialogDescription className="pt-2">
|
106 |
+
<div className="text-red-500 dark:text-red-400 font-semibold mb-4">
|
107 |
+
{t('documentPanel.clearDocuments.warning')}
|
108 |
+
</div>
|
109 |
+
<div className="mb-4">
|
110 |
+
{t('documentPanel.clearDocuments.confirm')}
|
111 |
+
</div>
|
112 |
+
</DialogDescription>
|
113 |
</DialogHeader>
|
114 |
+
|
115 |
+
<div className="space-y-4">
|
116 |
+
<div className="space-y-2">
|
117 |
+
<Label htmlFor="confirm-text" className="text-sm font-medium">
|
118 |
+
{t('documentPanel.clearDocuments.confirmPrompt')}
|
119 |
+
</Label>
|
120 |
+
<Input
|
121 |
+
id="confirm-text"
|
122 |
+
value={confirmText}
|
123 |
+
onChange={(e: React.ChangeEvent<HTMLInputElement>) => setConfirmText(e.target.value)}
|
124 |
+
placeholder={t('documentPanel.clearDocuments.confirmPlaceholder')}
|
125 |
+
className="w-full"
|
126 |
+
/>
|
127 |
+
</div>
|
128 |
+
|
129 |
+
<div className="flex items-center space-x-2">
|
130 |
+
<Checkbox
|
131 |
+
id="clear-cache"
|
132 |
+
checked={clearCacheOption}
|
133 |
+
onCheckedChange={(checked: boolean | 'indeterminate') => setClearCacheOption(checked === true)}
|
134 |
+
/>
|
135 |
+
<Label htmlFor="clear-cache" className="text-sm font-medium cursor-pointer">
|
136 |
+
{t('documentPanel.clearDocuments.clearCache')}
|
137 |
+
</Label>
|
138 |
+
</div>
|
139 |
+
</div>
|
140 |
+
|
141 |
+
<DialogFooter>
|
142 |
+
<Button variant="outline" onClick={() => setOpen(false)}>
|
143 |
+
{t('common.cancel')}
|
144 |
+
</Button>
|
145 |
+
<Button
|
146 |
+
variant="destructive"
|
147 |
+
onClick={handleClear}
|
148 |
+
disabled={!isConfirmEnabled}
|
149 |
+
>
|
150 |
+
{t('documentPanel.clearDocuments.confirmButton')}
|
151 |
+
</Button>
|
152 |
+
</DialogFooter>
|
153 |
</DialogContent>
|
154 |
</Dialog>
|
155 |
)
|
lightrag_webui/src/components/documents/UploadDocumentsDialog.tsx
CHANGED
@@ -17,7 +17,11 @@ import { uploadDocument } from '@/api/lightrag'
|
|
17 |
import { UploadIcon } from 'lucide-react'
|
18 |
import { useTranslation } from 'react-i18next'
|
19 |
|
20 |
-
|
|
|
|
|
|
|
|
|
21 |
const { t } = useTranslation()
|
22 |
const [open, setOpen] = useState(false)
|
23 |
const [isUploading, setIsUploading] = useState(false)
|
@@ -55,6 +59,7 @@ export default function UploadDocumentsDialog() {
|
|
55 |
const handleDocumentsUpload = useCallback(
|
56 |
async (filesToUpload: File[]) => {
|
57 |
setIsUploading(true)
|
|
|
58 |
|
59 |
// Only clear errors for files that are being uploaded, keep errors for rejected files
|
60 |
setFileErrors(prev => {
|
@@ -101,6 +106,9 @@ export default function UploadDocumentsDialog() {
|
|
101 |
...prev,
|
102 |
[file.name]: result.message
|
103 |
}))
|
|
|
|
|
|
|
104 |
}
|
105 |
} catch (err) {
|
106 |
console.error(`Upload failed for ${file.name}:`, err)
|
@@ -142,6 +150,16 @@ export default function UploadDocumentsDialog() {
|
|
142 |
} else {
|
143 |
toast.success(t('documentPanel.uploadDocuments.batch.success'), { id: toastId })
|
144 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
} catch (err) {
|
146 |
console.error('Unexpected error during upload:', err)
|
147 |
toast.error(t('documentPanel.uploadDocuments.generalError', { error: errorMessage(err) }), { id: toastId })
|
@@ -149,7 +167,7 @@ export default function UploadDocumentsDialog() {
|
|
149 |
setIsUploading(false)
|
150 |
}
|
151 |
},
|
152 |
-
[setIsUploading, setProgresses, setFileErrors, t]
|
153 |
)
|
154 |
|
155 |
return (
|
|
|
17 |
import { UploadIcon } from 'lucide-react'
|
18 |
import { useTranslation } from 'react-i18next'
|
19 |
|
20 |
+
interface UploadDocumentsDialogProps {
|
21 |
+
onDocumentsUploaded?: () => Promise<void>
|
22 |
+
}
|
23 |
+
|
24 |
+
export default function UploadDocumentsDialog({ onDocumentsUploaded }: UploadDocumentsDialogProps) {
|
25 |
const { t } = useTranslation()
|
26 |
const [open, setOpen] = useState(false)
|
27 |
const [isUploading, setIsUploading] = useState(false)
|
|
|
59 |
const handleDocumentsUpload = useCallback(
|
60 |
async (filesToUpload: File[]) => {
|
61 |
setIsUploading(true)
|
62 |
+
let hasSuccessfulUpload = false
|
63 |
|
64 |
// Only clear errors for files that are being uploaded, keep errors for rejected files
|
65 |
setFileErrors(prev => {
|
|
|
106 |
...prev,
|
107 |
[file.name]: result.message
|
108 |
}))
|
109 |
+
} else {
|
110 |
+
// Mark that we had at least one successful upload
|
111 |
+
hasSuccessfulUpload = true
|
112 |
}
|
113 |
} catch (err) {
|
114 |
console.error(`Upload failed for ${file.name}:`, err)
|
|
|
150 |
} else {
|
151 |
toast.success(t('documentPanel.uploadDocuments.batch.success'), { id: toastId })
|
152 |
}
|
153 |
+
|
154 |
+
// Only update if at least one file was uploaded successfully
|
155 |
+
if (hasSuccessfulUpload) {
|
156 |
+
// Refresh document list
|
157 |
+
if (onDocumentsUploaded) {
|
158 |
+
onDocumentsUploaded().catch(err => {
|
159 |
+
console.error('Error refreshing documents:', err)
|
160 |
+
})
|
161 |
+
}
|
162 |
+
}
|
163 |
} catch (err) {
|
164 |
console.error('Unexpected error during upload:', err)
|
165 |
toast.error(t('documentPanel.uploadDocuments.generalError', { error: errorMessage(err) }), { id: toastId })
|
|
|
167 |
setIsUploading(false)
|
168 |
}
|
169 |
},
|
170 |
+
[setIsUploading, setProgresses, setFileErrors, t, onDocumentsUploaded]
|
171 |
)
|
172 |
|
173 |
return (
|
lightrag_webui/src/components/graph/Settings.tsx
CHANGED
@@ -8,7 +8,7 @@ import Input from '@/components/ui/Input'
|
|
8 |
import { controlButtonVariant } from '@/lib/constants'
|
9 |
import { useSettingsStore } from '@/stores/settings'
|
10 |
|
11 |
-
import { SettingsIcon } from 'lucide-react'
|
12 |
import { useTranslation } from 'react-i18next';
|
13 |
|
14 |
/**
|
@@ -44,14 +44,17 @@ const LabeledNumberInput = ({
|
|
44 |
onEditFinished,
|
45 |
label,
|
46 |
min,
|
47 |
-
max
|
|
|
48 |
}: {
|
49 |
value: number
|
50 |
onEditFinished: (value: number) => void
|
51 |
label: string
|
52 |
min: number
|
53 |
max?: number
|
|
|
54 |
}) => {
|
|
|
55 |
const [currentValue, setCurrentValue] = useState<number | null>(value)
|
56 |
|
57 |
const onValueChange = useCallback(
|
@@ -81,6 +84,13 @@ const LabeledNumberInput = ({
|
|
81 |
}
|
82 |
}, [value, currentValue, onEditFinished])
|
83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
return (
|
85 |
<div className="flex flex-col gap-2">
|
86 |
<label
|
@@ -89,20 +99,34 @@ const LabeledNumberInput = ({
|
|
89 |
>
|
90 |
{label}
|
91 |
</label>
|
92 |
-
<
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
</div>
|
107 |
)
|
108 |
}
|
@@ -121,7 +145,7 @@ export default function Settings() {
|
|
121 |
const enableHideUnselectedEdges = useSettingsStore.use.enableHideUnselectedEdges()
|
122 |
const showEdgeLabel = useSettingsStore.use.showEdgeLabel()
|
123 |
const graphQueryMaxDepth = useSettingsStore.use.graphQueryMaxDepth()
|
124 |
-
const
|
125 |
const graphLayoutMaxIterations = useSettingsStore.use.graphLayoutMaxIterations()
|
126 |
|
127 |
const enableHealthCheck = useSettingsStore.use.enableHealthCheck()
|
@@ -180,15 +204,14 @@ export default function Settings() {
|
|
180 |
}, 300)
|
181 |
}, [])
|
182 |
|
183 |
-
const
|
184 |
-
if (
|
185 |
-
useSettingsStore.setState({
|
186 |
const currentLabel = useSettingsStore.getState().queryLabel
|
187 |
useSettingsStore.getState().setQueryLabel('')
|
188 |
setTimeout(() => {
|
189 |
useSettingsStore.getState().setQueryLabel(currentLabel)
|
190 |
}, 300)
|
191 |
-
|
192 |
}, [])
|
193 |
|
194 |
const setGraphLayoutMaxIterations = useCallback((iterations: number) => {
|
@@ -274,19 +297,23 @@ export default function Settings() {
|
|
274 |
label={t('graphPanel.sideBar.settings.maxQueryDepth')}
|
275 |
min={1}
|
276 |
value={graphQueryMaxDepth}
|
|
|
277 |
onEditFinished={setGraphQueryMaxDepth}
|
278 |
/>
|
279 |
<LabeledNumberInput
|
280 |
-
label={t('graphPanel.sideBar.settings.
|
281 |
-
min={
|
282 |
-
|
283 |
-
|
|
|
|
|
284 |
/>
|
285 |
<LabeledNumberInput
|
286 |
label={t('graphPanel.sideBar.settings.maxLayoutIterations')}
|
287 |
min={1}
|
288 |
max={30}
|
289 |
value={graphLayoutMaxIterations}
|
|
|
290 |
onEditFinished={setGraphLayoutMaxIterations}
|
291 |
/>
|
292 |
<Separator />
|
|
|
8 |
import { controlButtonVariant } from '@/lib/constants'
|
9 |
import { useSettingsStore } from '@/stores/settings'
|
10 |
|
11 |
+
import { SettingsIcon, Undo2 } from 'lucide-react'
|
12 |
import { useTranslation } from 'react-i18next';
|
13 |
|
14 |
/**
|
|
|
44 |
onEditFinished,
|
45 |
label,
|
46 |
min,
|
47 |
+
max,
|
48 |
+
defaultValue
|
49 |
}: {
|
50 |
value: number
|
51 |
onEditFinished: (value: number) => void
|
52 |
label: string
|
53 |
min: number
|
54 |
max?: number
|
55 |
+
defaultValue?: number
|
56 |
}) => {
|
57 |
+
const { t } = useTranslation();
|
58 |
const [currentValue, setCurrentValue] = useState<number | null>(value)
|
59 |
|
60 |
const onValueChange = useCallback(
|
|
|
84 |
}
|
85 |
}, [value, currentValue, onEditFinished])
|
86 |
|
87 |
+
const handleReset = useCallback(() => {
|
88 |
+
if (defaultValue !== undefined && value !== defaultValue) {
|
89 |
+
setCurrentValue(defaultValue)
|
90 |
+
onEditFinished(defaultValue)
|
91 |
+
}
|
92 |
+
}, [defaultValue, value, onEditFinished])
|
93 |
+
|
94 |
return (
|
95 |
<div className="flex flex-col gap-2">
|
96 |
<label
|
|
|
99 |
>
|
100 |
{label}
|
101 |
</label>
|
102 |
+
<div className="flex items-center gap-1">
|
103 |
+
<Input
|
104 |
+
type="number"
|
105 |
+
value={currentValue === null ? '' : currentValue}
|
106 |
+
onChange={onValueChange}
|
107 |
+
className="h-6 w-full min-w-0 pr-1"
|
108 |
+
min={min}
|
109 |
+
max={max}
|
110 |
+
onBlur={onBlur}
|
111 |
+
onKeyDown={(e) => {
|
112 |
+
if (e.key === 'Enter') {
|
113 |
+
onBlur()
|
114 |
+
}
|
115 |
+
}}
|
116 |
+
/>
|
117 |
+
{defaultValue !== undefined && (
|
118 |
+
<Button
|
119 |
+
variant="ghost"
|
120 |
+
size="icon"
|
121 |
+
className="h-6 w-6 flex-shrink-0 hover:bg-muted text-muted-foreground hover:text-foreground"
|
122 |
+
onClick={handleReset}
|
123 |
+
type="button"
|
124 |
+
title={t('graphPanel.sideBar.settings.resetToDefault')}
|
125 |
+
>
|
126 |
+
<Undo2 className="h-3.5 w-3.5" />
|
127 |
+
</Button>
|
128 |
+
)}
|
129 |
+
</div>
|
130 |
</div>
|
131 |
)
|
132 |
}
|
|
|
145 |
const enableHideUnselectedEdges = useSettingsStore.use.enableHideUnselectedEdges()
|
146 |
const showEdgeLabel = useSettingsStore.use.showEdgeLabel()
|
147 |
const graphQueryMaxDepth = useSettingsStore.use.graphQueryMaxDepth()
|
148 |
+
const graphMaxNodes = useSettingsStore.use.graphMaxNodes()
|
149 |
const graphLayoutMaxIterations = useSettingsStore.use.graphLayoutMaxIterations()
|
150 |
|
151 |
const enableHealthCheck = useSettingsStore.use.enableHealthCheck()
|
|
|
204 |
}, 300)
|
205 |
}, [])
|
206 |
|
207 |
+
const setGraphMaxNodes = useCallback((nodes: number) => {
|
208 |
+
if (nodes < 1 || nodes > 1000) return
|
209 |
+
useSettingsStore.setState({ graphMaxNodes: nodes })
|
210 |
const currentLabel = useSettingsStore.getState().queryLabel
|
211 |
useSettingsStore.getState().setQueryLabel('')
|
212 |
setTimeout(() => {
|
213 |
useSettingsStore.getState().setQueryLabel(currentLabel)
|
214 |
}, 300)
|
|
|
215 |
}, [])
|
216 |
|
217 |
const setGraphLayoutMaxIterations = useCallback((iterations: number) => {
|
|
|
297 |
label={t('graphPanel.sideBar.settings.maxQueryDepth')}
|
298 |
min={1}
|
299 |
value={graphQueryMaxDepth}
|
300 |
+
defaultValue={3}
|
301 |
onEditFinished={setGraphQueryMaxDepth}
|
302 |
/>
|
303 |
<LabeledNumberInput
|
304 |
+
label={t('graphPanel.sideBar.settings.maxNodes')}
|
305 |
+
min={1}
|
306 |
+
max={1000}
|
307 |
+
value={graphMaxNodes}
|
308 |
+
defaultValue={1000}
|
309 |
+
onEditFinished={setGraphMaxNodes}
|
310 |
/>
|
311 |
<LabeledNumberInput
|
312 |
label={t('graphPanel.sideBar.settings.maxLayoutIterations')}
|
313 |
min={1}
|
314 |
max={30}
|
315 |
value={graphLayoutMaxIterations}
|
316 |
+
defaultValue={15}
|
317 |
onEditFinished={setGraphLayoutMaxIterations}
|
318 |
/>
|
319 |
<Separator />
|