ArnoChen commited on
Commit
023166e
·
2 Parent(s): f456d4f 85fc9f9

Merge branch 'main' into light-webui

Browse files
.dockerignore CHANGED
@@ -1 +1,63 @@
1
- .env
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python-related files and directories
2
+ __pycache__
3
+ .cache
4
+
5
+ # Virtual environment directories
6
+ *.venv
7
+
8
+ # Env
9
+ env/
10
+ *.env*
11
+ .env_example
12
+
13
+ # Distribution / build files
14
+ site
15
+ dist/
16
+ build/
17
+ .eggs/
18
+ *.egg-info/
19
+ *.tgz
20
+ *.tar.gz
21
+
22
+ # Exclude files and folders
23
+ *.yml
24
+ .dockerignore
25
+ Dockerfile
26
+ Makefile
27
+
28
+ # Exclude other projects
29
+ /tests
30
+ /scripts
31
+
32
+ # Python version manager file
33
+ .python-version
34
+
35
+ # Reports
36
+ *.coverage/
37
+ *.log
38
+ log/
39
+ *.logfire
40
+
41
+ # Cache
42
+ .cache/
43
+ .mypy_cache
44
+ .pytest_cache
45
+ .ruff_cache
46
+ .gradio
47
+ .logfire
48
+ temp/
49
+
50
+ # MacOS-related files
51
+ .DS_Store
52
+
53
+ # VS Code settings (local configuration files)
54
+ .vscode
55
+
56
+ # file
57
+ TODO.md
58
+
59
+ # Exclude Git-related files
60
+ .git
61
+ .github
62
+ .gitignore
63
+ .pre-commit-config.yaml
.env.example CHANGED
@@ -1,19 +1,20 @@
1
  ### Server Configuration
2
- #HOST=0.0.0.0
3
- #PORT=9621
4
- #NAMESPACE_PREFIX=lightrag # separating data from difference Lightrag instances
 
5
 
6
  ### Optional SSL Configuration
7
- #SSL=true
8
- #SSL_CERTFILE=/path/to/cert.pem
9
- #SSL_KEYFILE=/path/to/key.pem
10
 
11
  ### Security (empty for no api-key is needed)
12
  # LIGHTRAG_API_KEY=your-secure-api-key-here
13
 
14
  ### Directory Configuration
15
- # WORKING_DIR=./rag_storage
16
- # INPUT_DIR=./inputs
17
 
18
  ### Logging level
19
  LOG_LEVEL=INFO
 
1
  ### Server Configuration
2
+ # HOST=0.0.0.0
3
+ # PORT=9621
4
+ # NAMESPACE_PREFIX=lightrag # separates data from different LightRAG instances
5
+ # CORS_ORIGINS=http://localhost:3000,http://localhost:8080
6
 
7
  ### Optional SSL Configuration
8
+ # SSL=true
9
+ # SSL_CERTFILE=/path/to/cert.pem
10
+ # SSL_KEYFILE=/path/to/key.pem
11
 
12
  ### Security (empty for no api-key is needed)
13
  # LIGHTRAG_API_KEY=your-secure-api-key-here
14
 
15
  ### Directory Configuration
16
+ # WORKING_DIR=<absolute_path_for_working_dir>
17
+ # INPUT_DIR=<absolute_path_for_doc_input_dir>
18
 
19
  ### Logging level
20
  LOG_LEVEL=INFO
.gitignore CHANGED
@@ -1,26 +1,61 @@
1
- __pycache__
2
- *.egg-info
3
- dickens/
4
- book.txt
5
- lightrag-dev/
6
- .idea/
7
- dist/
 
 
 
 
8
  env/
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  local_neo4jWorkDir/
10
  neo4jWorkDir/
11
- ignore_this.txt
12
- .venv/
13
- *.ignore.*
14
- .ruff_cache/
15
- gui/
16
- *.log
17
- .vscode
18
- inputs
19
- rag_storage
20
- .env
21
- venv/
22
  examples/input/
23
  examples/output/
 
 
24
  .DS_Store
25
- #Remove config.ini from repo
26
- *.ini
 
 
 
 
 
 
 
 
1
+ # Python-related files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *.egg-info/
5
+ .eggs/
6
+ *.tgz
7
+ *.tar.gz
8
+ *.ini # Remove config.ini from repo
9
+
10
+ # Virtual Environment
11
+ .venv/
12
  env/
13
+ venv/
14
+ *.env*
15
+ .env_example
16
+
17
+ # Build / Distribution
18
+ dist/
19
+ build/
20
+ site/
21
+
22
+ # Logs / Reports
23
+ *.log
24
+ *.logfire
25
+ *.coverage/
26
+ log/
27
+
28
+ # Caches
29
+ .cache/
30
+ .mypy_cache/
31
+ .pytest_cache/
32
+ .ruff_cache/
33
+ .gradio/
34
+ temp/
35
+
36
+ # IDE / Editor Files
37
+ .idea/
38
+ .vscode/
39
+ .vscode/settings.json
40
+
41
+ # Framework-specific files
42
  local_neo4jWorkDir/
43
  neo4jWorkDir/
44
+
45
+ # Data & Storage
46
+ inputs/
47
+ rag_storage/
 
 
 
 
 
 
 
48
  examples/input/
49
  examples/output/
50
+
51
+ # Miscellaneous
52
  .DS_Store
53
+ TODO.md
54
+ ignore_this.txt
55
+ *.ignore.*
56
+
57
+ # Project-specific files
58
+ dickens/
59
+ book.txt
60
+ lightrag-dev/
61
+ gui/
README.md CHANGED
@@ -237,7 +237,7 @@ rag = LightRAG(
237
 
238
  * If you want to use Hugging Face models, you only need to set LightRAG as follows:
239
  ```python
240
- from lightrag.llm import hf_model_complete, hf_embedding
241
  from transformers import AutoModel, AutoTokenizer
242
  from lightrag.utils import EmbeddingFunc
243
 
@@ -250,7 +250,7 @@ rag = LightRAG(
250
  embedding_func=EmbeddingFunc(
251
  embedding_dim=384,
252
  max_token_size=5000,
253
- func=lambda texts: hf_embedding(
254
  texts,
255
  tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
256
  embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
@@ -428,9 +428,9 @@ And using a routine to process news documents.
428
 
429
  ```python
430
  rag = LightRAG(..)
431
- await rag.apipeline_enqueue_documents(string_or_strings)
432
  # Your routine in loop
433
- await rag.apipeline_process_enqueue_documents(string_or_strings)
434
  ```
435
 
436
  ### Separate Keyword Extraction
 
237
 
238
  * If you want to use Hugging Face models, you only need to set LightRAG as follows:
239
  ```python
240
+ from lightrag.llm import hf_model_complete, hf_embed
241
  from transformers import AutoModel, AutoTokenizer
242
  from lightrag.utils import EmbeddingFunc
243
 
 
250
  embedding_func=EmbeddingFunc(
251
  embedding_dim=384,
252
  max_token_size=5000,
253
+ func=lambda texts: hf_embed(
254
  texts,
255
  tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
256
  embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
 
428
 
429
  ```python
430
  rag = LightRAG(..)
431
+ await rag.apipeline_enqueue_documents(input)
432
  # Your routine in loop
433
+ await rag.apipeline_process_enqueue_documents(input)
434
  ```
435
 
436
  ### Separate Keyword Extraction
examples/lightrag_oracle_demo.py CHANGED
@@ -113,7 +113,24 @@ async def main():
113
  )
114
 
115
  # Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
116
- rag.set_storage_client(db_client=oracle_db)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  # Extract and Insert into LightRAG storage
119
  with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:
 
113
  )
114
 
115
  # Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
116
+
117
+ for storage in [
118
+ rag.vector_db_storage_cls,
119
+ rag.graph_storage_cls,
120
+ rag.doc_status,
121
+ rag.full_docs,
122
+ rag.text_chunks,
123
+ rag.llm_response_cache,
124
+ rag.key_string_value_json_storage_cls,
125
+ rag.chunks_vdb,
126
+ rag.relationships_vdb,
127
+ rag.entities_vdb,
128
+ rag.graph_storage_cls,
129
+ rag.chunk_entity_relation_graph,
130
+ rag.llm_response_cache,
131
+ ]:
132
+ # set client
133
+ storage.db = oracle_db
134
 
135
  # Extract and Insert into LightRAG storage
136
  with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:
examples/test_chromadb.py CHANGED
@@ -15,6 +15,12 @@ if not os.path.exists(WORKING_DIR):
15
  os.mkdir(WORKING_DIR)
16
 
17
  # ChromaDB Configuration
 
 
 
 
 
 
18
  CHROMADB_HOST = os.environ.get("CHROMADB_HOST", "localhost")
19
  CHROMADB_PORT = int(os.environ.get("CHROMADB_PORT", 8000))
20
  CHROMADB_AUTH_TOKEN = os.environ.get("CHROMADB_AUTH_TOKEN", "secret-token")
@@ -60,30 +66,50 @@ async def create_embedding_function_instance():
60
 
61
  async def initialize_rag():
62
  embedding_func_instance = await create_embedding_function_instance()
63
-
64
- return LightRAG(
65
- working_dir=WORKING_DIR,
66
- llm_model_func=gpt_4o_mini_complete,
67
- embedding_func=embedding_func_instance,
68
- vector_storage="ChromaVectorDBStorage",
69
- log_level="DEBUG",
70
- embedding_batch_num=32,
71
- vector_db_storage_cls_kwargs={
72
- "host": CHROMADB_HOST,
73
- "port": CHROMADB_PORT,
74
- "auth_token": CHROMADB_AUTH_TOKEN,
75
- "auth_provider": CHROMADB_AUTH_PROVIDER,
76
- "auth_header_name": CHROMADB_AUTH_HEADER,
77
- "collection_settings": {
78
- "hnsw:space": "cosine",
79
- "hnsw:construction_ef": 128,
80
- "hnsw:search_ef": 128,
81
- "hnsw:M": 16,
82
- "hnsw:batch_size": 100,
83
- "hnsw:sync_threshold": 1000,
84
  },
85
- },
86
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
 
89
  # Run the initialization
 
15
  os.mkdir(WORKING_DIR)
16
 
17
  # ChromaDB Configuration
18
+ CHROMADB_USE_LOCAL_PERSISTENT = False
19
+ # Local PersistentClient Configuration
20
+ CHROMADB_LOCAL_PATH = os.environ.get(
21
+ "CHROMADB_LOCAL_PATH", os.path.join(WORKING_DIR, "chroma_data")
22
+ )
23
+ # Remote HttpClient Configuration
24
  CHROMADB_HOST = os.environ.get("CHROMADB_HOST", "localhost")
25
  CHROMADB_PORT = int(os.environ.get("CHROMADB_PORT", 8000))
26
  CHROMADB_AUTH_TOKEN = os.environ.get("CHROMADB_AUTH_TOKEN", "secret-token")
 
66
 
67
  async def initialize_rag():
68
  embedding_func_instance = await create_embedding_function_instance()
69
+ if CHROMADB_USE_LOCAL_PERSISTENT:
70
+ return LightRAG(
71
+ working_dir=WORKING_DIR,
72
+ llm_model_func=gpt_4o_mini_complete,
73
+ embedding_func=embedding_func_instance,
74
+ vector_storage="ChromaVectorDBStorage",
75
+ log_level="DEBUG",
76
+ embedding_batch_num=32,
77
+ vector_db_storage_cls_kwargs={
78
+ "local_path": CHROMADB_LOCAL_PATH,
79
+ "collection_settings": {
80
+ "hnsw:space": "cosine",
81
+ "hnsw:construction_ef": 128,
82
+ "hnsw:search_ef": 128,
83
+ "hnsw:M": 16,
84
+ "hnsw:batch_size": 100,
85
+ "hnsw:sync_threshold": 1000,
86
+ },
 
 
 
87
  },
88
+ )
89
+ else:
90
+ return LightRAG(
91
+ working_dir=WORKING_DIR,
92
+ llm_model_func=gpt_4o_mini_complete,
93
+ embedding_func=embedding_func_instance,
94
+ vector_storage="ChromaVectorDBStorage",
95
+ log_level="DEBUG",
96
+ embedding_batch_num=32,
97
+ vector_db_storage_cls_kwargs={
98
+ "host": CHROMADB_HOST,
99
+ "port": CHROMADB_PORT,
100
+ "auth_token": CHROMADB_AUTH_TOKEN,
101
+ "auth_provider": CHROMADB_AUTH_PROVIDER,
102
+ "auth_header_name": CHROMADB_AUTH_HEADER,
103
+ "collection_settings": {
104
+ "hnsw:space": "cosine",
105
+ "hnsw:construction_ef": 128,
106
+ "hnsw:search_ef": 128,
107
+ "hnsw:M": 16,
108
+ "hnsw:batch_size": 100,
109
+ "hnsw:sync_threshold": 1000,
110
+ },
111
+ },
112
+ )
113
 
114
 
115
  # Run the initialization
external_bindings/OpenWebuiTool/openwebui_tool.py DELETED
@@ -1,358 +0,0 @@
1
- """
2
- OpenWebui Lightrag Integration Tool
3
- ==================================
4
-
5
- This tool enables the integration and use of Lightrag within the OpenWebui environment,
6
- providing a seamless interface for RAG (Retrieval-Augmented Generation) operations.
7
-
8
- Author: ParisNeo ([email protected])
9
- Social:
10
- - Twitter: @ParisNeo_AI
11
- - Reddit: r/lollms
12
- - Instagram: https://www.instagram.com/parisneo_ai/
13
-
14
- License: Apache 2.0
15
- Copyright (c) 2024-2025 ParisNeo
16
-
17
- This tool is part of the LoLLMs project (Lord of Large Language and Multimodal Systems).
18
- For more information, visit: https://github.com/ParisNeo/lollms
19
-
20
- Requirements:
21
- - Python 3.8+
22
- - OpenWebui
23
- - Lightrag
24
- """
25
-
26
- # Tool version
27
- __version__ = "1.0.0"
28
- __author__ = "ParisNeo"
29
- __author_email__ = "[email protected]"
30
- __description__ = "Lightrag integration for OpenWebui"
31
-
32
-
33
- import requests
34
- import json
35
- from pydantic import BaseModel, Field
36
- from typing import Callable, Any, Literal, Union, List, Tuple
37
-
38
-
39
- class StatusEventEmitter:
40
- def __init__(self, event_emitter: Callable[[dict], Any] = None):
41
- self.event_emitter = event_emitter
42
-
43
- async def emit(self, description="Unknown State", status="in_progress", done=False):
44
- if self.event_emitter:
45
- await self.event_emitter(
46
- {
47
- "type": "status",
48
- "data": {
49
- "status": status,
50
- "description": description,
51
- "done": done,
52
- },
53
- }
54
- )
55
-
56
-
57
- class MessageEventEmitter:
58
- def __init__(self, event_emitter: Callable[[dict], Any] = None):
59
- self.event_emitter = event_emitter
60
-
61
- async def emit(self, content="Some message"):
62
- if self.event_emitter:
63
- await self.event_emitter(
64
- {
65
- "type": "message",
66
- "data": {
67
- "content": content,
68
- },
69
- }
70
- )
71
-
72
-
73
- class Tools:
74
- class Valves(BaseModel):
75
- LIGHTRAG_SERVER_URL: str = Field(
76
- default="http://localhost:9621/query",
77
- description="The base URL for the LightRag server",
78
- )
79
- MODE: Literal["naive", "local", "global", "hybrid"] = Field(
80
- default="hybrid",
81
- description="The mode to use for the LightRag query. Options: naive, local, global, hybrid",
82
- )
83
- ONLY_NEED_CONTEXT: bool = Field(
84
- default=False,
85
- description="If True, only the context is needed from the LightRag response",
86
- )
87
- DEBUG_MODE: bool = Field(
88
- default=False,
89
- description="If True, debugging information will be emitted",
90
- )
91
- KEY: str = Field(
92
- default="",
93
- description="Optional Bearer Key for authentication",
94
- )
95
- MAX_ENTITIES: int = Field(
96
- default=5,
97
- description="Maximum number of entities to keep",
98
- )
99
- MAX_RELATIONSHIPS: int = Field(
100
- default=5,
101
- description="Maximum number of relationships to keep",
102
- )
103
- MAX_SOURCES: int = Field(
104
- default=3,
105
- description="Maximum number of sources to keep",
106
- )
107
-
108
- def __init__(self):
109
- self.valves = self.Valves()
110
- self.headers = {
111
- "Content-Type": "application/json",
112
- "User-Agent": "LightRag-Tool/1.0",
113
- }
114
-
115
- async def query_lightrag(
116
- self,
117
- query: str,
118
- __event_emitter__: Callable[[dict], Any] = None,
119
- ) -> str:
120
- """
121
- Query the LightRag server and retrieve information.
122
- This function must be called before answering the user question
123
- :params query: The query string to send to the LightRag server.
124
- :return: The response from the LightRag server in Markdown format or raw response.
125
- """
126
- self.status_emitter = StatusEventEmitter(__event_emitter__)
127
- self.message_emitter = MessageEventEmitter(__event_emitter__)
128
-
129
- lightrag_url = self.valves.LIGHTRAG_SERVER_URL
130
- payload = {
131
- "query": query,
132
- "mode": str(self.valves.MODE),
133
- "stream": False,
134
- "only_need_context": self.valves.ONLY_NEED_CONTEXT,
135
- }
136
- await self.status_emitter.emit("Initializing Lightrag query..")
137
-
138
- if self.valves.DEBUG_MODE:
139
- await self.message_emitter.emit(
140
- "### Debug Mode Active\n\nDebugging information will be displayed.\n"
141
- )
142
- await self.message_emitter.emit(
143
- "#### Payload Sent to LightRag Server\n```json\n"
144
- + json.dumps(payload, indent=4)
145
- + "\n```\n"
146
- )
147
-
148
- # Add Bearer Key to headers if provided
149
- if self.valves.KEY:
150
- self.headers["Authorization"] = f"Bearer {self.valves.KEY}"
151
-
152
- try:
153
- await self.status_emitter.emit("Sending request to LightRag server")
154
-
155
- response = requests.post(
156
- lightrag_url, json=payload, headers=self.headers, timeout=120
157
- )
158
- response.raise_for_status()
159
- data = response.json()
160
- await self.status_emitter.emit(
161
- status="complete",
162
- description="LightRag query Succeeded",
163
- done=True,
164
- )
165
-
166
- # Return parsed Markdown if ONLY_NEED_CONTEXT is True, otherwise return raw response
167
- if self.valves.ONLY_NEED_CONTEXT:
168
- try:
169
- if self.valves.DEBUG_MODE:
170
- await self.message_emitter.emit(
171
- "#### LightRag Server Response\n```json\n"
172
- + data["response"]
173
- + "\n```\n"
174
- )
175
- except Exception as ex:
176
- if self.valves.DEBUG_MODE:
177
- await self.message_emitter.emit(
178
- "#### Exception\n" + str(ex) + "\n"
179
- )
180
- return f"Exception: {ex}"
181
- return data["response"]
182
- else:
183
- if self.valves.DEBUG_MODE:
184
- await self.message_emitter.emit(
185
- "#### LightRag Server Response\n```json\n"
186
- + data["response"]
187
- + "\n```\n"
188
- )
189
- await self.status_emitter.emit("Lightrag query success")
190
- return data["response"]
191
-
192
- except requests.exceptions.RequestException as e:
193
- await self.status_emitter.emit(
194
- status="error",
195
- description=f"Error during LightRag query: {str(e)}",
196
- done=True,
197
- )
198
- return json.dumps({"error": str(e)})
199
-
200
- def extract_code_blocks(
201
- self, text: str, return_remaining_text: bool = False
202
- ) -> Union[List[dict], Tuple[List[dict], str]]:
203
- """
204
- This function extracts code blocks from a given text and optionally returns the text without code blocks.
205
-
206
- Parameters:
207
- text (str): The text from which to extract code blocks. Code blocks are identified by triple backticks (```).
208
- return_remaining_text (bool): If True, also returns the text with code blocks removed.
209
-
210
- Returns:
211
- Union[List[dict], Tuple[List[dict], str]]:
212
- - If return_remaining_text is False: Returns only the list of code block dictionaries
213
- - If return_remaining_text is True: Returns a tuple containing:
214
- * List of code block dictionaries
215
- * String containing the text with all code blocks removed
216
-
217
- Each code block dictionary contains:
218
- - 'index' (int): The index of the code block in the text
219
- - 'file_name' (str): The name of the file extracted from the preceding line, if available
220
- - 'content' (str): The content of the code block
221
- - 'type' (str): The type of the code block
222
- - 'is_complete' (bool): True if the block has a closing tag, False otherwise
223
- """
224
- remaining = text
225
- bloc_index = 0
226
- first_index = 0
227
- indices = []
228
- text_without_blocks = text
229
-
230
- # Find all code block delimiters
231
- while len(remaining) > 0:
232
- try:
233
- index = remaining.index("```")
234
- indices.append(index + first_index)
235
- remaining = remaining[index + 3 :]
236
- first_index += index + 3
237
- bloc_index += 1
238
- except Exception:
239
- if bloc_index % 2 == 1:
240
- index = len(remaining)
241
- indices.append(index)
242
- remaining = ""
243
-
244
- code_blocks = []
245
- is_start = True
246
-
247
- # Process code blocks and build text without blocks if requested
248
- if return_remaining_text:
249
- text_parts = []
250
- last_end = 0
251
-
252
- for index, code_delimiter_position in enumerate(indices):
253
- if is_start:
254
- block_infos = {
255
- "index": len(code_blocks),
256
- "file_name": "",
257
- "section": "",
258
- "content": "",
259
- "type": "",
260
- "is_complete": False,
261
- }
262
-
263
- # Store text before code block if returning remaining text
264
- if return_remaining_text:
265
- text_parts.append(text[last_end:code_delimiter_position].strip())
266
-
267
- # Check the preceding line for file name
268
- preceding_text = text[:code_delimiter_position].strip().splitlines()
269
- if preceding_text:
270
- last_line = preceding_text[-1].strip()
271
- if last_line.startswith("<file_name>") and last_line.endswith(
272
- "</file_name>"
273
- ):
274
- file_name = last_line[
275
- len("<file_name>") : -len("</file_name>")
276
- ].strip()
277
- block_infos["file_name"] = file_name
278
- elif last_line.startswith("## filename:"):
279
- file_name = last_line[len("## filename:") :].strip()
280
- block_infos["file_name"] = file_name
281
- if last_line.startswith("<section>") and last_line.endswith(
282
- "</section>"
283
- ):
284
- section = last_line[
285
- len("<section>") : -len("</section>")
286
- ].strip()
287
- block_infos["section"] = section
288
-
289
- sub_text = text[code_delimiter_position + 3 :]
290
- if len(sub_text) > 0:
291
- try:
292
- find_space = sub_text.index(" ")
293
- except Exception:
294
- find_space = int(1e10)
295
- try:
296
- find_return = sub_text.index("\n")
297
- except Exception:
298
- find_return = int(1e10)
299
- next_index = min(find_return, find_space)
300
- if "{" in sub_text[:next_index]:
301
- next_index = 0
302
- start_pos = next_index
303
-
304
- if code_delimiter_position + 3 < len(text) and text[
305
- code_delimiter_position + 3
306
- ] in ["\n", " ", "\t"]:
307
- block_infos["type"] = "language-specific"
308
- else:
309
- block_infos["type"] = sub_text[:next_index]
310
-
311
- if index + 1 < len(indices):
312
- next_pos = indices[index + 1] - code_delimiter_position
313
- if (
314
- next_pos - 3 < len(sub_text)
315
- and sub_text[next_pos - 3] == "`"
316
- ):
317
- block_infos["content"] = sub_text[
318
- start_pos : next_pos - 3
319
- ].strip()
320
- block_infos["is_complete"] = True
321
- else:
322
- block_infos["content"] = sub_text[
323
- start_pos:next_pos
324
- ].strip()
325
- block_infos["is_complete"] = False
326
-
327
- if return_remaining_text:
328
- last_end = indices[index + 1] + 3
329
- else:
330
- block_infos["content"] = sub_text[start_pos:].strip()
331
- block_infos["is_complete"] = False
332
-
333
- if return_remaining_text:
334
- last_end = len(text)
335
-
336
- code_blocks.append(block_infos)
337
- is_start = False
338
- else:
339
- is_start = True
340
-
341
- if return_remaining_text:
342
- # Add any remaining text after the last code block
343
- if last_end < len(text):
344
- text_parts.append(text[last_end:].strip())
345
- # Join all non-code parts with newlines
346
- text_without_blocks = "\n".join(filter(None, text_parts))
347
- return code_blocks, text_without_blocks
348
-
349
- return code_blocks
350
-
351
- def clean(self, csv_content: str):
352
- lines = csv_content.splitlines()
353
- if lines:
354
- # Remove spaces around headers and ensure no spaces between commas
355
- header = ",".join([col.strip() for col in lines[0].split(",")])
356
- lines[0] = header # Replace the first line with the cleaned header
357
- csv_content = "\n".join(lines)
358
- return csv_content
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lightrag/api/README.md CHANGED
@@ -74,30 +74,38 @@ LLM_MODEL=model_name_of_azure_ai
74
  LLM_BINDING_API_KEY=api_key_of_azure_ai
75
  ```
76
 
77
- ### About Ollama API
78
 
79
- We provide an Ollama-compatible interfaces for LightRAG, aiming to emulate LightRAG as an Ollama chat model. This allows AI chat frontends supporting Ollama, such as Open WebUI, to access LightRAG easily.
80
 
81
- #### Choose Query mode in chat
 
 
 
 
82
 
83
- A query prefix in the query string can determines which LightRAG query mode is used to generate the respond for the query. The supported prefixes include:
84
 
85
- ```
86
- /local
87
- /global
88
- /hybrid
89
- /naive
90
- /mix
91
- /bypass
92
  ```
93
 
94
- For example, chat message "/mix 唐僧有几个徒弟" will trigger a mix mode query for LighRAG. A chat message without query prefix will trigger a hybrid mode query by default。
95
 
96
- "/bypass" is not a LightRAG query mode, it will tell API Server to pass the query directly to the underlying LLM with chat history. So user can use LLM to answer question base on the LightRAG query results. (If you are using Open WebUI as front end, you can just switch the model to a normal LLM instead of using /bypass prefix)
 
 
 
 
 
 
97
 
98
- #### Connect Open WebUI to LightRAG
99
 
100
- After starting the lightrag-server, you can add an Ollama-type connection in the Open WebUI admin pannel. And then a model named lightrag:latest will appear in Open WebUI's model management interface. Users can then send queries to LightRAG through the chat interface.
101
 
102
  ## Configuration
103
 
@@ -177,7 +185,8 @@ TiDBVectorDBStorage TiDB
177
  PGVectorStorage Postgres
178
  FaissVectorDBStorage Faiss
179
  QdrantVectorDBStorage Qdrant
180
- OracleVectorDBStorag Oracle
 
181
  ```
182
 
183
  * DOC_STATUS_STORAGE:supported implement-name
@@ -378,7 +387,7 @@ curl -X DELETE "http://localhost:9621/documents"
378
 
379
  #### GET /api/version
380
 
381
- Get Ollama version information
382
 
383
  ```bash
384
  curl http://localhost:9621/api/version
@@ -386,7 +395,7 @@ curl http://localhost:9621/api/version
386
 
387
  #### GET /api/tags
388
 
389
- Get Ollama available models
390
 
391
  ```bash
392
  curl http://localhost:9621/api/tags
@@ -394,7 +403,7 @@ curl http://localhost:9621/api/tags
394
 
395
  #### POST /api/chat
396
 
397
- Handle chat completion requests
398
 
399
  ```shell
400
  curl -N -X POST http://localhost:9621/api/chat -H "Content-Type: application/json" -d \
@@ -403,6 +412,10 @@ curl -N -X POST http://localhost:9621/api/chat -H "Content-Type: application/jso
403
 
404
  > For more information about Ollama API pls. visit : [Ollama API documentation](https://github.com/ollama/ollama/blob/main/docs/api.md)
405
 
 
 
 
 
406
  ### Utility Endpoints
407
 
408
  #### GET /health
@@ -412,7 +425,35 @@ Check server health and configuration.
412
  curl "http://localhost:9621/health"
413
  ```
414
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
415
  ## Development
 
416
  Contribute to the project: [Guide](contributor-readme.MD)
417
 
418
  ### Running in Development Mode
@@ -470,34 +511,3 @@ This intelligent caching mechanism:
470
  - Only new documents in the input directory will be processed
471
  - This optimization significantly reduces startup time for subsequent runs
472
  - The working directory (`--working-dir`) stores the vectorized documents database
473
-
474
- ## Install Lightrag as a Linux Service
475
-
476
- Create a your service file `lightrag.sevice` from the sample file : `lightrag.sevice.example`. Modified the WorkingDirectoryand EexecStart in the service file:
477
-
478
- ```text
479
- Description=LightRAG Ollama Service
480
- WorkingDirectory=<lightrag installed directory>
481
- ExecStart=<lightrag installed directory>/lightrag/api/lightrag-api
482
- ```
483
-
484
- Modify your service startup script: `lightrag-api`. Change you python virtual environment activation command as needed:
485
-
486
- ```shell
487
- #!/bin/bash
488
-
489
- # your python virtual environment activation
490
- source /home/netman/lightrag-xyj/venv/bin/activate
491
- # start lightrag api server
492
- lightrag-server
493
- ```
494
-
495
- Install LightRAG service. If your system is Ubuntu, the following commands will work:
496
-
497
- ```shell
498
- sudo cp lightrag.service /etc/systemd/system/
499
- sudo systemctl daemon-reload
500
- sudo systemctl start lightrag.service
501
- sudo systemctl status lightrag.service
502
- sudo systemctl enable lightrag.service
503
- ```
 
74
  LLM_BINDING_API_KEY=api_key_of_azure_ai
75
  ```
76
 
77
+ ### 3. Install Lightrag as a Linux Service
78
 
79
+ Create your service file `lightrag.service` from the sample file: `lightrag.service.example`. Modify the WorkingDirectory and ExecStart entries in the service file:
80
 
81
+ ```text
82
+ Description=LightRAG Ollama Service
83
+ WorkingDirectory=<lightrag installed directory>
84
+ ExecStart=<lightrag installed directory>/lightrag/api/lightrag-api
85
+ ```
86
 
87
+ Modify your service startup script: `lightrag-api`. Change your Python virtual environment activation command as needed:
88
 
89
+ ```shell
90
+ #!/bin/bash
91
+
92
+ # your python virtual environment activation
93
+ source /home/netman/lightrag-xyj/venv/bin/activate
94
+ # start lightrag api server
95
+ lightrag-server
96
  ```
97
 
98
+ Install LightRAG service. If your system is Ubuntu, the following commands will work:
99
 
100
+ ```shell
101
+ sudo cp lightrag.service /etc/systemd/system/
102
+ sudo systemctl daemon-reload
103
+ sudo systemctl start lightrag.service
104
+ sudo systemctl status lightrag.service
105
+ sudo systemctl enable lightrag.service
106
+ ```
107
 
 
108
 
 
109
 
110
  ## Configuration
111
 
 
185
  PGVectorStorage Postgres
186
  FaissVectorDBStorage Faiss
187
  QdrantVectorDBStorage Qdrant
188
+ OracleVectorDBStorage Oracle
189
+ MongoVectorDBStorage MongoDB
190
  ```
191
 
192
  * DOC_STATUS_STORAGE:supported implement-name
 
387
 
388
  #### GET /api/version
389
 
390
+ Get Ollama version information.
391
 
392
  ```bash
393
  curl http://localhost:9621/api/version
 
395
 
396
  #### GET /api/tags
397
 
398
+ Get Ollama available models.
399
 
400
  ```bash
401
  curl http://localhost:9621/api/tags
 
403
 
404
  #### POST /api/chat
405
 
406
+ Handle chat completion requests. Routes user queries through LightRAG by selecting a query mode based on the query prefix. Detects and forwards OpenWebUI session-related requests (for metadata generation tasks) directly to the underlying LLM.
407
 
408
  ```shell
409
  curl -N -X POST http://localhost:9621/api/chat -H "Content-Type: application/json" -d \
 
412
 
413
  > For more information about Ollama API pls. visit : [Ollama API documentation](https://github.com/ollama/ollama/blob/main/docs/api.md)
414
 
415
+ #### POST /api/generate
416
+
417
+ Handle generate completion requests. For compatibility purposes, the request is not processed by LightRAG and will be handled by the underlying LLM model.
418
+
419
  ### Utility Endpoints
420
 
421
  #### GET /health
 
425
  curl "http://localhost:9621/health"
426
  ```
427
 
428
+ ## Ollama Emulation
429
+
430
+ We provide an Ollama-compatible interface for LightRAG, aiming to emulate LightRAG as an Ollama chat model. This allows AI chat frontends that support Ollama, such as Open WebUI, to access LightRAG easily.
431
+
432
+ ### Connect Open WebUI to LightRAG
433
+
434
+ After starting the lightrag-server, you can add an Ollama-type connection in the Open WebUI admin panel. A model named lightrag:latest will then appear in Open WebUI's model management interface. Users can then send queries to LightRAG through the chat interface. It is best to install LightRAG as a service for this use case.
435
+
436
+ Open WebUI uses an LLM for its session title and session keyword generation tasks, so the Ollama chat completion API detects and forwards OpenWebUI session-related requests directly to the underlying LLM.
437
+
438
+ ### Choose Query mode in chat
439
+
440
+ A query prefix in the query string determines which LightRAG query mode is used to generate the response for the query. The supported prefixes include:
441
+
442
+ ```
443
+ /local
444
+ /global
445
+ /hybrid
446
+ /naive
447
+ /mix
448
+ /bypass
449
+ ```
450
+
451
+ For example, the chat message "/mix 唐僧有几个徒弟" will trigger a mix mode query for LightRAG. A chat message without a query prefix will trigger a hybrid mode query by default.
452
+
453
+ "/bypass" is not a LightRAG query mode, it will tell API Server to pass the query directly to the underlying LLM with chat history. So user can use LLM to answer question base on the chat history. If you are using Open WebUI as front end, you can just switch the model to a normal LLM instead of using /bypass prefix.
454
+
455
  ## Development
456
+
457
  Contribute to the project: [Guide](contributor-readme.MD)
458
 
459
  ### Running in Development Mode
 
511
  - Only new documents in the input directory will be processed
512
  - This optimization significantly reduces startup time for subsequent runs
513
  - The working directory (`--working-dir`) stores the vectorized documents database
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lightrag/api/lightrag_server.py CHANGED
@@ -3,7 +3,6 @@ from fastapi import (
3
  HTTPException,
4
  File,
5
  UploadFile,
6
- Form,
7
  BackgroundTasks,
8
  )
9
  import asyncio
@@ -14,7 +13,7 @@ import re
14
  from fastapi.staticfiles import StaticFiles
15
  import logging
16
  import argparse
17
- from typing import List, Any, Optional, Union, Dict
18
  from pydantic import BaseModel
19
  from lightrag import LightRAG, QueryParam
20
  from lightrag.types import GPTKeywordExtractionFormat
@@ -34,6 +33,9 @@ from starlette.status import HTTP_403_FORBIDDEN
34
  import pipmaster as pm
35
  from dotenv import load_dotenv
36
  import configparser
 
 
 
37
  from lightrag.utils import logger
38
  from .ollama_api import (
39
  OllamaAPI,
@@ -159,8 +161,12 @@ def display_splash_screen(args: argparse.Namespace) -> None:
159
  ASCIIColors.yellow(f"{args.host}")
160
  ASCIIColors.white(" ├─ Port: ", end="")
161
  ASCIIColors.yellow(f"{args.port}")
162
- ASCIIColors.white(" └─ SSL Enabled: ", end="")
 
 
163
  ASCIIColors.yellow(f"{args.ssl}")
 
 
164
  if args.ssl:
165
  ASCIIColors.white(" ├─ SSL Cert: ", end="")
166
  ASCIIColors.yellow(f"{args.ssl_certfile}")
@@ -229,10 +235,8 @@ def display_splash_screen(args: argparse.Namespace) -> None:
229
  ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}")
230
  ASCIIColors.white(" ├─ Log Level: ", end="")
231
  ASCIIColors.yellow(f"{args.log_level}")
232
- ASCIIColors.white(" ├─ Timeout: ", end="")
233
  ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}")
234
- ASCIIColors.white(" └─ API Key: ", end="")
235
- ASCIIColors.yellow("Set" if args.key else "Not Set")
236
 
237
  # Server Status
238
  ASCIIColors.green("\n✨ Server starting up...\n")
@@ -564,6 +568,10 @@ def parse_args() -> argparse.Namespace:
564
 
565
  args = parser.parse_args()
566
 
 
 
 
 
567
  ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
568
 
569
  return args
@@ -595,6 +603,7 @@ class DocumentManager:
595
  """Scan input directory for new files"""
596
  new_files = []
597
  for ext in self.supported_extensions:
 
598
  for file_path in self.input_dir.rglob(f"*{ext}"):
599
  if file_path not in self.indexed_files:
600
  new_files.append(file_path)
@@ -628,9 +637,47 @@ class SearchMode(str, Enum):
628
 
629
  class QueryRequest(BaseModel):
630
  query: str
 
 
631
  mode: SearchMode = SearchMode.hybrid
632
- stream: bool = False
633
- only_need_context: bool = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
634
 
635
 
636
  class QueryResponse(BaseModel):
@@ -639,13 +686,38 @@ class QueryResponse(BaseModel):
639
 
640
  class InsertTextRequest(BaseModel):
641
  text: str
642
- description: Optional[str] = None
643
 
644
 
645
  class InsertResponse(BaseModel):
646
  status: str
647
  message: str
648
- document_count: int
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
649
 
650
 
651
  def get_api_key_dependency(api_key: Optional[str]):
@@ -659,7 +731,9 @@ def get_api_key_dependency(api_key: Optional[str]):
659
  # If API key is configured, use proper authentication
660
  api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
661
 
662
- async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)):
 
 
663
  if not api_key_header_value:
664
  raise HTTPException(
665
  status_code=HTTP_403_FORBIDDEN, detail="API Key required"
@@ -675,6 +749,7 @@ def get_api_key_dependency(api_key: Optional[str]):
675
 
676
  # Global configuration
677
  global_top_k = 60 # default value
 
678
 
679
 
680
  def create_app(args):
@@ -842,10 +917,19 @@ def create_app(args):
842
  lifespan=lifespan,
843
  )
844
 
 
 
 
 
 
 
 
 
 
845
  # Add CORS middleware
846
  app.add_middleware(
847
  CORSMiddleware,
848
- allow_origins=["*"],
849
  allow_credentials=True,
850
  allow_methods=["*"],
851
  allow_headers=["*"],
@@ -1116,79 +1200,162 @@ def create_app(args):
1116
  ("llm_response_cache", rag.llm_response_cache),
1117
  ]
1118
 
1119
- async def index_file(file_path: Union[str, Path]) -> None:
1120
- """Index all files inside the folder with support for multiple file formats
1121
 
1122
  Args:
1123
- file_path: Path to the file to be indexed (str or Path object)
1124
-
1125
- Raises:
1126
- ValueError: If file format is not supported
1127
- FileNotFoundError: If file doesn't exist
1128
  """
1129
- if not pm.is_installed("aiofiles"):
1130
- pm.install("aiofiles")
 
1131
 
1132
- # Convert to Path object if string
1133
- file_path = Path(file_path)
 
1134
 
1135
- # Check if file exists
1136
- if not file_path.exists():
1137
- raise FileNotFoundError(f"File not found: {file_path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1138
 
1139
- content = ""
1140
- # Get file extension in lowercase
1141
- ext = file_path.suffix.lower()
 
 
 
 
 
 
 
 
 
 
1142
 
1143
- match ext:
1144
- case ".txt" | ".md":
1145
- # Text files handling
1146
- async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
1147
- content = await f.read()
1148
 
1149
- case ".pdf" | ".docx" | ".pptx" | ".xlsx":
1150
- if not pm.is_installed("docling"):
1151
- pm.install("docling")
1152
- from docling.document_converter import DocumentConverter
 
 
1153
 
1154
- async def convert_doc():
1155
- def sync_convert():
1156
- converter = DocumentConverter()
1157
- result = converter.convert(file_path)
1158
- return result.document.export_to_markdown()
1159
 
1160
- return await asyncio.to_thread(sync_convert)
 
1161
 
1162
- content = await convert_doc()
 
 
 
 
 
 
1163
 
1164
- case _:
1165
- raise ValueError(f"Unsupported file format: {ext}")
 
 
 
1166
 
1167
- # Insert content into RAG system
1168
- if content:
1169
- await rag.ainsert(content)
1170
- doc_manager.mark_as_indexed(file_path)
1171
- logging.info(f"Successfully indexed file: {file_path}")
1172
- else:
1173
- logging.warning(f"No content extracted from file: {file_path}")
1174
 
1175
- @app.post("/documents/scan", dependencies=[Depends(optional_api_key)])
1176
- async def scan_for_new_documents(background_tasks: BackgroundTasks):
1177
- """Trigger the scanning process"""
1178
- global scan_progress
1179
 
1180
- with progress_lock:
1181
- if scan_progress["is_scanning"]:
1182
- return {"status": "already_scanning"}
 
 
 
 
1183
 
1184
- scan_progress["is_scanning"] = True
1185
- scan_progress["indexed_count"] = 0
1186
- scan_progress["progress"] = 0
1187
 
1188
- # Start the scanning process in the background
1189
- background_tasks.add_task(run_scanning_process)
1190
 
1191
- return {"status": "scanning_started"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1192
 
1193
  async def run_scanning_process():
1194
  """Background task to scan and index documents"""
@@ -1198,12 +1365,13 @@ def create_app(args):
1198
  new_files = doc_manager.scan_directory_for_new_files()
1199
  scan_progress["total_files"] = len(new_files)
1200
 
 
1201
  for file_path in new_files:
1202
  try:
1203
  with progress_lock:
1204
  scan_progress["current_file"] = os.path.basename(file_path)
1205
 
1206
- await index_file(file_path)
1207
 
1208
  with progress_lock:
1209
  scan_progress["indexed_count"] += 1
@@ -1221,6 +1389,24 @@ def create_app(args):
1221
  with progress_lock:
1222
  scan_progress["is_scanning"] = False
1223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1224
  @app.get("/documents/scan-progress")
1225
  async def get_scan_progress():
1226
  """Get the current scanning progress"""
@@ -1228,7 +1414,9 @@ def create_app(args):
1228
  return scan_progress
1229
 
1230
  @app.post("/documents/upload", dependencies=[Depends(optional_api_key)])
1231
- async def upload_to_input_dir(file: UploadFile = File(...)):
 
 
1232
  """
1233
  Endpoint for uploading a file to the input directory and indexing it.
1234
 
@@ -1237,6 +1425,7 @@ def create_app(args):
1237
  indexes it for retrieval, and returns a success status with relevant details.
1238
 
1239
  Parameters:
 
1240
  file (UploadFile): The file to be uploaded. It must have an allowed extension as per
1241
  `doc_manager.supported_extensions`.
1242
 
@@ -1261,124 +1450,16 @@ def create_app(args):
1261
  with open(file_path, "wb") as buffer:
1262
  shutil.copyfileobj(file.file, buffer)
1263
 
1264
- # Immediately index the uploaded file
1265
- await index_file(file_path)
1266
-
1267
- return {
1268
- "status": "success",
1269
- "message": f"File uploaded and indexed: {file.filename}",
1270
- "total_documents": len(doc_manager.indexed_files),
1271
- }
1272
- except Exception as e:
1273
- raise HTTPException(status_code=500, detail=str(e))
1274
-
1275
- @app.post(
1276
- "/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]
1277
- )
1278
- async def query_text(request: QueryRequest):
1279
- """
1280
- Handle a POST request at the /query endpoint to process user queries using RAG capabilities.
1281
-
1282
- Parameters:
1283
- request (QueryRequest): A Pydantic model containing the following fields:
1284
- - query (str): The text of the user's query.
1285
- - mode (ModeEnum): Optional. Specifies the mode of retrieval augmentation.
1286
- - stream (bool): Optional. Determines if the response should be streamed.
1287
- - only_need_context (bool): Optional. If true, returns only the context without further processing.
1288
-
1289
- Returns:
1290
- QueryResponse: A Pydantic model containing the result of the query processing.
1291
- If a string is returned (e.g., cache hit), it's directly returned.
1292
- Otherwise, an async generator may be used to build the response.
1293
-
1294
- Raises:
1295
- HTTPException: Raised when an error occurs during the request handling process,
1296
- with status code 500 and detail containing the exception message.
1297
- """
1298
- try:
1299
- response = await rag.aquery(
1300
- request.query,
1301
- param=QueryParam(
1302
- mode=request.mode,
1303
- stream=request.stream,
1304
- only_need_context=request.only_need_context,
1305
- top_k=global_top_k,
1306
- ),
1307
- )
1308
-
1309
- # If response is a string (e.g. cache hit), return directly
1310
- if isinstance(response, str):
1311
- return QueryResponse(response=response)
1312
-
1313
- # If it's an async generator, decide whether to stream based on stream parameter
1314
- if request.stream:
1315
- result = ""
1316
- async for chunk in response:
1317
- result += chunk
1318
- return QueryResponse(response=result)
1319
- else:
1320
- result = ""
1321
- async for chunk in response:
1322
- result += chunk
1323
- return QueryResponse(response=result)
1324
- except Exception as e:
1325
- trace_exception(e)
1326
- raise HTTPException(status_code=500, detail=str(e))
1327
-
1328
- @app.post("/query/stream", dependencies=[Depends(optional_api_key)])
1329
- async def query_text_stream(request: QueryRequest):
1330
- """
1331
- This endpoint performs a retrieval-augmented generation (RAG) query and streams the response.
1332
-
1333
- Args:
1334
- request (QueryRequest): The request object containing the query parameters.
1335
- optional_api_key (Optional[str], optional): An optional API key for authentication. Defaults to None.
1336
-
1337
- Returns:
1338
- StreamingResponse: A streaming response containing the RAG query results.
1339
- """
1340
- try:
1341
- response = await rag.aquery( # Use aquery instead of query, and add await
1342
- request.query,
1343
- param=QueryParam(
1344
- mode=request.mode,
1345
- stream=True,
1346
- only_need_context=request.only_need_context,
1347
- top_k=global_top_k,
1348
- ),
1349
- )
1350
 
1351
- from fastapi.responses import StreamingResponse
1352
-
1353
- async def stream_generator():
1354
- if isinstance(response, str):
1355
- # If it's a string, send it all at once
1356
- yield f"{json.dumps({'response': response})}\n"
1357
- else:
1358
- # If it's an async generator, send chunks one by one
1359
- try:
1360
- async for chunk in response:
1361
- if chunk: # Only send non-empty content
1362
- yield f"{json.dumps({'response': chunk})}\n"
1363
- except Exception as e:
1364
- logging.error(f"Streaming error: {str(e)}")
1365
- yield f"{json.dumps({'error': str(e)})}\n"
1366
-
1367
- return StreamingResponse(
1368
- stream_generator(),
1369
- media_type="application/x-ndjson",
1370
- headers={
1371
- "Cache-Control": "no-cache",
1372
- "Connection": "keep-alive",
1373
- "Content-Type": "application/x-ndjson",
1374
- "Access-Control-Allow-Origin": "*",
1375
- "Access-Control-Allow-Methods": "POST, OPTIONS",
1376
- "Access-Control-Allow-Headers": "Content-Type",
1377
- "X-Accel-Buffering": "no", # Disable Nginx buffering
1378
- },
1379
  )
1380
  except Exception as e:
1381
- trace_exception(e)
 
1382
  raise HTTPException(status_code=500, detail=str(e))
1383
 
1384
  @app.post(
@@ -1386,7 +1467,9 @@ def create_app(args):
1386
  response_model=InsertResponse,
1387
  dependencies=[Depends(optional_api_key)],
1388
  )
1389
- async def insert_text(request: InsertTextRequest):
 
 
1390
  """
1391
  Insert text into the Retrieval-Augmented Generation (RAG) system.
1392
 
@@ -1394,18 +1477,20 @@ def create_app(args):
1394
 
1395
  Args:
1396
  request (InsertTextRequest): The request body containing the text to be inserted.
 
1397
 
1398
  Returns:
1399
  InsertResponse: A response object containing the status of the operation, a message, and the number of documents inserted.
1400
  """
1401
  try:
1402
- await rag.ainsert(request.text)
1403
  return InsertResponse(
1404
  status="success",
1405
- message="Text successfully inserted",
1406
- document_count=1,
1407
  )
1408
  except Exception as e:
 
 
1409
  raise HTTPException(status_code=500, detail=str(e))
1410
 
1411
  @app.post(
@@ -1413,12 +1498,14 @@ def create_app(args):
1413
  response_model=InsertResponse,
1414
  dependencies=[Depends(optional_api_key)],
1415
  )
1416
- async def insert_file(file: UploadFile = File(...), description: str = Form(None)):
 
 
1417
  """Insert a file directly into the RAG system
1418
 
1419
  Args:
 
1420
  file: Uploaded file
1421
- description: Optional description of the file
1422
 
1423
  Returns:
1424
  InsertResponse: Status of the insertion operation
@@ -1427,68 +1514,26 @@ def create_app(args):
1427
  HTTPException: For unsupported file types or processing errors
1428
  """
1429
  try:
1430
- content = ""
1431
- # Get file extension in lowercase
1432
- ext = Path(file.filename).suffix.lower()
1433
-
1434
- match ext:
1435
- case ".txt" | ".md":
1436
- # Text files handling
1437
- text_content = await file.read()
1438
- content = text_content.decode("utf-8")
1439
-
1440
- case ".pdf" | ".docx" | ".pptx" | ".xlsx":
1441
- if not pm.is_installed("docling"):
1442
- pm.install("docling")
1443
- from docling.document_converter import DocumentConverter
1444
-
1445
- # Create a temporary file to save the uploaded content
1446
- temp_path = Path("temp") / file.filename
1447
- temp_path.parent.mkdir(exist_ok=True)
1448
-
1449
- # Save the uploaded file
1450
- with temp_path.open("wb") as f:
1451
- f.write(await file.read())
1452
-
1453
- try:
1454
-
1455
- async def convert_doc():
1456
- def sync_convert():
1457
- converter = DocumentConverter()
1458
- result = converter.convert(str(temp_path))
1459
- return result.document.export_to_markdown()
1460
-
1461
- return await asyncio.to_thread(sync_convert)
1462
-
1463
- content = await convert_doc()
1464
- finally:
1465
- # Clean up the temporary file
1466
- temp_path.unlink()
1467
-
1468
- # Insert content into RAG system
1469
- if content:
1470
- # Add description if provided
1471
- if description:
1472
- content = f"{description}\n\n{content}"
1473
-
1474
- await rag.ainsert(content)
1475
- logging.info(f"Successfully indexed file: {file.filename}")
1476
-
1477
- return InsertResponse(
1478
- status="success",
1479
- message=f"File '{file.filename}' successfully inserted",
1480
- document_count=1,
1481
- )
1482
- else:
1483
  raise HTTPException(
1484
  status_code=400,
1485
- detail="No content could be extracted from the file",
1486
  )
1487
 
1488
- except UnicodeDecodeError:
1489
- raise HTTPException(status_code=400, detail="File encoding not supported")
 
 
 
 
 
 
 
 
 
1490
  except Exception as e:
1491
- logging.error(f"Error processing file {file.filename}: {str(e)}")
 
1492
  raise HTTPException(status_code=500, detail=str(e))
1493
 
1494
  @app.post(
@@ -1496,10 +1541,13 @@ def create_app(args):
1496
  response_model=InsertResponse,
1497
  dependencies=[Depends(optional_api_key)],
1498
  )
1499
- async def insert_batch(files: List[UploadFile] = File(...)):
 
 
1500
  """Process multiple files in batch mode
1501
 
1502
  Args:
 
1503
  files: List of files to process
1504
 
1505
  Returns:
@@ -1511,72 +1559,18 @@ def create_app(args):
1511
  try:
1512
  inserted_count = 0
1513
  failed_files = []
 
1514
 
1515
  for file in files:
1516
- try:
1517
- content = ""
1518
- ext = Path(file.filename).suffix.lower()
1519
-
1520
- match ext:
1521
- case ".txt" | ".md":
1522
- text_content = await file.read()
1523
- content = text_content.decode("utf-8")
1524
-
1525
- case ".pdf":
1526
- if not pm.is_installed("pypdf2"):
1527
- pm.install("pypdf2")
1528
- from PyPDF2 import PdfReader
1529
- from io import BytesIO
1530
-
1531
- pdf_content = await file.read()
1532
- pdf_file = BytesIO(pdf_content)
1533
- reader = PdfReader(pdf_file)
1534
- for page in reader.pages:
1535
- content += page.extract_text() + "\n"
1536
-
1537
- case ".docx":
1538
- if not pm.is_installed("docx"):
1539
- pm.install("docx")
1540
- from docx import Document
1541
- from io import BytesIO
1542
-
1543
- docx_content = await file.read()
1544
- docx_file = BytesIO(docx_content)
1545
- doc = Document(docx_file)
1546
- content = "\n".join(
1547
- [paragraph.text for paragraph in doc.paragraphs]
1548
- )
1549
-
1550
- case ".pptx":
1551
- if not pm.is_installed("pptx"):
1552
- pm.install("pptx")
1553
- from pptx import Presentation # type: ignore
1554
- from io import BytesIO
1555
-
1556
- pptx_content = await file.read()
1557
- pptx_file = BytesIO(pptx_content)
1558
- prs = Presentation(pptx_file)
1559
- for slide in prs.slides:
1560
- for shape in slide.shapes:
1561
- if hasattr(shape, "text"):
1562
- content += shape.text + "\n"
1563
-
1564
- case _:
1565
- failed_files.append(f"{file.filename} (unsupported type)")
1566
- continue
1567
-
1568
- if content:
1569
- await rag.ainsert(content)
1570
- inserted_count += 1
1571
- logging.info(f"Successfully indexed file: {file.filename}")
1572
- else:
1573
- failed_files.append(f"{file.filename} (no content extracted)")
1574
 
1575
- except UnicodeDecodeError:
1576
- failed_files.append(f"{file.filename} (encoding error)")
1577
- except Exception as e:
1578
- failed_files.append(f"{file.filename} ({str(e)})")
1579
- logging.error(f"Error processing file {file.filename}: {str(e)}")
1580
 
1581
  # Prepare status message
1582
  if inserted_count == len(files):
@@ -1593,14 +1587,11 @@ def create_app(args):
1593
  if failed_files:
1594
  status_message += f". Failed files: {', '.join(failed_files)}"
1595
 
1596
- return InsertResponse(
1597
- status=status,
1598
- message=status_message,
1599
- document_count=inserted_count,
1600
- )
1601
 
1602
  except Exception as e:
1603
- logging.error(f"Batch processing error: {str(e)}")
 
1604
  raise HTTPException(status_code=500, detail=str(e))
1605
 
1606
  @app.delete(
@@ -1623,11 +1614,103 @@ def create_app(args):
1623
  rag.entities_vdb = None
1624
  rag.relationships_vdb = None
1625
  return InsertResponse(
1626
- status="success",
1627
- message="All documents cleared successfully",
1628
- document_count=0,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1629
  )
1630
  except Exception as e:
 
1631
  raise HTTPException(status_code=500, detail=str(e))
1632
 
1633
  # query all graph labels
 
3
  HTTPException,
4
  File,
5
  UploadFile,
 
6
  BackgroundTasks,
7
  )
8
  import asyncio
 
13
  from fastapi.staticfiles import StaticFiles
14
  import logging
15
  import argparse
16
+ from typing import List, Any, Optional, Dict
17
  from pydantic import BaseModel
18
  from lightrag import LightRAG, QueryParam
19
  from lightrag.types import GPTKeywordExtractionFormat
 
33
  import pipmaster as pm
34
  from dotenv import load_dotenv
35
  import configparser
36
+ import traceback
37
+ from datetime import datetime
38
+
39
  from lightrag.utils import logger
40
  from .ollama_api import (
41
  OllamaAPI,
 
161
  ASCIIColors.yellow(f"{args.host}")
162
  ASCIIColors.white(" ├─ Port: ", end="")
163
  ASCIIColors.yellow(f"{args.port}")
164
+ ASCIIColors.white(" ├─ CORS Origins: ", end="")
165
+ ASCIIColors.yellow(f"{os.getenv('CORS_ORIGINS', '*')}")
166
+ ASCIIColors.white(" ├─ SSL Enabled: ", end="")
167
  ASCIIColors.yellow(f"{args.ssl}")
168
+ ASCIIColors.white(" └─ API Key: ", end="")
169
+ ASCIIColors.yellow("Set" if args.key else "Not Set")
170
  if args.ssl:
171
  ASCIIColors.white(" ├─ SSL Cert: ", end="")
172
  ASCIIColors.yellow(f"{args.ssl_certfile}")
 
235
  ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}")
236
  ASCIIColors.white(" ├─ Log Level: ", end="")
237
  ASCIIColors.yellow(f"{args.log_level}")
238
+ ASCIIColors.white(" └─ Timeout: ", end="")
239
  ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}")
 
 
240
 
241
  # Server Status
242
  ASCIIColors.green("\n✨ Server starting up...\n")
 
568
 
569
  args = parser.parse_args()
570
 
571
+ # convert relative path to absolute path
572
+ args.working_dir = os.path.abspath(args.working_dir)
573
+ args.input_dir = os.path.abspath(args.input_dir)
574
+
575
  ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
576
 
577
  return args
 
603
  """Scan input directory for new files"""
604
  new_files = []
605
  for ext in self.supported_extensions:
606
+ logger.info(f"Scanning for {ext} files in {self.input_dir}")
607
  for file_path in self.input_dir.rglob(f"*{ext}"):
608
  if file_path not in self.indexed_files:
609
  new_files.append(file_path)
 
637
 
638
  class QueryRequest(BaseModel):
639
  query: str
640
+
641
+ """Specifies the retrieval mode"""
642
  mode: SearchMode = SearchMode.hybrid
643
+
644
+ """If True, enables streaming output for real-time responses."""
645
+ stream: Optional[bool] = None
646
+
647
+ """If True, only returns the retrieved context without generating a response."""
648
+ only_need_context: Optional[bool] = None
649
+
650
+ """If True, only returns the generated prompt without producing a response."""
651
+ only_need_prompt: Optional[bool] = None
652
+
653
+ """Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'."""
654
+ response_type: Optional[str] = None
655
+
656
+ """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
657
+ top_k: Optional[int] = None
658
+
659
+ """Maximum number of tokens allowed for each retrieved text chunk."""
660
+ max_token_for_text_unit: Optional[int] = None
661
+
662
+ """Maximum number of tokens allocated for relationship descriptions in global retrieval."""
663
+ max_token_for_global_context: Optional[int] = None
664
+
665
+ """Maximum number of tokens allocated for entity descriptions in local retrieval."""
666
+ max_token_for_local_context: Optional[int] = None
667
+
668
+ """List of high-level keywords to prioritize in retrieval."""
669
+ hl_keywords: Optional[List[str]] = None
670
+
671
+ """List of low-level keywords to refine retrieval focus."""
672
+ ll_keywords: Optional[List[str]] = None
673
+
674
+ """Stores past conversation history to maintain context.
675
+ Format: [{"role": "user/assistant", "content": "message"}].
676
+ """
677
+ conversation_history: Optional[List[dict[str, Any]]] = None
678
+
679
+ """Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
680
+ history_turns: Optional[int] = None
681
 
682
 
683
  class QueryResponse(BaseModel):
 
686
 
687
  class InsertTextRequest(BaseModel):
688
  text: str
 
689
 
690
 
691
  class InsertResponse(BaseModel):
692
  status: str
693
  message: str
694
+
695
+
696
+ def QueryRequestToQueryParams(request: QueryRequest):
697
+ param = QueryParam(mode=request.mode, stream=request.stream)
698
+ if request.only_need_context is not None:
699
+ param.only_need_context = request.only_need_context
700
+ if request.only_need_prompt is not None:
701
+ param.only_need_prompt = request.only_need_prompt
702
+ if request.response_type is not None:
703
+ param.response_type = request.response_type
704
+ if request.top_k is not None:
705
+ param.top_k = request.top_k
706
+ if request.max_token_for_text_unit is not None:
707
+ param.max_token_for_text_unit = request.max_token_for_text_unit
708
+ if request.max_token_for_global_context is not None:
709
+ param.max_token_for_global_context = request.max_token_for_global_context
710
+ if request.max_token_for_local_context is not None:
711
+ param.max_token_for_local_context = request.max_token_for_local_context
712
+ if request.hl_keywords is not None:
713
+ param.hl_keywords = request.hl_keywords
714
+ if request.ll_keywords is not None:
715
+ param.ll_keywords = request.ll_keywords
716
+ if request.conversation_history is not None:
717
+ param.conversation_history = request.conversation_history
718
+ if request.history_turns is not None:
719
+ param.history_turns = request.history_turns
720
+ return param
721
 
722
 
723
  def get_api_key_dependency(api_key: Optional[str]):
 
731
  # If API key is configured, use proper authentication
732
  api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
733
 
734
+ async def api_key_auth(
735
+ api_key_header_value: Optional[str] = Security(api_key_header),
736
+ ):
737
  if not api_key_header_value:
738
  raise HTTPException(
739
  status_code=HTTP_403_FORBIDDEN, detail="API Key required"
 
749
 
750
  # Global configuration
751
  global_top_k = 60 # default value
752
+ temp_prefix = "__tmp_" # prefix for temporary files
753
 
754
 
755
  def create_app(args):
 
917
  lifespan=lifespan,
918
  )
919
 
920
+ def get_cors_origins():
921
+ """Get allowed origins from environment variable
922
+ Returns a list of allowed origins, defaults to ["*"] if not set
923
+ """
924
+ origins_str = os.getenv("CORS_ORIGINS", "*")
925
+ if origins_str == "*":
926
+ return ["*"]
927
+ return [origin.strip() for origin in origins_str.split(",")]
928
+
929
  # Add CORS middleware
930
  app.add_middleware(
931
  CORSMiddleware,
932
+ allow_origins=get_cors_origins(),
933
  allow_credentials=True,
934
  allow_methods=["*"],
935
  allow_headers=["*"],
 
1200
  ("llm_response_cache", rag.llm_response_cache),
1201
  ]
1202
 
1203
+ async def pipeline_enqueue_file(file_path: Path) -> bool:
1204
+ """Add a file to the queue for processing
1205
 
1206
  Args:
1207
+ file_path: Path to the saved file
1208
+ Returns:
1209
+ bool: True if the file was successfully enqueued, False otherwise
 
 
1210
  """
1211
+ try:
1212
+ content = ""
1213
+ ext = file_path.suffix.lower()
1214
 
1215
+ file = None
1216
+ async with aiofiles.open(file_path, "rb") as f:
1217
+ file = await f.read()
1218
 
1219
+ # Process based on file type
1220
+ match ext:
1221
+ case ".txt" | ".md":
1222
+ content = file.decode("utf-8")
1223
+ case ".pdf":
1224
+ if not pm.is_installed("pypdf2"):
1225
+ pm.install("pypdf2")
1226
+ from PyPDF2 import PdfReader
1227
+ from io import BytesIO
1228
+
1229
+ pdf_file = BytesIO(file)
1230
+ reader = PdfReader(pdf_file)
1231
+ for page in reader.pages:
1232
+ content += page.extract_text() + "\n"
1233
+ case ".docx":
1234
+ if not pm.is_installed("docx"):
1235
+ pm.install("docx")
1236
+ from docx import Document
1237
+ from io import BytesIO
1238
+
1239
+ docx_content = await file.read()
1240
+ docx_file = BytesIO(docx_content)
1241
+ doc = Document(docx_file)
1242
+ content = "\n".join(
1243
+ [paragraph.text for paragraph in doc.paragraphs]
1244
+ )
1245
+ case ".pptx":
1246
+ if not pm.is_installed("pptx"):
1247
+ pm.install("pptx")
1248
+ from pptx import Presentation # type: ignore
1249
+ from io import BytesIO
1250
+
1251
+ pptx_content = await file.read()
1252
+ pptx_file = BytesIO(pptx_content)
1253
+ prs = Presentation(pptx_file)
1254
+ for slide in prs.slides:
1255
+ for shape in slide.shapes:
1256
+ if hasattr(shape, "text"):
1257
+ content += shape.text + "\n"
1258
+ case _:
1259
+ logging.error(
1260
+ f"Unsupported file type: {file_path.name} (extension {ext})"
1261
+ )
1262
+ return False
1263
+
1264
+ # Insert into the RAG queue
1265
+ if content:
1266
+ await rag.apipeline_enqueue_documents(content)
1267
+ logging.info(
1268
+ f"Successfully processed and enqueued file: {file_path.name}"
1269
+ )
1270
+ return True
1271
+ else:
1272
+ logging.error(
1273
+ f"No content could be extracted from file: {file_path.name}"
1274
+ )
1275
 
1276
+ except Exception as e:
1277
+ logging.error(
1278
+ f"Error processing or enqueueing file {file_path.name}: {str(e)}"
1279
+ )
1280
+ logging.error(traceback.format_exc())
1281
+ finally:
1282
+ if file_path.name.startswith(temp_prefix):
1283
+ # Clean up the temporary file after indexing
1284
+ try:
1285
+ file_path.unlink()
1286
+ except Exception as e:
1287
+ logging.error(f"Error deleting file {file_path}: {str(e)}")
1288
+ return False
1289
 
1290
+ async def pipeline_index_file(file_path: Path):
1291
+ """Index a file
 
 
 
1292
 
1293
+ Args:
1294
+ file_path: Path to the saved file
1295
+ """
1296
+ try:
1297
+ if await pipeline_enqueue_file(file_path):
1298
+ await rag.apipeline_process_enqueue_documents()
1299
 
1300
+ except Exception as e:
1301
+ logging.error(f"Error indexing file {file_path.name}: {str(e)}")
1302
+ logging.error(traceback.format_exc())
 
 
1303
 
1304
+ async def pipeline_index_files(file_paths: List[Path]):
1305
+ """Index multiple files concurrently
1306
 
1307
+ Args:
1308
+ file_paths: Paths to the files to index
1309
+ """
1310
+ if not file_paths:
1311
+ return
1312
+ try:
1313
+ enqueued = False
1314
 
1315
+ if len(file_paths) == 1:
1316
+ enqueued = await pipeline_enqueue_file(file_paths[0])
1317
+ else:
1318
+ tasks = [pipeline_enqueue_file(path) for path in file_paths]
1319
+ enqueued = any(await asyncio.gather(*tasks))
1320
 
1321
+ if enqueued:
1322
+ await rag.apipeline_process_enqueue_documents()
1323
+ except Exception as e:
1324
+ logging.error(f"Error indexing files: {str(e)}")
1325
+ logging.error(traceback.format_exc())
 
 
1326
 
1327
+ async def pipeline_index_texts(texts: List[str]):
1328
+ """Index a list of texts
 
 
1329
 
1330
+ Args:
1331
+ texts: The texts to index
1332
+ """
1333
+ if not texts:
1334
+ return
1335
+ await rag.apipeline_enqueue_documents(texts)
1336
+ await rag.apipeline_process_enqueue_documents()
1337
 
1338
+ async def save_temp_file(file: UploadFile = File(...)) -> Path:
1339
+ """Save the uploaded file to a temporary location
 
1340
 
1341
+ Args:
1342
+ file: The uploaded file
1343
 
1344
+ Returns:
1345
+ Path: The path to the saved file
1346
+ """
1347
+ # Generate unique filename to avoid conflicts
1348
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
1349
+ unique_filename = f"{temp_prefix}{timestamp}_{file.filename}"
1350
+
1351
+ # Create a temporary file to save the uploaded content
1352
+ temp_path = doc_manager.input_dir / "temp" / unique_filename
1353
+ temp_path.parent.mkdir(exist_ok=True)
1354
+
1355
+ # Save the file
1356
+ with open(temp_path, "wb") as buffer:
1357
+ shutil.copyfileobj(file.file, buffer)
1358
+ return temp_path
1359
 
1360
  async def run_scanning_process():
1361
  """Background task to scan and index documents"""
 
1365
  new_files = doc_manager.scan_directory_for_new_files()
1366
  scan_progress["total_files"] = len(new_files)
1367
 
1368
+ logger.info(f"Found {len(new_files)} new files to index.")
1369
  for file_path in new_files:
1370
  try:
1371
  with progress_lock:
1372
  scan_progress["current_file"] = os.path.basename(file_path)
1373
 
1374
+ await pipeline_index_file(file_path)
1375
 
1376
  with progress_lock:
1377
  scan_progress["indexed_count"] += 1
 
1389
  with progress_lock:
1390
  scan_progress["is_scanning"] = False
1391
 
1392
+ @app.post("/documents/scan", dependencies=[Depends(optional_api_key)])
1393
+ async def scan_for_new_documents(background_tasks: BackgroundTasks):
1394
+ """Trigger the scanning process"""
1395
+ global scan_progress
1396
+
1397
+ with progress_lock:
1398
+ if scan_progress["is_scanning"]:
1399
+ return {"status": "already_scanning"}
1400
+
1401
+ scan_progress["is_scanning"] = True
1402
+ scan_progress["indexed_count"] = 0
1403
+ scan_progress["progress"] = 0
1404
+
1405
+ # Start the scanning process in the background
1406
+ background_tasks.add_task(run_scanning_process)
1407
+
1408
+ return {"status": "scanning_started"}
1409
+
1410
  @app.get("/documents/scan-progress")
1411
  async def get_scan_progress():
1412
  """Get the current scanning progress"""
 
1414
  return scan_progress
1415
 
1416
  @app.post("/documents/upload", dependencies=[Depends(optional_api_key)])
1417
+ async def upload_to_input_dir(
1418
+ background_tasks: BackgroundTasks, file: UploadFile = File(...)
1419
+ ):
1420
  """
1421
  Endpoint for uploading a file to the input directory and indexing it.
1422
 
 
1425
  indexes it for retrieval, and returns a success status with relevant details.
1426
 
1427
  Parameters:
1428
+ background_tasks: FastAPI BackgroundTasks for async processing
1429
  file (UploadFile): The file to be uploaded. It must have an allowed extension as per
1430
  `doc_manager.supported_extensions`.
1431
 
 
1450
  with open(file_path, "wb") as buffer:
1451
  shutil.copyfileobj(file.file, buffer)
1452
 
1453
+ # Add to background tasks
1454
+ background_tasks.add_task(pipeline_index_file, file_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1455
 
1456
+ return InsertResponse(
1457
+ status="success",
1458
+ message=f"File '{file.filename}' uploaded successfully. Processing will continue in background.",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1459
  )
1460
  except Exception as e:
1461
+ logging.error(f"Error /documents/upload: {file.filename}: {str(e)}")
1462
+ logging.error(traceback.format_exc())
1463
  raise HTTPException(status_code=500, detail=str(e))
1464
 
1465
  @app.post(
 
1467
  response_model=InsertResponse,
1468
  dependencies=[Depends(optional_api_key)],
1469
  )
1470
+ async def insert_text(
1471
+ request: InsertTextRequest, background_tasks: BackgroundTasks
1472
+ ):
1473
  """
1474
  Insert text into the Retrieval-Augmented Generation (RAG) system.
1475
 
 
1477
 
1478
  Args:
1479
  request (InsertTextRequest): The request body containing the text to be inserted.
1480
+ background_tasks: FastAPI BackgroundTasks for async processing
1481
 
1482
  Returns:
1483
  InsertResponse: A response object containing the status of the operation, a message, and the number of documents inserted.
1484
  """
1485
  try:
1486
+ background_tasks.add_task(pipeline_index_texts, [request.text])
1487
  return InsertResponse(
1488
  status="success",
1489
+ message="Text successfully received. Processing will continue in background.",
 
1490
  )
1491
  except Exception as e:
1492
+ logging.error(f"Error /documents/text: {str(e)}")
1493
+ logging.error(traceback.format_exc())
1494
  raise HTTPException(status_code=500, detail=str(e))
1495
 
1496
  @app.post(
 
1498
  response_model=InsertResponse,
1499
  dependencies=[Depends(optional_api_key)],
1500
  )
1501
+ async def insert_file(
1502
+ background_tasks: BackgroundTasks, file: UploadFile = File(...)
1503
+ ):
1504
  """Insert a file directly into the RAG system
1505
 
1506
  Args:
1507
+ background_tasks: FastAPI BackgroundTasks for async processing
1508
  file: Uploaded file
 
1509
 
1510
  Returns:
1511
  InsertResponse: Status of the insertion operation
 
1514
  HTTPException: For unsupported file types or processing errors
1515
  """
1516
  try:
1517
+ if not doc_manager.is_supported_file(file.filename):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1518
  raise HTTPException(
1519
  status_code=400,
1520
+ detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}",
1521
  )
1522
 
1523
+ # Create a temporary file to save the uploaded content
1524
+ temp_path = save_temp_file(file)
1525
+
1526
+ # Add to background tasks
1527
+ background_tasks.add_task(pipeline_index_file, temp_path)
1528
+
1529
+ return InsertResponse(
1530
+ status="success",
1531
+ message=f"File '{file.filename}' saved successfully. Processing will continue in background.",
1532
+ )
1533
+
1534
  except Exception as e:
1535
+ logging.error(f"Error /documents/file: {str(e)}")
1536
+ logging.error(traceback.format_exc())
1537
  raise HTTPException(status_code=500, detail=str(e))
1538
 
1539
  @app.post(
 
1541
  response_model=InsertResponse,
1542
  dependencies=[Depends(optional_api_key)],
1543
  )
1544
+ async def insert_batch(
1545
+ background_tasks: BackgroundTasks, files: List[UploadFile] = File(...)
1546
+ ):
1547
  """Process multiple files in batch mode
1548
 
1549
  Args:
1550
+ background_tasks: FastAPI BackgroundTasks for async processing
1551
  files: List of files to process
1552
 
1553
  Returns:
 
1559
  try:
1560
  inserted_count = 0
1561
  failed_files = []
1562
+ temp_files = []
1563
 
1564
  for file in files:
1565
+ if doc_manager.is_supported_file(file.filename):
1566
+ # Create a temporary file to save the uploaded content
1567
+ temp_files.append(save_temp_file(file))
1568
+ inserted_count += 1
1569
+ else:
1570
+ failed_files.append(f"{file.filename} (unsupported type)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1571
 
1572
+ if temp_files:
1573
+ background_tasks.add_task(pipeline_index_files, temp_files)
 
 
 
1574
 
1575
  # Prepare status message
1576
  if inserted_count == len(files):
 
1587
  if failed_files:
1588
  status_message += f". Failed files: {', '.join(failed_files)}"
1589
 
1590
+ return InsertResponse(status=status, message=status_message)
 
 
 
 
1591
 
1592
  except Exception as e:
1593
+ logging.error(f"Error /documents/batch: {file.filename}: {str(e)}")
1594
+ logging.error(traceback.format_exc())
1595
  raise HTTPException(status_code=500, detail=str(e))
1596
 
1597
  @app.delete(
 
1614
  rag.entities_vdb = None
1615
  rag.relationships_vdb = None
1616
  return InsertResponse(
1617
+ status="success", message="All documents cleared successfully"
1618
+ )
1619
+ except Exception as e:
1620
+ logging.error(f"Error DELETE /documents: {str(e)}")
1621
+ logging.error(traceback.format_exc())
1622
+ raise HTTPException(status_code=500, detail=str(e))
1623
+
1624
+ @app.post(
1625
+ "/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]
1626
+ )
1627
+ async def query_text(request: QueryRequest):
1628
+ """
1629
+ Handle a POST request at the /query endpoint to process user queries using RAG capabilities.
1630
+
1631
+ Parameters:
1632
+ request (QueryRequest): The request object containing the query parameters.
1633
+ Returns:
1634
+ QueryResponse: A Pydantic model containing the result of the query processing.
1635
+ If a string is returned (e.g., cache hit), it's directly returned.
1636
+ Otherwise, an async generator may be used to build the response.
1637
+
1638
+ Raises:
1639
+ HTTPException: Raised when an error occurs during the request handling process,
1640
+ with status code 500 and detail containing the exception message.
1641
+ """
1642
+ try:
1643
+ response = await rag.aquery(
1644
+ request.query, param=QueryRequestToQueryParams(request)
1645
+ )
1646
+
1647
+ # If response is a string (e.g. cache hit), return directly
1648
+ if isinstance(response, str):
1649
+ return QueryResponse(response=response)
1650
+
1651
+ # If it's an async generator, decide whether to stream based on stream parameter
1652
+ if request.stream or hasattr(response, "__aiter__"):
1653
+ result = ""
1654
+ async for chunk in response:
1655
+ result += chunk
1656
+ return QueryResponse(response=result)
1657
+ elif isinstance(response, dict):
1658
+ result = json.dumps(response, indent=2)
1659
+ return QueryResponse(response=result)
1660
+ else:
1661
+ return QueryResponse(response=str(response))
1662
+ except Exception as e:
1663
+ trace_exception(e)
1664
+ raise HTTPException(status_code=500, detail=str(e))
1665
+
1666
+ @app.post("/query/stream", dependencies=[Depends(optional_api_key)])
1667
+ async def query_text_stream(request: QueryRequest):
1668
+ """
1669
+ This endpoint performs a retrieval-augmented generation (RAG) query and streams the response.
1670
+
1671
+ Args:
1672
+ request (QueryRequest): The request object containing the query parameters.
1673
+ optional_api_key (Optional[str], optional): An optional API key for authentication. Defaults to None.
1674
+
1675
+ Returns:
1676
+ StreamingResponse: A streaming response containing the RAG query results.
1677
+ """
1678
+ try:
1679
+ params = QueryRequestToQueryParams(request)
1680
+
1681
+ params.stream = True
1682
+ response = await rag.aquery( # Use aquery instead of query, and add await
1683
+ request.query, param=params
1684
+ )
1685
+
1686
+ from fastapi.responses import StreamingResponse
1687
+
1688
+ async def stream_generator():
1689
+ if isinstance(response, str):
1690
+ # If it's a string, send it all at once
1691
+ yield f"{json.dumps({'response': response})}\n"
1692
+ else:
1693
+ # If it's an async generator, send chunks one by one
1694
+ try:
1695
+ async for chunk in response:
1696
+ if chunk: # Only send non-empty content
1697
+ yield f"{json.dumps({'response': chunk})}\n"
1698
+ except Exception as e:
1699
+ logging.error(f"Streaming error: {str(e)}")
1700
+ yield f"{json.dumps({'error': str(e)})}\n"
1701
+
1702
+ return StreamingResponse(
1703
+ stream_generator(),
1704
+ media_type="application/x-ndjson",
1705
+ headers={
1706
+ "Cache-Control": "no-cache",
1707
+ "Connection": "keep-alive",
1708
+ "Content-Type": "application/x-ndjson",
1709
+ "X-Accel-Buffering": "no", # 确保在Nginx代理时正确处理流式响应
1710
+ },
1711
  )
1712
  except Exception as e:
1713
+ trace_exception(e)
1714
  raise HTTPException(status_code=500, detail=str(e))
1715
 
1716
  # query all graph labels
lightrag/api/ollama_api.py CHANGED
@@ -316,9 +316,7 @@ class OllamaAPI:
316
  "Cache-Control": "no-cache",
317
  "Connection": "keep-alive",
318
  "Content-Type": "application/x-ndjson",
319
- "Access-Control-Allow-Origin": "*",
320
- "Access-Control-Allow-Methods": "POST, OPTIONS",
321
- "Access-Control-Allow-Headers": "Content-Type",
322
  },
323
  )
324
  else:
@@ -534,9 +532,7 @@ class OllamaAPI:
534
  "Cache-Control": "no-cache",
535
  "Connection": "keep-alive",
536
  "Content-Type": "application/x-ndjson",
537
- "Access-Control-Allow-Origin": "*",
538
- "Access-Control-Allow-Methods": "POST, OPTIONS",
539
- "Access-Control-Allow-Headers": "Content-Type",
540
  },
541
  )
542
  else:
 
316
  "Cache-Control": "no-cache",
317
  "Connection": "keep-alive",
318
  "Content-Type": "application/x-ndjson",
319
+ "X-Accel-Buffering": "no", # 确保在Nginx代理时正确处理流式响应
 
 
320
  },
321
  )
322
  else:
 
532
  "Cache-Control": "no-cache",
533
  "Connection": "keep-alive",
534
  "Content-Type": "application/x-ndjson",
535
+ "X-Accel-Buffering": "no", # 确保在Nginx代理时正确处理流式响应
 
 
536
  },
537
  )
538
  else:
lightrag/base.py CHANGED
@@ -1,13 +1,13 @@
 
 
1
  import os
2
  from dataclasses import dataclass, field
3
  from enum import Enum
4
  from typing import (
5
  Any,
6
  Literal,
7
- Optional,
8
  TypedDict,
9
  TypeVar,
10
- Union,
11
  )
12
 
13
  import numpy as np
@@ -69,7 +69,7 @@ class QueryParam:
69
  ll_keywords: list[str] = field(default_factory=list)
70
  """List of low-level keywords to refine retrieval focus."""
71
 
72
- conversation_history: list[dict[str, Any]] = field(default_factory=list)
73
  """Stores past conversation history to maintain context.
74
  Format: [{"role": "user/assistant", "content": "message"}].
75
  """
@@ -83,19 +83,15 @@ class StorageNameSpace:
83
  namespace: str
84
  global_config: dict[str, Any]
85
 
86
- async def index_done_callback(self):
87
  """Commit the storage operations after indexing"""
88
  pass
89
 
90
- async def query_done_callback(self):
91
- """Commit the storage operations after querying"""
92
- pass
93
-
94
 
95
  @dataclass
96
  class BaseVectorStorage(StorageNameSpace):
97
  embedding_func: EmbeddingFunc
98
- meta_fields: set = field(default_factory=set)
99
 
100
  async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
101
  raise NotImplementedError
@@ -106,12 +102,20 @@ class BaseVectorStorage(StorageNameSpace):
106
  """
107
  raise NotImplementedError
108
 
 
 
 
 
 
 
 
 
109
 
110
  @dataclass
111
  class BaseKVStorage(StorageNameSpace):
112
- embedding_func: EmbeddingFunc
113
 
114
- async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
115
  raise NotImplementedError
116
 
117
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@@ -130,50 +134,75 @@ class BaseKVStorage(StorageNameSpace):
130
 
131
  @dataclass
132
  class BaseGraphStorage(StorageNameSpace):
133
- embedding_func: EmbeddingFunc = None
 
134
 
135
  async def has_node(self, node_id: str) -> bool:
136
  raise NotImplementedError
137
 
 
 
138
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
139
  raise NotImplementedError
140
 
 
 
141
  async def node_degree(self, node_id: str) -> int:
142
  raise NotImplementedError
143
 
 
 
144
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
145
  raise NotImplementedError
146
 
147
- async def get_node(self, node_id: str) -> Union[dict, None]:
 
 
148
  raise NotImplementedError
149
 
 
 
150
  async def get_edge(
151
  self, source_node_id: str, target_node_id: str
152
- ) -> Union[dict, None]:
153
  raise NotImplementedError
154
 
155
- async def get_node_edges(
156
- self, source_node_id: str
157
- ) -> Union[list[tuple[str, str]], None]:
158
  raise NotImplementedError
159
 
160
- async def upsert_node(self, node_id: str, node_data: dict[str, str]):
 
 
161
  raise NotImplementedError
162
 
 
 
163
  async def upsert_edge(
164
  self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
165
- ):
166
  raise NotImplementedError
167
 
168
- async def delete_node(self, node_id: str):
 
 
169
  raise NotImplementedError
170
 
171
- async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
 
 
 
 
172
  raise NotImplementedError("Node embedding is not used in lightrag.")
173
 
 
 
174
  async def get_all_labels(self) -> list[str]:
175
  raise NotImplementedError
176
 
 
 
177
  async def get_knowledge_graph(
178
  self, node_label: str, max_depth: int = 5
179
  ) -> KnowledgeGraph:
@@ -205,9 +234,9 @@ class DocProcessingStatus:
205
  """ISO format timestamp when document was created"""
206
  updated_at: str
207
  """ISO format timestamp when document was last updated"""
208
- chunks_count: Optional[int] = None
209
  """Number of chunks after splitting, used for processing"""
210
- error: Optional[str] = None
211
  """Error message if failed"""
212
  metadata: dict[str, Any] = field(default_factory=dict)
213
  """Additional metadata"""
@@ -220,20 +249,10 @@ class DocStatusStorage(BaseKVStorage):
220
  """Get counts of documents in each status"""
221
  raise NotImplementedError
222
 
223
- async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
224
- """Get all failed documents"""
225
- raise NotImplementedError
226
-
227
- async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
228
- """Get all pending documents"""
229
- raise NotImplementedError
230
-
231
- async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
232
- """Get all processing documents"""
233
- raise NotImplementedError
234
-
235
- async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
236
- """Get all procesed documents"""
237
  raise NotImplementedError
238
 
239
  async def update_doc_status(self, data: dict[str, Any]) -> None:
 
1
+ from __future__ import annotations
2
+
3
  import os
4
  from dataclasses import dataclass, field
5
  from enum import Enum
6
  from typing import (
7
  Any,
8
  Literal,
 
9
  TypedDict,
10
  TypeVar,
 
11
  )
12
 
13
  import numpy as np
 
69
  ll_keywords: list[str] = field(default_factory=list)
70
  """List of low-level keywords to refine retrieval focus."""
71
 
72
+ conversation_history: list[dict[str, str]] = field(default_factory=list)
73
  """Stores past conversation history to maintain context.
74
  Format: [{"role": "user/assistant", "content": "message"}].
75
  """
 
83
  namespace: str
84
  global_config: dict[str, Any]
85
 
86
+ async def index_done_callback(self) -> None:
87
  """Commit the storage operations after indexing"""
88
  pass
89
 
 
 
 
 
90
 
91
  @dataclass
92
  class BaseVectorStorage(StorageNameSpace):
93
  embedding_func: EmbeddingFunc
94
+ meta_fields: set[str] = field(default_factory=set)
95
 
96
  async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
97
  raise NotImplementedError
 
102
  """
103
  raise NotImplementedError
104
 
105
+ async def delete_entity(self, entity_name: str) -> None:
106
+ """Delete a single entity by its name"""
107
+ raise NotImplementedError
108
+
109
+ async def delete_entity_relation(self, entity_name: str) -> None:
110
+ """Delete relations for a given entity by scanning metadata"""
111
+ raise NotImplementedError
112
+
113
 
114
  @dataclass
115
  class BaseKVStorage(StorageNameSpace):
116
+ embedding_func: EmbeddingFunc | None = None
117
 
118
+ async def get_by_id(self, id: str) -> dict[str, Any] | None:
119
  raise NotImplementedError
120
 
121
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
 
134
 
135
  @dataclass
136
  class BaseGraphStorage(StorageNameSpace):
137
+ embedding_func: EmbeddingFunc | None = None
138
+ """Check if a node exists in the graph."""
139
 
140
  async def has_node(self, node_id: str) -> bool:
141
  raise NotImplementedError
142
 
143
+ """Check if an edge exists in the graph."""
144
+
145
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
146
  raise NotImplementedError
147
 
148
+ """Get the degree of a node."""
149
+
150
  async def node_degree(self, node_id: str) -> int:
151
  raise NotImplementedError
152
 
153
+ """Get the degree of an edge."""
154
+
155
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
156
  raise NotImplementedError
157
 
158
+ """Get a node by its id."""
159
+
160
+ async def get_node(self, node_id: str) -> dict[str, str] | None:
161
  raise NotImplementedError
162
 
163
+ """Get an edge by its source and target node ids."""
164
+
165
  async def get_edge(
166
  self, source_node_id: str, target_node_id: str
167
+ ) -> dict[str, str] | None:
168
  raise NotImplementedError
169
 
170
+ """Get all edges connected to a node."""
171
+
172
+ async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
173
  raise NotImplementedError
174
 
175
+ """Upsert a node into the graph."""
176
+
177
+ async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
178
  raise NotImplementedError
179
 
180
+ """Upsert an edge into the graph."""
181
+
182
  async def upsert_edge(
183
  self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
184
+ ) -> None:
185
  raise NotImplementedError
186
 
187
+ """Delete a node from the graph."""
188
+
189
+ async def delete_node(self, node_id: str) -> None:
190
  raise NotImplementedError
191
 
192
+ """Embed nodes using an algorithm."""
193
+
194
+ async def embed_nodes(
195
+ self, algorithm: str
196
+ ) -> tuple[np.ndarray[Any, Any], list[str]]:
197
  raise NotImplementedError("Node embedding is not used in lightrag.")
198
 
199
+ """Get all labels in the graph."""
200
+
201
  async def get_all_labels(self) -> list[str]:
202
  raise NotImplementedError
203
 
204
+ """Get a knowledge graph of a node."""
205
+
206
  async def get_knowledge_graph(
207
  self, node_label: str, max_depth: int = 5
208
  ) -> KnowledgeGraph:
 
234
  """ISO format timestamp when document was created"""
235
  updated_at: str
236
  """ISO format timestamp when document was last updated"""
237
+ chunks_count: int | None = None
238
  """Number of chunks after splitting, used for processing"""
239
+ error: str | None = None
240
  """Error message if failed"""
241
  metadata: dict[str, Any] = field(default_factory=dict)
242
  """Additional metadata"""
 
249
  """Get counts of documents in each status"""
250
  raise NotImplementedError
251
 
252
+ async def get_docs_by_status(
253
+ self, status: DocStatus
254
+ ) -> dict[str, DocProcessingStatus]:
255
+ """Get all documents with a specific status"""
 
 
 
 
 
 
 
 
 
 
256
  raise NotImplementedError
257
 
258
  async def update_doc_status(self, data: dict[str, Any]) -> None:
lightrag/exceptions.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import httpx
2
  from typing import Literal
3
 
 
1
+ from __future__ import annotations
2
+
3
  import httpx
4
  from typing import Literal
5
 
lightrag/kg/chroma_impl.py CHANGED
@@ -2,7 +2,7 @@ import asyncio
2
  from dataclasses import dataclass
3
  from typing import Union
4
  import numpy as np
5
- from chromadb import HttpClient
6
  from chromadb.config import Settings
7
  from lightrag.base import BaseVectorStorage
8
  from lightrag.utils import logger
@@ -49,31 +49,43 @@ class ChromaVectorDBStorage(BaseVectorStorage):
49
  **user_collection_settings,
50
  }
51
 
52
- auth_provider = config.get(
53
- "auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
54
- )
55
- auth_credentials = config.get("auth_token", "secret-token")
56
- headers = {}
57
-
58
- if "token_authn" in auth_provider:
59
- headers = {
60
- config.get("auth_header_name", "X-Chroma-Token"): auth_credentials
61
- }
62
- elif "basic_authn" in auth_provider:
63
- auth_credentials = config.get("auth_credentials", "admin:admin")
64
-
65
- self._client = HttpClient(
66
- host=config.get("host", "localhost"),
67
- port=config.get("port", 8000),
68
- headers=headers,
69
- settings=Settings(
70
- chroma_api_impl="rest",
71
- chroma_client_auth_provider=auth_provider,
72
- chroma_client_auth_credentials=auth_credentials,
73
- allow_reset=True,
74
- anonymized_telemetry=False,
75
- ),
76
- )
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  self._collection = self._client.get_or_create_collection(
79
  name=self.namespace,
@@ -144,7 +156,9 @@ class ChromaVectorDBStorage(BaseVectorStorage):
144
  embedding = await self.embedding_func([query])
145
 
146
  results = self._collection.query(
147
- query_embeddings=embedding.tolist(),
 
 
148
  n_results=top_k * 2, # Request more results to allow for filtering
149
  include=["metadatas", "distances", "documents"],
150
  )
 
2
  from dataclasses import dataclass
3
  from typing import Union
4
  import numpy as np
5
+ from chromadb import HttpClient, PersistentClient
6
  from chromadb.config import Settings
7
  from lightrag.base import BaseVectorStorage
8
  from lightrag.utils import logger
 
49
  **user_collection_settings,
50
  }
51
 
52
+ local_path = config.get("local_path", None)
53
+ if local_path:
54
+ self._client = PersistentClient(
55
+ path=local_path,
56
+ settings=Settings(
57
+ allow_reset=True,
58
+ anonymized_telemetry=False,
59
+ ),
60
+ )
61
+ else:
62
+ auth_provider = config.get(
63
+ "auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
64
+ )
65
+ auth_credentials = config.get("auth_token", "secret-token")
66
+ headers = {}
67
+
68
+ if "token_authn" in auth_provider:
69
+ headers = {
70
+ config.get(
71
+ "auth_header_name", "X-Chroma-Token"
72
+ ): auth_credentials
73
+ }
74
+ elif "basic_authn" in auth_provider:
75
+ auth_credentials = config.get("auth_credentials", "admin:admin")
76
+
77
+ self._client = HttpClient(
78
+ host=config.get("host", "localhost"),
79
+ port=config.get("port", 8000),
80
+ headers=headers,
81
+ settings=Settings(
82
+ chroma_api_impl="rest",
83
+ chroma_client_auth_provider=auth_provider,
84
+ chroma_client_auth_credentials=auth_credentials,
85
+ allow_reset=True,
86
+ anonymized_telemetry=False,
87
+ ),
88
+ )
89
 
90
  self._collection = self._client.get_or_create_collection(
91
  name=self.namespace,
 
156
  embedding = await self.embedding_func([query])
157
 
158
  results = self._collection.query(
159
+ query_embeddings=embedding.tolist()
160
+ if not isinstance(embedding, list)
161
+ else embedding,
162
  n_results=top_k * 2, # Request more results to allow for filtering
163
  include=["metadatas", "distances", "documents"],
164
  )
lightrag/kg/faiss_impl.py CHANGED
@@ -27,8 +27,8 @@ class FaissVectorDBStorage(BaseVectorStorage):
27
 
28
  def __post_init__(self):
29
  # Grab config values if available
30
- config = self.global_config.get("vector_db_storage_cls_kwargs", {})
31
- cosine_threshold = config.get("cosine_better_than_threshold")
32
  if cosine_threshold is None:
33
  raise ValueError(
34
  "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
@@ -219,7 +219,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
219
  logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
220
  await self.delete([entity_id])
221
 
222
- async def delete_entity_relation(self, entity_name: str):
223
  """
224
  Delete relations for a given entity by scanning metadata.
225
  """
 
27
 
28
  def __post_init__(self):
29
  # Grab config values if available
30
+ kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
31
+ cosine_threshold = kwargs.get("cosine_better_than_threshold")
32
  if cosine_threshold is None:
33
  raise ValueError(
34
  "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
 
219
  logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
220
  await self.delete([entity_id])
221
 
222
+ async def delete_entity_relation(self, entity_name: str) -> None:
223
  """
224
  Delete relations for a given entity by scanning metadata.
225
  """
lightrag/kg/json_doc_status_impl.py CHANGED
@@ -93,36 +93,14 @@ class JsonDocStatusStorage(DocStatusStorage):
93
  counts[doc["status"]] += 1
94
  return counts
95
 
96
- async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
97
- """Get all failed documents"""
 
 
98
  return {
99
  k: DocProcessingStatus(**v)
100
  for k, v in self._data.items()
101
- if v["status"] == DocStatus.FAILED
102
- }
103
-
104
- async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
105
- """Get all pending documents"""
106
- return {
107
- k: DocProcessingStatus(**v)
108
- for k, v in self._data.items()
109
- if v["status"] == DocStatus.PENDING
110
- }
111
-
112
- async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
113
- """Get all processed documents"""
114
- return {
115
- k: DocProcessingStatus(**v)
116
- for k, v in self._data.items()
117
- if v["status"] == DocStatus.PROCESSED
118
- }
119
-
120
- async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
121
- """Get all processing documents"""
122
- return {
123
- k: DocProcessingStatus(**v)
124
- for k, v in self._data.items()
125
- if v["status"] == DocStatus.PROCESSING
126
  }
127
 
128
  async def index_done_callback(self):
 
93
  counts[doc["status"]] += 1
94
  return counts
95
 
96
+ async def get_docs_by_status(
97
+ self, status: DocStatus
98
+ ) -> dict[str, DocProcessingStatus]:
99
+ """all documents with a specific status"""
100
  return {
101
  k: DocProcessingStatus(**v)
102
  for k, v in self._data.items()
103
+ if v["status"] == status
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  }
105
 
106
  async def index_done_callback(self):
lightrag/kg/json_kv_impl.py CHANGED
@@ -47,3 +47,8 @@ class JsonKVStorage(BaseKVStorage):
47
 
48
  async def drop(self) -> None:
49
  self._data = {}
 
 
 
 
 
 
47
 
48
  async def drop(self) -> None:
49
  self._data = {}
50
+
51
+ async def delete(self, ids: list[str]) -> None:
52
+ for doc_id in ids:
53
+ self._data.pop(doc_id, None)
54
+ await self.index_done_callback()
lightrag/kg/mongo_impl.py CHANGED
@@ -4,6 +4,7 @@ import numpy as np
4
  import pipmaster as pm
5
  import configparser
6
  from tqdm.asyncio import tqdm as tqdm_async
 
7
 
8
  if not pm.is_installed("pymongo"):
9
  pm.install("pymongo")
@@ -14,16 +15,20 @@ if not pm.is_installed("motor"):
14
  from typing import Any, List, Tuple, Union
15
  from motor.motor_asyncio import AsyncIOMotorClient
16
  from pymongo import MongoClient
 
 
17
 
18
  from ..base import (
19
  BaseGraphStorage,
20
  BaseKVStorage,
 
21
  DocProcessingStatus,
22
  DocStatus,
23
  DocStatusStorage,
24
  )
25
  from ..namespace import NameSpace, is_namespace
26
  from ..utils import logger
 
27
 
28
 
29
  config = configparser.ConfigParser()
@@ -33,56 +38,66 @@ config.read("config.ini", "utf-8")
33
  @dataclass
34
  class MongoKVStorage(BaseKVStorage):
35
  def __post_init__(self):
36
- client = MongoClient(
37
- os.environ.get(
38
- "MONGO_URI",
39
- config.get(
40
- "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
41
- ),
42
- )
43
  )
 
44
  database = client.get_database(
45
  os.environ.get(
46
  "MONGO_DATABASE",
47
  config.get("mongodb", "database", fallback="LightRAG"),
48
  )
49
  )
50
- self._data = database.get_collection(self.namespace)
51
- logger.info(f"Use MongoDB as KV {self.namespace}")
 
 
 
 
 
 
52
 
53
  async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
54
- return self._data.find_one({"_id": id})
55
 
56
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
57
- return list(self._data.find({"_id": {"$in": ids}}))
 
58
 
59
  async def filter_keys(self, data: set[str]) -> set[str]:
60
- existing_ids = [
61
- str(x["_id"])
62
- for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
63
- ]
64
- return set([s for s in data if s not in existing_ids])
65
 
66
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
67
  if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
 
68
  for mode, items in data.items():
69
- for k, v in tqdm_async(items.items(), desc="Upserting"):
70
  key = f"{mode}_{k}"
71
- result = self._data.update_one(
72
- {"_id": key}, {"$setOnInsert": v}, upsert=True
 
 
 
73
  )
74
- if result.upserted_id:
75
- logger.debug(f"\nInserted new document with key: {key}")
76
- data[mode][k]["_id"] = key
77
  else:
78
- for k, v in tqdm_async(data.items(), desc="Upserting"):
79
- self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
80
  data[k]["_id"] = k
 
 
 
 
81
 
82
  async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
83
  if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
84
  res = {}
85
- v = self._data.find_one({"_id": mode + "_" + id})
86
  if v:
87
  res[id] = v
88
  logger.debug(f"llm_response_cache find one by:{id}")
@@ -100,30 +115,48 @@ class MongoKVStorage(BaseKVStorage):
100
  @dataclass
101
  class MongoDocStatusStorage(DocStatusStorage):
102
  def __post_init__(self):
103
- client = MongoClient(
104
- os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/")
 
 
 
105
  )
106
- database = client.get_database(os.environ.get("MONGO_DATABASE", "LightRAG"))
107
- self._data = database.get_collection(self.namespace)
108
- logger.info(f"Use MongoDB as doc status {self.namespace}")
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
  async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
111
- return self._data.find_one({"_id": id})
112
 
113
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
114
- return list(self._data.find({"_id": {"$in": ids}}))
 
115
 
116
  async def filter_keys(self, data: set[str]) -> set[str]:
117
- existing_ids = [
118
- str(x["_id"])
119
- for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
120
- ]
121
- return set([s for s in data if s not in existing_ids])
122
 
123
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
 
124
  for k, v in data.items():
125
- self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
126
  data[k]["_id"] = k
 
 
 
 
127
 
128
  async def drop(self) -> None:
129
  """Drop the collection"""
@@ -132,7 +165,8 @@ class MongoDocStatusStorage(DocStatusStorage):
132
  async def get_status_counts(self) -> dict[str, int]:
133
  """Get counts of documents in each status"""
134
  pipeline = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}]
135
- result = list(self._data.aggregate(pipeline))
 
136
  counts = {}
137
  for doc in result:
138
  counts[doc["_id"]] = doc["count"]
@@ -141,8 +175,9 @@ class MongoDocStatusStorage(DocStatusStorage):
141
  async def get_docs_by_status(
142
  self, status: DocStatus
143
  ) -> dict[str, DocProcessingStatus]:
144
- """Get all documents by status"""
145
- result = list(self._data.find({"status": status.value}))
 
146
  return {
147
  doc["_id"]: DocProcessingStatus(
148
  content=doc["content"],
@@ -156,22 +191,6 @@ class MongoDocStatusStorage(DocStatusStorage):
156
  for doc in result
157
  }
158
 
159
- async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
160
- """Get all failed documents"""
161
- return await self.get_docs_by_status(DocStatus.FAILED)
162
-
163
- async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
164
- """Get all pending documents"""
165
- return await self.get_docs_by_status(DocStatus.PENDING)
166
-
167
- async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
168
- """Get all processing documents"""
169
- return await self.get_docs_by_status(DocStatus.PROCESSING)
170
-
171
- async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
172
- """Get all processed documents"""
173
- return await self.get_docs_by_status(DocStatus.PROCESSED)
174
-
175
 
176
  @dataclass
177
  class MongoGraphStorage(BaseGraphStorage):
@@ -185,26 +204,27 @@ class MongoGraphStorage(BaseGraphStorage):
185
  global_config=global_config,
186
  embedding_func=embedding_func,
187
  )
188
- self.client = AsyncIOMotorClient(
189
- os.environ.get(
190
- "MONGO_URI",
191
- config.get(
192
- "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
193
- ),
194
- )
195
  )
196
- self.db = self.client[
 
197
  os.environ.get(
198
  "MONGO_DATABASE",
199
- mongo_database=config.get("mongodb", "database", fallback="LightRAG"),
200
- )
201
- ]
202
- self.collection = self.db[
203
- os.environ.get(
204
- "MONGO_KG_COLLECTION",
205
- config.getboolean("mongodb", "kg_collection", fallback="MDB_KG"),
206
  )
207
- ]
 
 
 
 
 
 
 
 
208
 
209
  #
210
  # -------------------------------------------------------------------------
@@ -451,7 +471,7 @@ class MongoGraphStorage(BaseGraphStorage):
451
  self, source_node_id: str
452
  ) -> Union[List[Tuple[str, str]], None]:
453
  """
454
- Return a list of (target_id, relation) for direct edges from source_node_id.
455
  Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler.
456
  """
457
  pipeline = [
@@ -475,7 +495,7 @@ class MongoGraphStorage(BaseGraphStorage):
475
  return None
476
 
477
  edges = result[0].get("edges", [])
478
- return [(e["target"], e["relation"]) for e in edges]
479
 
480
  #
481
  # -------------------------------------------------------------------------
@@ -522,7 +542,7 @@ class MongoGraphStorage(BaseGraphStorage):
522
 
523
  async def delete_node(self, node_id: str):
524
  """
525
- 1) Remove nodes doc entirely.
526
  2) Remove inbound edges from any doc that references node_id.
527
  """
528
  # Remove inbound edges from all other docs
@@ -542,3 +562,359 @@ class MongoGraphStorage(BaseGraphStorage):
542
  Placeholder for demonstration, raises NotImplementedError.
543
  """
544
  raise NotImplementedError("Node embedding is not used in lightrag.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import pipmaster as pm
5
  import configparser
6
  from tqdm.asyncio import tqdm as tqdm_async
7
+ import asyncio
8
 
9
  if not pm.is_installed("pymongo"):
10
  pm.install("pymongo")
 
15
  from typing import Any, List, Tuple, Union
16
  from motor.motor_asyncio import AsyncIOMotorClient
17
  from pymongo import MongoClient
18
+ from pymongo.operations import SearchIndexModel
19
+ from pymongo.errors import PyMongoError
20
 
21
  from ..base import (
22
  BaseGraphStorage,
23
  BaseKVStorage,
24
+ BaseVectorStorage,
25
  DocProcessingStatus,
26
  DocStatus,
27
  DocStatusStorage,
28
  )
29
  from ..namespace import NameSpace, is_namespace
30
  from ..utils import logger
31
+ from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
32
 
33
 
34
  config = configparser.ConfigParser()
 
38
  @dataclass
39
  class MongoKVStorage(BaseKVStorage):
40
    def __post_init__(self):
        """Connect to MongoDB and bind this KV store to its namespace collection.

        Connection settings come from the MONGO_URI / MONGO_DATABASE
        environment variables, falling back to the [mongodb] section of the
        config file, then to hard-coded local defaults.
        """
        uri = os.environ.get(
            "MONGO_URI",
            config.get(
                "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
            ),
        )
        client = AsyncIOMotorClient(uri)
        database = client.get_database(
            os.environ.get(
                "MONGO_DATABASE",
                config.get("mongodb", "database", fallback="LightRAG"),
            )
        )

        # One collection per namespace keeps separate stores isolated
        # inside the same database.
        self._collection_name = self.namespace

        self._data = database.get_collection(self._collection_name)
        logger.debug(f"Use MongoDB as KV {self._collection_name}")

        # Ensure collection exists (uses a separate synchronous client).
        create_collection_if_not_exists(uri, database.name, self._collection_name)
62
 
63
    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
        """Return the stored document with the given id, or None if absent."""
        return await self._data.find_one({"_id": id})
65
 
66
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
67
+ cursor = self._data.find({"_id": {"$in": ids}})
68
+ return await cursor.to_list()
69
 
70
  async def filter_keys(self, data: set[str]) -> set[str]:
71
+ cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
72
+ existing_ids = {str(x["_id"]) async for x in cursor}
73
+ return data - existing_ids
 
 
74
 
75
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
76
  if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
77
+ update_tasks = []
78
  for mode, items in data.items():
79
+ for k, v in items.items():
80
  key = f"{mode}_{k}"
81
+ data[mode][k]["_id"] = f"{mode}_{k}"
82
+ update_tasks.append(
83
+ self._data.update_one(
84
+ {"_id": key}, {"$setOnInsert": v}, upsert=True
85
+ )
86
  )
87
+ await asyncio.gather(*update_tasks)
 
 
88
  else:
89
+ update_tasks = []
90
+ for k, v in data.items():
91
  data[k]["_id"] = k
92
+ update_tasks.append(
93
+ self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
94
+ )
95
+ await asyncio.gather(*update_tasks)
96
 
97
  async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
98
  if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
99
  res = {}
100
+ v = await self._data.find_one({"_id": mode + "_" + id})
101
  if v:
102
  res[id] = v
103
  logger.debug(f"llm_response_cache find one by:{id}")
 
115
  @dataclass
116
  class MongoDocStatusStorage(DocStatusStorage):
117
    def __post_init__(self):
        """Connect to MongoDB and bind this doc-status store to its namespace collection.

        Connection settings come from the MONGO_URI / MONGO_DATABASE
        environment variables, falling back to the [mongodb] section of the
        config file, then to hard-coded local defaults.
        """
        uri = os.environ.get(
            "MONGO_URI",
            config.get(
                "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
            ),
        )
        client = AsyncIOMotorClient(uri)
        database = client.get_database(
            os.environ.get(
                "MONGO_DATABASE",
                config.get("mongodb", "database", fallback="LightRAG"),
            )
        )

        # One collection per namespace keeps stores isolated in the same DB.
        self._collection_name = self.namespace
        self._data = database.get_collection(self._collection_name)

        logger.debug(f"Use MongoDB as doc status {self._collection_name}")

        # Ensure collection exists (uses a separate synchronous client).
        create_collection_if_not_exists(uri, database.name, self._collection_name)
139
 
140
    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
        """Return the status document with the given id, or None if absent."""
        return await self._data.find_one({"_id": id})
142
 
143
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
144
+ cursor = self._data.find({"_id": {"$in": ids}})
145
+ return await cursor.to_list()
146
 
147
  async def filter_keys(self, data: set[str]) -> set[str]:
148
+ cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
149
+ existing_ids = {str(x["_id"]) async for x in cursor}
150
+ return data - existing_ids
 
 
151
 
152
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
153
+ update_tasks = []
154
  for k, v in data.items():
 
155
  data[k]["_id"] = k
156
+ update_tasks.append(
157
+ self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
158
+ )
159
+ await asyncio.gather(*update_tasks)
160
 
161
  async def drop(self) -> None:
162
  """Drop the collection"""
 
165
  async def get_status_counts(self) -> dict[str, int]:
166
  """Get counts of documents in each status"""
167
  pipeline = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}]
168
+ cursor = self._data.aggregate(pipeline)
169
+ result = await cursor.to_list()
170
  counts = {}
171
  for doc in result:
172
  counts[doc["_id"]] = doc["count"]
 
175
  async def get_docs_by_status(
176
  self, status: DocStatus
177
  ) -> dict[str, DocProcessingStatus]:
178
+ """Get all documents with a specific status"""
179
+ cursor = self._data.find({"status": status.value})
180
+ result = await cursor.to_list()
181
  return {
182
  doc["_id"]: DocProcessingStatus(
183
  content=doc["content"],
 
191
  for doc in result
192
  }
193
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
  @dataclass
196
  class MongoGraphStorage(BaseGraphStorage):
 
204
  global_config=global_config,
205
  embedding_func=embedding_func,
206
  )
207
+ uri = os.environ.get(
208
+ "MONGO_URI",
209
+ config.get(
210
+ "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
211
+ ),
 
 
212
  )
213
+ client = AsyncIOMotorClient(uri)
214
+ database = client.get_database(
215
  os.environ.get(
216
  "MONGO_DATABASE",
217
+ config.get("mongodb", "database", fallback="LightRAG"),
 
 
 
 
 
 
218
  )
219
+ )
220
+
221
+ self._collection_name = self.namespace
222
+ self.collection = database.get_collection(self._collection_name)
223
+
224
+ logger.debug(f"Use MongoDB as KG {self._collection_name}")
225
+
226
+ # Ensure collection exists
227
+ create_collection_if_not_exists(uri, database.name, self._collection_name)
228
 
229
  #
230
  # -------------------------------------------------------------------------
 
471
  self, source_node_id: str
472
  ) -> Union[List[Tuple[str, str]], None]:
473
  """
474
+ Return a list of (source_id, target_id) for direct edges from source_node_id.
475
  Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler.
476
  """
477
  pipeline = [
 
495
  return None
496
 
497
  edges = result[0].get("edges", [])
498
+ return [(source_node_id, e["target"]) for e in edges]
499
 
500
  #
501
  # -------------------------------------------------------------------------
 
542
 
543
  async def delete_node(self, node_id: str):
544
  """
545
+ 1) Remove node's doc entirely.
546
  2) Remove inbound edges from any doc that references node_id.
547
  """
548
  # Remove inbound edges from all other docs
 
562
  Placeholder for demonstration, raises NotImplementedError.
563
  """
564
  raise NotImplementedError("Node embedding is not used in lightrag.")
565
+
566
+ #
567
+ # -------------------------------------------------------------------------
568
+ # QUERY
569
+ # -------------------------------------------------------------------------
570
+ #
571
+
572
+ async def get_all_labels(self) -> list[str]:
573
+ """
574
+ Get all existing node _id in the database
575
+ Returns:
576
+ [id1, id2, ...] # Alphabetically sorted id list
577
+ """
578
+ # Use MongoDB's distinct and aggregation to get all unique labels
579
+ pipeline = [
580
+ {"$group": {"_id": "$_id"}}, # Group by _id
581
+ {"$sort": {"_id": 1}}, # Sort alphabetically
582
+ ]
583
+
584
+ cursor = self.collection.aggregate(pipeline)
585
+ labels = []
586
+ async for doc in cursor:
587
+ labels.append(doc["_id"])
588
+ return labels
589
+
590
    async def get_knowledge_graph(
        self, node_label: str, max_depth: int = 5
    ) -> KnowledgeGraph:
        """
        Get complete connected subgraph for specified node (including the starting node itself)

        Args:
            node_label: Label of the nodes to start from ("*" returns the whole graph)
            max_depth: Maximum depth of traversal (default: 5)

        Returns:
            KnowledgeGraph object containing nodes and edges of the subgraph
        """
        label = node_label
        result = KnowledgeGraph()
        # Deduplication sets: node ids and "source-target" edge ids already emitted.
        seen_nodes = set()
        seen_edges = set()

        try:
            if label == "*":
                # Get all nodes and edges
                async for node_doc in self.collection.find({}):
                    node_id = str(node_doc["_id"])
                    if node_id not in seen_nodes:
                        result.nodes.append(
                            KnowledgeGraphNode(
                                id=node_id,
                                labels=[node_doc.get("_id")],
                                # Everything except the id and adjacency list
                                # becomes a node property.
                                properties={
                                    k: v
                                    for k, v in node_doc.items()
                                    if k not in ["_id", "edges"]
                                },
                            )
                        )
                        seen_nodes.add(node_id)

                    # Process edges
                    for edge in node_doc.get("edges", []):
                        edge_id = f"{node_id}-{edge['target']}"
                        if edge_id not in seen_edges:
                            result.edges.append(
                                KnowledgeGraphEdge(
                                    id=edge_id,
                                    type=edge.get("relation", ""),
                                    source=node_id,
                                    target=edge["target"],
                                    properties={
                                        k: v
                                        for k, v in edge.items()
                                        if k not in ["target", "relation"]
                                    },
                                )
                            )
                            seen_edges.add(edge_id)
            else:
                # Verify if starting node exists
                start_nodes = self.collection.find({"_id": label})
                start_nodes_exist = await start_nodes.to_list(length=1)
                if not start_nodes_exist:
                    logger.warning(f"Starting node with label {label} does not exist!")
                    return result

                # Use $graphLookup for traversal
                pipeline = [
                    {
                        "$match": {"_id": label}
                    },  # Start with nodes having the specified label
                    {
                        "$graphLookup": {
                            "from": self._collection_name,
                            "startWith": "$edges.target",
                            "connectFromField": "edges.target",
                            "connectToField": "_id",
                            "maxDepth": max_depth,
                            "depthField": "depth",
                            "as": "connected_nodes",
                        }
                    },
                ]

                async for doc in self.collection.aggregate(pipeline):
                    # Add the start node
                    node_id = str(doc["_id"])
                    if node_id not in seen_nodes:
                        result.nodes.append(
                            KnowledgeGraphNode(
                                id=node_id,
                                labels=[
                                    doc.get(
                                        "_id",
                                    )
                                ],
                                # Strip id, adjacency and $graphLookup bookkeeping
                                # fields from the exposed node properties.
                                properties={
                                    k: v
                                    for k, v in doc.items()
                                    if k
                                    not in [
                                        "_id",
                                        "edges",
                                        "connected_nodes",
                                        "depth",
                                    ]
                                },
                            )
                        )
                        seen_nodes.add(node_id)

                    # Add edges from start node
                    for edge in doc.get("edges", []):
                        edge_id = f"{node_id}-{edge['target']}"
                        if edge_id not in seen_edges:
                            result.edges.append(
                                KnowledgeGraphEdge(
                                    id=edge_id,
                                    type=edge.get("relation", ""),
                                    source=node_id,
                                    target=edge["target"],
                                    properties={
                                        k: v
                                        for k, v in edge.items()
                                        if k not in ["target", "relation"]
                                    },
                                )
                            )
                            seen_edges.add(edge_id)

                    # Add connected nodes and their edges
                    # NOTE(review): node_id is intentionally rebound here to the
                    # connected node's id for the nested edge loop below.
                    for connected in doc.get("connected_nodes", []):
                        node_id = str(connected["_id"])
                        if node_id not in seen_nodes:
                            result.nodes.append(
                                KnowledgeGraphNode(
                                    id=node_id,
                                    labels=[connected.get("_id")],
                                    properties={
                                        k: v
                                        for k, v in connected.items()
                                        if k not in ["_id", "edges", "depth"]
                                    },
                                )
                            )
                            seen_nodes.add(node_id)

                        # Add edges from connected nodes
                        for edge in connected.get("edges", []):
                            edge_id = f"{node_id}-{edge['target']}"
                            if edge_id not in seen_edges:
                                result.edges.append(
                                    KnowledgeGraphEdge(
                                        id=edge_id,
                                        type=edge.get("relation", ""),
                                        source=node_id,
                                        target=edge["target"],
                                        properties={
                                            k: v
                                            for k, v in edge.items()
                                            if k not in ["target", "relation"]
                                        },
                                    )
                                )
                                seen_edges.add(edge_id)

            logger.info(
                f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
            )

        # Best-effort: database errors are logged and a (possibly partial)
        # result is still returned.
        except PyMongoError as e:
            logger.error(f"MongoDB query failed: {str(e)}")

        return result
761
+
762
+
763
@dataclass
class MongoVectorDBStorage(BaseVectorStorage):
    """Vector storage backed by MongoDB Atlas Vector Search.

    Each vector is stored as one document with the embedding in the
    ``vector`` field; similarity queries run through the ``$vectorSearch``
    aggregation stage against a cosine index named ``vector_knn_index``.
    """

    # Populated from global_config in __post_init__; raises if missing there.
    cosine_better_than_threshold: float = None

    def __post_init__(self):
        # The similarity cutoff is mandatory configuration.
        kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
        cosine_threshold = kwargs.get("cosine_better_than_threshold")
        if cosine_threshold is None:
            raise ValueError(
                "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
            )
        self.cosine_better_than_threshold = cosine_threshold

        uri = os.environ.get(
            "MONGO_URI",
            config.get(
                "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
            ),
        )
        client = AsyncIOMotorClient(uri)
        database = client.get_database(
            os.environ.get(
                "MONGO_DATABASE",
                config.get("mongodb", "database", fallback="LightRAG"),
            )
        )

        self._collection_name = self.namespace
        self._data = database.get_collection(self._collection_name)
        self._max_batch_size = self.global_config["embedding_batch_num"]

        logger.debug(f"Use MongoDB as VDB {self._collection_name}")

        # Ensure collection exists
        create_collection_if_not_exists(uri, database.name, self._collection_name)

        # Ensure vector index exists
        self.create_vector_index(uri, database.name, self._collection_name)

    def create_vector_index(self, uri: str, database_name: str, collection_name: str):
        """Create the Atlas Vector Search index on *collection_name* (idempotent).

        Uses a short-lived synchronous client because this runs from
        __post_init__, outside any event loop.
        """
        client = MongoClient(uri)
        try:
            # Fix: honor the collection_name parameter (previously ignored in
            # favor of self._collection_name, which happened to be equal at
            # the only call site).
            collection = client.get_database(database_name).get_collection(
                collection_name
            )

            search_index_model = SearchIndexModel(
                definition={
                    "fields": [
                        {
                            "type": "vector",
                            "numDimensions": self.embedding_func.embedding_dim,  # Ensure correct dimensions
                            "path": "vector",
                            "similarity": "cosine",  # Options: euclidean, cosine, dotProduct
                        }
                    ]
                },
                name="vector_knn_index",
                type="vectorSearch",
            )

            collection.create_search_index(search_index_model)
            logger.info("Vector index created successfully.")

        except PyMongoError:
            # Most commonly the index already exists; creation is best-effort.
            logger.debug("vector index already exist")
        finally:
            # Fix: close the synchronous client instead of leaking its pool.
            client.close()

    async def upsert(self, data: dict[str, dict]):
        """Embed and upsert *data*; returns the list of written documents.

        Args:
            data: Mapping of id -> payload; each payload must contain
                "content" (the text to embed) plus optional meta fields.

        Returns:
            The list of documents written (meta fields + "_id" + "vector"),
            or [] when *data* is empty.
        """
        logger.debug(f"Inserting {len(data)} vectors to {self.namespace}")
        if not data:
            logger.warning("You are inserting an empty data set to vector DB")
            return []

        # Keep only the configured meta fields, plus the document id.
        list_data = [
            {
                "_id": k,
                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]

        async def wrapped_task(batch):
            result = await self.embedding_func(batch)
            pbar.update(1)
            return result

        embedding_tasks = [wrapped_task(batch) for batch in batches]
        pbar = tqdm_async(
            total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
        )
        embeddings_list = await asyncio.gather(*embedding_tasks)

        # One embedding row per input document, in insertion order.
        embeddings = np.concatenate(embeddings_list)
        for i, d in enumerate(list_data):
            d["vector"] = np.array(embeddings[i], dtype=np.float32).tolist()

        update_tasks = []
        for doc in list_data:
            update_tasks.append(
                self._data.update_one({"_id": doc["_id"]}, {"$set": doc}, upsert=True)
            )
        await asyncio.gather(*update_tasks)

        return list_data

    async def query(self, query, top_k=5):
        """Queries the vector database using Atlas Vector Search."""
        # Generate the embedding
        embedding = await self.embedding_func([query])

        # Convert numpy array to a list to ensure compatibility with MongoDB
        query_vector = embedding[0].tolist()

        # Define the aggregation pipeline with the converted query vector
        pipeline = [
            {
                "$vectorSearch": {
                    "index": "vector_knn_index",  # Ensure this matches the created index name
                    "path": "vector",
                    "queryVector": query_vector,
                    "numCandidates": 100,  # Adjust for performance
                    "limit": top_k,
                }
            },
            {"$addFields": {"score": {"$meta": "vectorSearchScore"}}},
            # Drop results below the configured cosine-similarity cutoff.
            {"$match": {"score": {"$gte": self.cosine_better_than_threshold}}},
            {"$project": {"vector": 0}},
        ]

        # Execute the aggregation pipeline; length=None exhausts the cursor
        # (older Motor versions require the argument explicitly).
        cursor = self._data.aggregate(pipeline)
        results = await cursor.to_list(length=None)

        # Format and return the results
        return [
            {**doc, "id": doc["_id"], "distance": doc.get("score", None)}
            for doc in results
        ]
907
+
908
+
909
def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str):
    """Check if the collection exists; if not, create it.

    Uses a short-lived synchronous MongoClient because callers invoke this
    from __post_init__, before any event loop exists.

    Args:
        uri: MongoDB connection string.
        database_name: Target database name.
        collection_name: Collection to create when missing.
    """
    client = MongoClient(uri)
    try:
        database = client.get_database(database_name)

        if collection_name not in database.list_collection_names():
            database.create_collection(collection_name)
            logger.info(f"Created collection: {collection_name}")
        else:
            logger.debug(f"Collection '{collection_name}' already exists.")
    finally:
        # Fix: close the client instead of leaking one connection pool per call.
        client.close()
lightrag/kg/nano_vector_db_impl.py CHANGED
@@ -79,8 +79,8 @@ class NanoVectorDBStorage(BaseVectorStorage):
79
  # Initialize lock only for file operations
80
  self._save_lock = asyncio.Lock()
81
  # Use global config value if specified, otherwise use default
82
- config = self.global_config.get("vector_db_storage_cls_kwargs", {})
83
- cosine_threshold = config.get("cosine_better_than_threshold")
84
  if cosine_threshold is None:
85
  raise ValueError(
86
  "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
@@ -191,7 +191,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
191
  except Exception as e:
192
  logger.error(f"Error deleting entity {entity_name}: {e}")
193
 
194
- async def delete_entity_relation(self, entity_name: str):
195
  try:
196
  relations = [
197
  dp
 
79
  # Initialize lock only for file operations
80
  self._save_lock = asyncio.Lock()
81
  # Use global config value if specified, otherwise use default
82
+ kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
83
+ cosine_threshold = kwargs.get("cosine_better_than_threshold")
84
  if cosine_threshold is None:
85
  raise ValueError(
86
  "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
 
191
  except Exception as e:
192
  logger.error(f"Error deleting entity {entity_name}: {e}")
193
 
194
+ async def delete_entity_relation(self, entity_name: str) -> None:
195
  try:
196
  relations = [
197
  dp
lightrag/kg/neo4j_impl.py CHANGED
@@ -143,9 +143,27 @@ class Neo4JStorage(BaseGraphStorage):
143
  async def index_done_callback(self):
144
  print("KG successfully indexed.")
145
 
146
- async def has_node(self, node_id: str) -> bool:
147
- entity_name_label = node_id.strip('"')
 
 
 
 
 
 
 
 
 
148
 
 
 
 
 
 
 
 
 
 
149
  async with self._driver.session(database=self._DATABASE) as session:
150
  query = (
151
  f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
@@ -174,8 +192,17 @@ class Neo4JStorage(BaseGraphStorage):
174
  return single_result["edgeExists"]
175
 
176
  async def get_node(self, node_id: str) -> Union[dict, None]:
 
 
 
 
 
 
 
 
 
177
  async with self._driver.session(database=self._DATABASE) as session:
178
- entity_name_label = node_id.strip('"')
179
  query = f"MATCH (n:`{entity_name_label}`) RETURN n"
180
  result = await session.run(query)
181
  record = await result.single()
@@ -226,38 +253,73 @@ class Neo4JStorage(BaseGraphStorage):
226
  async def get_edge(
227
  self, source_node_id: str, target_node_id: str
228
  ) -> Union[dict, None]:
229
- entity_name_label_source = source_node_id.strip('"')
230
- entity_name_label_target = target_node_id.strip('"')
231
- """
232
- Find all edges between nodes of two given labels
233
 
234
  Args:
235
- source_node_label (str): Label of the source nodes
236
- target_node_label (str): Label of the target nodes
237
 
238
  Returns:
239
- list: List of all relationships/edges found
 
240
  """
241
- async with self._driver.session(database=self._DATABASE) as session:
242
- query = f"""
243
- MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
244
- RETURN properties(r) as edge_properties
245
- LIMIT 1
246
- """.format(
247
- entity_name_label_source=entity_name_label_source,
248
- entity_name_label_target=entity_name_label_target,
249
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
 
251
- result = await session.run(query)
252
- record = await result.single()
253
- if record:
254
- result = dict(record["edge_properties"])
255
  logger.debug(
256
- f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
257
  )
258
- return result
259
- else:
260
- return None
 
 
 
 
 
 
261
 
262
  async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
263
  node_label = source_node_id.strip('"')
@@ -310,7 +372,7 @@ class Neo4JStorage(BaseGraphStorage):
310
  node_id: The unique identifier for the node (used as label)
311
  node_data: Dictionary of node properties
312
  """
313
- label = node_id.strip('"')
314
  properties = node_data
315
 
316
  async def _do_upsert(tx: AsyncManagedTransaction):
@@ -338,6 +400,7 @@ class Neo4JStorage(BaseGraphStorage):
338
  neo4jExceptions.ServiceUnavailable,
339
  neo4jExceptions.TransientError,
340
  neo4jExceptions.WriteServiceUnavailable,
 
341
  )
342
  ),
343
  )
@@ -352,22 +415,23 @@ class Neo4JStorage(BaseGraphStorage):
352
  target_node_id (str): Label of the target node (used as identifier)
353
  edge_data (dict): Dictionary of properties to set on the edge
354
  """
355
- source_node_label = source_node_id.strip('"')
356
- target_node_label = target_node_id.strip('"')
357
  edge_properties = edge_data
358
 
359
  async def _do_upsert_edge(tx: AsyncManagedTransaction):
360
  query = f"""
361
- MATCH (source:`{source_node_label}`)
362
  WITH source
363
- MATCH (target:`{target_node_label}`)
364
  MERGE (source)-[r:DIRECTED]->(target)
365
  SET r += $properties
366
  RETURN r
367
  """
368
- await tx.run(query, properties=edge_properties)
 
369
  logger.debug(
370
- f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}"
371
  )
372
 
373
  try:
 
143
  async def index_done_callback(self):
144
  print("KG successfully indexed.")
145
 
146
+ async def _label_exists(self, label: str) -> bool:
147
+ """Check if a label exists in the Neo4j database."""
148
+ query = "CALL db.labels() YIELD label RETURN label"
149
+ try:
150
+ async with self._driver.session(database=self._DATABASE) as session:
151
+ result = await session.run(query)
152
+ labels = [record["label"] for record in await result.data()]
153
+ return label in labels
154
+ except Exception as e:
155
+ logger.error(f"Error checking label existence: {e}")
156
+ return False
157
 
158
    async def _ensure_label(self, label: str) -> str:
        """Strip surrounding double quotes from *label* and warn if it is unknown.

        Validation is advisory only: the cleaned label is returned even when
        it does not (yet) exist in the database.
        """
        clean_label = label.strip('"')
        if not await self._label_exists(clean_label):
            logger.warning(f"Label '{clean_label}' does not exist in Neo4j")
        return clean_label
164
+
165
+ async def has_node(self, node_id: str) -> bool:
166
+ entity_name_label = await self._ensure_label(node_id)
167
  async with self._driver.session(database=self._DATABASE) as session:
168
  query = (
169
  f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
 
192
  return single_result["edgeExists"]
193
 
194
  async def get_node(self, node_id: str) -> Union[dict, None]:
195
+ """Get node by its label identifier.
196
+
197
+ Args:
198
+ node_id: The node label to look up
199
+
200
+ Returns:
201
+ dict: Node properties if found
202
+ None: If node not found
203
+ """
204
  async with self._driver.session(database=self._DATABASE) as session:
205
+ entity_name_label = await self._ensure_label(node_id)
206
  query = f"MATCH (n:`{entity_name_label}`) RETURN n"
207
  result = await session.run(query)
208
  record = await result.single()
 
253
  async def get_edge(
254
  self, source_node_id: str, target_node_id: str
255
  ) -> Union[dict, None]:
256
+ """Find edge between two nodes identified by their labels.
 
 
 
257
 
258
  Args:
259
+ source_node_id (str): Label of the source node
260
+ target_node_id (str): Label of the target node
261
 
262
  Returns:
263
+ dict: Edge properties if found, with at least {"weight": 0.0}
264
+ None: If error occurs
265
  """
266
+ try:
267
+ entity_name_label_source = source_node_id.strip('"')
268
+ entity_name_label_target = target_node_id.strip('"')
269
+
270
+ async with self._driver.session(database=self._DATABASE) as session:
271
+ query = f"""
272
+ MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
273
+ RETURN properties(r) as edge_properties
274
+ LIMIT 1
275
+ """.format(
276
+ entity_name_label_source=entity_name_label_source,
277
+ entity_name_label_target=entity_name_label_target,
278
+ )
279
+
280
+ result = await session.run(query)
281
+ record = await result.single()
282
+ if record and "edge_properties" in record:
283
+ try:
284
+ result = dict(record["edge_properties"])
285
+ # Ensure required keys exist with defaults
286
+ required_keys = {
287
+ "weight": 0.0,
288
+ "source_id": None,
289
+ "target_id": None,
290
+ }
291
+ for key, default_value in required_keys.items():
292
+ if key not in result:
293
+ result[key] = default_value
294
+ logger.warning(
295
+ f"Edge between {entity_name_label_source} and {entity_name_label_target} "
296
+ f"missing {key}, using default: {default_value}"
297
+ )
298
+
299
+ logger.debug(
300
+ f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
301
+ )
302
+ return result
303
+ except (KeyError, TypeError, ValueError) as e:
304
+ logger.error(
305
+ f"Error processing edge properties between {entity_name_label_source} "
306
+ f"and {entity_name_label_target}: {str(e)}"
307
+ )
308
+ # Return default edge properties on error
309
+ return {"weight": 0.0, "source_id": None, "target_id": None}
310
 
 
 
 
 
311
  logger.debug(
312
+ f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
313
  )
314
+ # Return default edge properties when no edge found
315
+ return {"weight": 0.0, "source_id": None, "target_id": None}
316
+
317
+ except Exception as e:
318
+ logger.error(
319
+ f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}"
320
+ )
321
+ # Return default edge properties on error
322
+ return {"weight": 0.0, "source_id": None, "target_id": None}
323
 
324
  async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
325
  node_label = source_node_id.strip('"')
 
372
  node_id: The unique identifier for the node (used as label)
373
  node_data: Dictionary of node properties
374
  """
375
+ label = await self._ensure_label(node_id)
376
  properties = node_data
377
 
378
  async def _do_upsert(tx: AsyncManagedTransaction):
 
400
  neo4jExceptions.ServiceUnavailable,
401
  neo4jExceptions.TransientError,
402
  neo4jExceptions.WriteServiceUnavailable,
403
+ neo4jExceptions.ClientError,
404
  )
405
  ),
406
  )
 
415
  target_node_id (str): Label of the target node (used as identifier)
416
  edge_data (dict): Dictionary of properties to set on the edge
417
  """
418
+ source_label = await self._ensure_label(source_node_id)
419
+ target_label = await self._ensure_label(target_node_id)
420
  edge_properties = edge_data
421
 
422
  async def _do_upsert_edge(tx: AsyncManagedTransaction):
423
  query = f"""
424
+ MATCH (source:`{source_label}`)
425
  WITH source
426
+ MATCH (target:`{target_label}`)
427
  MERGE (source)-[r:DIRECTED]->(target)
428
  SET r += $properties
429
  RETURN r
430
  """
431
+ result = await tx.run(query, properties=edge_properties)
432
+ record = await result.single()
433
  logger.debug(
434
+ f"Upserted edge from '{source_label}' to '{target_label}' with properties: {edge_properties}, result: {record['r'] if record else None}"
435
  )
436
 
437
  try:
lightrag/kg/postgres_impl.py CHANGED
@@ -468,7 +468,7 @@ class PGDocStatusStorage(DocStatusStorage):
468
  async def get_docs_by_status(
469
  self, status: DocStatus
470
  ) -> Dict[str, DocProcessingStatus]:
471
- """Get all documents by status"""
472
  sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
473
  params = {"workspace": self.db.workspace, "status": status}
474
  result = await self.db.query(sql, params, True)
@@ -485,22 +485,6 @@ class PGDocStatusStorage(DocStatusStorage):
485
  for element in result
486
  }
487
 
488
- async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
489
- """Get all failed documents"""
490
- return await self.get_docs_by_status(DocStatus.FAILED)
491
-
492
- async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
493
- """Get all pending documents"""
494
- return await self.get_docs_by_status(DocStatus.PENDING)
495
-
496
- async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
497
- """Get all processing documents"""
498
- return await self.get_docs_by_status(DocStatus.PROCESSING)
499
-
500
- async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
501
- """Get all procesed documents"""
502
- return await self.get_docs_by_status(DocStatus.PROCESSED)
503
-
504
  async def index_done_callback(self):
505
  """Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here"""
506
  logger.info("Doc status had been saved into postgresql db!")
 
468
  async def get_docs_by_status(
469
  self, status: DocStatus
470
  ) -> Dict[str, DocProcessingStatus]:
471
+ """all documents with a specific status"""
472
  sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
473
  params = {"workspace": self.db.workspace, "status": status}
474
  result = await self.db.query(sql, params, True)
 
485
  for element in result
486
  }
487
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
488
  async def index_done_callback(self):
489
  """Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here"""
490
  logger.info("Doc status had been saved into postgresql db!")
lightrag/lightrag.py CHANGED
@@ -1,10 +1,12 @@
 
 
1
  import asyncio
2
  import os
3
  import configparser
4
  from dataclasses import asdict, dataclass, field
5
  from datetime import datetime
6
  from functools import partial
7
- from typing import Any, Callable, Optional, Type, Union, cast
8
 
9
  from .base import (
10
  BaseGraphStorage,
@@ -76,6 +78,7 @@ STORAGE_IMPLEMENTATIONS = {
76
  "FaissVectorDBStorage",
77
  "QdrantVectorDBStorage",
78
  "OracleVectorDBStorage",
 
79
  ],
80
  "required_methods": ["query", "upsert"],
81
  },
@@ -86,12 +89,12 @@ STORAGE_IMPLEMENTATIONS = {
86
  "PGDocStatusStorage",
87
  "MongoDocStatusStorage",
88
  ],
89
- "required_methods": ["get_pending_docs"],
90
  },
91
  }
92
 
93
  # Storage implementation environment variable without default value
94
- STORAGE_ENV_REQUIREMENTS = {
95
  # KV Storage Implementations
96
  "JsonKVStorage": [],
97
  "MongoKVStorage": [],
@@ -140,6 +143,7 @@ STORAGE_ENV_REQUIREMENTS = {
140
  "ORACLE_PASSWORD",
141
  "ORACLE_CONFIG_DIR",
142
  ],
 
143
  # Document Status Storage Implementations
144
  "JsonDocStatusStorage": [],
145
  "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
@@ -160,6 +164,7 @@ STORAGES = {
160
  "MongoKVStorage": ".kg.mongo_impl",
161
  "MongoDocStatusStorage": ".kg.mongo_impl",
162
  "MongoGraphStorage": ".kg.mongo_impl",
 
163
  "RedisKVStorage": ".kg.redis_impl",
164
  "ChromaVectorDBStorage": ".kg.chroma_impl",
165
  "TiDBKVStorage": ".kg.tidb_impl",
@@ -176,7 +181,7 @@ STORAGES = {
176
  }
177
 
178
 
179
- def lazy_external_import(module_name: str, class_name: str):
180
  """Lazily import a class from an external module based on the package of the caller."""
181
  # Get the caller's module and package
182
  import inspect
@@ -185,7 +190,7 @@ def lazy_external_import(module_name: str, class_name: str):
185
  module = inspect.getmodule(caller_frame)
186
  package = module.__package__ if module else None
187
 
188
- def import_class(*args, **kwargs):
189
  import importlib
190
 
191
  module = importlib.import_module(module_name, package=package)
@@ -225,7 +230,7 @@ class LightRAG:
225
  """LightRAG: Simple and Fast Retrieval-Augmented Generation."""
226
 
227
  working_dir: str = field(
228
- default_factory=lambda: f'./lightrag_cache_{datetime.now().strftime("%Y-%m-%d-%H:%M:%S")}'
229
  )
230
  """Directory where cache and temporary files are stored."""
231
 
@@ -302,7 +307,7 @@ class LightRAG:
302
  - random_seed: Seed value for reproducibility.
303
  """
304
 
305
- embedding_func: EmbeddingFunc = None
306
  """Function for computing text embeddings. Must be set before use."""
307
 
308
  embedding_batch_num: int = 32
@@ -312,7 +317,7 @@ class LightRAG:
312
  """Maximum number of concurrent embedding function calls."""
313
 
314
  # LLM Configuration
315
- llm_model_func: callable = None
316
  """Function for interacting with the large language model (LLM). Must be set before use."""
317
 
318
  llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
@@ -342,10 +347,8 @@ class LightRAG:
342
 
343
  # Extensions
344
  addon_params: dict[str, Any] = field(default_factory=dict)
345
- """Dictionary for additional parameters and extensions."""
346
 
347
- # extension
348
- addon_params: dict[str, Any] = field(default_factory=dict)
349
  convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
350
  convert_response_to_json
351
  )
@@ -354,7 +357,7 @@ class LightRAG:
354
  chunking_func: Callable[
355
  [
356
  str,
357
- Optional[str],
358
  bool,
359
  int,
360
  int,
@@ -443,77 +446,74 @@ class LightRAG:
443
  logger.debug(f"LightRAG init with param:\n {_print_config}\n")
444
 
445
  # Init LLM
446
- self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
447
  self.embedding_func
448
  )
449
 
450
  # Initialize all storages
451
- self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
452
  self._get_storage_class(self.kv_storage)
453
- )
454
- self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(
455
  self.vector_storage
456
- )
457
- self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(
458
  self.graph_storage
459
- )
460
-
461
- self.key_string_value_json_storage_cls = partial(
462
  self.key_string_value_json_storage_cls, global_config=global_config
463
  )
464
-
465
- self.vector_db_storage_cls = partial(
466
  self.vector_db_storage_cls, global_config=global_config
467
  )
468
-
469
- self.graph_storage_cls = partial(
470
  self.graph_storage_cls, global_config=global_config
471
  )
472
 
473
  # Initialize document status storage
474
  self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
475
 
476
- self.llm_response_cache = self.key_string_value_json_storage_cls(
477
  namespace=make_namespace(
478
  self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
479
  ),
480
  embedding_func=self.embedding_func,
481
  )
482
 
483
- self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls(
484
  namespace=make_namespace(
485
  self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
486
  ),
487
  embedding_func=self.embedding_func,
488
  )
489
- self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls(
490
  namespace=make_namespace(
491
  self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
492
  ),
493
  embedding_func=self.embedding_func,
494
  )
495
- self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls(
496
  namespace=make_namespace(
497
  self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
498
  ),
499
  embedding_func=self.embedding_func,
500
  )
501
 
502
- self.entities_vdb = self.vector_db_storage_cls(
503
  namespace=make_namespace(
504
  self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES
505
  ),
506
  embedding_func=self.embedding_func,
507
  meta_fields={"entity_name"},
508
  )
509
- self.relationships_vdb = self.vector_db_storage_cls(
510
  namespace=make_namespace(
511
  self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS
512
  ),
513
  embedding_func=self.embedding_func,
514
  meta_fields={"src_id", "tgt_id"},
515
  )
516
- self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls(
517
  namespace=make_namespace(
518
  self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS
519
  ),
@@ -527,13 +527,12 @@ class LightRAG:
527
  embedding_func=None,
528
  )
529
 
530
- # What's for, Is this nessisary ?
531
  if self.llm_response_cache and hasattr(
532
  self.llm_response_cache, "global_config"
533
  ):
534
  hashing_kv = self.llm_response_cache
535
  else:
536
- hashing_kv = self.key_string_value_json_storage_cls(
537
  namespace=make_namespace(
538
  self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
539
  ),
@@ -542,7 +541,7 @@ class LightRAG:
542
 
543
  self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
544
  partial(
545
- self.llm_model_func,
546
  hashing_kv=hashing_kv,
547
  **self.llm_model_kwargs,
548
  )
@@ -559,68 +558,45 @@ class LightRAG:
559
  node_label=nodel_label, max_depth=max_depth
560
  )
561
 
562
- def _get_storage_class(self, storage_name: str) -> dict:
563
  import_path = STORAGES[storage_name]
564
  storage_class = lazy_external_import(import_path, storage_name)
565
  return storage_class
566
 
567
- def set_storage_client(self, db_client):
568
- # Deprecated, seting correct value to *_storage of LightRAG insteaded
569
- # Inject db to storage implementation (only tested on Oracle Database)
570
- for storage in [
571
- self.vector_db_storage_cls,
572
- self.graph_storage_cls,
573
- self.doc_status,
574
- self.full_docs,
575
- self.text_chunks,
576
- self.llm_response_cache,
577
- self.key_string_value_json_storage_cls,
578
- self.chunks_vdb,
579
- self.relationships_vdb,
580
- self.entities_vdb,
581
- self.graph_storage_cls,
582
- self.chunk_entity_relation_graph,
583
- self.llm_response_cache,
584
- ]:
585
- # set client
586
- storage.db = db_client
587
-
588
  def insert(
589
  self,
590
- string_or_strings: Union[str, list[str]],
591
  split_by_character: str | None = None,
592
  split_by_character_only: bool = False,
593
  ):
594
  """Sync Insert documents with checkpoint support
595
 
596
  Args:
597
- string_or_strings: Single document string or list of document strings
598
  split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
599
- chunk_size, split the sub chunk by token size.
600
  split_by_character_only: if split_by_character_only is True, split the string by character only, when
601
  split_by_character is None, this parameter is ignored.
602
  """
603
  loop = always_get_an_event_loop()
604
  return loop.run_until_complete(
605
- self.ainsert(string_or_strings, split_by_character, split_by_character_only)
606
  )
607
 
608
  async def ainsert(
609
  self,
610
- string_or_strings: Union[str, list[str]],
611
  split_by_character: str | None = None,
612
  split_by_character_only: bool = False,
613
  ):
614
  """Async Insert documents with checkpoint support
615
 
616
  Args:
617
- string_or_strings: Single document string or list of document strings
618
  split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
619
- chunk_size, split the sub chunk by token size.
620
  split_by_character_only: if split_by_character_only is True, split the string by character only, when
621
  split_by_character is None, this parameter is ignored.
622
  """
623
- await self.apipeline_enqueue_documents(string_or_strings)
624
  await self.apipeline_process_enqueue_documents(
625
  split_by_character, split_by_character_only
626
  )
@@ -677,7 +653,7 @@ class LightRAG:
677
  if update_storage:
678
  await self._insert_done()
679
 
680
- async def apipeline_enqueue_documents(self, string_or_strings: str | list[str]):
681
  """
682
  Pipeline for Processing Documents
683
 
@@ -686,11 +662,11 @@ class LightRAG:
686
  3. Filter out already processed documents
687
  4. Enqueue document in status
688
  """
689
- if isinstance(string_or_strings, str):
690
- string_or_strings = [string_or_strings]
691
 
692
  # 1. Remove duplicate contents from the list
693
- unique_contents = list(set(doc.strip() for doc in string_or_strings))
694
 
695
  # 2. Generate document IDs and initial status
696
  new_docs: dict[str, Any] = {
@@ -739,11 +715,11 @@ class LightRAG:
739
  # 1. Get all pending, failed, and abnormally terminated processing documents.
740
  to_process_docs: dict[str, DocProcessingStatus] = {}
741
 
742
- processing_docs = await self.doc_status.get_processing_docs()
743
  to_process_docs.update(processing_docs)
744
- failed_docs = await self.doc_status.get_failed_docs()
745
  to_process_docs.update(failed_docs)
746
- pendings_docs = await self.doc_status.get_pending_docs()
747
  to_process_docs.update(pendings_docs)
748
 
749
  if not to_process_docs:
@@ -857,32 +833,32 @@ class LightRAG:
857
  raise e
858
 
859
  async def _insert_done(self):
860
- tasks = []
861
- for storage_inst in [
862
- self.full_docs,
863
- self.text_chunks,
864
- self.llm_response_cache,
865
- self.entities_vdb,
866
- self.relationships_vdb,
867
- self.chunks_vdb,
868
- self.chunk_entity_relation_graph,
869
- ]:
870
- if storage_inst is None:
871
- continue
872
- tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
873
  await asyncio.gather(*tasks)
874
 
875
- def insert_custom_kg(self, custom_kg: dict):
876
  loop = always_get_an_event_loop()
877
  return loop.run_until_complete(self.ainsert_custom_kg(custom_kg))
878
 
879
- async def ainsert_custom_kg(self, custom_kg: dict):
880
  update_storage = False
881
  try:
882
  # Insert chunks into vector storage
883
- all_chunks_data = {}
884
- chunk_to_source_map = {}
885
- for chunk_data in custom_kg.get("chunks", []):
886
  chunk_content = chunk_data["content"]
887
  source_id = chunk_data["source_id"]
888
  chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-")
@@ -892,13 +868,13 @@ class LightRAG:
892
  chunk_to_source_map[source_id] = chunk_id
893
  update_storage = True
894
 
895
- if self.chunks_vdb is not None and all_chunks_data:
896
  await self.chunks_vdb.upsert(all_chunks_data)
897
- if self.text_chunks is not None and all_chunks_data:
898
  await self.text_chunks.upsert(all_chunks_data)
899
 
900
  # Insert entities into knowledge graph
901
- all_entities_data = []
902
  for entity_data in custom_kg.get("entities", []):
903
  entity_name = f'"{entity_data["entity_name"].upper()}"'
904
  entity_type = entity_data.get("entity_type", "UNKNOWN")
@@ -914,7 +890,7 @@ class LightRAG:
914
  )
915
 
916
  # Prepare node data
917
- node_data = {
918
  "entity_type": entity_type,
919
  "description": description,
920
  "source_id": source_id,
@@ -928,7 +904,7 @@ class LightRAG:
928
  update_storage = True
929
 
930
  # Insert relationships into knowledge graph
931
- all_relationships_data = []
932
  for relationship_data in custom_kg.get("relationships", []):
933
  src_id = f'"{relationship_data["src_id"].upper()}"'
934
  tgt_id = f'"{relationship_data["tgt_id"].upper()}"'
@@ -970,7 +946,7 @@ class LightRAG:
970
  "source_id": source_id,
971
  },
972
  )
973
- edge_data = {
974
  "src_id": src_id,
975
  "tgt_id": tgt_id,
976
  "description": description,
@@ -980,41 +956,68 @@ class LightRAG:
980
  update_storage = True
981
 
982
  # Insert entities into vector storage if needed
983
- if self.entities_vdb is not None:
984
- data_for_vdb = {
985
- compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
986
- "content": dp["entity_name"] + dp["description"],
987
- "entity_name": dp["entity_name"],
988
- }
989
- for dp in all_entities_data
990
  }
991
- await self.entities_vdb.upsert(data_for_vdb)
 
 
992
 
993
  # Insert relationships into vector storage if needed
994
- if self.relationships_vdb is not None:
995
- data_for_vdb = {
996
- compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
997
- "src_id": dp["src_id"],
998
- "tgt_id": dp["tgt_id"],
999
- "content": dp["keywords"]
1000
- + dp["src_id"]
1001
- + dp["tgt_id"]
1002
- + dp["description"],
1003
- }
1004
- for dp in all_relationships_data
1005
  }
1006
- await self.relationships_vdb.upsert(data_for_vdb)
 
 
 
1007
  finally:
1008
  if update_storage:
1009
  await self._insert_done()
1010
 
1011
- def query(self, query: str, prompt: str = "", param: QueryParam = QueryParam()):
 
 
 
 
 
 
 
 
 
 
 
 
 
1012
  loop = always_get_an_event_loop()
1013
- return loop.run_until_complete(self.aquery(query, prompt, param))
 
1014
 
1015
  async def aquery(
1016
- self, query: str, prompt: str = "", param: QueryParam = QueryParam()
1017
- ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1018
  if param.mode in ["local", "global", "hybrid"]:
1019
  response = await kg_query(
1020
  query,
@@ -1094,7 +1097,7 @@ class LightRAG:
1094
 
1095
  async def aquery_with_separate_keyword_extraction(
1096
  self, query: str, prompt: str, param: QueryParam = QueryParam()
1097
- ):
1098
  """
1099
  1. Calls extract_keywords_only to get HL/LL keywords from 'query'.
1100
  2. Then calls kg_query(...) or naive_query(...), etc. as the main query, while also injecting the newly extracted keywords if needed.
@@ -1117,8 +1120,8 @@ class LightRAG:
1117
  ),
1118
  )
1119
 
1120
- param.hl_keywords = (hl_keywords,)
1121
- param.ll_keywords = (ll_keywords,)
1122
 
1123
  # ---------------------
1124
  # STEP 2: Final Query Logic
@@ -1146,7 +1149,7 @@ class LightRAG:
1146
  self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
1147
  ),
1148
  global_config=asdict(self),
1149
- embedding_func=self.embedding_funcne,
1150
  ),
1151
  )
1152
  elif param.mode == "naive":
@@ -1195,12 +1198,7 @@ class LightRAG:
1195
  return response
1196
 
1197
  async def _query_done(self):
1198
- tasks = []
1199
- for storage_inst in [self.llm_response_cache]:
1200
- if storage_inst is None:
1201
- continue
1202
- tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
1203
- await asyncio.gather(*tasks)
1204
 
1205
  def delete_by_entity(self, entity_name: str):
1206
  loop = always_get_an_event_loop()
@@ -1222,16 +1220,16 @@ class LightRAG:
1222
  logger.error(f"Error while deleting entity '{entity_name}': {e}")
1223
 
1224
  async def _delete_by_entity_done(self):
1225
- tasks = []
1226
- for storage_inst in [
1227
- self.entities_vdb,
1228
- self.relationships_vdb,
1229
- self.chunk_entity_relation_graph,
1230
- ]:
1231
- if storage_inst is None:
1232
- continue
1233
- tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
1234
- await asyncio.gather(*tasks)
1235
 
1236
  def _get_content_summary(self, content: str, max_length: int = 100) -> str:
1237
  """Get summary of document content
@@ -1256,7 +1254,7 @@ class LightRAG:
1256
  """
1257
  return await self.doc_status.get_status_counts()
1258
 
1259
- async def adelete_by_doc_id(self, doc_id: str):
1260
  """Delete a document and all its related data
1261
 
1262
  Args:
@@ -1273,6 +1271,9 @@ class LightRAG:
1273
 
1274
  # 2. Get all related chunks
1275
  chunks = await self.text_chunks.get_by_id(doc_id)
 
 
 
1276
  chunk_ids = list(chunks.keys())
1277
  logger.debug(f"Found {len(chunk_ids)} chunks to delete")
1278
 
@@ -1443,13 +1444,9 @@ class LightRAG:
1443
  except Exception as e:
1444
  logger.error(f"Error while deleting document {doc_id}: {e}")
1445
 
1446
- def delete_by_doc_id(self, doc_id: str):
1447
- """Synchronous version of adelete"""
1448
- return asyncio.run(self.adelete_by_doc_id(doc_id))
1449
-
1450
  async def get_entity_info(
1451
  self, entity_name: str, include_vector_data: bool = False
1452
- ):
1453
  """Get detailed information of an entity
1454
 
1455
  Args:
@@ -1469,7 +1466,7 @@ class LightRAG:
1469
  node_data = await self.chunk_entity_relation_graph.get_node(entity_name)
1470
  source_id = node_data.get("source_id") if node_data else None
1471
 
1472
- result = {
1473
  "entity_name": entity_name,
1474
  "source_id": source_id,
1475
  "graph_data": node_data,
@@ -1483,21 +1480,6 @@ class LightRAG:
1483
 
1484
  return result
1485
 
1486
- def get_entity_info_sync(self, entity_name: str, include_vector_data: bool = False):
1487
- """Synchronous version of getting entity information
1488
-
1489
- Args:
1490
- entity_name: Entity name (no need for quotes)
1491
- include_vector_data: Whether to include data from the vector database
1492
- """
1493
- try:
1494
- import tracemalloc
1495
-
1496
- tracemalloc.start()
1497
- return asyncio.run(self.get_entity_info(entity_name, include_vector_data))
1498
- finally:
1499
- tracemalloc.stop()
1500
-
1501
  async def get_relation_info(
1502
  self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
1503
  ):
@@ -1525,7 +1507,7 @@ class LightRAG:
1525
  )
1526
  source_id = edge_data.get("source_id") if edge_data else None
1527
 
1528
- result = {
1529
  "src_entity": src_entity,
1530
  "tgt_entity": tgt_entity,
1531
  "source_id": source_id,
@@ -1539,23 +1521,3 @@ class LightRAG:
1539
  result["vector_data"] = vector_data[0] if vector_data else None
1540
 
1541
  return result
1542
-
1543
- def get_relation_info_sync(
1544
- self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
1545
- ):
1546
- """Synchronous version of getting relationship information
1547
-
1548
- Args:
1549
- src_entity: Source entity name (no need for quotes)
1550
- tgt_entity: Target entity name (no need for quotes)
1551
- include_vector_data: Whether to include data from the vector database
1552
- """
1553
- try:
1554
- import tracemalloc
1555
-
1556
- tracemalloc.start()
1557
- return asyncio.run(
1558
- self.get_relation_info(src_entity, tgt_entity, include_vector_data)
1559
- )
1560
- finally:
1561
- tracemalloc.stop()
 
1
+ from __future__ import annotations
2
+
3
  import asyncio
4
  import os
5
  import configparser
6
  from dataclasses import asdict, dataclass, field
7
  from datetime import datetime
8
  from functools import partial
9
+ from typing import Any, AsyncIterator, Callable, Iterator, cast
10
 
11
  from .base import (
12
  BaseGraphStorage,
 
78
  "FaissVectorDBStorage",
79
  "QdrantVectorDBStorage",
80
  "OracleVectorDBStorage",
81
+ "MongoVectorDBStorage",
82
  ],
83
  "required_methods": ["query", "upsert"],
84
  },
 
89
  "PGDocStatusStorage",
90
  "MongoDocStatusStorage",
91
  ],
92
+ "required_methods": ["get_docs_by_status"],
93
  },
94
  }
95
 
96
  # Storage implementation environment variable without default value
97
+ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
98
  # KV Storage Implementations
99
  "JsonKVStorage": [],
100
  "MongoKVStorage": [],
 
143
  "ORACLE_PASSWORD",
144
  "ORACLE_CONFIG_DIR",
145
  ],
146
+ "MongoVectorDBStorage": [],
147
  # Document Status Storage Implementations
148
  "JsonDocStatusStorage": [],
149
  "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
 
164
  "MongoKVStorage": ".kg.mongo_impl",
165
  "MongoDocStatusStorage": ".kg.mongo_impl",
166
  "MongoGraphStorage": ".kg.mongo_impl",
167
+ "MongoVectorDBStorage": ".kg.mongo_impl",
168
  "RedisKVStorage": ".kg.redis_impl",
169
  "ChromaVectorDBStorage": ".kg.chroma_impl",
170
  "TiDBKVStorage": ".kg.tidb_impl",
 
181
  }
182
 
183
 
184
+ def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
185
  """Lazily import a class from an external module based on the package of the caller."""
186
  # Get the caller's module and package
187
  import inspect
 
190
  module = inspect.getmodule(caller_frame)
191
  package = module.__package__ if module else None
192
 
193
+ def import_class(*args: Any, **kwargs: Any):
194
  import importlib
195
 
196
  module = importlib.import_module(module_name, package=package)
 
230
  """LightRAG: Simple and Fast Retrieval-Augmented Generation."""
231
 
232
  working_dir: str = field(
233
+ default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
234
  )
235
  """Directory where cache and temporary files are stored."""
236
 
 
307
  - random_seed: Seed value for reproducibility.
308
  """
309
 
310
+ embedding_func: EmbeddingFunc | None = None
311
  """Function for computing text embeddings. Must be set before use."""
312
 
313
  embedding_batch_num: int = 32
 
317
  """Maximum number of concurrent embedding function calls."""
318
 
319
  # LLM Configuration
320
+ llm_model_func: Callable[..., object] | None = None
321
  """Function for interacting with the large language model (LLM). Must be set before use."""
322
 
323
  llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
 
347
 
348
  # Extensions
349
  addon_params: dict[str, Any] = field(default_factory=dict)
 
350
 
351
+ """Dictionary for additional parameters and extensions."""
 
352
  convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
353
  convert_response_to_json
354
  )
 
357
  chunking_func: Callable[
358
  [
359
  str,
360
+ str | None,
361
  bool,
362
  int,
363
  int,
 
446
  logger.debug(f"LightRAG init with param:\n {_print_config}\n")
447
 
448
  # Init LLM
449
+ self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( # type: ignore
450
  self.embedding_func
451
  )
452
 
453
  # Initialize all storages
454
+ self.key_string_value_json_storage_cls: type[BaseKVStorage] = (
455
  self._get_storage_class(self.kv_storage)
456
+ ) # type: ignore
457
+ self.vector_db_storage_cls: type[BaseVectorStorage] = self._get_storage_class(
458
  self.vector_storage
459
+ ) # type: ignore
460
+ self.graph_storage_cls: type[BaseGraphStorage] = self._get_storage_class(
461
  self.graph_storage
462
+ ) # type: ignore
463
+ self.key_string_value_json_storage_cls = partial( # type: ignore
 
464
  self.key_string_value_json_storage_cls, global_config=global_config
465
  )
466
+ self.vector_db_storage_cls = partial( # type: ignore
 
467
  self.vector_db_storage_cls, global_config=global_config
468
  )
469
+ self.graph_storage_cls = partial( # type: ignore
 
470
  self.graph_storage_cls, global_config=global_config
471
  )
472
 
473
  # Initialize document status storage
474
  self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
475
 
476
+ self.llm_response_cache: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
477
  namespace=make_namespace(
478
  self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
479
  ),
480
  embedding_func=self.embedding_func,
481
  )
482
 
483
+ self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
484
  namespace=make_namespace(
485
  self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
486
  ),
487
  embedding_func=self.embedding_func,
488
  )
489
+ self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
490
  namespace=make_namespace(
491
  self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
492
  ),
493
  embedding_func=self.embedding_func,
494
  )
495
+ self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls( # type: ignore
496
  namespace=make_namespace(
497
  self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
498
  ),
499
  embedding_func=self.embedding_func,
500
  )
501
 
502
+ self.entities_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
503
  namespace=make_namespace(
504
  self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES
505
  ),
506
  embedding_func=self.embedding_func,
507
  meta_fields={"entity_name"},
508
  )
509
+ self.relationships_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
510
  namespace=make_namespace(
511
  self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS
512
  ),
513
  embedding_func=self.embedding_func,
514
  meta_fields={"src_id", "tgt_id"},
515
  )
516
+ self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
517
  namespace=make_namespace(
518
  self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS
519
  ),
 
527
  embedding_func=None,
528
  )
529
 
 
530
  if self.llm_response_cache and hasattr(
531
  self.llm_response_cache, "global_config"
532
  ):
533
  hashing_kv = self.llm_response_cache
534
  else:
535
+ hashing_kv = self.key_string_value_json_storage_cls( # type: ignore
536
  namespace=make_namespace(
537
  self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
538
  ),
 
541
 
542
  self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
543
  partial(
544
+ self.llm_model_func, # type: ignore
545
  hashing_kv=hashing_kv,
546
  **self.llm_model_kwargs,
547
  )
 
558
  node_label=nodel_label, max_depth=max_depth
559
  )
560
 
561
+ def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
562
  import_path = STORAGES[storage_name]
563
  storage_class = lazy_external_import(import_path, storage_name)
564
  return storage_class
565
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
566
  def insert(
567
  self,
568
+ input: str | list[str],
569
  split_by_character: str | None = None,
570
  split_by_character_only: bool = False,
571
  ):
572
  """Sync Insert documents with checkpoint support
573
 
574
  Args:
575
+ input: Single document string or list of document strings
576
  split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
 
577
  split_by_character_only: if split_by_character_only is True, split the string by character only, when
578
  split_by_character is None, this parameter is ignored.
579
  """
580
  loop = always_get_an_event_loop()
581
  return loop.run_until_complete(
582
+ self.ainsert(input, split_by_character, split_by_character_only)
583
  )
584
 
585
  async def ainsert(
586
  self,
587
+ input: str | list[str],
588
  split_by_character: str | None = None,
589
  split_by_character_only: bool = False,
590
  ):
591
  """Async Insert documents with checkpoint support
592
 
593
  Args:
594
+ input: Single document string or list of document strings
595
  split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
 
596
  split_by_character_only: if split_by_character_only is True, split the string by character only, when
597
  split_by_character is None, this parameter is ignored.
598
  """
599
+ await self.apipeline_enqueue_documents(input)
600
  await self.apipeline_process_enqueue_documents(
601
  split_by_character, split_by_character_only
602
  )
 
653
  if update_storage:
654
  await self._insert_done()
655
 
656
+ async def apipeline_enqueue_documents(self, input: str | list[str]):
657
  """
658
  Pipeline for Processing Documents
659
 
 
662
  3. Filter out already processed documents
663
  4. Enqueue document in status
664
  """
665
+ if isinstance(input, str):
666
+ input = [input]
667
 
668
  # 1. Remove duplicate contents from the list
669
+ unique_contents = list(set(doc.strip() for doc in input))
670
 
671
  # 2. Generate document IDs and initial status
672
  new_docs: dict[str, Any] = {
 
715
  # 1. Get all pending, failed, and abnormally terminated processing documents.
716
  to_process_docs: dict[str, DocProcessingStatus] = {}
717
 
718
+ processing_docs = await self.doc_status.get_docs_by_status(DocStatus.PROCESSING)
719
  to_process_docs.update(processing_docs)
720
+ failed_docs = await self.doc_status.get_docs_by_status(DocStatus.FAILED)
721
  to_process_docs.update(failed_docs)
722
+ pendings_docs = await self.doc_status.get_docs_by_status(DocStatus.PENDING)
723
  to_process_docs.update(pendings_docs)
724
 
725
  if not to_process_docs:
 
833
  raise e
834
 
835
  async def _insert_done(self):
836
+ tasks = [
837
+ cast(StorageNameSpace, storage_inst).index_done_callback()
838
+ for storage_inst in [ # type: ignore
839
+ self.full_docs,
840
+ self.text_chunks,
841
+ self.llm_response_cache,
842
+ self.entities_vdb,
843
+ self.relationships_vdb,
844
+ self.chunks_vdb,
845
+ self.chunk_entity_relation_graph,
846
+ ]
847
+ if storage_inst is not None
848
+ ]
849
  await asyncio.gather(*tasks)
850
 
851
+ def insert_custom_kg(self, custom_kg: dict[str, Any]):
852
  loop = always_get_an_event_loop()
853
  return loop.run_until_complete(self.ainsert_custom_kg(custom_kg))
854
 
855
+ async def ainsert_custom_kg(self, custom_kg: dict[str, Any]):
856
  update_storage = False
857
  try:
858
  # Insert chunks into vector storage
859
+ all_chunks_data: dict[str, dict[str, str]] = {}
860
+ chunk_to_source_map: dict[str, str] = {}
861
+ for chunk_data in custom_kg.get("chunks", {}):
862
  chunk_content = chunk_data["content"]
863
  source_id = chunk_data["source_id"]
864
  chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-")
 
868
  chunk_to_source_map[source_id] = chunk_id
869
  update_storage = True
870
 
871
+ if all_chunks_data:
872
  await self.chunks_vdb.upsert(all_chunks_data)
873
+ if all_chunks_data:
874
  await self.text_chunks.upsert(all_chunks_data)
875
 
876
  # Insert entities into knowledge graph
877
+ all_entities_data: list[dict[str, str]] = []
878
  for entity_data in custom_kg.get("entities", []):
879
  entity_name = f'"{entity_data["entity_name"].upper()}"'
880
  entity_type = entity_data.get("entity_type", "UNKNOWN")
 
890
  )
891
 
892
  # Prepare node data
893
+ node_data: dict[str, str] = {
894
  "entity_type": entity_type,
895
  "description": description,
896
  "source_id": source_id,
 
904
  update_storage = True
905
 
906
  # Insert relationships into knowledge graph
907
+ all_relationships_data: list[dict[str, str]] = []
908
  for relationship_data in custom_kg.get("relationships", []):
909
  src_id = f'"{relationship_data["src_id"].upper()}"'
910
  tgt_id = f'"{relationship_data["tgt_id"].upper()}"'
 
946
  "source_id": source_id,
947
  },
948
  )
949
+ edge_data: dict[str, str] = {
950
  "src_id": src_id,
951
  "tgt_id": tgt_id,
952
  "description": description,
 
956
  update_storage = True
957
 
958
  # Insert entities into vector storage if needed
959
+ data_for_vdb = {
960
+ compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
961
+ "content": dp["entity_name"] + dp["description"],
962
+ "entity_name": dp["entity_name"],
 
 
 
963
  }
964
+ for dp in all_entities_data
965
+ }
966
+ await self.entities_vdb.upsert(data_for_vdb)
967
 
968
  # Insert relationships into vector storage if needed
969
+ data_for_vdb = {
970
+ compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
971
+ "src_id": dp["src_id"],
972
+ "tgt_id": dp["tgt_id"],
973
+ "content": dp["keywords"]
974
+ + dp["src_id"]
975
+ + dp["tgt_id"]
976
+ + dp["description"],
 
 
 
977
  }
978
+ for dp in all_relationships_data
979
+ }
980
+ await self.relationships_vdb.upsert(data_for_vdb)
981
+
982
  finally:
983
  if update_storage:
984
  await self._insert_done()
985
 
986
+ def query(
987
+ self, query: str, param: QueryParam = QueryParam(), prompt: str | None = None
988
+ ) -> str | Iterator[str]:
989
+ """
990
+ Perform a sync query.
991
+
992
+ Args:
993
+ query (str): The query to be executed.
994
+ param (QueryParam): Configuration parameters for query execution.
995
+ prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
996
+
997
+ Returns:
998
+ str: The result of the query execution.
999
+ """
1000
  loop = always_get_an_event_loop()
1001
+
1002
+ return loop.run_until_complete(self.aquery(query, param, prompt)) # type: ignore
1003
 
1004
  async def aquery(
1005
+ self,
1006
+ query: str,
1007
+ param: QueryParam = QueryParam(),
1008
+ prompt: str | None = None,
1009
+ ) -> str | AsyncIterator[str]:
1010
+ """
1011
+ Perform a async query.
1012
+
1013
+ Args:
1014
+ query (str): The query to be executed.
1015
+ param (QueryParam): Configuration parameters for query execution.
1016
+ prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
1017
+
1018
+ Returns:
1019
+ str: The result of the query execution.
1020
+ """
1021
  if param.mode in ["local", "global", "hybrid"]:
1022
  response = await kg_query(
1023
  query,
 
1097
 
1098
  async def aquery_with_separate_keyword_extraction(
1099
  self, query: str, prompt: str, param: QueryParam = QueryParam()
1100
+ ) -> str | AsyncIterator[str]:
1101
  """
1102
  1. Calls extract_keywords_only to get HL/LL keywords from 'query'.
1103
  2. Then calls kg_query(...) or naive_query(...), etc. as the main query, while also injecting the newly extracted keywords if needed.
 
1120
  ),
1121
  )
1122
 
1123
+ param.hl_keywords = hl_keywords
1124
+ param.ll_keywords = ll_keywords
1125
 
1126
  # ---------------------
1127
  # STEP 2: Final Query Logic
 
1149
  self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
1150
  ),
1151
  global_config=asdict(self),
1152
+ embedding_func=self.embedding_func,
1153
  ),
1154
  )
1155
  elif param.mode == "naive":
 
1198
  return response
1199
 
1200
  async def _query_done(self):
1201
+ await self.llm_response_cache.index_done_callback()
 
 
 
 
 
1202
 
1203
  def delete_by_entity(self, entity_name: str):
1204
  loop = always_get_an_event_loop()
 
1220
  logger.error(f"Error while deleting entity '{entity_name}': {e}")
1221
 
1222
  async def _delete_by_entity_done(self):
1223
+ await asyncio.gather(
1224
+ *[
1225
+ cast(StorageNameSpace, storage_inst).index_done_callback()
1226
+ for storage_inst in [ # type: ignore
1227
+ self.entities_vdb,
1228
+ self.relationships_vdb,
1229
+ self.chunk_entity_relation_graph,
1230
+ ]
1231
+ ]
1232
+ )
1233
 
1234
  def _get_content_summary(self, content: str, max_length: int = 100) -> str:
1235
  """Get summary of document content
 
1254
  """
1255
  return await self.doc_status.get_status_counts()
1256
 
1257
+ async def adelete_by_doc_id(self, doc_id: str) -> None:
1258
  """Delete a document and all its related data
1259
 
1260
  Args:
 
1271
 
1272
  # 2. Get all related chunks
1273
  chunks = await self.text_chunks.get_by_id(doc_id)
1274
+ if not chunks:
1275
+ return
1276
+
1277
  chunk_ids = list(chunks.keys())
1278
  logger.debug(f"Found {len(chunk_ids)} chunks to delete")
1279
 
 
1444
  except Exception as e:
1445
  logger.error(f"Error while deleting document {doc_id}: {e}")
1446
 
 
 
 
 
1447
  async def get_entity_info(
1448
  self, entity_name: str, include_vector_data: bool = False
1449
+ ) -> dict[str, str | None | dict[str, str]]:
1450
  """Get detailed information of an entity
1451
 
1452
  Args:
 
1466
  node_data = await self.chunk_entity_relation_graph.get_node(entity_name)
1467
  source_id = node_data.get("source_id") if node_data else None
1468
 
1469
+ result: dict[str, str | None | dict[str, str]] = {
1470
  "entity_name": entity_name,
1471
  "source_id": source_id,
1472
  "graph_data": node_data,
 
1480
 
1481
  return result
1482
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1483
  async def get_relation_info(
1484
  self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
1485
  ):
 
1507
  )
1508
  source_id = edge_data.get("source_id") if edge_data else None
1509
 
1510
+ result: dict[str, str | None | dict[str, str]] = {
1511
  "src_entity": src_entity,
1512
  "tgt_entity": tgt_entity,
1513
  "source_id": source_id,
 
1521
  result["vector_data"] = vector_data[0] if vector_data else None
1522
 
1523
  return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lightrag/llm.py CHANGED
@@ -1,4 +1,6 @@
1
- from typing import List, Dict, Callable, Any
 
 
2
  from pydantic import BaseModel, Field
3
 
4
 
@@ -23,7 +25,7 @@ class Model(BaseModel):
23
  ...,
24
  description="A function that generates the response from the llm. The response must be a string",
25
  )
26
- kwargs: Dict[str, Any] = Field(
27
  ...,
28
  description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
29
  )
@@ -57,7 +59,7 @@ class MultiModel:
57
  ```
58
  """
59
 
60
- def __init__(self, models: List[Model]):
61
  self._models = models
62
  self._current_model = 0
63
 
@@ -66,7 +68,11 @@ class MultiModel:
66
  return self._models[self._current_model]
67
 
68
  async def llm_model_func(
69
- self, prompt, system_prompt=None, history_messages=[], **kwargs
 
 
 
 
70
  ) -> str:
71
  kwargs.pop("model", None) # stop from overwriting the custom model name
72
  kwargs.pop("keyword_extraction", None)
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Callable, Any
4
  from pydantic import BaseModel, Field
5
 
6
 
 
25
  ...,
26
  description="A function that generates the response from the llm. The response must be a string",
27
  )
28
+ kwargs: dict[str, Any] = Field(
29
  ...,
30
  description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
31
  )
 
59
  ```
60
  """
61
 
62
+ def __init__(self, models: list[Model]):
63
  self._models = models
64
  self._current_model = 0
65
 
 
68
  return self._models[self._current_model]
69
 
70
  async def llm_model_func(
71
+ self,
72
+ prompt: str,
73
+ system_prompt: str | None = None,
74
+ history_messages: list[dict[str, Any]] = [],
75
+ **kwargs: Any,
76
  ) -> str:
77
  kwargs.pop("model", None) # stop from overwriting the custom model name
78
  kwargs.pop("keyword_extraction", None)
lightrag/namespace.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  from typing import Iterable
2
 
3
 
 
1
+ from __future__ import annotations
2
+
3
  from typing import Iterable
4
 
5
 
lightrag/operate.py CHANGED
@@ -1,8 +1,10 @@
 
 
1
  import asyncio
2
  import json
3
  import re
4
  from tqdm.asyncio import tqdm as tqdm_async
5
- from typing import Any, Union
6
  from collections import Counter, defaultdict
7
  from .utils import (
8
  logger,
@@ -36,7 +38,7 @@ import time
36
 
37
  def chunking_by_token_size(
38
  content: str,
39
- split_by_character: Union[str, None] = None,
40
  split_by_character_only: bool = False,
41
  overlap_token_size: int = 128,
42
  max_token_size: int = 1024,
@@ -237,25 +239,65 @@ async def _merge_edges_then_upsert(
237
 
238
  if await knowledge_graph_inst.has_edge(src_id, tgt_id):
239
  already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
240
- already_weights.append(already_edge["weight"])
241
- already_source_ids.extend(
242
- split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP])
243
- )
244
- already_description.append(already_edge["description"])
245
- already_keywords.extend(
246
- split_string_by_multi_markers(already_edge["keywords"], [GRAPH_FIELD_SEP])
247
- )
 
 
 
 
 
 
 
 
 
 
248
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  weight = sum([dp["weight"] for dp in edges_data] + already_weights)
250
  description = GRAPH_FIELD_SEP.join(
251
- sorted(set([dp["description"] for dp in edges_data] + already_description))
 
 
 
 
 
252
  )
253
  keywords = GRAPH_FIELD_SEP.join(
254
- sorted(set([dp["keywords"] for dp in edges_data] + already_keywords))
 
 
 
 
 
255
  )
256
  source_id = GRAPH_FIELD_SEP.join(
257
- set([dp["source_id"] for dp in edges_data] + already_source_ids)
 
 
 
258
  )
 
259
  for need_insert_id in [src_id, tgt_id]:
260
  if not (await knowledge_graph_inst.has_node(need_insert_id)):
261
  await knowledge_graph_inst.upsert_node(
@@ -295,9 +337,9 @@ async def extract_entities(
295
  knowledge_graph_inst: BaseGraphStorage,
296
  entity_vdb: BaseVectorStorage,
297
  relationships_vdb: BaseVectorStorage,
298
- global_config: dict,
299
- llm_response_cache: BaseKVStorage = None,
300
- ) -> Union[BaseGraphStorage, None]:
301
  use_llm_func: callable = global_config["llm_model_func"]
302
  entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
303
  enable_llm_cache_for_entity_extract: bool = global_config[
@@ -563,15 +605,15 @@ async def extract_entities(
563
 
564
 
565
  async def kg_query(
566
- query,
567
  knowledge_graph_inst: BaseGraphStorage,
568
  entities_vdb: BaseVectorStorage,
569
  relationships_vdb: BaseVectorStorage,
570
  text_chunks_db: BaseKVStorage,
571
  query_param: QueryParam,
572
- global_config: dict,
573
- hashing_kv: BaseKVStorage = None,
574
- prompt: str = "",
575
  ) -> str:
576
  # Handle cache
577
  use_model_func = global_config["llm_model_func"]
@@ -681,8 +723,8 @@ async def kg_query(
681
  async def extract_keywords_only(
682
  text: str,
683
  param: QueryParam,
684
- global_config: dict,
685
- hashing_kv: BaseKVStorage = None,
686
  ) -> tuple[list[str], list[str]]:
687
  """
688
  Extract high-level and low-level keywords from the given 'text' using the LLM.
@@ -778,9 +820,9 @@ async def mix_kg_vector_query(
778
  chunks_vdb: BaseVectorStorage,
779
  text_chunks_db: BaseKVStorage,
780
  query_param: QueryParam,
781
- global_config: dict,
782
- hashing_kv: BaseKVStorage = None,
783
- ) -> str:
784
  """
785
  Hybrid retrieval implementation combining knowledge graph and vector search.
786
 
@@ -1499,13 +1541,13 @@ def combine_contexts(entities, relationships, sources):
1499
 
1500
 
1501
  async def naive_query(
1502
- query,
1503
  chunks_vdb: BaseVectorStorage,
1504
  text_chunks_db: BaseKVStorage,
1505
  query_param: QueryParam,
1506
- global_config: dict,
1507
- hashing_kv: BaseKVStorage = None,
1508
- ):
1509
  # Handle cache
1510
  use_model_func = global_config["llm_model_func"]
1511
  args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
@@ -1606,9 +1648,9 @@ async def kg_query_with_keywords(
1606
  relationships_vdb: BaseVectorStorage,
1607
  text_chunks_db: BaseKVStorage,
1608
  query_param: QueryParam,
1609
- global_config: dict,
1610
- hashing_kv: BaseKVStorage = None,
1611
- ) -> str:
1612
  """
1613
  Refactored kg_query that does NOT extract keywords by itself.
1614
  It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty.
 
1
+ from __future__ import annotations
2
+
3
  import asyncio
4
  import json
5
  import re
6
  from tqdm.asyncio import tqdm as tqdm_async
7
+ from typing import Any, AsyncIterator
8
  from collections import Counter, defaultdict
9
  from .utils import (
10
  logger,
 
38
 
39
  def chunking_by_token_size(
40
  content: str,
41
+ split_by_character: str | None = None,
42
  split_by_character_only: bool = False,
43
  overlap_token_size: int = 128,
44
  max_token_size: int = 1024,
 
239
 
240
  if await knowledge_graph_inst.has_edge(src_id, tgt_id):
241
  already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
242
+ # Handle the case where get_edge returns None or missing fields
243
+ if already_edge:
244
+ # Get weight with default 0.0 if missing
245
+ if "weight" in already_edge:
246
+ already_weights.append(already_edge["weight"])
247
+ else:
248
+ logger.warning(
249
+ f"Edge between {src_id} and {tgt_id} missing weight field"
250
+ )
251
+ already_weights.append(0.0)
252
+
253
+ # Get source_id with empty string default if missing or None
254
+ if "source_id" in already_edge and already_edge["source_id"] is not None:
255
+ already_source_ids.extend(
256
+ split_string_by_multi_markers(
257
+ already_edge["source_id"], [GRAPH_FIELD_SEP]
258
+ )
259
+ )
260
 
261
+ # Get description with empty string default if missing or None
262
+ if (
263
+ "description" in already_edge
264
+ and already_edge["description"] is not None
265
+ ):
266
+ already_description.append(already_edge["description"])
267
+
268
+ # Get keywords with empty string default if missing or None
269
+ if "keywords" in already_edge and already_edge["keywords"] is not None:
270
+ already_keywords.extend(
271
+ split_string_by_multi_markers(
272
+ already_edge["keywords"], [GRAPH_FIELD_SEP]
273
+ )
274
+ )
275
+
276
+ # Process edges_data with None checks
277
  weight = sum([dp["weight"] for dp in edges_data] + already_weights)
278
  description = GRAPH_FIELD_SEP.join(
279
+ sorted(
280
+ set(
281
+ [dp["description"] for dp in edges_data if dp.get("description")]
282
+ + already_description
283
+ )
284
+ )
285
  )
286
  keywords = GRAPH_FIELD_SEP.join(
287
+ sorted(
288
+ set(
289
+ [dp["keywords"] for dp in edges_data if dp.get("keywords")]
290
+ + already_keywords
291
+ )
292
+ )
293
  )
294
  source_id = GRAPH_FIELD_SEP.join(
295
+ set(
296
+ [dp["source_id"] for dp in edges_data if dp.get("source_id")]
297
+ + already_source_ids
298
+ )
299
  )
300
+
301
  for need_insert_id in [src_id, tgt_id]:
302
  if not (await knowledge_graph_inst.has_node(need_insert_id)):
303
  await knowledge_graph_inst.upsert_node(
 
337
  knowledge_graph_inst: BaseGraphStorage,
338
  entity_vdb: BaseVectorStorage,
339
  relationships_vdb: BaseVectorStorage,
340
+ global_config: dict[str, str],
341
+ llm_response_cache: BaseKVStorage | None = None,
342
+ ) -> BaseGraphStorage | None:
343
  use_llm_func: callable = global_config["llm_model_func"]
344
  entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
345
  enable_llm_cache_for_entity_extract: bool = global_config[
 
605
 
606
 
607
  async def kg_query(
608
+ query: str,
609
  knowledge_graph_inst: BaseGraphStorage,
610
  entities_vdb: BaseVectorStorage,
611
  relationships_vdb: BaseVectorStorage,
612
  text_chunks_db: BaseKVStorage,
613
  query_param: QueryParam,
614
+ global_config: dict[str, str],
615
+ hashing_kv: BaseKVStorage | None = None,
616
+ prompt: str | None = None,
617
  ) -> str:
618
  # Handle cache
619
  use_model_func = global_config["llm_model_func"]
 
723
  async def extract_keywords_only(
724
  text: str,
725
  param: QueryParam,
726
+ global_config: dict[str, str],
727
+ hashing_kv: BaseKVStorage | None = None,
728
  ) -> tuple[list[str], list[str]]:
729
  """
730
  Extract high-level and low-level keywords from the given 'text' using the LLM.
 
820
  chunks_vdb: BaseVectorStorage,
821
  text_chunks_db: BaseKVStorage,
822
  query_param: QueryParam,
823
+ global_config: dict[str, str],
824
+ hashing_kv: BaseKVStorage | None = None,
825
+ ) -> str | AsyncIterator[str]:
826
  """
827
  Hybrid retrieval implementation combining knowledge graph and vector search.
828
 
 
1541
 
1542
 
1543
  async def naive_query(
1544
+ query: str,
1545
  chunks_vdb: BaseVectorStorage,
1546
  text_chunks_db: BaseKVStorage,
1547
  query_param: QueryParam,
1548
+ global_config: dict[str, str],
1549
+ hashing_kv: BaseKVStorage | None = None,
1550
+ ) -> str | AsyncIterator[str]:
1551
  # Handle cache
1552
  use_model_func = global_config["llm_model_func"]
1553
  args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
 
1648
  relationships_vdb: BaseVectorStorage,
1649
  text_chunks_db: BaseKVStorage,
1650
  query_param: QueryParam,
1651
+ global_config: dict[str, str],
1652
+ hashing_kv: BaseKVStorage | None = None,
1653
+ ) -> str | AsyncIterator[str]:
1654
  """
1655
  Refactored kg_query that does NOT extract keywords by itself.
1656
  It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty.
lightrag/prompt.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  GRAPH_FIELD_SEP = "<SEP>"
2
 
3
  PROMPTS = {}
 
1
+ from __future__ import annotations
2
+
3
  GRAPH_FIELD_SEP = "<SEP>"
4
 
5
  PROMPTS = {}
lightrag/types.py CHANGED
@@ -1,26 +1,28 @@
 
 
1
  from pydantic import BaseModel
2
- from typing import List, Dict, Any
3
 
4
 
5
  class GPTKeywordExtractionFormat(BaseModel):
6
- high_level_keywords: List[str]
7
- low_level_keywords: List[str]
8
 
9
 
10
  class KnowledgeGraphNode(BaseModel):
11
  id: str
12
- labels: List[str]
13
- properties: Dict[str, Any] # anything else goes here
14
 
15
 
16
  class KnowledgeGraphEdge(BaseModel):
17
  id: str
18
- type: str
19
  source: str # id of source node
20
  target: str # id of target node
21
- properties: Dict[str, Any] # anything else goes here
22
 
23
 
24
  class KnowledgeGraph(BaseModel):
25
- nodes: List[KnowledgeGraphNode] = []
26
- edges: List[KnowledgeGraphEdge] = []
 
1
+ from __future__ import annotations
2
+
3
  from pydantic import BaseModel
4
+ from typing import Any, Optional
5
 
6
 
7
  class GPTKeywordExtractionFormat(BaseModel):
8
+ high_level_keywords: list[str]
9
+ low_level_keywords: list[str]
10
 
11
 
12
  class KnowledgeGraphNode(BaseModel):
13
  id: str
14
+ labels: list[str]
15
+ properties: dict[str, Any] # anything else goes here
16
 
17
 
18
  class KnowledgeGraphEdge(BaseModel):
19
  id: str
20
+ type: Optional[str]
21
  source: str # id of source node
22
  target: str # id of target node
23
+ properties: dict[str, Any] # anything else goes here
24
 
25
 
26
  class KnowledgeGraph(BaseModel):
27
+ nodes: list[KnowledgeGraphNode] = []
28
+ edges: list[KnowledgeGraphEdge] = []
lightrag/utils.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import asyncio
2
  import html
3
  import io
@@ -9,7 +11,7 @@ import re
9
  from dataclasses import dataclass
10
  from functools import wraps
11
  from hashlib import md5
12
- from typing import Any, Union, List, Optional
13
  import xml.etree.ElementTree as ET
14
  import bs4
15
 
@@ -67,12 +69,12 @@ class EmbeddingFunc:
67
 
68
  @dataclass
69
  class ReasoningResponse:
70
- reasoning_content: str
71
  response_content: str
72
  tag: str
73
 
74
 
75
- def locate_json_string_body_from_string(content: str) -> Union[str, None]:
76
  """Locate the JSON string body from a string"""
77
  try:
78
  maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
@@ -109,7 +111,7 @@ def convert_response_to_json(response: str) -> dict[str, Any]:
109
  raise e from None
110
 
111
 
112
- def compute_args_hash(*args, cache_type: str = None) -> str:
113
  """Compute a hash for the given arguments.
114
  Args:
115
  *args: Arguments to hash
@@ -128,7 +130,12 @@ def compute_args_hash(*args, cache_type: str = None) -> str:
128
  return hashlib.md5(args_str.encode()).hexdigest()
129
 
130
 
131
- def compute_mdhash_id(content, prefix: str = ""):
 
 
 
 
 
132
  return prefix + md5(content.encode()).hexdigest()
133
 
134
 
@@ -215,11 +222,13 @@ def clean_str(input: Any) -> str:
215
  return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
216
 
217
 
218
- def is_float_regex(value):
219
  return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
220
 
221
 
222
- def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
 
 
223
  """Truncate a list of data by token size"""
224
  if max_token_size <= 0:
225
  return []
@@ -231,7 +240,7 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
231
  return list_data
232
 
233
 
234
- def list_of_list_to_csv(data: List[List[str]]) -> str:
235
  output = io.StringIO()
236
  writer = csv.writer(
237
  output,
@@ -244,7 +253,7 @@ def list_of_list_to_csv(data: List[List[str]]) -> str:
244
  return output.getvalue()
245
 
246
 
247
- def csv_string_to_list(csv_string: str) -> List[List[str]]:
248
  # Clean the string by removing NUL characters
249
  cleaned_string = csv_string.replace("\0", "")
250
 
@@ -329,7 +338,7 @@ def xml_to_json(xml_file):
329
  return None
330
 
331
 
332
- def process_combine_contexts(hl, ll):
333
  header = None
334
  list_hl = csv_string_to_list(hl.strip())
335
  list_ll = csv_string_to_list(ll.strip())
@@ -375,7 +384,7 @@ async def get_best_cached_response(
375
  llm_func=None,
376
  original_prompt=None,
377
  cache_type=None,
378
- ) -> Union[str, None]:
379
  logger.debug(
380
  f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
381
  )
@@ -479,7 +488,7 @@ def cosine_similarity(v1, v2):
479
  return dot_product / (norm1 * norm2)
480
 
481
 
482
- def quantize_embedding(embedding: Union[np.ndarray, list], bits=8) -> tuple:
483
  """Quantize embedding to specified bits"""
484
  # Convert list to numpy array if needed
485
  if isinstance(embedding, list):
@@ -570,9 +579,9 @@ class CacheData:
570
  args_hash: str
571
  content: str
572
  prompt: str
573
- quantized: Optional[np.ndarray] = None
574
- min_val: Optional[float] = None
575
- max_val: Optional[float] = None
576
  mode: str = "default"
577
  cache_type: str = "query"
578
 
@@ -635,7 +644,9 @@ def exists_func(obj, func_name: str) -> bool:
635
  return False
636
 
637
 
638
- def get_conversation_turns(conversation_history: list[dict], num_turns: int) -> str:
 
 
639
  """
640
  Process conversation history to get the specified number of complete turns.
641
 
@@ -647,8 +658,8 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) ->
647
  Formatted string of the conversation history
648
  """
649
  # Group messages into turns
650
- turns = []
651
- messages = []
652
 
653
  # First, filter out keyword extraction messages
654
  for msg in conversation_history:
@@ -682,7 +693,7 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) ->
682
  turns = turns[-num_turns:]
683
 
684
  # Format the turns into a string
685
- formatted_turns = []
686
  for turn in turns:
687
  formatted_turns.extend(
688
  [f"user: {turn[0]['content']}", f"assistant: {turn[1]['content']}"]
 
1
+ from __future__ import annotations
2
+
3
  import asyncio
4
  import html
5
  import io
 
11
  from dataclasses import dataclass
12
  from functools import wraps
13
  from hashlib import md5
14
+ from typing import Any, Callable
15
  import xml.etree.ElementTree as ET
16
  import bs4
17
 
 
69
 
70
  @dataclass
71
  class ReasoningResponse:
72
+ reasoning_content: str | None
73
  response_content: str
74
  tag: str
75
 
76
 
77
+ def locate_json_string_body_from_string(content: str) -> str | None:
78
  """Locate the JSON string body from a string"""
79
  try:
80
  maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
 
111
  raise e from None
112
 
113
 
114
+ def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:
115
  """Compute a hash for the given arguments.
116
  Args:
117
  *args: Arguments to hash
 
130
  return hashlib.md5(args_str.encode()).hexdigest()
131
 
132
 
133
+ def compute_mdhash_id(content: str, prefix: str = "") -> str:
134
+ """
135
+ Compute a unique ID for a given content string.
136
+
137
+ The ID is a combination of the given prefix and the MD5 hash of the content string.
138
+ """
139
  return prefix + md5(content.encode()).hexdigest()
140
 
141
 
 
222
  return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
223
 
224
 
225
+ def is_float_regex(value: str) -> bool:
226
  return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
227
 
228
 
229
+ def truncate_list_by_token_size(
230
+ list_data: list[Any], key: Callable[[Any], str], max_token_size: int
231
+ ) -> list[int]:
232
  """Truncate a list of data by token size"""
233
  if max_token_size <= 0:
234
  return []
 
240
  return list_data
241
 
242
 
243
+ def list_of_list_to_csv(data: list[list[str]]) -> str:
244
  output = io.StringIO()
245
  writer = csv.writer(
246
  output,
 
253
  return output.getvalue()
254
 
255
 
256
+ def csv_string_to_list(csv_string: str) -> list[list[str]]:
257
  # Clean the string by removing NUL characters
258
  cleaned_string = csv_string.replace("\0", "")
259
 
 
338
  return None
339
 
340
 
341
+ def process_combine_contexts(hl: str, ll: str):
342
  header = None
343
  list_hl = csv_string_to_list(hl.strip())
344
  list_ll = csv_string_to_list(ll.strip())
 
384
  llm_func=None,
385
  original_prompt=None,
386
  cache_type=None,
387
+ ) -> str | None:
388
  logger.debug(
389
  f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
390
  )
 
488
  return dot_product / (norm1 * norm2)
489
 
490
 
491
+ def quantize_embedding(embedding: np.ndarray | list[float], bits: int = 8) -> tuple:
492
  """Quantize embedding to specified bits"""
493
  # Convert list to numpy array if needed
494
  if isinstance(embedding, list):
 
579
  args_hash: str
580
  content: str
581
  prompt: str
582
+ quantized: np.ndarray | None = None
583
+ min_val: float | None = None
584
+ max_val: float | None = None
585
  mode: str = "default"
586
  cache_type: str = "query"
587
 
 
644
  return False
645
 
646
 
647
+ def get_conversation_turns(
648
+ conversation_history: list[dict[str, Any]], num_turns: int
649
+ ) -> str:
650
  """
651
  Process conversation history to get the specified number of complete turns.
652
 
 
658
  Formatted string of the conversation history
659
  """
660
  # Group messages into turns
661
+ turns: list[list[dict[str, Any]]] = []
662
+ messages: list[dict[str, Any]] = []
663
 
664
  # First, filter out keyword extraction messages
665
  for msg in conversation_history:
 
693
  turns = turns[-num_turns:]
694
 
695
  # Format the turns into a string
696
+ formatted_turns: list[str] = []
697
  for turn in turns:
698
  formatted_turns.extend(
699
  [f"user: {turn[0]['content']}", f"assistant: {turn[1]['content']}"]