gzdaniel committed
Commit 917d3f4 · Parents: 3b2a14d, a9c0b40

Merge branch 'main' into add-Memgraph-graph-db
env.example CHANGED

@@ -58,6 +58,8 @@ SUMMARY_LANGUAGE=English
  # FORCE_LLM_SUMMARY_ON_MERGE=6
  ### Max tokens for entity/relations description after merge
  # MAX_TOKEN_SUMMARY=500
+ ### Maximum number of entity extraction attempts for ambiguous content
+ # MAX_GLEANING=1
 
  ### Number of parallel processing documents(Less than MAX_ASYNC/2 is recommended)
  # MAX_PARALLEL_INSERT=2

@@ -112,15 +114,6 @@ EMBEDDING_BINDING_HOST=http://localhost:11434
  # LIGHTRAG_DOC_STATUS_STORAGE=PGDocStatusStorage
  # LIGHTRAG_GRAPH_STORAGE=Neo4JStorage
 
- ### TiDB Configuration (Deprecated)
- # TIDB_HOST=localhost
- # TIDB_PORT=4000
- # TIDB_USER=your_username
- # TIDB_PASSWORD='your_password'
- # TIDB_DATABASE=your_database
- ### separating all data from difference Lightrag instances(deprecating)
- # TIDB_WORKSPACE=default
-
  ### PostgreSQL Configuration
  POSTGRES_HOST=localhost
  POSTGRES_PORT=5432

@@ -128,7 +121,7 @@ POSTGRES_USER=your_username
  POSTGRES_PASSWORD='your_password'
  POSTGRES_DATABASE=your_database
  POSTGRES_MAX_CONNECTIONS=12
- ### separating all data from difference Lightrag instances(deprecating)
+ ### separating all data from difference Lightrag instances
  # POSTGRES_WORKSPACE=default
 
  ### Neo4j Configuration

@@ -144,14 +137,15 @@ NEO4J_PASSWORD='your_password'
  # AGE_POSTGRES_PORT=8529
 
  # AGE Graph Name(apply to PostgreSQL and independent AGM)
- ### AGE_GRAPH_NAME is precated
+ ### AGE_GRAPH_NAME is deprecated
  # AGE_GRAPH_NAME=lightrag
 
  ### MongoDB Configuration
  MONGO_URI=mongodb://root:root@localhost:27017/
  MONGO_DATABASE=LightRAG
  ### separating all data from difference Lightrag instances(deprecating)
- # MONGODB_GRAPH=false
+ ### separating all data from difference Lightrag instances
+ # MONGODB_WORKSPACE=default
 
  ### Milvus Configuration
  MILVUS_URI=http://localhost:19530
examples/raganything_example.py CHANGED
@@ -11,9 +11,74 @@ This example shows how to:
11
  import os
12
  import argparse
13
  import asyncio
 
 
14
  from lightrag.llm.openai import openai_complete_if_cache, openai_embed
15
- from lightrag.utils import EmbeddingFunc
16
- from raganything.raganything import RAGAnything
 
 
17
 
18
 
19
  async def process_with_rag(
@@ -31,15 +96,21 @@ async def process_with_rag(
31
  output_dir: Output directory for RAG results
32
  api_key: OpenAI API key
33
  base_url: Optional base URL for API
 
34
  """
35
  try:
36
- # Initialize RAGAnything
37
- rag = RAGAnything(
38
- working_dir=working_dir,
39
- llm_model_func=lambda prompt,
40
- system_prompt=None,
41
- history_messages=[],
42
- **kwargs: openai_complete_if_cache(
 
 
43
  "gpt-4o-mini",
44
  prompt,
45
  system_prompt=system_prompt,
@@ -47,81 +118,123 @@ async def process_with_rag(
47
  api_key=api_key,
48
  base_url=base_url,
49
  **kwargs,
50
- ),
51
- vision_model_func=lambda prompt,
52
- system_prompt=None,
53
- history_messages=[],
54
- image_data=None,
55
- **kwargs: openai_complete_if_cache(
56
- "gpt-4o",
57
- "",
58
- system_prompt=None,
59
- history_messages=[],
60
- messages=[
61
- {"role": "system", "content": system_prompt}
62
- if system_prompt
63
- else None,
64
- {
65
- "role": "user",
66
- "content": [
67
- {"type": "text", "text": prompt},
68
- {
69
- "type": "image_url",
70
- "image_url": {
71
- "url": f"data:image/jpeg;base64,{image_data}"
72
- },
73
- },
74
- ],
75
- }
76
- if image_data
77
- else {"role": "user", "content": prompt},
78
- ],
79
- api_key=api_key,
80
- base_url=base_url,
81
- **kwargs,
82
  )
83
- if image_data
84
- else openai_complete_if_cache(
85
- "gpt-4o-mini",
86
- prompt,
87
- system_prompt=system_prompt,
88
- history_messages=history_messages,
89
- api_key=api_key,
90
- base_url=base_url,
91
- **kwargs,
92
- ),
93
- embedding_func=EmbeddingFunc(
94
- embedding_dim=3072,
95
- max_token_size=8192,
96
- func=lambda texts: openai_embed(
97
- texts,
98
- model="text-embedding-3-large",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  api_key=api_key,
100
  base_url=base_url,
101
- ),
 
 
102
  ),
103
  )
 
 
105
  # Process document
106
  await rag.process_document_complete(
107
  file_path=file_path, output_dir=output_dir, parse_method="auto"
108
  )
109
 
110
- # Example queries
111
- queries = [
 
 
 
112
  "What is the main content of the document?",
113
- "Describe the images and figures in the document",
114
- "Tell me about the experimental results and data tables",
115
  ]
116
 
117
- print("\nQuerying processed document:")
118
- for query in queries:
119
- print(f"\nQuery: {query}")
120
- result = await rag.query_with_multimodal(query, mode="hybrid")
121
- print(f"Answer: {result}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
  except Exception as e:
124
- print(f"Error processing with RAG: {str(e)}")
 
 
 
125
 
126
 
127
  def main():
@@ -135,12 +248,20 @@ def main():
135
  "--output", "-o", default="./output", help="Output directory path"
136
  )
137
  parser.add_argument(
138
- "--api-key", required=True, help="OpenAI API key for RAG processing"
 
 
139
  )
140
  parser.add_argument("--base-url", help="Optional base URL for API")
141
 
142
  args = parser.parse_args()
143
 
 
 
 
 
 
 
144
  # Create output directory if specified
145
  if args.output:
146
  os.makedirs(args.output, exist_ok=True)
@@ -154,4 +275,12 @@ def main():
154
 
155
 
156
  if __name__ == "__main__":
 
 
157
  main()
 
11
  import os
12
  import argparse
13
  import asyncio
14
+ import logging
15
+ import logging.config
16
+ from pathlib import Path
17
+
18
+ # Add project root directory to Python path
19
+ import sys
20
+
21
+ sys.path.append(str(Path(__file__).parent.parent))
22
+
23
  from lightrag.llm.openai import openai_complete_if_cache, openai_embed
24
+ from lightrag.utils import EmbeddingFunc, logger, set_verbose_debug
25
+ from raganything import RAGAnything, RAGAnythingConfig
26
+
27
+
28
+ def configure_logging():
29
+ """Configure logging for the application"""
30
+ # Get log directory path from environment variable or use current directory
31
+ log_dir = os.getenv("LOG_DIR", os.getcwd())
32
+ log_file_path = os.path.abspath(os.path.join(log_dir, "raganything_example.log"))
33
+
34
+ print(f"\nRAGAnything example log file: {log_file_path}\n")
35
+ os.makedirs(os.path.dirname(log_dir), exist_ok=True)
36
+
37
+ # Get log file max size and backup count from environment variables
38
+ log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB
39
+ log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups
40
+
41
+ logging.config.dictConfig(
42
+ {
43
+ "version": 1,
44
+ "disable_existing_loggers": False,
45
+ "formatters": {
46
+ "default": {
47
+ "format": "%(levelname)s: %(message)s",
48
+ },
49
+ "detailed": {
50
+ "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
51
+ },
52
+ },
53
+ "handlers": {
54
+ "console": {
55
+ "formatter": "default",
56
+ "class": "logging.StreamHandler",
57
+ "stream": "ext://sys.stderr",
58
+ },
59
+ "file": {
60
+ "formatter": "detailed",
61
+ "class": "logging.handlers.RotatingFileHandler",
62
+ "filename": log_file_path,
63
+ "maxBytes": log_max_bytes,
64
+ "backupCount": log_backup_count,
65
+ "encoding": "utf-8",
66
+ },
67
+ },
68
+ "loggers": {
69
+ "lightrag": {
70
+ "handlers": ["console", "file"],
71
+ "level": "INFO",
72
+ "propagate": False,
73
+ },
74
+ },
75
+ }
76
+ )
77
+
78
+ # Set the logger level to INFO
79
+ logger.setLevel(logging.INFO)
80
+ # Enable verbose debug if needed
81
+ set_verbose_debug(os.getenv("VERBOSE", "false").lower() == "true")
82
 
83
 
84
  async def process_with_rag(
 
96
  output_dir: Output directory for RAG results
97
  api_key: OpenAI API key
98
  base_url: Optional base URL for API
99
+ working_dir: Working directory for RAG storage
100
  """
101
  try:
102
+ # Create RAGAnything configuration
103
+ config = RAGAnythingConfig(
104
+ working_dir=working_dir or "./rag_storage",
105
+ mineru_parse_method="auto",
106
+ enable_image_processing=True,
107
+ enable_table_processing=True,
108
+ enable_equation_processing=True,
109
+ )
110
+
111
+ # Define LLM model function
112
+ def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs):
113
+ return openai_complete_if_cache(
114
  "gpt-4o-mini",
115
  prompt,
116
  system_prompt=system_prompt,
 
118
  api_key=api_key,
119
  base_url=base_url,
120
  **kwargs,
 
 
121
  )
122
+
123
+ # Define vision model function for image processing
124
+ def vision_model_func(
125
+ prompt, system_prompt=None, history_messages=[], image_data=None, **kwargs
126
+ ):
127
+ if image_data:
128
+ return openai_complete_if_cache(
129
+ "gpt-4o",
130
+ "",
131
+ system_prompt=None,
132
+ history_messages=[],
133
+ messages=[
134
+ {"role": "system", "content": system_prompt}
135
+ if system_prompt
136
+ else None,
137
+ {
138
+ "role": "user",
139
+ "content": [
140
+ {"type": "text", "text": prompt},
141
+ {
142
+ "type": "image_url",
143
+ "image_url": {
144
+ "url": f"data:image/jpeg;base64,{image_data}"
145
+ },
146
+ },
147
+ ],
148
+ }
149
+ if image_data
150
+ else {"role": "user", "content": prompt},
151
+ ],
152
  api_key=api_key,
153
  base_url=base_url,
154
+ **kwargs,
155
+ )
156
+ else:
157
+ return llm_model_func(prompt, system_prompt, history_messages, **kwargs)
158
+
159
+ # Define embedding function
160
+ embedding_func = EmbeddingFunc(
161
+ embedding_dim=3072,
162
+ max_token_size=8192,
163
+ func=lambda texts: openai_embed(
164
+ texts,
165
+ model="text-embedding-3-large",
166
+ api_key=api_key,
167
+ base_url=base_url,
168
  ),
169
  )
170
 
171
+ # Initialize RAGAnything with new dataclass structure
172
+ rag = RAGAnything(
173
+ config=config,
174
+ llm_model_func=llm_model_func,
175
+ vision_model_func=vision_model_func,
176
+ embedding_func=embedding_func,
177
+ )
178
+
179
  # Process document
180
  await rag.process_document_complete(
181
  file_path=file_path, output_dir=output_dir, parse_method="auto"
182
  )
183
 
184
+ # Example queries - demonstrating different query approaches
185
+ logger.info("\nQuerying processed document:")
186
+
187
+ # 1. Pure text queries using aquery()
188
+ text_queries = [
189
  "What is the main content of the document?",
190
+ "What are the key topics discussed?",
 
191
  ]
192
 
193
+ for query in text_queries:
194
+ logger.info(f"\n[Text Query]: {query}")
195
+ result = await rag.aquery(query, mode="hybrid")
196
+ logger.info(f"Answer: {result}")
197
+
198
+ # 2. Multimodal query with specific multimodal content using aquery_with_multimodal()
199
+ logger.info(
200
+ "\n[Multimodal Query]: Analyzing performance data in context of document"
201
+ )
202
+ multimodal_result = await rag.aquery_with_multimodal(
203
+ "Compare this performance data with any similar results mentioned in the document",
204
+ multimodal_content=[
205
+ {
206
+ "type": "table",
207
+ "table_data": """Method,Accuracy,Processing_Time
208
+ RAGAnything,95.2%,120ms
209
+ Traditional_RAG,87.3%,180ms
210
+ Baseline,82.1%,200ms""",
211
+ "table_caption": "Performance comparison results",
212
+ }
213
+ ],
214
+ mode="hybrid",
215
+ )
216
+ logger.info(f"Answer: {multimodal_result}")
217
+
218
+ # 3. Another multimodal query with equation content
219
+ logger.info("\n[Multimodal Query]: Mathematical formula analysis")
220
+ equation_result = await rag.aquery_with_multimodal(
221
+ "Explain this formula and relate it to any mathematical concepts in the document",
222
+ multimodal_content=[
223
+ {
224
+ "type": "equation",
225
+ "latex": "F1 = 2 \\cdot \\frac{precision \\cdot recall}{precision + recall}",
226
+ "equation_caption": "F1-score calculation formula",
227
+ }
228
+ ],
229
+ mode="hybrid",
230
+ )
231
+ logger.info(f"Answer: {equation_result}")
232
 
233
  except Exception as e:
234
+ logger.error(f"Error processing with RAG: {str(e)}")
235
+ import traceback
236
+
237
+ logger.error(traceback.format_exc())
238
 
239
 
240
  def main():
 
248
  "--output", "-o", default="./output", help="Output directory path"
249
  )
250
  parser.add_argument(
251
+ "--api-key",
252
+ default=os.getenv("OPENAI_API_KEY"),
253
+ help="OpenAI API key (defaults to OPENAI_API_KEY env var)",
254
  )
255
  parser.add_argument("--base-url", help="Optional base URL for API")
256
 
257
  args = parser.parse_args()
258
 
259
+ # Check if API key is provided
260
+ if not args.api_key:
261
+ logger.error("Error: OpenAI API key is required")
262
+ logger.error("Set OPENAI_API_KEY environment variable or use --api-key option")
263
+ return
264
+
265
  # Create output directory if specified
266
  if args.output:
267
  os.makedirs(args.output, exist_ok=True)
 
275
 
276
 
277
  if __name__ == "__main__":
278
+ # Configure logging first
279
+ configure_logging()
280
+
281
+ print("RAGAnything Example")
282
+ print("=" * 30)
283
+ print("Processing document with multimodal RAG pipeline")
284
+ print("=" * 30)
285
+
286
  main()
examples/unofficial-sample/copy_llm_cache_to_another_storage.py CHANGED

@@ -52,18 +52,23 @@ async def copy_from_postgres_to_json():
         embedding_func=None,
     )
 
+    # Get all cache data using the new flattened structure
+    all_data = await from_llm_response_cache.get_all()
+
+    # Convert flattened data to hierarchical structure for JsonKVStorage
     kv = {}
-    for c_id in await from_llm_response_cache.all_keys():
-        print(f"Copying {c_id}")
-        workspace = c_id["workspace"]
-        mode = c_id["mode"]
-        _id = c_id["id"]
-        postgres_db.workspace = workspace
-        obj = await from_llm_response_cache.get_by_mode_and_id(mode, _id)
-        if mode not in kv:
-            kv[mode] = {}
-        kv[mode][_id] = obj[_id]
-        print(f"Object {obj}")
+    for flattened_key, cache_entry in all_data.items():
+        # Parse flattened key: {mode}:{cache_type}:{hash}
+        parts = flattened_key.split(":", 2)
+        if len(parts) == 3:
+            mode, cache_type, hash_value = parts
+            if mode not in kv:
+                kv[mode] = {}
+            kv[mode][hash_value] = cache_entry
+            print(f"Copying {flattened_key} -> {mode}[{hash_value}]")
+        else:
+            print(f"Skipping invalid key format: {flattened_key}")
+
     await to_llm_response_cache.upsert(kv)
     await to_llm_response_cache.index_done_callback()
     print("Mission accomplished!")

@@ -85,13 +90,24 @@ async def copy_from_json_to_postgres():
         db=postgres_db,
     )
 
-    for mode in await from_llm_response_cache.all_keys():
-        print(f"Copying {mode}")
-        caches = await from_llm_response_cache.get_by_id(mode)
-        for k, v in caches.items():
-            item = {mode: {k: v}}
-            print(f"\tCopying {item}")
-            await to_llm_response_cache.upsert(item)
+    # Get all cache data from JsonKVStorage (hierarchical structure)
+    all_data = await from_llm_response_cache.get_all()
+
+    # Convert hierarchical data to flattened structure for PGKVStorage
+    flattened_data = {}
+    for mode, mode_data in all_data.items():
+        print(f"Processing mode: {mode}")
+        for hash_value, cache_entry in mode_data.items():
+            # Determine cache_type from cache entry or use default
+            cache_type = cache_entry.get("cache_type", "extract")
+            # Create flattened key: {mode}:{cache_type}:{hash}
+            flattened_key = f"{mode}:{cache_type}:{hash_value}"
+            flattened_data[flattened_key] = cache_entry
+            print(f"\tConverting {mode}[{hash_value}] -> {flattened_key}")
+
+    # Upsert the flattened data
+    await to_llm_response_cache.upsert(flattened_data)
+    print("Mission accomplished!")
 
 
 if __name__ == "__main__":
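As a side note, the flattened cache key convention this migration script relies on is easy to check in isolation. The sketch below only assumes the {mode}:{cache_type}:{hash} format shown in the diff above; the helper names are hypothetical and not part of the commit.

    # Hypothetical round-trip check for the flattened cache key format
    # used by the migration above: {mode}:{cache_type}:{hash}.
    def flatten_key(mode: str, cache_type: str, hash_value: str) -> str:
        return f"{mode}:{cache_type}:{hash_value}"

    def split_key(flattened_key: str) -> tuple[str, str, str] | None:
        parts = flattened_key.split(":", 2)
        return tuple(parts) if len(parts) == 3 else None

    key = flatten_key("hybrid", "extract", "abc123")
    assert split_key(key) == ("hybrid", "extract", "abc123")
    assert split_key("malformed-key") is None  # such keys are skipped by the copy loop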
lightrag/api/__init__.py CHANGED

@@ -1 +1 @@
- __api_version__ = "0176"
+ __api_version__ = "0178"
lightrag/api/routers/document_routes.py CHANGED

@@ -62,6 +62,51 @@ router = APIRouter(
 temp_prefix = "__tmp__"
 
 
+ def sanitize_filename(filename: str, input_dir: Path) -> str:
+     """
+     Sanitize uploaded filename to prevent Path Traversal attacks.
+
+     Args:
+         filename: The original filename from the upload
+         input_dir: The target input directory
+
+     Returns:
+         str: Sanitized filename that is safe to use
+
+     Raises:
+         HTTPException: If the filename is unsafe or invalid
+     """
+     # Basic validation
+     if not filename or not filename.strip():
+         raise HTTPException(status_code=400, detail="Filename cannot be empty")
+
+     # Remove path separators and traversal sequences
+     clean_name = filename.replace("/", "").replace("\\", "")
+     clean_name = clean_name.replace("..", "")
+
+     # Remove control characters and null bytes
+     clean_name = "".join(c for c in clean_name if ord(c) >= 32 and c != "\x7f")
+
+     # Remove leading/trailing whitespace and dots
+     clean_name = clean_name.strip().strip(".")
+
+     # Check if anything is left after sanitization
+     if not clean_name:
+         raise HTTPException(
+             status_code=400, detail="Invalid filename after sanitization"
+         )
+
+     # Verify the final path stays within the input directory
+     try:
+         final_path = (input_dir / clean_name).resolve()
+         if not final_path.is_relative_to(input_dir.resolve()):
+             raise HTTPException(status_code=400, detail="Unsafe filename detected")
+     except (OSError, ValueError):
+         raise HTTPException(status_code=400, detail="Invalid filename")
+
+     return clean_name
+
+
 class ScanResponse(BaseModel):
     """Response model for document scanning operation

@@ -783,7 +828,7 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
     try:
         new_files = doc_manager.scan_directory_for_new_files()
         total_files = len(new_files)
-        logger.info(f"Found {total_files} new files to index.")
+        logger.info(f"Found {total_files} files to index.")
 
         if not new_files:
             return

@@ -816,8 +861,13 @@ async def background_delete_documents(
     successful_deletions = []
     failed_deletions = []
 
-    # Set pipeline status to busy for deletion
+    # Double-check pipeline status before proceeding
     async with pipeline_status_lock:
+        if pipeline_status.get("busy", False):
+            logger.warning("Error: Unexpected pipeline busy state, aborting deletion.")
+            return  # Abort deletion operation
+
+        # Set pipeline status to busy for deletion
         pipeline_status.update(
             {
                 "busy": True,

@@ -926,13 +976,26 @@ async def background_delete_documents(
         async with pipeline_status_lock:
             pipeline_status["history_messages"].append(error_msg)
     finally:
-        # Final summary
+        # Final summary and check for pending requests
         async with pipeline_status_lock:
             pipeline_status["busy"] = False
             completion_msg = f"Deletion completed: {len(successful_deletions)} successful, {len(failed_deletions)} failed"
             pipeline_status["latest_message"] = completion_msg
             pipeline_status["history_messages"].append(completion_msg)
 
+        # Check if there are pending document indexing requests
+        has_pending_request = pipeline_status.get("request_pending", False)
+
+        # If there are pending requests, start document processing pipeline
+        if has_pending_request:
+            try:
+                logger.info(
+                    "Processing pending document indexing requests after deletion"
+                )
+                await rag.apipeline_process_enqueue_documents()
+            except Exception as e:
+                logger.error(f"Error processing pending documents after deletion: {e}")
+
 
 def create_document_routes(
     rag: LightRAG, doc_manager: DocumentManager, api_key: Optional[str] = None

@@ -986,18 +1049,21 @@
             HTTPException: If the file type is not supported (400) or other errors occur (500).
         """
         try:
-            if not doc_manager.is_supported_file(file.filename):
+            # Sanitize filename to prevent Path Traversal attacks
+            safe_filename = sanitize_filename(file.filename, doc_manager.input_dir)
+
+            if not doc_manager.is_supported_file(safe_filename):
                 raise HTTPException(
                     status_code=400,
                     detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}",
                 )
 
-            file_path = doc_manager.input_dir / file.filename
+            file_path = doc_manager.input_dir / safe_filename
             # Check if file already exists
             if file_path.exists():
                 return InsertResponse(
                     status="duplicated",
-                    message=f"File '{file.filename}' already exists in the input directory.",
+                    message=f"File '{safe_filename}' already exists in the input directory.",
                 )
 
             with open(file_path, "wb") as buffer:

@@ -1008,7 +1074,7 @@
 
             return InsertResponse(
                 status="success",
-                message=f"File '{file.filename}' uploaded successfully. Processing will continue in background.",
+                message=f"File '{safe_filename}' uploaded successfully. Processing will continue in background.",
             )
         except Exception as e:
             logger.error(f"Error /documents/upload: {file.filename}: {str(e)}")
lightrag/api/routers/ollama_api.py CHANGED

@@ -234,7 +234,7 @@ class OllamaAPI:
         @self.router.get("/version", dependencies=[Depends(combined_auth)])
         async def get_version():
             """Get Ollama version information"""
-            return OllamaVersionResponse(version="0.5.4")
+            return OllamaVersionResponse(version="0.9.3")
 
         @self.router.get("/tags", dependencies=[Depends(combined_auth)])
         async def get_tags():

@@ -244,9 +244,9 @@ class OllamaAPI:
                 {
                     "name": self.ollama_server_infos.LIGHTRAG_MODEL,
                     "model": self.ollama_server_infos.LIGHTRAG_MODEL,
+                    "modified_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                     "size": self.ollama_server_infos.LIGHTRAG_SIZE,
                     "digest": self.ollama_server_infos.LIGHTRAG_DIGEST,
-                    "modified_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                     "details": {
                         "parent_model": "",
                         "format": "gguf",

@@ -337,7 +337,10 @@ class OllamaAPI:
                     data = {
                         "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                         "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
+                        "response": "",
                         "done": True,
+                        "done_reason": "stop",
+                        "context": [],
                         "total_duration": total_time,
                         "load_duration": 0,
                         "prompt_eval_count": prompt_tokens,

@@ -377,6 +380,7 @@ class OllamaAPI:
                         "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                         "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                         "response": f"\n\nError: {error_msg}",
+                        "error": f"\n\nError: {error_msg}",
                         "done": False,
                     }
                     yield f"{json.dumps(error_data, ensure_ascii=False)}\n"

@@ -385,6 +389,7 @@ class OllamaAPI:
                     final_data = {
                         "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                         "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
+                        "response": "",
                         "done": True,
                     }
                     yield f"{json.dumps(final_data, ensure_ascii=False)}\n"

@@ -399,7 +404,10 @@ class OllamaAPI:
                     data = {
                         "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                         "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
+                        "response": "",
                         "done": True,
+                        "done_reason": "stop",
+                        "context": [],
                         "total_duration": total_time,
                         "load_duration": 0,
                         "prompt_eval_count": prompt_tokens,

@@ -444,6 +452,8 @@ class OllamaAPI:
                     "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                     "response": str(response_text),
                     "done": True,
+                    "done_reason": "stop",
+                    "context": [],
                     "total_duration": total_time,
                     "load_duration": 0,
                     "prompt_eval_count": prompt_tokens,

@@ -557,6 +567,12 @@ class OllamaAPI:
                     data = {
                         "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                         "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
+                        "message": {
+                            "role": "assistant",
+                            "content": "",
+                            "images": None,
+                        },
+                        "done_reason": "stop",
                         "done": True,
                         "total_duration": total_time,
                         "load_duration": 0,

@@ -605,6 +621,7 @@ class OllamaAPI:
                             "content": f"\n\nError: {error_msg}",
                             "images": None,
                         },
+                        "error": f"\n\nError: {error_msg}",
                         "done": False,
                     }
                     yield f"{json.dumps(error_data, ensure_ascii=False)}\n"

@@ -613,6 +630,11 @@ class OllamaAPI:
                     final_data = {
                         "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                         "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
+                        "message": {
+                            "role": "assistant",
+                            "content": "",
+                            "images": None,
+                        },
                         "done": True,
                     }
                     yield f"{json.dumps(final_data, ensure_ascii=False)}\n"

@@ -633,6 +655,7 @@ class OllamaAPI:
                         "content": "",
                         "images": None,
                     },
+                    "done_reason": "stop",
                     "done": True,
                     "total_duration": total_time,
                     "load_duration": 0,

@@ -697,6 +720,7 @@ class OllamaAPI:
                     "content": str(response_text),
                     "images": None,
                 },
+                "done_reason": "stop",
                 "done": True,
                 "total_duration": total_time,
                 "load_duration": 0,
lightrag/api/routers/query_routes.py CHANGED

@@ -183,6 +183,9 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
             if isinstance(response, str):
                 # If it's a string, send it all at once
                 yield f"{json.dumps({'response': response})}\n"
+            elif response is None:
+                # Handle None response (e.g., when only_need_context=True but no context found)
+                yield f"{json.dumps({'response': 'No relevant context found for the query.'})}\n"
             else:
                 # If it's an async generator, send chunks one by one
                 try:
lightrag/base.py CHANGED

@@ -297,6 +297,8 @@ class BaseKVStorage(StorageNameSpace, ABC):
 
 @dataclass
 class BaseGraphStorage(StorageNameSpace, ABC):
+    """All operations related to edges in graph should be undirected."""
+
     embedding_func: EmbeddingFunc
 
     @abstractmethod

@@ -468,17 +470,6 @@
             list[dict]: A list of nodes, where each node is a dictionary of its properties.
                 An empty list if no matching nodes are found.
         """
-        # Default implementation iterates through all nodes, which is inefficient.
-        # This method should be overridden by subclasses for better performance.
-        all_nodes = []
-        all_labels = await self.get_all_labels()
-        for label in all_labels:
-            node = await self.get_node(label)
-            if node and "source_id" in node:
-                source_ids = set(node["source_id"].split(GRAPH_FIELD_SEP))
-                if not source_ids.isdisjoint(chunk_ids):
-                    all_nodes.append(node)
-        return all_nodes
 
     @abstractmethod
     async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:

@@ -643,6 +634,8 @@ class DocProcessingStatus:
     """ISO format timestamp when document was last updated"""
     chunks_count: int | None = None
     """Number of chunks after splitting, used for processing"""
+    chunks_list: list[str] | None = field(default_factory=list)
+    """List of chunk IDs associated with this document, used for deletion"""
     error: str | None = None
     """Error message if failed"""
     metadata: dict[str, Any] = field(default_factory=dict)
lightrag/constants.py CHANGED

@@ -7,6 +7,7 @@ consistency and makes maintenance easier.
 """
 
 # Default values for environment variables
+DEFAULT_MAX_GLEANING = 1
 DEFAULT_MAX_TOKEN_SUMMARY = 500
 DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE = 6
 DEFAULT_WOKERS = 2
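DEFAULT_MAX_GLEANING pairs with the new MAX_GLEANING variable in env.example above. A constant like this typically acts as the fallback for an environment override; the snippet below is only a sketch of that pattern, not the actual lightrag code path.

    import os

    # Hypothetical illustration: DEFAULT_MAX_GLEANING comes from lightrag.constants,
    # MAX_GLEANING from the environment (see env.example in this commit).
    DEFAULT_MAX_GLEANING = 1

    max_gleaning = int(os.getenv("MAX_GLEANING", DEFAULT_MAX_GLEANING))
    print(f"entity extraction will re-glean up to {max_gleaning} time(s)")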
lightrag/kg/__init__.py CHANGED

@@ -26,11 +26,11 @@ STORAGE_IMPLEMENTATIONS = {
     "implementations": [
         "NanoVectorDBStorage",
         "MilvusVectorDBStorage",
-        "ChromaVectorDBStorage",
         "PGVectorStorage",
         "FaissVectorDBStorage",
         "QdrantVectorDBStorage",
         "MongoVectorDBStorage",
+        # "ChromaVectorDBStorage",
         # "TiDBVectorDBStorage",
     ],
     "required_methods": ["query", "upsert"],

@@ -38,6 +38,7 @@ STORAGE_IMPLEMENTATIONS = {
     "DOC_STATUS_STORAGE": {
         "implementations": [
             "JsonDocStatusStorage",
+            "RedisDocStatusStorage",
             "PGDocStatusStorage",
             "MongoDocStatusStorage",
         ],

@@ -81,6 +82,7 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
     "MongoVectorDBStorage": [],
     # Document Status Storage Implementations
     "JsonDocStatusStorage": [],
+    "RedisDocStatusStorage": ["REDIS_URI"],
     "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
     "MongoDocStatusStorage": [],
 }

@@ -98,6 +100,7 @@ STORAGES = {
     "MongoGraphStorage": ".kg.mongo_impl",
     "MongoVectorDBStorage": ".kg.mongo_impl",
     "RedisKVStorage": ".kg.redis_impl",
+    "RedisDocStatusStorage": ".kg.redis_impl",
     "ChromaVectorDBStorage": ".kg.chroma_impl",
     # "TiDBKVStorage": ".kg.tidb_impl",
     # "TiDBVectorDBStorage": ".kg.tidb_impl",
lightrag/kg/{chroma_impl.py → deprecated/chroma_impl.py} RENAMED

@@ -109,7 +109,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
             raise
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        logger.debug(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
 
@@ -234,7 +234,6 @@ class ChromaVectorDBStorage(BaseVectorStorage):
             ids: List of vector IDs to be deleted
         """
         try:
-            logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
             self._collection.delete(ids=ids)
             logger.debug(
                 f"Successfully deleted {len(ids)} vectors from {self.namespace}"
lightrag/kg/{gremlin_impl.py → deprecated/gremlin_impl.py} RENAMED
File without changes
lightrag/kg/{tidb_impl.py → deprecated/tidb_impl.py} RENAMED

@@ -257,7 +257,7 @@ class TiDBKVStorage(BaseKVStorage):
 
     ################ INSERT full_doc AND chunks ################
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        logger.debug(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
         left_data = {k: v for k, v in data.items() if k not in self._data}

@@ -454,11 +454,9 @@ class TiDBVectorDBStorage(BaseVectorStorage):
 
     ###### INSERT entities And relationships ######
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
-
-        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
+        logger.debug(f"Inserting {len(data)} vectors to {self.namespace}")
 
         # Get current time as UNIX timestamp
         import time

@@ -522,11 +520,6 @@ class TiDBVectorDBStorage(BaseVectorStorage):
         }
         await self.db.execute(SQL_TEMPLATES["upsert_relationship"], param)
 
-    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
-        SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
-        params = {"workspace": self.db.workspace, "status": status}
-        return await self.db.query(SQL, params, multirows=True)
-
     async def delete(self, ids: list[str]) -> None:
         """Delete vectors with specified IDs from the storage.
 
lightrag/kg/faiss_impl.py CHANGED

@@ -17,14 +17,13 @@ from .shared_storage import (
     set_all_update_flags,
 )
 
-import faiss  # type: ignore
-
 USE_GPU = os.getenv("FAISS_USE_GPU", "0") == "1"
 FAISS_PACKAGE = "faiss-gpu" if USE_GPU else "faiss-cpu"
-
 if not pm.is_installed(FAISS_PACKAGE):
     pm.install(FAISS_PACKAGE)
 
+import faiss  # type: ignore
+
 
 @final
 @dataclass
lightrag/kg/json_doc_status_impl.py CHANGED

@@ -118,6 +118,10 @@ class JsonDocStatusStorage(DocStatusStorage):
             return
         logger.debug(f"Inserting {len(data)} records to {self.namespace}")
         async with self._storage_lock:
+            # Ensure chunks_list field exists for new documents
+            for doc_id, doc_data in data.items():
+                if "chunks_list" not in doc_data:
+                    doc_data["chunks_list"] = []
             self._data.update(data)
             await set_all_update_flags(self.namespace)
 
lightrag/kg/json_kv_impl.py CHANGED
@@ -42,19 +42,14 @@ class JsonKVStorage(BaseKVStorage):
42
  if need_init:
43
  loaded_data = load_json(self._file_name) or {}
44
  async with self._storage_lock:
45
- self._data.update(loaded_data)
46
-
47
- # Calculate data count based on namespace
48
- if self.namespace.endswith("cache"):
49
- # For cache namespaces, sum the cache entries across all cache types
50
- data_count = sum(
51
- len(first_level_dict)
52
- for first_level_dict in loaded_data.values()
53
- if isinstance(first_level_dict, dict)
54
  )
55
- else:
56
- # For non-cache namespaces, use the original count method
57
- data_count = len(loaded_data)
58
 
59
  logger.info(
60
  f"Process {os.getpid()} KV load {self.namespace} with {data_count} records"
@@ -67,17 +62,8 @@ class JsonKVStorage(BaseKVStorage):
67
  dict(self._data) if hasattr(self._data, "_getvalue") else self._data
68
  )
69
 
70
- # Calculate data count based on namespace
71
- if self.namespace.endswith("cache"):
72
- # # For cache namespaces, sum the cache entries across all cache types
73
- data_count = sum(
74
- len(first_level_dict)
75
- for first_level_dict in data_dict.values()
76
- if isinstance(first_level_dict, dict)
77
- )
78
- else:
79
- # For non-cache namespaces, use the original count method
80
- data_count = len(data_dict)
81
 
82
  logger.debug(
83
  f"Process {os.getpid()} KV writting {data_count} records to {self.namespace}"
@@ -92,22 +78,49 @@ class JsonKVStorage(BaseKVStorage):
92
  Dictionary containing all stored data
93
  """
94
  async with self._storage_lock:
95
- return dict(self._data)
 
 
96
 
97
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
98
  async with self._storage_lock:
99
- return self._data.get(id)
 
 
100
 
101
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
102
  async with self._storage_lock:
103
- return [
104
- (
105
- {k: v for k, v in self._data[id].items()}
106
- if self._data.get(id, None)
107
- else None
108
- )
109
- for id in ids
110
- ]
 
 
111
 
112
  async def filter_keys(self, keys: set[str]) -> set[str]:
113
  async with self._storage_lock:
@@ -121,8 +134,29 @@ class JsonKVStorage(BaseKVStorage):
121
  """
122
  if not data:
123
  return
 
 
 
 
 
124
  logger.debug(f"Inserting {len(data)} records to {self.namespace}")
125
  async with self._storage_lock:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  self._data.update(data)
127
  await set_all_update_flags(self.namespace)
128
 
@@ -150,14 +184,14 @@ class JsonKVStorage(BaseKVStorage):
150
  await set_all_update_flags(self.namespace)
151
 
152
  async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
153
- """Delete specific records from storage by by cache mode
154
 
155
  Importance notes for in-memory storage:
156
  1. Changes will be persisted to disk during the next index_done_callback
157
  2. update flags to notify other processes that data persistence is needed
158
 
159
  Args:
160
- ids (list[str]): List of cache mode to be drop from storage
161
 
162
  Returns:
163
  True: if the cache drop successfully
@@ -167,9 +201,29 @@ class JsonKVStorage(BaseKVStorage):
167
  return False
168
 
169
  try:
170
- await self.delete(modes)
 
 
171
  return True
172
- except Exception:
 
173
  return False
174
 
175
  # async def drop_cache_by_chunk_ids(self, chunk_ids: list[str] | None = None) -> bool:
@@ -245,9 +299,58 @@
  logger.error(f"Error dropping {self.namespace}: {e}")
  return {"status": "error", "message": str(e)}
 
 
248
  async def finalize(self):
249
  """Finalize storage resources
250
  Persistence cache data to disk before exiting
251
  """
252
- if self.namespace.endswith("cache"):
253
  await self.index_done_callback()
 
42
  if need_init:
43
  loaded_data = load_json(self._file_name) or {}
44
  async with self._storage_lock:
45
+ # Migrate legacy cache structure if needed
46
+ if self.namespace.endswith("_cache"):
47
+ loaded_data = await self._migrate_legacy_cache_structure(
48
+ loaded_data
 
 
 
 
 
49
  )
50
+
51
+ self._data.update(loaded_data)
52
+ data_count = len(loaded_data)
53
 
54
  logger.info(
55
  f"Process {os.getpid()} KV load {self.namespace} with {data_count} records"
 
62
  dict(self._data) if hasattr(self._data, "_getvalue") else self._data
63
  )
64
 
65
+ # Calculate data count - all data is now flattened
66
+ data_count = len(data_dict)
 
 
 
 
 
 
 
 
 
67
 
68
  logger.debug(
69
  f"Process {os.getpid()} KV writting {data_count} records to {self.namespace}"
 
78
  Dictionary containing all stored data
79
  """
80
  async with self._storage_lock:
81
+ result = {}
82
+ for key, value in self._data.items():
83
+ if value:
84
+ # Create a copy to avoid modifying the original data
85
+ data = dict(value)
86
+ # Ensure time fields are present, provide default values for old data
87
+ data.setdefault("create_time", 0)
88
+ data.setdefault("update_time", 0)
89
+ result[key] = data
90
+ else:
91
+ result[key] = value
92
+ return result
93
 
94
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
95
  async with self._storage_lock:
96
+ result = self._data.get(id)
97
+ if result:
98
+ # Create a copy to avoid modifying the original data
99
+ result = dict(result)
100
+ # Ensure time fields are present, provide default values for old data
101
+ result.setdefault("create_time", 0)
102
+ result.setdefault("update_time", 0)
103
+ # Ensure _id field contains the clean ID
104
+ result["_id"] = id
105
+ return result
106
 
107
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
108
  async with self._storage_lock:
109
+ results = []
110
+ for id in ids:
111
+ data = self._data.get(id, None)
112
+ if data:
113
+ # Create a copy to avoid modifying the original data
114
+ result = {k: v for k, v in data.items()}
115
+ # Ensure time fields are present, provide default values for old data
116
+ result.setdefault("create_time", 0)
117
+ result.setdefault("update_time", 0)
118
+ # Ensure _id field contains the clean ID
119
+ result["_id"] = id
120
+ results.append(result)
121
+ else:
122
+ results.append(None)
123
+ return results
124
 
125
  async def filter_keys(self, keys: set[str]) -> set[str]:
126
  async with self._storage_lock:
 
134
  """
135
  if not data:
136
  return
137
+
138
+ import time
139
+
140
+ current_time = int(time.time()) # Get current Unix timestamp
141
+
142
  logger.debug(f"Inserting {len(data)} records to {self.namespace}")
143
  async with self._storage_lock:
144
+ # Add timestamps to data based on whether key exists
145
+ for k, v in data.items():
146
+ # For text_chunks namespace, ensure llm_cache_list field exists
147
+ if "text_chunks" in self.namespace:
148
+ if "llm_cache_list" not in v:
149
+ v["llm_cache_list"] = []
150
+
151
+ # Add timestamps based on whether key exists
152
+ if k in self._data: # Key exists, only update update_time
153
+ v["update_time"] = current_time
154
+ else: # New key, set both create_time and update_time
155
+ v["create_time"] = current_time
156
+ v["update_time"] = current_time
157
+
158
+ v["_id"] = k
159
+
160
  self._data.update(data)
161
  await set_all_update_flags(self.namespace)
162
 
 
184
  await set_all_update_flags(self.namespace)
185
 
186
  async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
187
+ """Delete specific records from storage by cache mode
188
 
189
  Importance notes for in-memory storage:
190
  1. Changes will be persisted to disk during the next index_done_callback
191
  2. update flags to notify other processes that data persistence is needed
192
 
193
  Args:
194
+ modes (list[str]): List of cache modes to be dropped from storage
195
 
196
  Returns:
197
  True: if the cache drop successfully
 
201
  return False
202
 
203
  try:
204
+ async with self._storage_lock:
205
+ keys_to_delete = []
206
+ modes_set = set(modes) # Convert to set for efficient lookup
207
+
208
+ for key in list(self._data.keys()):
209
+ # Parse flattened cache key: mode:cache_type:hash
210
+ parts = key.split(":", 2)
211
+ if len(parts) == 3 and parts[0] in modes_set:
212
+ keys_to_delete.append(key)
213
+
214
+ # Batch delete
215
+ for key in keys_to_delete:
216
+ self._data.pop(key, None)
217
+
218
+ if keys_to_delete:
219
+ await set_all_update_flags(self.namespace)
220
+ logger.info(
221
+ f"Dropped {len(keys_to_delete)} cache entries for modes: {modes}"
222
+ )
223
+
224
  return True
225
+ except Exception as e:
226
+ logger.error(f"Error dropping cache by modes: {e}")
227
  return False
228
 
229
  # async def drop_cache_by_chunk_ids(self, chunk_ids: list[str] | None = None) -> bool:
 
299
  logger.error(f"Error dropping {self.namespace}: {e}")
300
  return {"status": "error", "message": str(e)}
301
 
302
+ async def _migrate_legacy_cache_structure(self, data: dict) -> dict:
303
+ """Migrate legacy nested cache structure to flattened structure
304
+
305
+ Args:
306
+ data: Original data dictionary that may contain legacy structure
307
+
308
+ Returns:
309
+ Migrated data dictionary with flattened cache keys
310
+ """
311
+ from lightrag.utils import generate_cache_key
312
+
313
+ # Early return if data is empty
314
+ if not data:
315
+ return data
316
+
317
+ # Check first entry to see if it's already in new format
318
+ first_key = next(iter(data.keys()))
319
+ if ":" in first_key and len(first_key.split(":")) == 3:
320
+ # Already in flattened format, return as-is
321
+ return data
322
+
323
+ migrated_data = {}
324
+ migration_count = 0
325
+
326
+ for key, value in data.items():
327
+ # Check if this is a legacy nested cache structure
328
+ if isinstance(value, dict) and all(
329
+ isinstance(v, dict) and "return" in v for v in value.values()
330
+ ):
331
+ # This looks like a legacy cache mode with nested structure
332
+ mode = key
333
+ for cache_hash, cache_entry in value.items():
334
+ cache_type = cache_entry.get("cache_type", "extract")
335
+ flattened_key = generate_cache_key(mode, cache_type, cache_hash)
336
+ migrated_data[flattened_key] = cache_entry
337
+ migration_count += 1
338
+ else:
339
+ # Keep non-cache data or already flattened cache data as-is
340
+ migrated_data[key] = value
341
+
342
+ if migration_count > 0:
343
+ logger.info(
344
+ f"Migrated {migration_count} legacy cache entries to flattened structure"
345
+ )
346
+ # Persist migrated data immediately
347
+ write_json(migrated_data, self._file_name)
348
+
349
+ return migrated_data
350
+
351
  async def finalize(self):
352
  """Finalize storage resources
353
  Persistence cache data to disk before exiting
354
  """
355
+ if self.namespace.endswith("_cache"):
356
  await self.index_done_callback()
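The legacy-to-flattened cache migration above boils down to turning {mode: {hash: entry}} into {f"{mode}:{cache_type}:{hash}": entry}. A self-contained sketch of that transformation, with a hypothetical stand-in for lightrag.utils.generate_cache_key, looks like this:

    # Hypothetical stand-in for lightrag.utils.generate_cache_key, shown only to
    # illustrate the nested -> flattened migration performed above.
    def generate_cache_key(mode: str, cache_type: str, cache_hash: str) -> str:
        return f"{mode}:{cache_type}:{cache_hash}"

    legacy = {
        "hybrid": {"abc123": {"return": "...", "cache_type": "query"}},
        "naive": {"def456": {"return": "..."}},  # no cache_type -> falls back to "extract"
    }

    flattened = {
        generate_cache_key(mode, entry.get("cache_type", "extract"), cache_hash): entry
        for mode, entries in legacy.items()
        for cache_hash, entry in entries.items()
    }

    print(sorted(flattened))  # ['hybrid:query:abc123', 'naive:extract:def456']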
lightrag/kg/milvus_impl.py CHANGED
@@ -15,7 +15,7 @@ if not pm.is_installed("pymilvus"):
15
  pm.install("pymilvus")
16
 
17
  import configparser
18
- from pymilvus import MilvusClient # type: ignore
19
 
20
  config = configparser.ConfigParser()
21
  config.read("config.ini", "utf-8")
@@ -24,16 +24,605 @@ config.read("config.ini", "utf-8")
24
  @final
25
  @dataclass
26
  class MilvusVectorDBStorage(BaseVectorStorage):
27
- @staticmethod
28
- def create_collection_if_not_exist(
29
- client: MilvusClient, collection_name: str, **kwargs
30
- ):
31
- if client.has_collection(collection_name):
 
 
32
  return
33
- client.create_collection(
34
- collection_name, max_length=64, id_type="string", **kwargs
 
 
35
  )
 
 
37
  def __post_init__(self):
38
  kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
39
  cosine_threshold = kwargs.get("cosine_better_than_threshold")
@@ -43,6 +632,10 @@ class MilvusVectorDBStorage(BaseVectorStorage):
43
  )
44
  self.cosine_better_than_threshold = cosine_threshold
45
 
 
 
 
 
46
  self._client = MilvusClient(
47
  uri=os.environ.get(
48
  "MILVUS_URI",
@@ -68,14 +661,12 @@ class MilvusVectorDBStorage(BaseVectorStorage):
68
  ),
69
  )
70
  self._max_batch_size = self.global_config["embedding_batch_num"]
71
- MilvusVectorDBStorage.create_collection_if_not_exist(
72
- self._client,
73
- self.namespace,
74
- dimension=self.embedding_func.embedding_dim,
75
- )
76
 
77
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
78
- logger.info(f"Inserting {len(data)} to {self.namespace}")
79
  if not data:
80
  return
81
 
@@ -112,23 +703,25 @@ class MilvusVectorDBStorage(BaseVectorStorage):
112
  embedding = await self.embedding_func(
113
  [query], _priority=5
114
  ) # higher priority for query
 
 
 
 
115
  results = self._client.search(
116
  collection_name=self.namespace,
117
  data=embedding,
118
  limit=top_k,
119
- output_fields=list(self.meta_fields) + ["created_at"],
120
  search_params={
121
  "metric_type": "COSINE",
122
  "params": {"radius": self.cosine_better_than_threshold},
123
  },
124
  )
125
- print(results)
126
  return [
127
  {
128
  **dp["entity"],
129
  "id": dp["id"],
130
  "distance": dp["distance"],
131
- # created_at is requested in output_fields, so it should be a top-level key in the result dict (dp)
132
  "created_at": dp.get("created_at"),
133
  }
134
  for dp in results[0]
@@ -232,20 +825,19 @@ class MilvusVectorDBStorage(BaseVectorStorage):
232
  The vector data if found, or None if not found
233
  """
234
  try:
 
 
 
235
  # Query Milvus for a specific ID
236
  result = self._client.query(
237
  collection_name=self.namespace,
238
  filter=f'id == "{id}"',
239
- output_fields=list(self.meta_fields) + ["id", "created_at"],
240
  )
241
 
242
  if not result or len(result) == 0:
243
  return None
244
 
245
- # Ensure the result contains created_at field
246
- if "created_at" not in result[0]:
247
- result[0]["created_at"] = None
248
-
249
  return result[0]
250
  except Exception as e:
251
  logger.error(f"Error retrieving vector data for ID {id}: {e}")
@@ -264,6 +856,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
264
  return []
265
 
266
  try:
 
 
 
267
  # Prepare the ID filter expression
268
  id_list = '", "'.join(ids)
269
  filter_expr = f'id in ["{id_list}"]'
@@ -272,14 +867,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
272
  result = self._client.query(
273
  collection_name=self.namespace,
274
  filter=filter_expr,
275
- output_fields=list(self.meta_fields) + ["id", "created_at"],
276
  )
277
 
278
- # Ensure each result contains created_at field
279
- for item in result:
280
- if "created_at" not in item:
281
- item["created_at"] = None
282
-
283
  return result or []
284
  except Exception as e:
285
  logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
@@ -301,11 +891,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
301
  self._client.drop_collection(self.namespace)
302
 
303
  # Recreate the collection
304
- MilvusVectorDBStorage.create_collection_if_not_exist(
305
- self._client,
306
- self.namespace,
307
- dimension=self.embedding_func.embedding_dim,
308
- )
309
 
310
  logger.info(
311
  f"Process {os.getpid()} drop Milvus collection {self.namespace}"
 
15
  pm.install("pymilvus")
16
 
17
  import configparser
18
+ from pymilvus import MilvusClient, DataType, CollectionSchema, FieldSchema # type: ignore
19
 
20
  config = configparser.ConfigParser()
21
  config.read("config.ini", "utf-8")
 
24
  @final
25
  @dataclass
26
  class MilvusVectorDBStorage(BaseVectorStorage):
27
+ def _create_schema_for_namespace(self) -> CollectionSchema:
28
+ """Create schema based on the current instance's namespace"""
29
+
30
+ # Get vector dimension from embedding_func
31
+ dimension = self.embedding_func.embedding_dim
32
+
33
+ # Base fields (common to all collections)
34
+ base_fields = [
35
+ FieldSchema(
36
+ name="id", dtype=DataType.VARCHAR, max_length=64, is_primary=True
37
+ ),
38
+ FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
39
+ FieldSchema(name="created_at", dtype=DataType.INT64),
40
+ ]
41
+
42
+ # Determine specific fields based on namespace
43
+ if "entities" in self.namespace.lower():
44
+ specific_fields = [
45
+ FieldSchema(
46
+ name="entity_name",
47
+ dtype=DataType.VARCHAR,
48
+ max_length=256,
49
+ nullable=True,
50
+ ),
51
+ FieldSchema(
52
+ name="entity_type",
53
+ dtype=DataType.VARCHAR,
54
+ max_length=64,
55
+ nullable=True,
56
+ ),
57
+ FieldSchema(
58
+ name="file_path",
59
+ dtype=DataType.VARCHAR,
60
+ max_length=512,
61
+ nullable=True,
62
+ ),
63
+ ]
64
+ description = "LightRAG entities vector storage"
65
+
66
+ elif "relationships" in self.namespace.lower():
67
+ specific_fields = [
68
+ FieldSchema(
69
+ name="src_id", dtype=DataType.VARCHAR, max_length=256, nullable=True
70
+ ),
71
+ FieldSchema(
72
+ name="tgt_id", dtype=DataType.VARCHAR, max_length=256, nullable=True
73
+ ),
74
+ FieldSchema(name="weight", dtype=DataType.DOUBLE, nullable=True),
75
+ FieldSchema(
76
+ name="file_path",
77
+ dtype=DataType.VARCHAR,
78
+ max_length=512,
79
+ nullable=True,
80
+ ),
81
+ ]
82
+ description = "LightRAG relationships vector storage"
83
+
84
+ elif "chunks" in self.namespace.lower():
85
+ specific_fields = [
86
+ FieldSchema(
87
+ name="full_doc_id",
88
+ dtype=DataType.VARCHAR,
89
+ max_length=64,
90
+ nullable=True,
91
+ ),
92
+ FieldSchema(
93
+ name="file_path",
94
+ dtype=DataType.VARCHAR,
95
+ max_length=512,
96
+ nullable=True,
97
+ ),
98
+ ]
99
+ description = "LightRAG chunks vector storage"
100
+
101
+ else:
102
+ # Default generic schema (backward compatibility)
103
+ specific_fields = [
104
+ FieldSchema(
105
+ name="file_path",
106
+ dtype=DataType.VARCHAR,
107
+ max_length=512,
108
+ nullable=True,
109
+ ),
110
+ ]
111
+ description = "LightRAG generic vector storage"
112
+
113
+ # Merge all fields
114
+ all_fields = base_fields + specific_fields
115
+
116
+ return CollectionSchema(
117
+ fields=all_fields,
118
+ description=description,
119
+ enable_dynamic_field=True, # Support dynamic fields
120
+ )
121
+
122
+ def _get_index_params(self):
123
+ """Get IndexParams in a version-compatible way"""
124
+ try:
125
+ # Try to use client's prepare_index_params method (most common)
126
+ if hasattr(self._client, "prepare_index_params"):
127
+ return self._client.prepare_index_params()
128
+ except Exception:
129
+ pass
130
+
131
+ try:
132
+ # Try to import IndexParams from different possible locations
133
+ from pymilvus.client.prepare import IndexParams
134
+
135
+ return IndexParams()
136
+ except ImportError:
137
+ pass
138
+
139
+ try:
140
+ from pymilvus.client.types import IndexParams
141
+
142
+ return IndexParams()
143
+ except ImportError:
144
+ pass
145
+
146
+ try:
147
+ from pymilvus import IndexParams
148
+
149
+ return IndexParams()
150
+ except ImportError:
151
+ pass
152
+
153
+ # If all else fails, return None to use fallback method
154
+ return None
155
+
156
+ def _create_vector_index_fallback(self):
157
+ """Fallback method to create vector index using direct API"""
158
+ try:
159
+ self._client.create_index(
160
+ collection_name=self.namespace,
161
+ field_name="vector",
162
+ index_params={
163
+ "index_type": "HNSW",
164
+ "metric_type": "COSINE",
165
+ "params": {"M": 16, "efConstruction": 256},
166
+ },
167
+ )
168
+ logger.debug("Created vector index using fallback method")
169
+ except Exception as e:
170
+ logger.warning(f"Failed to create vector index using fallback method: {e}")
171
+
172
+ def _create_scalar_index_fallback(self, field_name: str, index_type: str):
173
+ """Fallback method to create scalar index using direct API"""
174
+ # Skip unsupported index types
175
+ if index_type == "SORTED":
176
+ logger.info(
177
+ f"Skipping SORTED index for {field_name} (not supported in this Milvus version)"
178
+ )
179
  return
180
+
181
+ try:
182
+ self._client.create_index(
183
+ collection_name=self.namespace,
184
+ field_name=field_name,
185
+ index_params={"index_type": index_type},
186
+ )
187
+ logger.debug(f"Created {field_name} index using fallback method")
188
+ except Exception as e:
189
+ logger.info(
190
+ f"Could not create {field_name} index using fallback method: {e}"
191
+ )
192
+
193
+ def _create_indexes_after_collection(self):
194
+ """Create indexes after collection is created"""
195
+ try:
196
+ # Try to get IndexParams in a version-compatible way
197
+ IndexParamsClass = self._get_index_params()
198
+
199
+ if IndexParamsClass is not None:
200
+ # Use IndexParams approach if available
201
+ try:
202
+ # Create vector index first (required for most operations)
203
+ vector_index = IndexParamsClass
204
+ vector_index.add_index(
205
+ field_name="vector",
206
+ index_type="HNSW",
207
+ metric_type="COSINE",
208
+ params={"M": 16, "efConstruction": 256},
209
+ )
210
+ self._client.create_index(
211
+ collection_name=self.namespace, index_params=vector_index
212
+ )
213
+ logger.debug("Created vector index using IndexParams")
214
+ except Exception as e:
215
+ logger.debug(f"IndexParams method failed for vector index: {e}")
216
+ self._create_vector_index_fallback()
217
+
218
+ # Create scalar indexes based on namespace
219
+ if "entities" in self.namespace.lower():
220
+ # Create indexes for entity fields
221
+ try:
222
+ entity_name_index = self._get_index_params()
223
+ entity_name_index.add_index(
224
+ field_name="entity_name", index_type="INVERTED"
225
+ )
226
+ self._client.create_index(
227
+ collection_name=self.namespace,
228
+ index_params=entity_name_index,
229
+ )
230
+ except Exception as e:
231
+ logger.debug(f"IndexParams method failed for entity_name: {e}")
232
+ self._create_scalar_index_fallback("entity_name", "INVERTED")
233
+
234
+ try:
235
+ entity_type_index = self._get_index_params()
236
+ entity_type_index.add_index(
237
+ field_name="entity_type", index_type="INVERTED"
238
+ )
239
+ self._client.create_index(
240
+ collection_name=self.namespace,
241
+ index_params=entity_type_index,
242
+ )
243
+ except Exception as e:
244
+ logger.debug(f"IndexParams method failed for entity_type: {e}")
245
+ self._create_scalar_index_fallback("entity_type", "INVERTED")
246
+
247
+ elif "relationships" in self.namespace.lower():
248
+ # Create indexes for relationship fields
249
+ try:
250
+ src_id_index = self._get_index_params()
251
+ src_id_index.add_index(
252
+ field_name="src_id", index_type="INVERTED"
253
+ )
254
+ self._client.create_index(
255
+ collection_name=self.namespace, index_params=src_id_index
256
+ )
257
+ except Exception as e:
258
+ logger.debug(f"IndexParams method failed for src_id: {e}")
259
+ self._create_scalar_index_fallback("src_id", "INVERTED")
260
+
261
+ try:
262
+ tgt_id_index = self._get_index_params()
263
+ tgt_id_index.add_index(
264
+ field_name="tgt_id", index_type="INVERTED"
265
+ )
266
+ self._client.create_index(
267
+ collection_name=self.namespace, index_params=tgt_id_index
268
+ )
269
+ except Exception as e:
270
+ logger.debug(f"IndexParams method failed for tgt_id: {e}")
271
+ self._create_scalar_index_fallback("tgt_id", "INVERTED")
272
+
273
+ elif "chunks" in self.namespace.lower():
274
+ # Create indexes for chunk fields
275
+ try:
276
+ doc_id_index = self._get_index_params()
277
+ doc_id_index.add_index(
278
+ field_name="full_doc_id", index_type="INVERTED"
279
+ )
280
+ self._client.create_index(
281
+ collection_name=self.namespace, index_params=doc_id_index
282
+ )
283
+ except Exception as e:
284
+ logger.debug(f"IndexParams method failed for full_doc_id: {e}")
285
+ self._create_scalar_index_fallback("full_doc_id", "INVERTED")
286
+
287
+ # No common indexes needed
288
+
289
+ else:
290
+ # Fallback to direct API calls if IndexParams is not available
291
+ logger.info(
292
+ f"IndexParams not available, using fallback methods for {self.namespace}"
293
+ )
294
+
295
+ # Create vector index using fallback
296
+ self._create_vector_index_fallback()
297
+
298
+ # Create scalar indexes using fallback
299
+ if "entities" in self.namespace.lower():
300
+ self._create_scalar_index_fallback("entity_name", "INVERTED")
301
+ self._create_scalar_index_fallback("entity_type", "INVERTED")
302
+ elif "relationships" in self.namespace.lower():
303
+ self._create_scalar_index_fallback("src_id", "INVERTED")
304
+ self._create_scalar_index_fallback("tgt_id", "INVERTED")
305
+ elif "chunks" in self.namespace.lower():
306
+ self._create_scalar_index_fallback("full_doc_id", "INVERTED")
307
+
308
+ logger.info(f"Created indexes for collection: {self.namespace}")
309
+
310
+ except Exception as e:
311
+ logger.warning(f"Failed to create some indexes for {self.namespace}: {e}")
312
+
313
+ def _get_required_fields_for_namespace(self) -> dict:
314
+ """Get required core field definitions for current namespace"""
315
+
316
+ # Base fields (common to all types)
317
+ base_fields = {
318
+ "id": {"type": "VarChar", "is_primary": True},
319
+ "vector": {"type": "FloatVector"},
320
+ "created_at": {"type": "Int64"},
321
+ }
322
+
323
+ # Add specific fields based on namespace
324
+ if "entities" in self.namespace.lower():
325
+ specific_fields = {
326
+ "entity_name": {"type": "VarChar"},
327
+ "entity_type": {"type": "VarChar"},
328
+ "file_path": {"type": "VarChar"},
329
+ }
330
+ elif "relationships" in self.namespace.lower():
331
+ specific_fields = {
332
+ "src_id": {"type": "VarChar"},
333
+ "tgt_id": {"type": "VarChar"},
334
+ "weight": {"type": "Double"},
335
+ "file_path": {"type": "VarChar"},
336
+ }
337
+ elif "chunks" in self.namespace.lower():
338
+ specific_fields = {
339
+ "full_doc_id": {"type": "VarChar"},
340
+ "file_path": {"type": "VarChar"},
341
+ }
342
+ else:
343
+ specific_fields = {
344
+ "file_path": {"type": "VarChar"},
345
+ }
346
+
347
+ return {**base_fields, **specific_fields}
348
+
349
+ def _is_field_compatible(self, existing_field: dict, expected_config: dict) -> bool:
350
+ """Check compatibility of a single field"""
351
+ field_name = existing_field.get("name", "unknown")
352
+ existing_type = existing_field.get("type")
353
+ expected_type = expected_config.get("type")
354
+
355
+ logger.debug(
356
+ f"Checking field '{field_name}': existing_type={existing_type} (type={type(existing_type)}), expected_type={expected_type}"
357
+ )
358
+
359
+ # Convert DataType enum values to string names if needed
360
+ original_existing_type = existing_type
361
+ if hasattr(existing_type, "name"):
362
+ existing_type = existing_type.name
363
+ logger.debug(
364
+ f"Converted enum to name: {original_existing_type} -> {existing_type}"
365
+ )
366
+ elif isinstance(existing_type, int):
367
+ # Map common Milvus internal type codes to type names for backward compatibility
368
+ type_mapping = {
369
+ 21: "VarChar",
370
+ 101: "FloatVector",
371
+ 5: "Int64",
372
+ 9: "Double",
373
+ }
374
+ mapped_type = type_mapping.get(existing_type, str(existing_type))
375
+ logger.debug(f"Mapped numeric type: {existing_type} -> {mapped_type}")
376
+ existing_type = mapped_type
377
+
378
+ # Normalize type names for comparison
379
+ type_aliases = {
380
+ "VARCHAR": "VarChar",
381
+ "String": "VarChar",
382
+ "FLOAT_VECTOR": "FloatVector",
383
+ "INT64": "Int64",
384
+ "BigInt": "Int64",
385
+ "DOUBLE": "Double",
386
+ "Float": "Double",
387
+ }
388
+
389
+ original_existing = existing_type
390
+ original_expected = expected_type
391
+ existing_type = type_aliases.get(existing_type, existing_type)
392
+ expected_type = type_aliases.get(expected_type, expected_type)
393
+
394
+ if original_existing != existing_type or original_expected != expected_type:
395
+ logger.debug(
396
+ f"Applied aliases: {original_existing} -> {existing_type}, {original_expected} -> {expected_type}"
397
+ )
398
+
399
+ # Basic type compatibility check
400
+ type_compatible = existing_type == expected_type
401
+ logger.debug(
402
+ f"Type compatibility for '{field_name}': {existing_type} == {expected_type} -> {type_compatible}"
403
  )
404
 
405
+ if not type_compatible:
406
+ logger.warning(
407
+ f"Type mismatch for field '{field_name}': expected {expected_type}, got {existing_type}"
408
+ )
409
+ return False
410
+
411
+ # Primary key check - be more flexible about primary key detection
412
+ if expected_config.get("is_primary"):
413
+ # Check multiple possible field names for primary key status
414
+ is_primary = (
415
+ existing_field.get("is_primary_key", False)
416
+ or existing_field.get("is_primary", False)
417
+ or existing_field.get("primary_key", False)
418
+ )
419
+ logger.debug(
420
+ f"Primary key check for '{field_name}': expected=True, actual={is_primary}"
421
+ )
422
+ logger.debug(f"Raw field data for '{field_name}': {existing_field}")
423
+
424
+ # For ID field, be more lenient - if it's the ID field, assume it should be primary
425
+ if field_name == "id" and not is_primary:
426
+ logger.info(
427
+ f"ID field '{field_name}' not marked as primary in existing collection, but treating as compatible"
428
+ )
429
+ # Don't fail for ID field primary key mismatch
430
+ elif not is_primary:
431
+ logger.warning(
432
+ f"Primary key mismatch for field '{field_name}': expected primary key, but field is not primary"
433
+ )
434
+ return False
435
+
436
+ logger.debug(f"Field '{field_name}' is compatible")
437
+ return True
438
+
439
+ def _check_vector_dimension(self, collection_info: dict):
440
+ """Check vector dimension compatibility"""
441
+ current_dimension = self.embedding_func.embedding_dim
442
+
443
+ # Find vector field dimension
444
+ for field in collection_info.get("fields", []):
445
+ if field.get("name") == "vector":
446
+ field_type = field.get("type")
447
+ if field_type in ["FloatVector", "FLOAT_VECTOR"]:
448
+ existing_dimension = field.get("params", {}).get("dim")
449
+
450
+ if existing_dimension != current_dimension:
451
+ raise ValueError(
452
+ f"Vector dimension mismatch for collection '{self.namespace}': "
453
+ f"existing={existing_dimension}, current={current_dimension}"
454
+ )
455
+
456
+ logger.debug(f"Vector dimension check passed: {current_dimension}")
457
+ return
458
+
459
+ # If no vector field found, this might be an old collection created with simple schema
460
+ logger.warning(
461
+ f"Vector field not found in collection '{self.namespace}'. This might be an old collection created with simple schema."
462
+ )
463
+ logger.warning("Consider recreating the collection for optimal performance.")
464
+ return
465
+
466
+ def _check_schema_compatibility(self, collection_info: dict):
467
+ """Check schema field compatibility"""
468
+ existing_fields = {
469
+ field["name"]: field for field in collection_info.get("fields", [])
470
+ }
471
+
472
+ # Check if this is an old collection created with simple schema
473
+ has_vector_field = any(
474
+ field.get("name") == "vector" for field in collection_info.get("fields", [])
475
+ )
476
+
477
+ if not has_vector_field:
478
+ logger.warning(
479
+ f"Collection {self.namespace} appears to be created with old simple schema (no vector field)"
480
+ )
481
+ logger.warning(
482
+ "This collection will work but may have suboptimal performance"
483
+ )
484
+ logger.warning("Consider recreating the collection for optimal performance")
485
+ return
486
+
487
+ # For collections with vector field, check basic compatibility
488
+ # Only check for critical incompatibilities, not missing optional fields
489
+ critical_fields = {"id": {"type": "VarChar", "is_primary": True}}
490
+
491
+ incompatible_fields = []
492
+
493
+ for field_name, expected_config in critical_fields.items():
494
+ if field_name in existing_fields:
495
+ existing_field = existing_fields[field_name]
496
+ if not self._is_field_compatible(existing_field, expected_config):
497
+ incompatible_fields.append(
498
+ f"{field_name}: expected {expected_config['type']}, "
499
+ f"got {existing_field.get('type')}"
500
+ )
501
+
502
+ if incompatible_fields:
503
+ raise ValueError(
504
+ f"Critical schema incompatibility in collection '{self.namespace}': {incompatible_fields}"
505
+ )
506
+
507
+ # Get all expected fields for informational purposes
508
+ expected_fields = self._get_required_fields_for_namespace()
509
+ missing_fields = [
510
+ field for field in expected_fields if field not in existing_fields
511
+ ]
512
+
513
+ if missing_fields:
514
+ logger.info(
515
+ f"Collection {self.namespace} missing optional fields: {missing_fields}"
516
+ )
517
+ logger.info(
518
+ "These fields would be available in a newly created collection for better performance"
519
+ )
520
+
521
+ logger.debug(f"Schema compatibility check passed for {self.namespace}")
522
+
523
+ def _validate_collection_compatibility(self):
524
+ """Validate existing collection's dimension and schema compatibility"""
525
+ try:
526
+ collection_info = self._client.describe_collection(self.namespace)
527
+
528
+ # 1. Check vector dimension
529
+ self._check_vector_dimension(collection_info)
530
+
531
+ # 2. Check schema compatibility
532
+ self._check_schema_compatibility(collection_info)
533
+
534
+ logger.info(f"Collection {self.namespace} compatibility validation passed")
535
+
536
+ except Exception as e:
537
+ logger.error(
538
+ f"Collection compatibility validation failed for {self.namespace}: {e}"
539
+ )
540
+ raise
541
+
542
+ def _create_collection_if_not_exist(self):
543
+ """Create collection if not exists and check existing collection compatibility"""
544
+
545
+ try:
546
+ # First, list all collections to see what actually exists
547
+ try:
548
+ all_collections = self._client.list_collections()
549
+ logger.debug(f"All collections in database: {all_collections}")
550
+ except Exception as list_error:
551
+ logger.warning(f"Could not list collections: {list_error}")
552
+ all_collections = []
553
+
554
+ # Check if our specific collection exists
555
+ collection_exists = self._client.has_collection(self.namespace)
556
+ logger.info(
557
+ f"Collection '{self.namespace}' exists check: {collection_exists}"
558
+ )
559
+
560
+ if collection_exists:
561
+ # Double-check by trying to describe the collection
562
+ try:
563
+ self._client.describe_collection(self.namespace)
564
+ logger.info(
565
+ f"Collection '{self.namespace}' confirmed to exist, validating compatibility..."
566
+ )
567
+ self._validate_collection_compatibility()
568
+ return
569
+ except Exception as describe_error:
570
+ logger.warning(
571
+ f"Collection '{self.namespace}' exists but cannot be described: {describe_error}"
572
+ )
573
+ logger.info(
574
+ "Treating as if collection doesn't exist and creating new one..."
575
+ )
576
+ # Fall through to creation logic
577
+
578
+ # Collection doesn't exist, create new collection
579
+ logger.info(f"Creating new collection: {self.namespace}")
580
+ schema = self._create_schema_for_namespace()
581
+
582
+ # Create collection with schema only first
583
+ self._client.create_collection(
584
+ collection_name=self.namespace, schema=schema
585
+ )
586
+
587
+ # Then create indexes
588
+ self._create_indexes_after_collection()
589
+
590
+ logger.info(f"Successfully created Milvus collection: {self.namespace}")
591
+
592
+ except Exception as e:
593
+ logger.error(
594
+ f"Error in _create_collection_if_not_exist for {self.namespace}: {e}"
595
+ )
596
+
597
+ # If there's any error, try to force create the collection
598
+ logger.info(f"Attempting to force create collection {self.namespace}...")
599
+ try:
600
+ # Try to drop the collection first if it exists in a bad state
601
+ try:
602
+ if self._client.has_collection(self.namespace):
603
+ logger.info(
604
+ f"Dropping potentially corrupted collection {self.namespace}"
605
+ )
606
+ self._client.drop_collection(self.namespace)
607
+ except Exception as drop_error:
608
+ logger.warning(
609
+ f"Could not drop collection {self.namespace}: {drop_error}"
610
+ )
611
+
612
+ # Create fresh collection
613
+ schema = self._create_schema_for_namespace()
614
+ self._client.create_collection(
615
+ collection_name=self.namespace, schema=schema
616
+ )
617
+ self._create_indexes_after_collection()
618
+ logger.info(f"Successfully force-created collection {self.namespace}")
619
+
620
+ except Exception as create_error:
621
+ logger.error(
622
+ f"Failed to force-create collection {self.namespace}: {create_error}"
623
+ )
624
+ raise
625
+
626
  def __post_init__(self):
627
  kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
628
  cosine_threshold = kwargs.get("cosine_better_than_threshold")
 
632
  )
633
  self.cosine_better_than_threshold = cosine_threshold
634
 
635
+ # Ensure created_at is in meta_fields
636
+ if "created_at" not in self.meta_fields:
637
+ self.meta_fields.add("created_at")
638
+
639
  self._client = MilvusClient(
640
  uri=os.environ.get(
641
  "MILVUS_URI",
 
661
  ),
662
  )
663
  self._max_batch_size = self.global_config["embedding_batch_num"]
664
+
665
+ # Create collection and check compatibility
666
+ self._create_collection_if_not_exist()
 
 
667
 
668
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
669
+ logger.debug(f"Inserting {len(data)} to {self.namespace}")
670
  if not data:
671
  return
672
 
 
703
  embedding = await self.embedding_func(
704
  [query], _priority=5
705
  ) # higher priority for query
706
+
707
+ # Include all meta_fields (created_at is now always included)
708
+ output_fields = list(self.meta_fields)
709
+
710
  results = self._client.search(
711
  collection_name=self.namespace,
712
  data=embedding,
713
  limit=top_k,
714
+ output_fields=output_fields,
715
  search_params={
716
  "metric_type": "COSINE",
717
  "params": {"radius": self.cosine_better_than_threshold},
718
  },
719
  )
 
720
  return [
721
  {
722
  **dp["entity"],
723
  "id": dp["id"],
724
  "distance": dp["distance"],
 
725
  "created_at": dp.get("created_at"),
726
  }
727
  for dp in results[0]
 
825
  The vector data if found, or None if not found
826
  """
827
  try:
828
+ # Include all meta_fields (created_at is now always included) plus id
829
+ output_fields = list(self.meta_fields) + ["id"]
830
+
831
  # Query Milvus for a specific ID
832
  result = self._client.query(
833
  collection_name=self.namespace,
834
  filter=f'id == "{id}"',
835
+ output_fields=output_fields,
836
  )
837
 
838
  if not result or len(result) == 0:
839
  return None
840
 
 
 
 
 
841
  return result[0]
842
  except Exception as e:
843
  logger.error(f"Error retrieving vector data for ID {id}: {e}")
 
856
  return []
857
 
858
  try:
859
+ # Include all meta_fields (created_at is now always included) plus id
860
+ output_fields = list(self.meta_fields) + ["id"]
861
+
862
  # Prepare the ID filter expression
863
  id_list = '", "'.join(ids)
864
  filter_expr = f'id in ["{id_list}"]'
 
867
  result = self._client.query(
868
  collection_name=self.namespace,
869
  filter=filter_expr,
870
+ output_fields=output_fields,
871
  )
872
 
 
 
 
 
 
873
  return result or []
874
  except Exception as e:
875
  logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
 
891
  self._client.drop_collection(self.namespace)
892
 
893
  # Recreate the collection
894
+ self._create_collection_if_not_exist()
 
 
 
 
895
 
896
  logger.info(
897
  f"Process {os.getpid()} drop Milvus collection {self.namespace}"
lightrag/kg/mongo_impl.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  from dataclasses import dataclass, field
3
  import numpy as np
4
  import configparser
@@ -14,7 +15,6 @@ from ..base import (
14
  DocStatus,
15
  DocStatusStorage,
16
  )
17
- from ..namespace import NameSpace, is_namespace
18
  from ..utils import logger, compute_mdhash_id
19
  from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
20
  from ..constants import GRAPH_FIELD_SEP
@@ -35,6 +35,7 @@ config.read("config.ini", "utf-8")
35
 
36
  # Get maximum number of graph nodes from environment variable, default is 1000
37
  MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
 
38
 
39
 
40
  class ClientManager:
@@ -96,11 +97,22 @@ class MongoKVStorage(BaseKVStorage):
96
  self._data = None
97
 
98
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
99
- return await self._data.find_one({"_id": id})
100
 
101
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
102
  cursor = self._data.find({"_id": {"$in": ids}})
103
- return await cursor.to_list()
104
 
105
  async def filter_keys(self, keys: set[str]) -> set[str]:
106
  cursor = self._data.find({"_id": {"$in": list(keys)}}, {"_id": 1})
@@ -117,47 +129,53 @@ class MongoKVStorage(BaseKVStorage):
117
  result = {}
118
  async for doc in cursor:
119
  doc_id = doc.pop("_id")
120
  result[doc_id] = doc
121
  return result
122
 
123
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
124
- logger.info(f"Inserting {len(data)} to {self.namespace}")
125
  if not data:
126
  return
127
 
128
- if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
129
- update_tasks: list[Any] = []
130
- for mode, items in data.items():
131
- for k, v in items.items():
132
- key = f"{mode}_{k}"
133
- data[mode][k]["_id"] = f"{mode}_{k}"
134
- update_tasks.append(
135
- self._data.update_one(
136
- {"_id": key}, {"$setOnInsert": v}, upsert=True
137
- )
138
- )
139
- await asyncio.gather(*update_tasks)
140
- else:
141
- update_tasks = []
142
- for k, v in data.items():
143
- data[k]["_id"] = k
144
- update_tasks.append(
145
- self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  )
147
- await asyncio.gather(*update_tasks)
148
-
149
- async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
150
- if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
151
- res = {}
152
- v = await self._data.find_one({"_id": mode + "_" + id})
153
- if v:
154
- res[id] = v
155
- logger.debug(f"llm_response_cache find one by:{id}")
156
- return res
157
- else:
158
- return None
159
- else:
160
- return None
161
 
162
  async def index_done_callback(self) -> None:
163
  # Mongo handles persistence automatically
@@ -197,8 +215,8 @@ class MongoKVStorage(BaseKVStorage):
197
  return False
198
 
199
  try:
200
- # Build regex pattern to match documents with the specified modes
201
- pattern = f"^({'|'.join(modes)})_"
202
  result = await self._data.delete_many({"_id": {"$regex": pattern}})
203
  logger.info(f"Deleted {result.deleted_count} documents by modes: {modes}")
204
  return True
@@ -262,11 +280,14 @@ class MongoDocStatusStorage(DocStatusStorage):
262
  return data - existing_ids
263
 
264
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
265
- logger.info(f"Inserting {len(data)} to {self.namespace}")
266
  if not data:
267
  return
268
  update_tasks: list[Any] = []
269
  for k, v in data.items():
 
 
 
270
  data[k]["_id"] = k
271
  update_tasks.append(
272
  self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
@@ -299,6 +320,7 @@ class MongoDocStatusStorage(DocStatusStorage):
299
  updated_at=doc.get("updated_at"),
300
  chunks_count=doc.get("chunks_count", -1),
301
  file_path=doc.get("file_path", doc["_id"]),
 
302
  )
303
  for doc in result
304
  }
@@ -417,11 +439,21 @@ class MongoGraphStorage(BaseGraphStorage):
417
 
418
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
419
  """
420
- Check if there's a direct single-hop edge from source_node_id to target_node_id.
421
  """
422
- # Direct check if the target_node appears among the edges array.
423
  doc = await self.edge_collection.find_one(
424
- {"source_node_id": source_node_id, "target_node_id": target_node_id},
 
 
 
 
 
 
 
 
 
 
 
425
  {"_id": 1},
426
  )
427
  return doc is not None
@@ -651,7 +683,7 @@ class MongoGraphStorage(BaseGraphStorage):
651
  self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
652
  ) -> None:
653
  """
654
- Upsert an edge from source_node_id -> target_node_id with optional 'relation'.
655
  If an edge with the same target exists, we remove it and re-insert with updated data.
656
  """
657
  # Ensure source node exists
@@ -663,8 +695,22 @@ class MongoGraphStorage(BaseGraphStorage):
663
  GRAPH_FIELD_SEP
664
  )
665
 
 
 
 
666
  await self.edge_collection.update_one(
667
- {"source_node_id": source_node_id, "target_node_id": target_node_id},
 
 
 
 
 
 
 
 
 
 
 
668
  update_doc,
669
  upsert=True,
670
  )
@@ -678,7 +724,7 @@ class MongoGraphStorage(BaseGraphStorage):
678
  async def delete_node(self, node_id: str) -> None:
679
  """
680
  1) Remove node's doc entirely.
681
- 2) Remove inbound edges from any doc that references node_id.
682
  """
683
  # Remove all edges
684
  await self.edge_collection.delete_many(
@@ -709,141 +755,369 @@ class MongoGraphStorage(BaseGraphStorage):
  labels.append(doc["_id"])
  return labels

712
  async def get_knowledge_graph(
713
  self,
714
  node_label: str,
715
- max_depth: int = 5,
716
  max_nodes: int = MAX_GRAPH_NODES,
717
  ) -> KnowledgeGraph:
718
  """
719
- Get complete connected subgraph for specified node (including the starting node itself)
720
 
721
  Args:
722
- node_label: Label of the nodes to start from
723
- max_depth: Maximum depth of traversal (default: 5)
 
724
 
725
  Returns:
726
- KnowledgeGraph object containing nodes and edges of the subgraph
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
727
  """
728
- label = node_label
729
  result = KnowledgeGraph()
730
- seen_nodes = set()
731
- seen_edges = set()
732
- node_edges = []
733
 
734
  try:
735
  # Optimize pipeline to avoid memory issues with large datasets
736
- if label == "*":
737
- # For getting all nodes, use a simpler pipeline to avoid memory issues
738
- pipeline = [
739
- {"$limit": max_nodes}, # Limit early to reduce memory usage
740
- {
741
- "$graphLookup": {
742
- "from": self._edge_collection_name,
743
- "startWith": "$_id",
744
- "connectFromField": "target_node_id",
745
- "connectToField": "source_node_id",
746
- "maxDepth": max_depth,
747
- "depthField": "depth",
748
- "as": "connected_edges",
749
- },
750
- },
751
- ]
752
-
753
- # Check if we need to set truncation flag
754
- all_node_count = await self.collection.count_documents({})
755
- result.is_truncated = all_node_count > max_nodes
756
  else:
757
- # Verify if starting node exists
758
- start_node = await self.collection.find_one({"_id": label})
759
- if not start_node:
760
- logger.warning(f"Starting node with label {label} does not exist!")
761
- return result
762
-
763
- # For specific node queries, use the original pipeline but optimized
764
- pipeline = [
765
- {"$match": {"_id": label}},
766
- {
767
- "$graphLookup": {
768
- "from": self._edge_collection_name,
769
- "startWith": "$_id",
770
- "connectFromField": "target_node_id",
771
- "connectToField": "source_node_id",
772
- "maxDepth": max_depth,
773
- "depthField": "depth",
774
- "as": "connected_edges",
775
- },
776
- },
777
- {"$addFields": {"edge_count": {"$size": "$connected_edges"}}},
778
- {"$sort": {"edge_count": -1}},
779
- {"$limit": max_nodes},
780
- ]
781
-
782
- cursor = await self.collection.aggregate(pipeline, allowDiskUse=True)
783
- nodes_processed = 0
784
-
785
- async for doc in cursor:
786
- # Add the start node
787
- node_id = str(doc["_id"])
788
- result.nodes.append(
789
- KnowledgeGraphNode(
790
- id=node_id,
791
- labels=[node_id],
792
- properties={
793
- k: v
794
- for k, v in doc.items()
795
- if k
796
- not in [
797
- "_id",
798
- "connected_edges",
799
- "edge_count",
800
- ]
801
- },
802
- )
803
  )
804
- seen_nodes.add(node_id)
805
- if doc.get("connected_edges", []):
806
- node_edges.extend(doc.get("connected_edges"))
807
 
808
- nodes_processed += 1
809
-
810
- # Additional safety check to prevent memory issues
811
- if nodes_processed >= max_nodes:
812
- result.is_truncated = True
813
- break
814
-
815
- for edge in node_edges:
816
- if (
817
- edge["source_node_id"] not in seen_nodes
818
- or edge["target_node_id"] not in seen_nodes
819
- ):
820
- continue
821
-
822
- edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
823
- if edge_id not in seen_edges:
824
- result.edges.append(
825
- KnowledgeGraphEdge(
826
- id=edge_id,
827
- type=edge.get("relationship", ""),
828
- source=edge["source_node_id"],
829
- target=edge["target_node_id"],
830
- properties={
831
- k: v
832
- for k, v in edge.items()
833
- if k
834
- not in [
835
- "_id",
836
- "source_node_id",
837
- "target_node_id",
838
- "relationship",
839
- ]
840
- },
841
- )
842
- )
843
- seen_edges.add(edge_id)
844
 
845
  logger.info(
846
- f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)} | Truncated: {result.is_truncated}"
847
  )
848
 
849
  except PyMongoError as e:
@@ -856,13 +1130,8 @@ class MongoGraphStorage(BaseGraphStorage):
856
  try:
857
  simple_cursor = self.collection.find({}).limit(max_nodes)
858
  async for doc in simple_cursor:
859
- node_id = str(doc["_id"])
860
  result.nodes.append(
861
- KnowledgeGraphNode(
862
- id=node_id,
863
- labels=[node_id],
864
- properties={k: v for k, v in doc.items() if k != "_id"},
865
- )
866
  )
867
  result.is_truncated = True
868
  logger.info(
@@ -1023,13 +1292,11 @@ class MongoVectorDBStorage(BaseVectorStorage):
1023
  logger.debug("vector index already exist")
1024
 
1025
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
1026
- logger.info(f"Inserting {len(data)} to {self.namespace}")
1027
  if not data:
1028
  return
1029
 
1030
  # Add current time as Unix timestamp
1031
- import time
1032
-
1033
  current_time = int(time.time())
1034
 
1035
  list_data = [
@@ -1114,7 +1381,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
1114
  Args:
1115
  ids: List of vector IDs to be deleted
1116
  """
1117
- logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
1118
  if not ids:
1119
  return
1120
 
 
1
  import os
2
+ import time
3
  from dataclasses import dataclass, field
4
  import numpy as np
5
  import configparser
 
15
  DocStatus,
16
  DocStatusStorage,
17
  )
 
18
  from ..utils import logger, compute_mdhash_id
19
  from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
20
  from ..constants import GRAPH_FIELD_SEP
 
35
 
36
  # Get maximum number of graph nodes from environment variable, default is 1000
37
  MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
38
+ GRAPH_BFS_MODE = os.getenv("MONGO_GRAPH_BFS_MODE", "bidirectional")
39
 
40
 
41
  class ClientManager:
 
97
  self._data = None
98
 
99
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
100
+ # Unified handling for flattened keys
101
+ doc = await self._data.find_one({"_id": id})
102
+ if doc:
103
+ # Ensure time fields are present, provide default values for old data
104
+ doc.setdefault("create_time", 0)
105
+ doc.setdefault("update_time", 0)
106
+ return doc
107
 
108
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
109
  cursor = self._data.find({"_id": {"$in": ids}})
110
+ docs = await cursor.to_list()
111
+ # Ensure time fields are present for all documents
112
+ for doc in docs:
113
+ doc.setdefault("create_time", 0)
114
+ doc.setdefault("update_time", 0)
115
+ return docs
116
 
117
  async def filter_keys(self, keys: set[str]) -> set[str]:
118
  cursor = self._data.find({"_id": {"$in": list(keys)}}, {"_id": 1})
 
129
  result = {}
130
  async for doc in cursor:
131
  doc_id = doc.pop("_id")
132
+ # Ensure time fields are present for all documents
133
+ doc.setdefault("create_time", 0)
134
+ doc.setdefault("update_time", 0)
135
  result[doc_id] = doc
136
  return result
137
 
138
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
139
+ logger.debug(f"Inserting {len(data)} to {self.namespace}")
140
  if not data:
141
  return
142
 
143
+ # Unified handling for all namespaces with flattened keys
144
+ # Use bulk_write for better performance
145
+ from pymongo import UpdateOne
146
+
147
+ operations = []
148
+ current_time = int(time.time()) # Get current Unix timestamp
149
+
150
+ for k, v in data.items():
151
+ # For text_chunks namespace, ensure llm_cache_list field exists
152
+ if self.namespace.endswith("text_chunks"):
153
+ if "llm_cache_list" not in v:
154
+ v["llm_cache_list"] = []
155
+
156
+ # Create a copy of v for $set operation, excluding create_time to avoid conflicts
157
+ v_for_set = v.copy()
158
+ v_for_set["_id"] = k # Use flattened key as _id
159
+ v_for_set["update_time"] = current_time # Always update update_time
160
+
161
+ # Remove create_time from $set to avoid conflict with $setOnInsert
162
+ v_for_set.pop("create_time", None)
163
+
164
+ operations.append(
165
+ UpdateOne(
166
+ {"_id": k},
167
+ {
168
+ "$set": v_for_set, # Update all fields except create_time
169
+ "$setOnInsert": {
170
+ "create_time": current_time
171
+ }, # Set create_time only on insert
172
+ },
173
+ upsert=True,
174
  )
175
+ )
176
+
177
+ if operations:
178
+ await self._data.bulk_write(operations)
 
179
 
180
  async def index_done_callback(self) -> None:
181
  # Mongo handles persistence automatically
 
215
  return False
216
 
217
  try:
218
+ # Build regex pattern to match flattened key format: mode:cache_type:hash
219
+ pattern = f"^({'|'.join(modes)}):"
220
  result = await self._data.delete_many({"_id": {"$regex": pattern}})
221
  logger.info(f"Deleted {result.deleted_count} documents by modes: {modes}")
222
  return True
 
280
  return data - existing_ids
281
 
282
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
283
+ logger.debug(f"Inserting {len(data)} to {self.namespace}")
284
  if not data:
285
  return
286
  update_tasks: list[Any] = []
287
  for k, v in data.items():
288
+ # Ensure chunks_list field exists and is an array
289
+ if "chunks_list" not in v:
290
+ v["chunks_list"] = []
291
  data[k]["_id"] = k
292
  update_tasks.append(
293
  self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
 
320
  updated_at=doc.get("updated_at"),
321
  chunks_count=doc.get("chunks_count", -1),
322
  file_path=doc.get("file_path", doc["_id"]),
323
+ chunks_list=doc.get("chunks_list", []),
324
  )
325
  for doc in result
326
  }
 
439
 
440
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
441
  """
442
+ Check if there's a direct single-hop edge between source_node_id and target_node_id.
443
  """
 
444
  doc = await self.edge_collection.find_one(
445
+ {
446
+ "$or": [
447
+ {
448
+ "source_node_id": source_node_id,
449
+ "target_node_id": target_node_id,
450
+ },
451
+ {
452
+ "source_node_id": target_node_id,
453
+ "target_node_id": source_node_id,
454
+ },
455
+ ]
456
+ },
457
  {"_id": 1},
458
  )
459
  return doc is not None
 
683
  self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
684
  ) -> None:
685
  """
686
+ Upsert an edge between source_node_id and target_node_id with optional 'relation'.
687
  If an edge with the same target exists, we remove it and re-insert with updated data.
688
  """
689
  # Ensure source node exists
 
695
  GRAPH_FIELD_SEP
696
  )
697
 
698
+ edge_data["source_node_id"] = source_node_id
699
+ edge_data["target_node_id"] = target_node_id
700
+
701
  await self.edge_collection.update_one(
702
+ {
703
+ "$or": [
704
+ {
705
+ "source_node_id": source_node_id,
706
+ "target_node_id": target_node_id,
707
+ },
708
+ {
709
+ "source_node_id": target_node_id,
710
+ "target_node_id": source_node_id,
711
+ },
712
+ ]
713
+ },
714
  update_doc,
715
  upsert=True,
716
  )
 
724
  async def delete_node(self, node_id: str) -> None:
725
  """
726
  1) Remove node's doc entirely.
727
+ 2) Remove inbound & outbound edges from any doc that references node_id.
728
  """
729
  # Remove all edges
730
  await self.edge_collection.delete_many(
 
755
  labels.append(doc["_id"])
756
  return labels
757
 
758
+ def _construct_graph_node(
759
+ self, node_id, node_data: dict[str, str]
760
+ ) -> KnowledgeGraphNode:
761
+ return KnowledgeGraphNode(
762
+ id=node_id,
763
+ labels=[node_id],
764
+ properties={
765
+ k: v
766
+ for k, v in node_data.items()
767
+ if k
768
+ not in [
769
+ "_id",
770
+ "connected_edges",
771
+ "source_ids",
772
+ "edge_count",
773
+ ]
774
+ },
775
+ )
776
+
777
+ def _construct_graph_edge(self, edge_id: str, edge: dict[str, str]):
778
+ return KnowledgeGraphEdge(
779
+ id=edge_id,
780
+ type=edge.get("relationship", ""),
781
+ source=edge["source_node_id"],
782
+ target=edge["target_node_id"],
783
+ properties={
784
+ k: v
785
+ for k, v in edge.items()
786
+ if k
787
+ not in [
788
+ "_id",
789
+ "source_node_id",
790
+ "target_node_id",
791
+ "relationship",
792
+ "source_ids",
793
+ ]
794
+ },
795
+ )
796
+
797
+ async def get_knowledge_graph_all_by_degree(
798
+ self, max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES
799
+ ) -> KnowledgeGraph:
800
+ """
801
+ It's possible that a node with one or more relationships is retrieved
+ while its neighbors are not. Such a node may then appear disconnected in the UI.
803
+ """
804
+
805
+ total_node_count = await self.collection.count_documents({})
806
+ result = KnowledgeGraph()
807
+ seen_edges = set()
808
+
809
+ result.is_truncated = total_node_count > max_nodes
810
+ if result.is_truncated:
811
+ # Get all node_ids ranked by degree if max_nodes exceeds total node count
812
+ pipeline = [
813
+ {"$project": {"source_node_id": 1, "_id": 0}},
814
+ {"$group": {"_id": "$source_node_id", "degree": {"$sum": 1}}},
815
+ {
816
+ "$unionWith": {
817
+ "coll": self._edge_collection_name,
818
+ "pipeline": [
819
+ {"$project": {"target_node_id": 1, "_id": 0}},
820
+ {
821
+ "$group": {
822
+ "_id": "$target_node_id",
823
+ "degree": {"$sum": 1},
824
+ }
825
+ },
826
+ ],
827
+ }
828
+ },
829
+ {"$group": {"_id": "$_id", "degree": {"$sum": "$degree"}}},
830
+ {"$sort": {"degree": -1}},
831
+ {"$limit": max_nodes},
832
+ ]
833
+ cursor = await self.edge_collection.aggregate(pipeline, allowDiskUse=True)
834
+
835
+ node_ids = []
836
+ async for doc in cursor:
837
+ node_id = str(doc["_id"])
838
+ node_ids.append(node_id)
839
+
840
+ cursor = self.collection.find({"_id": {"$in": node_ids}}, {"source_ids": 0})
841
+ async for doc in cursor:
842
+ result.nodes.append(self._construct_graph_node(doc["_id"], doc))
843
+
844
+ # As node count reaches the limit, only need to fetch the edges that directly connect to these nodes
845
+ edge_cursor = self.edge_collection.find(
846
+ {
847
+ "$and": [
848
+ {"source_node_id": {"$in": node_ids}},
849
+ {"target_node_id": {"$in": node_ids}},
850
+ ]
851
+ }
852
+ )
853
+ else:
854
+ # All nodes and edges are needed
855
+ cursor = self.collection.find({}, {"source_ids": 0})
856
+
857
+ async for doc in cursor:
858
+ node_id = str(doc["_id"])
859
+ result.nodes.append(self._construct_graph_node(doc["_id"], doc))
860
+
861
+ edge_cursor = self.edge_collection.find({})
862
+
863
+ async for edge in edge_cursor:
864
+ edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
865
+ if edge_id not in seen_edges:
866
+ seen_edges.add(edge_id)
867
+ result.edges.append(self._construct_graph_edge(edge_id, edge))
868
+
869
+ return result
870
+
871
+ async def _bidirectional_bfs_nodes(
872
+ self,
873
+ node_labels: list[str],
874
+ seen_nodes: set[str],
875
+ result: KnowledgeGraph,
876
+ depth: int = 0,
877
+ max_depth: int = 3,
878
+ max_nodes: int = MAX_GRAPH_NODES,
879
+ ) -> KnowledgeGraph:
880
+ if depth > max_depth or len(result.nodes) > max_nodes:
881
+ return result
882
+
883
+ cursor = self.collection.find({"_id": {"$in": node_labels}})
884
+
885
+ async for node in cursor:
886
+ node_id = node["_id"]
887
+ if node_id not in seen_nodes:
888
+ seen_nodes.add(node_id)
889
+ result.nodes.append(self._construct_graph_node(node_id, node))
890
+ if len(result.nodes) > max_nodes:
891
+ return result
892
+
893
+ # Collect neighbors
894
+ # Get both inbound and outbound one hop nodes
895
+ cursor = self.edge_collection.find(
896
+ {
897
+ "$or": [
898
+ {"source_node_id": {"$in": node_labels}},
899
+ {"target_node_id": {"$in": node_labels}},
900
+ ]
901
+ }
902
+ )
903
+
904
+ neighbor_nodes = []
905
+ async for edge in cursor:
906
+ if edge["source_node_id"] not in seen_nodes:
907
+ neighbor_nodes.append(edge["source_node_id"])
908
+ if edge["target_node_id"] not in seen_nodes:
909
+ neighbor_nodes.append(edge["target_node_id"])
910
+
911
+ if neighbor_nodes:
912
+ result = await self._bidirectional_bfs_nodes(
913
+ neighbor_nodes, seen_nodes, result, depth + 1, max_depth, max_nodes
914
+ )
915
+
916
+ return result
917
+
918
+ async def get_knowledge_subgraph_bidirectional_bfs(
919
+ self,
920
+ node_label: str,
921
+ depth=0,
922
+ max_depth: int = 3,
923
+ max_nodes: int = MAX_GRAPH_NODES,
924
+ ) -> KnowledgeGraph:
925
+ seen_nodes = set()
926
+ seen_edges = set()
927
+ result = KnowledgeGraph()
928
+
929
+ result = await self._bidirectional_bfs_nodes(
930
+ [node_label], seen_nodes, result, depth, max_depth, max_nodes
931
+ )
932
+
933
+ # Get all edges from seen_nodes
934
+ all_node_ids = list(seen_nodes)
935
+ cursor = self.edge_collection.find(
936
+ {
937
+ "$and": [
938
+ {"source_node_id": {"$in": all_node_ids}},
939
+ {"target_node_id": {"$in": all_node_ids}},
940
+ ]
941
+ }
942
+ )
943
+
944
+ async for edge in cursor:
945
+ edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
946
+ if edge_id not in seen_edges:
947
+ result.edges.append(self._construct_graph_edge(edge_id, edge))
948
+ seen_edges.add(edge_id)
949
+
950
+ return result
951
+
952
+ async def get_knowledge_subgraph_in_out_bound_bfs(
953
+ self, node_label: str, max_depth: int = 3, max_nodes: int = MAX_GRAPH_NODES
954
+ ) -> KnowledgeGraph:
955
+ seen_nodes = set()
956
+ seen_edges = set()
957
+ result = KnowledgeGraph()
958
+ project_doc = {
959
+ "source_ids": 0,
960
+ "created_at": 0,
961
+ "entity_type": 0,
962
+ "file_path": 0,
963
+ }
964
+
965
+ # Verify if starting node exists
966
+ start_node = await self.collection.find_one({"_id": node_label})
967
+ if not start_node:
968
+ logger.warning(f"Starting node with label {node_label} does not exist!")
969
+ return result
970
+
971
+ seen_nodes.add(node_label)
972
+ result.nodes.append(self._construct_graph_node(node_label, start_node))
973
+
974
+ if max_depth == 0:
975
+ return result
976
+
977
+ # In MongoDB, depth = 0 means one-hop
978
+ max_depth = max_depth - 1
979
+
980
+ pipeline = [
981
+ {"$match": {"_id": node_label}},
982
+ {"$project": project_doc},
983
+ {
984
+ "$graphLookup": {
985
+ "from": self._edge_collection_name,
986
+ "startWith": "$_id",
987
+ "connectFromField": "target_node_id",
988
+ "connectToField": "source_node_id",
989
+ "maxDepth": max_depth,
990
+ "depthField": "depth",
991
+ "as": "connected_edges",
992
+ },
993
+ },
994
+ {
995
+ "$unionWith": {
996
+ "coll": self._collection_name,
997
+ "pipeline": [
998
+ {"$match": {"_id": node_label}},
999
+ {"$project": project_doc},
1000
+ {
1001
+ "$graphLookup": {
1002
+ "from": self._edge_collection_name,
1003
+ "startWith": "$_id",
1004
+ "connectFromField": "source_node_id",
1005
+ "connectToField": "target_node_id",
1006
+ "maxDepth": max_depth,
1007
+ "depthField": "depth",
1008
+ "as": "connected_edges",
1009
+ }
1010
+ },
1011
+ ],
1012
+ }
1013
+ },
1014
+ ]
1015
+
1016
+ cursor = await self.collection.aggregate(pipeline, allowDiskUse=True)
1017
+ node_edges = []
1018
+
1019
+ # Two records for node_label are returned capturing outbound and inbound connected_edges
1020
+ async for doc in cursor:
1021
+ if doc.get("connected_edges", []):
1022
+ node_edges.extend(doc.get("connected_edges"))
1023
+
1024
+ # Sort the connected edges by depth ascending and weight descending
1025
+ # And stores the source_node_id and target_node_id in sequence to retrieve the neighbouring nodes
1026
+ node_edges = sorted(
1027
+ node_edges,
1028
+ key=lambda x: (x["depth"], -x["weight"]),
1029
+ )
1030
+
1031
+ # As order matters, we need to use another list to store the node_id
1032
+ # And only take the first max_nodes ones
1033
+ node_ids = []
1034
+ for edge in node_edges:
1035
+ if len(node_ids) < max_nodes and edge["source_node_id"] not in seen_nodes:
1036
+ node_ids.append(edge["source_node_id"])
1037
+ seen_nodes.add(edge["source_node_id"])
1038
+
1039
+ if len(node_ids) < max_nodes and edge["target_node_id"] not in seen_nodes:
1040
+ node_ids.append(edge["target_node_id"])
1041
+ seen_nodes.add(edge["target_node_id"])
1042
+
1043
+ # Exclude nodes whose id equals node_label so we do not need to re-check their existence in the next step
1044
+ cursor = self.collection.find({"_id": {"$in": node_ids}})
1045
+
1046
+ async for doc in cursor:
1047
+ result.nodes.append(self._construct_graph_node(str(doc["_id"]), doc))
1048
+
1049
+ for edge in node_edges:
1050
+ if (
1051
+ edge["source_node_id"] not in seen_nodes
1052
+ or edge["target_node_id"] not in seen_nodes
1053
+ ):
1054
+ continue
1055
+
1056
+ edge_id = f"{edge['source_node_id']}-{edge['target_node_id']}"
1057
+ if edge_id not in seen_edges:
1058
+ result.edges.append(self._construct_graph_edge(edge_id, edge))
1059
+ seen_edges.add(edge_id)
1060
+
1061
+ return result
1062
+
1063
  async def get_knowledge_graph(
1064
  self,
1065
  node_label: str,
1066
+ max_depth: int = 3,
1067
  max_nodes: int = MAX_GRAPH_NODES,
1068
  ) -> KnowledgeGraph:
1069
  """
1070
+ Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
1071
 
1072
  Args:
1073
+ node_label: Label of the starting node, * means all nodes
+ max_depth: Maximum depth of the subgraph, defaults to 3
+ max_nodes: Maximum number of nodes to return, defaults to 1000

  Returns:
+ KnowledgeGraph object containing nodes and edges, with an is_truncated flag
+ indicating whether the graph was truncated due to the max_nodes limit
+
+ If a graph looks like this and we start from B:
+ A → B ← C ← F, B → E, C → D
+
+ Outbound BFS:
+ B → E
+
+ Inbound BFS:
+ A → B
+ C → B
+ F → C
+
+ Bidirectional BFS:
+ A → B
+ B → E
+ F → C
+ C → B
+ C → D
  """
 
1099
  result = KnowledgeGraph()
1100
+ start = time.perf_counter()
 
 
1101
 
1102
  try:
1103
  # Optimize pipeline to avoid memory issues with large datasets
1104
+ if node_label == "*":
1105
+ result = await self.get_knowledge_graph_all_by_degree(
1106
+ max_depth, max_nodes
1107
+ )
1108
+ elif GRAPH_BFS_MODE == "in_out_bound":
1109
+ result = await self.get_knowledge_subgraph_in_out_bound_bfs(
1110
+ node_label, max_depth, max_nodes
1111
+ )
 
 
 
 
 
 
 
 
 
 
 
 
1112
  else:
1113
+ result = await self.get_knowledge_subgraph_bidirectional_bfs(
1114
+ node_label, 0, max_depth, max_nodes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1115
  )
 
 
 
1116
 
1117
+ duration = time.perf_counter() - start
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1118
 
1119
  logger.info(
1120
+ f"Subgraph query successful in {duration:.4f} seconds | Node count: {len(result.nodes)} | Edge count: {len(result.edges)} | Truncated: {result.is_truncated}"
1121
  )
1122
 
1123
  except PyMongoError as e:
 
1130
  try:
1131
  simple_cursor = self.collection.find({}).limit(max_nodes)
1132
  async for doc in simple_cursor:
 
1133
  result.nodes.append(
1134
+ self._construct_graph_node(str(doc["_id"]), doc)
1135
  )
1136
  result.is_truncated = True
1137
  logger.info(
 
1292
  logger.debug("vector index already exist")
1293
 
1294
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
1295
+ logger.debug(f"Inserting {len(data)} to {self.namespace}")
1296
  if not data:
1297
  return
1298
 
1299
  # Add current time as Unix timestamp
 
 
1300
  current_time = int(time.time())
1301
 
1302
  list_data = [
 
1381
  Args:
1382
  ids: List of vector IDs to be deleted
1383
  """
1384
+ logger.debug(f"Deleting {len(ids)} vectors from {self.namespace}")
1385
  if not ids:
1386
  return
1387
 
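The BFS helpers above drive the new MongoDB subgraph retrieval. A simplified, in-memory illustration of the bidirectional expansion (the edge list is the hypothetical A/B/C/D/E/F example from the get_knowledge_graph docstring; the real implementation queries the edge collection instead):

from collections import deque

# Edges as (source, target) pairs; the traversal treats them as undirected.
edges = [("A", "B"), ("C", "B"), ("F", "C"), ("B", "E"), ("C", "D")]

def bidirectional_bfs(start: str, max_depth: int = 3, max_nodes: int = 1000) -> set[str]:
    seen = {start}
    frontier = deque([(start, 0)])
    while frontier:
        node, depth = frontier.popleft()
        if depth >= max_depth or len(seen) >= max_nodes:
            break
        for src, tgt in edges:
            # An edge is traversable from either endpoint (bidirectional BFS)
            neighbor = tgt if src == node else src if tgt == node else None
            if neighbor is not None and neighbor not in seen:
                seen.add(neighbor)
                frontier.append((neighbor, depth + 1))
    return seen

print(bidirectional_bfs("B"))  # reaches A, C, D, E and F, unlike an outbound-only walk from B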
lightrag/kg/networkx_impl.py CHANGED
@@ -106,7 +106,9 @@ class NetworkXStorage(BaseGraphStorage):

  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
  graph = await self._get_graph()
- return graph.degree(src_id) + graph.degree(tgt_id)

  async def get_edge(
  self, source_node_id: str, target_node_id: str

  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
  graph = await self._get_graph()
+ src_degree = graph.degree(src_id) if graph.has_node(src_id) else 0
+ tgt_degree = graph.degree(tgt_id) if graph.has_node(tgt_id) else 0
+ return src_degree + tgt_degree

  async def get_edge(
  self, source_node_id: str, target_node_id: str
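The guard matters because graph.degree() is only meaningful for nodes that exist in the graph; with the checks, a missing endpoint simply contributes 0. A minimal sketch with networkx:

import networkx as nx

g = nx.Graph()
g.add_edge("A", "B")

def edge_degree(graph: nx.Graph, src: str, tgt: str) -> int:
    # Missing endpoints count as degree 0 instead of erroring out
    src_degree = graph.degree(src) if graph.has_node(src) else 0
    tgt_degree = graph.degree(tgt) if graph.has_node(tgt) else 0
    return src_degree + tgt_degree

print(edge_degree(g, "A", "missing"))  # 1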
lightrag/kg/postgres_impl.py CHANGED
@@ -136,6 +136,52 @@ class PostgreSQLDB:
136
  except Exception as e:
137
  logger.warning(f"Failed to add chunk_id column to LIGHTRAG_LLM_CACHE: {e}")
138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  async def _migrate_timestamp_columns(self):
140
  """Migrate timestamp columns in tables to timezone-aware types, assuming original data is in UTC time"""
141
  # Tables and columns that need migration
@@ -189,6 +235,239 @@ class PostgreSQLDB:
189
  # Log error but don't interrupt the process
190
  logger.warning(f"Failed to migrate {table_name}.{column_name}: {e}")
191
 
192
  async def check_tables(self):
193
  # First create all tables
194
  for k, v in TABLES.items():
@@ -240,6 +519,44 @@ class PostgreSQLDB:
240
  logger.error(f"PostgreSQL, Failed to migrate LLM cache chunk_id field: {e}")
241
  # Don't throw an exception, allow the initialization process to continue
242
 
243
  async def query(
244
  self,
245
  sql: str,
@@ -423,74 +740,139 @@ class PGKVStorage(BaseKVStorage):
423
  try:
424
  results = await self.db.query(sql, params, multirows=True)
425
 
 
426
  if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
427
- result_dict = {}
428
  for row in results:
429
- mode = row["mode"]
430
- if mode not in result_dict:
431
- result_dict[mode] = {}
432
- result_dict[mode][row["id"]] = row
433
- return result_dict
434
- else:
435
- return {row["id"]: row for row in results}
 
436
  except Exception as e:
437
  logger.error(f"Error retrieving all data from {self.namespace}: {e}")
438
  return {}
439
 
440
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
441
- """Get doc_full data by id."""
442
  sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
443
  params = {"workspace": self.db.workspace, "id": id}
444
- if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
445
- array_res = await self.db.query(sql, params, multirows=True)
446
- res = {}
447
- for row in array_res:
448
- res[row["id"]] = row
449
- return res if res else None
450
- else:
451
- response = await self.db.query(sql, params)
452
- return response if response else None
453
-
454
- async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
455
- """Specifically for llm_response_cache."""
456
- sql = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
457
- params = {"workspace": self.db.workspace, "mode": mode, "id": id}
458
- if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
459
- array_res = await self.db.query(sql, params, multirows=True)
460
- res = {}
461
- for row in array_res:
462
- res[row["id"]] = row
463
- return res
464
- else:
465
- return None
466
 
467
  # Query by id
468
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
469
- """Get doc_chunks data by id"""
470
  sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
471
  ids=",".join([f"'{id}'" for id in ids])
472
  )
473
  params = {"workspace": self.db.workspace}
474
- if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
475
- array_res = await self.db.query(sql, params, multirows=True)
476
- modes = set()
477
- dict_res: dict[str, dict] = {}
478
- for row in array_res:
479
- modes.add(row["mode"])
480
- for mode in modes:
481
- if mode not in dict_res:
482
- dict_res[mode] = {}
483
- for row in array_res:
484
- dict_res[row["mode"]][row["id"]] = row
485
- return [{k: v} for k, v in dict_res.items()]
486
- else:
487
- return await self.db.query(sql, params, multirows=True)
488
 
489
- async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
490
- """Specifically for llm_response_cache."""
491
- SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
492
- params = {"workspace": self.db.workspace, "status": status}
493
- return await self.db.query(SQL, params, multirows=True)
494
 
495
  async def filter_keys(self, keys: set[str]) -> set[str]:
496
  """Filter out duplicated content"""
@@ -520,7 +902,22 @@ class PGKVStorage(BaseKVStorage):
520
  return
521
 
522
  if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
523
- pass
524
  elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
525
  for k, v in data.items():
526
  upsert_sql = SQL_TEMPLATES["upsert_doc_full"]
@@ -531,19 +928,21 @@ class PGKVStorage(BaseKVStorage):
531
  }
532
  await self.db.execute(upsert_sql, _data)
533
  elif is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
534
- for mode, items in data.items():
535
- for k, v in items.items():
536
- upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
537
- _data = {
538
- "workspace": self.db.workspace,
539
- "id": k,
540
- "original_prompt": v["original_prompt"],
541
- "return_value": v["return"],
542
- "mode": mode,
543
- "chunk_id": v.get("chunk_id"),
544
- }
 
 
545
 
546
- await self.db.execute(upsert_sql, _data)
547
 
548
  async def index_done_callback(self) -> None:
549
  # PG handles persistence automatically
@@ -949,8 +1348,8 @@ class PGDocStatusStorage(DocStatusStorage):
949
  else:
950
  exist_keys = []
951
  new_keys = set([s for s in keys if s not in exist_keys])
952
- print(f"keys: {keys}")
953
- print(f"new_keys: {new_keys}")
954
  return new_keys
955
  except Exception as e:
956
  logger.error(
@@ -965,6 +1364,14 @@ class PGDocStatusStorage(DocStatusStorage):
965
  if result is None or result == []:
966
  return None
967
  else:
968
  return dict(
969
  content=result[0]["content"],
970
  content_length=result[0]["content_length"],
@@ -974,6 +1381,7 @@ class PGDocStatusStorage(DocStatusStorage):
974
  created_at=result[0]["created_at"],
975
  updated_at=result[0]["updated_at"],
976
  file_path=result[0]["file_path"],
 
977
  )
978
 
979
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@@ -988,19 +1396,32 @@ class PGDocStatusStorage(DocStatusStorage):
988
 
989
  if not results:
990
  return []
991
- return [
992
- {
993
- "content": row["content"],
994
- "content_length": row["content_length"],
995
- "content_summary": row["content_summary"],
996
- "status": row["status"],
997
- "chunks_count": row["chunks_count"],
998
- "created_at": row["created_at"],
999
- "updated_at": row["updated_at"],
1000
- "file_path": row["file_path"],
1001
- }
1002
- for row in results
1003
- ]
1004
 
1005
  async def get_status_counts(self) -> dict[str, int]:
1006
  """Get counts of documents in each status"""
@@ -1021,8 +1442,18 @@ class PGDocStatusStorage(DocStatusStorage):
1021
  sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
1022
  params = {"workspace": self.db.workspace, "status": status.value}
1023
  result = await self.db.query(sql, params, True)
1024
- docs_by_status = {
1025
- element["id"]: DocProcessingStatus(
1026
  content=element["content"],
1027
  content_summary=element["content_summary"],
1028
  content_length=element["content_length"],
@@ -1031,9 +1462,9 @@ class PGDocStatusStorage(DocStatusStorage):
1031
  updated_at=element["updated_at"],
1032
  chunks_count=element["chunks_count"],
1033
  file_path=element["file_path"],
 
1034
  )
1035
- for element in result
1036
- }
1037
  return docs_by_status
1038
 
1039
  async def index_done_callback(self) -> None:
@@ -1097,10 +1528,10 @@ class PGDocStatusStorage(DocStatusStorage):
1097
  logger.warning(f"Unable to parse datetime string: {dt_str}")
1098
  return None
1099
 
1100
- # Modified SQL to include created_at and updated_at in both INSERT and UPDATE operations
1101
- # Both fields are updated from the input data in both INSERT and UPDATE cases
1102
- sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content,content_summary,content_length,chunks_count,status,file_path,created_at,updated_at)
1103
- values($1,$2,$3,$4,$5,$6,$7,$8,$9,$10)
1104
  on conflict(id,workspace) do update set
1105
  content = EXCLUDED.content,
1106
  content_summary = EXCLUDED.content_summary,
@@ -1108,6 +1539,7 @@ class PGDocStatusStorage(DocStatusStorage):
1108
  chunks_count = EXCLUDED.chunks_count,
1109
  status = EXCLUDED.status,
1110
  file_path = EXCLUDED.file_path,
 
1111
  created_at = EXCLUDED.created_at,
1112
  updated_at = EXCLUDED.updated_at"""
1113
  for k, v in data.items():
@@ -1115,7 +1547,7 @@ class PGDocStatusStorage(DocStatusStorage):
1115
  created_at = parse_datetime(v.get("created_at"))
1116
  updated_at = parse_datetime(v.get("updated_at"))
1117
 
1118
- # chunks_count is optional
1119
  await self.db.execute(
1120
  sql,
1121
  {
@@ -1127,6 +1559,7 @@ class PGDocStatusStorage(DocStatusStorage):
1127
  "chunks_count": v["chunks_count"] if "chunks_count" in v else -1,
1128
  "status": v["status"],
1129
  "file_path": v["file_path"],
 
1130
  "created_at": created_at, # Use the converted datetime object
1131
  "updated_at": updated_at, # Use the converted datetime object
1132
  },
@@ -2409,7 +2842,7 @@ class PGGraphStorage(BaseGraphStorage):
2409
  NAMESPACE_TABLE_MAP = {
2410
  NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
2411
  NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
2412
- NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
2413
  NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_VDB_ENTITY",
2414
  NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_VDB_RELATION",
2415
  NameSpace.DOC_STATUS: "LIGHTRAG_DOC_STATUS",
@@ -2444,13 +2877,28 @@ TABLES = {
2444
  chunk_order_index INTEGER,
2445
  tokens INTEGER,
2446
  content TEXT,
2447
- content_vector VECTOR,
2448
  file_path VARCHAR(256),
 
2449
  create_time TIMESTAMP(0) WITH TIME ZONE,
2450
  update_time TIMESTAMP(0) WITH TIME ZONE,
2451
  CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id)
2452
  )"""
2453
  },
2454
  "LIGHTRAG_VDB_ENTITY": {
2455
  "ddl": """CREATE TABLE LIGHTRAG_VDB_ENTITY (
2456
  id VARCHAR(255),
@@ -2503,6 +2951,7 @@ TABLES = {
2503
  chunks_count int4 NULL,
2504
  status varchar(64) NULL,
2505
  file_path TEXT NULL,
 
2506
  created_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP NULL,
2507
  updated_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP NULL,
2508
  CONSTRAINT LIGHTRAG_DOC_STATUS_PK PRIMARY KEY (workspace, id)
@@ -2517,24 +2966,30 @@ SQL_TEMPLATES = {
2517
  FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id=$2
2518
  """,
2519
  "get_by_id_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
2520
- chunk_order_index, full_doc_id, file_path
 
 
2521
  FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
2522
  """,
2523
- "get_by_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id
2524
- FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2
 
2525
  """,
2526
- "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id
2527
  FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 AND id=$3
2528
  """,
2529
  "get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content
2530
  FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids})
2531
  """,
2532
  "get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
2533
- chunk_order_index, full_doc_id, file_path
 
 
2534
  FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
2535
  """,
2536
- "get_by_ids_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id
2537
- FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode= IN ({ids})
 
2538
  """,
2539
  "filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})",
2540
  "upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, workspace)
@@ -2542,16 +2997,31 @@ SQL_TEMPLATES = {
2542
  ON CONFLICT (workspace,id) DO UPDATE
2543
  SET content = $2, update_time = CURRENT_TIMESTAMP
2544
  """,
2545
- "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode,chunk_id)
2546
- VALUES ($1, $2, $3, $4, $5, $6)
2547
  ON CONFLICT (workspace,mode,id) DO UPDATE
2548
  SET original_prompt = EXCLUDED.original_prompt,
2549
  return_value=EXCLUDED.return_value,
2550
  mode=EXCLUDED.mode,
2551
  chunk_id=EXCLUDED.chunk_id,
 
2552
  update_time = CURRENT_TIMESTAMP
2553
  """,
2554
- "upsert_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens,
2555
  chunk_order_index, full_doc_id, content, content_vector, file_path,
2556
  create_time, update_time)
2557
  VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
@@ -2564,7 +3034,6 @@ SQL_TEMPLATES = {
2564
  file_path=EXCLUDED.file_path,
2565
  update_time = EXCLUDED.update_time
2566
  """,
2567
- # SQL for VectorStorage
2568
  "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
2569
  content_vector, chunk_ids, file_path, create_time, update_time)
2570
  VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7, $8, $9)
@@ -2591,7 +3060,7 @@ SQL_TEMPLATES = {
2591
  "relationships": """
2592
  WITH relevant_chunks AS (
2593
  SELECT id as chunk_id
2594
- FROM LIGHTRAG_DOC_CHUNKS
2595
  WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[])
2596
  )
2597
  SELECT source_id as src_id, target_id as tgt_id, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at
@@ -2608,7 +3077,7 @@ SQL_TEMPLATES = {
2608
  "entities": """
2609
  WITH relevant_chunks AS (
2610
  SELECT id as chunk_id
2611
- FROM LIGHTRAG_DOC_CHUNKS
2612
  WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[])
2613
  )
2614
  SELECT entity_name, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM
@@ -2625,13 +3094,13 @@ SQL_TEMPLATES = {
2625
  "chunks": """
2626
  WITH relevant_chunks AS (
2627
  SELECT id as chunk_id
2628
- FROM LIGHTRAG_DOC_CHUNKS
2629
  WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[])
2630
  )
2631
  SELECT id, content, file_path, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM
2632
  (
2633
  SELECT id, content, file_path, create_time, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
2634
- FROM LIGHTRAG_DOC_CHUNKS
2635
  WHERE workspace=$1
2636
  AND id IN (SELECT chunk_id FROM relevant_chunks)
2637
  ) as chunk_distances
 
136
  except Exception as e:
137
  logger.warning(f"Failed to add chunk_id column to LIGHTRAG_LLM_CACHE: {e}")
138
 
139
+ async def _migrate_llm_cache_add_cache_type(self):
140
+ """Add cache_type column to LIGHTRAG_LLM_CACHE table if it doesn't exist"""
141
+ try:
142
+ # Check if cache_type column exists
143
+ check_column_sql = """
144
+ SELECT column_name
145
+ FROM information_schema.columns
146
+ WHERE table_name = 'lightrag_llm_cache'
147
+ AND column_name = 'cache_type'
148
+ """
149
+
150
+ column_info = await self.query(check_column_sql)
151
+ if not column_info:
152
+ logger.info("Adding cache_type column to LIGHTRAG_LLM_CACHE table")
153
+ add_column_sql = """
154
+ ALTER TABLE LIGHTRAG_LLM_CACHE
155
+ ADD COLUMN cache_type VARCHAR(32) NULL
156
+ """
157
+ await self.execute(add_column_sql)
158
+ logger.info(
159
+ "Successfully added cache_type column to LIGHTRAG_LLM_CACHE table"
160
+ )
161
+
162
+ # Migrate existing data: extract cache_type from flattened keys
163
+ logger.info(
164
+ "Migrating existing LLM cache data to populate cache_type field"
165
+ )
166
+ update_sql = """
167
+ UPDATE LIGHTRAG_LLM_CACHE
168
+ SET cache_type = CASE
169
+ WHEN id LIKE '%:%:%' THEN split_part(id, ':', 2)
170
+ ELSE 'extract'
171
+ END
172
+ WHERE cache_type IS NULL
173
+ """
174
+ await self.execute(update_sql)
175
+ logger.info("Successfully migrated existing LLM cache data")
176
+ else:
177
+ logger.info(
178
+ "cache_type column already exists in LIGHTRAG_LLM_CACHE table"
179
+ )
180
+ except Exception as e:
181
+ logger.warning(
182
+ f"Failed to add cache_type column to LIGHTRAG_LLM_CACHE: {e}"
183
+ )
184
+
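The cache_type backfill above assumes flattened cache ids of the form mode:cache_type:hash, and extracts the middle segment with split_part(id, ':', 2), falling back to 'extract' for legacy keys. A small Python sketch of that key layout and of the equivalent parsing; the md5 here is only an illustrative stand-in for lightrag's own hash helper (the diff recomputes hashes via lightrag.utils.compute_args_hash):

from hashlib import md5

def build_cache_key(mode: str, cache_type: str, prompt: str) -> str:
    # Stand-in hash; the real code uses compute_args_hash(mode, prompt).
    args_hash = md5(f"{mode}{prompt}".encode("utf-8")).hexdigest()
    return f"{mode}:{cache_type}:{args_hash}"

def parse_cache_type(key: str) -> str:
    # Python equivalent of split_part(id, ':', 2) with the 'extract' fallback.
    parts = key.split(":")
    return parts[1] if len(parts) >= 3 else "extract"

key = build_cache_key("default", "extract", "What is LightRAG?")
print(key)
print(parse_cache_type(key))           # extract
print(parse_cache_type("legacyhash"))  # extract (old, unflattened id)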
185
  async def _migrate_timestamp_columns(self):
186
  """Migrate timestamp columns in tables to timezone-aware types, assuming original data is in UTC time"""
187
  # Tables and columns that need migration
 
235
  # Log error but don't interrupt the process
236
  logger.warning(f"Failed to migrate {table_name}.{column_name}: {e}")
237
 
238
+ async def _migrate_doc_chunks_to_vdb_chunks(self):
239
+ """
240
+ Migrate data from LIGHTRAG_DOC_CHUNKS to LIGHTRAG_VDB_CHUNKS if specific conditions are met.
241
+ This migration is intended for users who are upgrading and have an older table structure
242
+ where LIGHTRAG_DOC_CHUNKS contained a `content_vector` column.
243
+
244
+ """
245
+ try:
246
+ # 1. Check if the new table LIGHTRAG_VDB_CHUNKS is empty
247
+ vdb_chunks_count_sql = "SELECT COUNT(1) as count FROM LIGHTRAG_VDB_CHUNKS"
248
+ vdb_chunks_count_result = await self.query(vdb_chunks_count_sql)
249
+ if vdb_chunks_count_result and vdb_chunks_count_result["count"] > 0:
250
+ logger.info(
251
+ "Skipping migration: LIGHTRAG_VDB_CHUNKS already contains data."
252
+ )
253
+ return
254
+
255
+ # 2. Check if `content_vector` column exists in the old table
256
+ check_column_sql = """
257
+ SELECT 1 FROM information_schema.columns
258
+ WHERE table_name = 'lightrag_doc_chunks' AND column_name = 'content_vector'
259
+ """
260
+ column_exists = await self.query(check_column_sql)
261
+ if not column_exists:
262
+ logger.info(
263
+ "Skipping migration: `content_vector` not found in LIGHTRAG_DOC_CHUNKS"
264
+ )
265
+ return
266
+
267
+ # 3. Check if the old table LIGHTRAG_DOC_CHUNKS has data
268
+ doc_chunks_count_sql = "SELECT COUNT(1) as count FROM LIGHTRAG_DOC_CHUNKS"
269
+ doc_chunks_count_result = await self.query(doc_chunks_count_sql)
270
+ if not doc_chunks_count_result or doc_chunks_count_result["count"] == 0:
271
+ logger.info("Skipping migration: LIGHTRAG_DOC_CHUNKS is empty.")
272
+ return
273
+
274
+ # 4. Perform the migration
275
+ logger.info(
276
+ "Starting data migration from LIGHTRAG_DOC_CHUNKS to LIGHTRAG_VDB_CHUNKS..."
277
+ )
278
+ migration_sql = """
279
+ INSERT INTO LIGHTRAG_VDB_CHUNKS (
280
+ id, workspace, full_doc_id, chunk_order_index, tokens, content,
281
+ content_vector, file_path, create_time, update_time
282
+ )
283
+ SELECT
284
+ id, workspace, full_doc_id, chunk_order_index, tokens, content,
285
+ content_vector, file_path, create_time, update_time
286
+ FROM LIGHTRAG_DOC_CHUNKS
287
+ ON CONFLICT (workspace, id) DO NOTHING;
288
+ """
289
+ await self.execute(migration_sql)
290
+ logger.info("Data migration to LIGHTRAG_VDB_CHUNKS completed successfully.")
291
+
292
+ except Exception as e:
293
+ logger.error(f"Failed during data migration to LIGHTRAG_VDB_CHUNKS: {e}")
294
+ # Do not re-raise, to allow the application to start
295
+
296
+ async def _check_llm_cache_needs_migration(self):
297
+ """Check if LLM cache data needs migration by examining the first record"""
298
+ try:
299
+ # Only query the first record to determine format
300
+ check_sql = """
301
+ SELECT id FROM LIGHTRAG_LLM_CACHE
302
+ ORDER BY create_time ASC
303
+ LIMIT 1
304
+ """
305
+ result = await self.query(check_sql)
306
+
307
+ if result and result.get("id"):
308
+ # If id doesn't contain colon, it's old format
309
+ return ":" not in result["id"]
310
+
311
+ return False # No data or already new format
312
+ except Exception as e:
313
+ logger.warning(f"Failed to check LLM cache migration status: {e}")
314
+ return False
315
+
316
+ async def _migrate_llm_cache_to_flattened_keys(self):
317
+ """Migrate LLM cache to flattened key format, recalculating hash values"""
318
+ try:
319
+ # Get all old format data
320
+ old_data_sql = """
321
+ SELECT id, mode, original_prompt, return_value, chunk_id,
322
+ create_time, update_time
323
+ FROM LIGHTRAG_LLM_CACHE
324
+ WHERE id NOT LIKE '%:%'
325
+ """
326
+
327
+ old_records = await self.query(old_data_sql, multirows=True)
328
+
329
+ if not old_records:
330
+ logger.info("No old format LLM cache data found, skipping migration")
331
+ return
332
+
333
+ logger.info(
334
+ f"Found {len(old_records)} old format cache records, starting migration..."
335
+ )
336
+
337
+ # Import hash calculation function
338
+ from ..utils import compute_args_hash
339
+
340
+ migrated_count = 0
341
+
342
+ # Migrate data in batches
343
+ for record in old_records:
344
+ try:
345
+ # Recalculate hash using correct method
346
+ new_hash = compute_args_hash(
347
+ record["mode"], record["original_prompt"]
348
+ )
349
+
350
+ # Determine cache_type based on mode
351
+ cache_type = "extract" if record["mode"] == "default" else "unknown"
352
+
353
+ # Generate new flattened key
354
+ new_key = f"{record['mode']}:{cache_type}:{new_hash}"
355
+
356
+ # Insert new format data with cache_type field
357
+ insert_sql = """
358
+ INSERT INTO LIGHTRAG_LLM_CACHE
359
+ (workspace, id, mode, original_prompt, return_value, chunk_id, cache_type, create_time, update_time)
360
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
361
+ ON CONFLICT (workspace, mode, id) DO NOTHING
362
+ """
363
+
364
+ await self.execute(
365
+ insert_sql,
366
+ {
367
+ "workspace": self.workspace,
368
+ "id": new_key,
369
+ "mode": record["mode"],
370
+ "original_prompt": record["original_prompt"],
371
+ "return_value": record["return_value"],
372
+ "chunk_id": record["chunk_id"],
373
+ "cache_type": cache_type, # Add cache_type field
374
+ "create_time": record["create_time"],
375
+ "update_time": record["update_time"],
376
+ },
377
+ )
378
+
379
+ # Delete old data
380
+ delete_sql = """
381
+ DELETE FROM LIGHTRAG_LLM_CACHE
382
+ WHERE workspace=$1 AND mode=$2 AND id=$3
383
+ """
384
+ await self.execute(
385
+ delete_sql,
386
+ {
387
+ "workspace": self.workspace,
388
+ "mode": record["mode"],
389
+ "id": record["id"], # Old id
390
+ },
391
+ )
392
+
393
+ migrated_count += 1
394
+
395
+ except Exception as e:
396
+ logger.warning(
397
+ f"Failed to migrate cache record {record['id']}: {e}"
398
+ )
399
+ continue
400
+
401
+ logger.info(
402
+ f"Successfully migrated {migrated_count} cache records to flattened format"
403
+ )
404
+
405
+ except Exception as e:
406
+ logger.error(f"LLM cache migration failed: {e}")
407
+ # Don't raise exception, allow system to continue startup
408
+
409
+ async def _migrate_doc_status_add_chunks_list(self):
410
+ """Add chunks_list column to LIGHTRAG_DOC_STATUS table if it doesn't exist"""
411
+ try:
412
+ # Check if chunks_list column exists
413
+ check_column_sql = """
414
+ SELECT column_name
415
+ FROM information_schema.columns
416
+ WHERE table_name = 'lightrag_doc_status'
417
+ AND column_name = 'chunks_list'
418
+ """
419
+
420
+ column_info = await self.query(check_column_sql)
421
+ if not column_info:
422
+ logger.info("Adding chunks_list column to LIGHTRAG_DOC_STATUS table")
423
+ add_column_sql = """
424
+ ALTER TABLE LIGHTRAG_DOC_STATUS
425
+ ADD COLUMN chunks_list JSONB NULL DEFAULT '[]'::jsonb
426
+ """
427
+ await self.execute(add_column_sql)
428
+ logger.info(
429
+ "Successfully added chunks_list column to LIGHTRAG_DOC_STATUS table"
430
+ )
431
+ else:
432
+ logger.info(
433
+ "chunks_list column already exists in LIGHTRAG_DOC_STATUS table"
434
+ )
435
+ except Exception as e:
436
+ logger.warning(
437
+ f"Failed to add chunks_list column to LIGHTRAG_DOC_STATUS: {e}"
438
+ )
439
+
440
+ async def _migrate_text_chunks_add_llm_cache_list(self):
441
+ """Add llm_cache_list column to LIGHTRAG_DOC_CHUNKS table if it doesn't exist"""
442
+ try:
443
+ # Check if llm_cache_list column exists
444
+ check_column_sql = """
445
+ SELECT column_name
446
+ FROM information_schema.columns
447
+ WHERE table_name = 'lightrag_doc_chunks'
448
+ AND column_name = 'llm_cache_list'
449
+ """
450
+
451
+ column_info = await self.query(check_column_sql)
452
+ if not column_info:
453
+ logger.info("Adding llm_cache_list column to LIGHTRAG_DOC_CHUNKS table")
454
+ add_column_sql = """
455
+ ALTER TABLE LIGHTRAG_DOC_CHUNKS
456
+ ADD COLUMN llm_cache_list JSONB NULL DEFAULT '[]'::jsonb
457
+ """
458
+ await self.execute(add_column_sql)
459
+ logger.info(
460
+ "Successfully added llm_cache_list column to LIGHTRAG_DOC_CHUNKS table"
461
+ )
462
+ else:
463
+ logger.info(
464
+ "llm_cache_list column already exists in LIGHTRAG_DOC_CHUNKS table"
465
+ )
466
+ except Exception as e:
467
+ logger.warning(
468
+ f"Failed to add llm_cache_list column to LIGHTRAG_DOC_CHUNKS: {e}"
469
+ )
470
+
471
  async def check_tables(self):
472
  # First create all tables
473
  for k, v in TABLES.items():
 
519
  logger.error(f"PostgreSQL, Failed to migrate LLM cache chunk_id field: {e}")
520
  # Don't throw an exception, allow the initialization process to continue
521
 
522
+ # Migrate LLM cache table to add cache_type field if needed
523
+ try:
524
+ await self._migrate_llm_cache_add_cache_type()
525
+ except Exception as e:
526
+ logger.error(
527
+ f"PostgreSQL, Failed to migrate LLM cache cache_type field: {e}"
528
+ )
529
+ # Don't throw an exception, allow the initialization process to continue
530
+
531
+ # Finally, attempt to migrate old doc chunks data if needed
532
+ try:
533
+ await self._migrate_doc_chunks_to_vdb_chunks()
534
+ except Exception as e:
535
+ logger.error(f"PostgreSQL, Failed to migrate doc_chunks to vdb_chunks: {e}")
536
+
537
+ # Check and migrate LLM cache to flattened keys if needed
538
+ try:
539
+ if await self._check_llm_cache_needs_migration():
540
+ await self._migrate_llm_cache_to_flattened_keys()
541
+ except Exception as e:
542
+ logger.error(f"PostgreSQL, LLM cache migration failed: {e}")
543
+
544
+ # Migrate doc status to add chunks_list field if needed
545
+ try:
546
+ await self._migrate_doc_status_add_chunks_list()
547
+ except Exception as e:
548
+ logger.error(
549
+ f"PostgreSQL, Failed to migrate doc status chunks_list field: {e}"
550
+ )
551
+
552
+ # Migrate text chunks to add llm_cache_list field if needed
553
+ try:
554
+ await self._migrate_text_chunks_add_llm_cache_list()
555
+ except Exception as e:
556
+ logger.error(
557
+ f"PostgreSQL, Failed to migrate text chunks llm_cache_list field: {e}"
558
+ )
559
+
560
  async def query(
561
  self,
562
  sql: str,
 
740
  try:
741
  results = await self.db.query(sql, params, multirows=True)
742
 
743
+ # Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
744
  if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
745
+ processed_results = {}
746
  for row in results:
747
+ create_time = row.get("create_time", 0)
748
+ update_time = row.get("update_time", 0)
749
+ # Map field names and add cache_type for compatibility
750
+ processed_row = {
751
+ **row,
752
+ "return": row.get("return_value", ""),
753
+ "cache_type": row.get("original_prompt", "unknow"),
754
+ "original_prompt": row.get("original_prompt", ""),
755
+ "chunk_id": row.get("chunk_id"),
756
+ "mode": row.get("mode", "default"),
757
+ "create_time": create_time,
758
+ "update_time": create_time if update_time == 0 else update_time,
759
+ }
760
+ processed_results[row["id"]] = processed_row
761
+ return processed_results
762
+
763
+ # For text_chunks namespace, parse llm_cache_list JSON string back to list
764
+ if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
765
+ processed_results = {}
766
+ for row in results:
767
+ llm_cache_list = row.get("llm_cache_list", [])
768
+ if isinstance(llm_cache_list, str):
769
+ try:
770
+ llm_cache_list = json.loads(llm_cache_list)
771
+ except json.JSONDecodeError:
772
+ llm_cache_list = []
773
+ row["llm_cache_list"] = llm_cache_list
774
+ create_time = row.get("create_time", 0)
775
+ update_time = row.get("update_time", 0)
776
+ row["create_time"] = create_time
777
+ row["update_time"] = (
778
+ create_time if update_time == 0 else update_time
779
+ )
780
+ processed_results[row["id"]] = row
781
+ return processed_results
782
+
783
+ # For other namespaces, return as-is
784
+ return {row["id"]: row for row in results}
785
  except Exception as e:
786
  logger.error(f"Error retrieving all data from {self.namespace}: {e}")
787
  return {}
788
 
789
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
790
+ """Get data by id."""
791
  sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
792
  params = {"workspace": self.db.workspace, "id": id}
793
+ response = await self.db.query(sql, params)
794
+
795
+ if response and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
796
+ # Parse llm_cache_list JSON string back to list
797
+ llm_cache_list = response.get("llm_cache_list", [])
798
+ if isinstance(llm_cache_list, str):
799
+ try:
800
+ llm_cache_list = json.loads(llm_cache_list)
801
+ except json.JSONDecodeError:
802
+ llm_cache_list = []
803
+ response["llm_cache_list"] = llm_cache_list
804
+ create_time = response.get("create_time", 0)
805
+ update_time = response.get("update_time", 0)
806
+ response["create_time"] = create_time
807
+ response["update_time"] = create_time if update_time == 0 else update_time
808
+
809
+ # Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
810
+ if response and is_namespace(
811
+ self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
812
+ ):
813
+ create_time = response.get("create_time", 0)
814
+ update_time = response.get("update_time", 0)
815
+ # Map field names and add cache_type for compatibility
816
+ response = {
817
+ **response,
818
+ "return": response.get("return_value", ""),
819
+ "cache_type": response.get("cache_type"),
820
+ "original_prompt": response.get("original_prompt", ""),
821
+ "chunk_id": response.get("chunk_id"),
822
+ "mode": response.get("mode", "default"),
823
+ "create_time": create_time,
824
+ "update_time": create_time if update_time == 0 else update_time,
825
+ }
826
+
827
+ return response if response else None
828
 
829
  # Query by id
830
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
831
+ """Get data by ids"""
832
  sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
833
  ids=",".join([f"'{id}'" for id in ids])
834
  )
835
  params = {"workspace": self.db.workspace}
836
+ results = await self.db.query(sql, params, multirows=True)
837
+
838
+ if results and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
839
+ # Parse llm_cache_list JSON string back to list for each result
840
+ for result in results:
841
+ llm_cache_list = result.get("llm_cache_list", [])
842
+ if isinstance(llm_cache_list, str):
843
+ try:
844
+ llm_cache_list = json.loads(llm_cache_list)
845
+ except json.JSONDecodeError:
846
+ llm_cache_list = []
847
+ result["llm_cache_list"] = llm_cache_list
848
+ create_time = result.get("create_time", 0)
849
+ update_time = result.get("update_time", 0)
850
+ result["create_time"] = create_time
851
+ result["update_time"] = create_time if update_time == 0 else update_time
852
+
853
+ # Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
854
+ if results and is_namespace(
855
+ self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
856
+ ):
857
+ processed_results = []
858
+ for row in results:
859
+ create_time = row.get("create_time", 0)
860
+ update_time = row.get("update_time", 0)
861
+ # Map field names and add cache_type for compatibility
862
+ processed_row = {
863
+ **row,
864
+ "return": row.get("return_value", ""),
865
+ "cache_type": row.get("cache_type"),
866
+ "original_prompt": row.get("original_prompt", ""),
867
+ "chunk_id": row.get("chunk_id"),
868
+ "mode": row.get("mode", "default"),
869
+ "create_time": create_time,
870
+ "update_time": create_time if update_time == 0 else update_time,
871
+ }
872
+ processed_results.append(processed_row)
873
+ return processed_results
874
 
875
+ return results if results else []
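get_all, get_by_id, and get_by_ids above all repeat the same defensive step for JSONB columns: if llm_cache_list comes back as a JSON string rather than a list, parse it and fall back to an empty list on decode errors. A possible shared helper, shown only as a sketch of that repeated pattern rather than anything the patch itself adds:

import json
from typing import Any

def normalize_json_list(value: Any) -> list:
    # JSONB may arrive as a Python list or a JSON string depending on the
    # driver codec; coerce defensively, defaulting to an empty list.
    if isinstance(value, str):
        try:
            return json.loads(value)
        except json.JSONDecodeError:
            return []
    return value if isinstance(value, list) else []

print(normalize_json_list('["default:extract:abc123"]'))  # parsed list
print(normalize_json_list(["already", "a", "list"]))       # unchanged
print(normalize_json_list(None))                            # []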
876
 
877
  async def filter_keys(self, keys: set[str]) -> set[str]:
878
  """Filter out duplicated content"""
 
902
  return
903
 
904
  if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
905
+ current_time = datetime.datetime.now(timezone.utc)
906
+ for k, v in data.items():
907
+ upsert_sql = SQL_TEMPLATES["upsert_text_chunk"]
908
+ _data = {
909
+ "workspace": self.db.workspace,
910
+ "id": k,
911
+ "tokens": v["tokens"],
912
+ "chunk_order_index": v["chunk_order_index"],
913
+ "full_doc_id": v["full_doc_id"],
914
+ "content": v["content"],
915
+ "file_path": v["file_path"],
916
+ "llm_cache_list": json.dumps(v.get("llm_cache_list", [])),
917
+ "create_time": current_time,
918
+ "update_time": current_time,
919
+ }
920
+ await self.db.execute(upsert_sql, _data)
921
  elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
922
  for k, v in data.items():
923
  upsert_sql = SQL_TEMPLATES["upsert_doc_full"]
 
928
  }
929
  await self.db.execute(upsert_sql, _data)
930
  elif is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
931
+ for k, v in data.items():
932
+ upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
933
+ _data = {
934
+ "workspace": self.db.workspace,
935
+ "id": k, # Use flattened key as id
936
+ "original_prompt": v["original_prompt"],
937
+ "return_value": v["return"],
938
+ "mode": v.get("mode", "default"), # Get mode from data
939
+ "chunk_id": v.get("chunk_id"),
940
+ "cache_type": v.get(
941
+ "cache_type", "extract"
942
+ ), # Get cache_type from data
943
+ }
944
 
945
+ await self.db.execute(upsert_sql, _data)
946
 
947
  async def index_done_callback(self) -> None:
948
  # PG handles persistence automatically
 
1348
  else:
1349
  exist_keys = []
1350
  new_keys = set([s for s in keys if s not in exist_keys])
1351
+ # print(f"keys: {keys}")
1352
+ # print(f"new_keys: {new_keys}")
1353
  return new_keys
1354
  except Exception as e:
1355
  logger.error(
 
1364
  if result is None or result == []:
1365
  return None
1366
  else:
1367
+ # Parse chunks_list JSON string back to list
1368
+ chunks_list = result[0].get("chunks_list", [])
1369
+ if isinstance(chunks_list, str):
1370
+ try:
1371
+ chunks_list = json.loads(chunks_list)
1372
+ except json.JSONDecodeError:
1373
+ chunks_list = []
1374
+
1375
  return dict(
1376
  content=result[0]["content"],
1377
  content_length=result[0]["content_length"],
 
1381
  created_at=result[0]["created_at"],
1382
  updated_at=result[0]["updated_at"],
1383
  file_path=result[0]["file_path"],
1384
+ chunks_list=chunks_list,
1385
  )
1386
 
1387
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
 
1396
 
1397
  if not results:
1398
  return []
1399
+
1400
+ processed_results = []
1401
+ for row in results:
1402
+ # Parse chunks_list JSON string back to list
1403
+ chunks_list = row.get("chunks_list", [])
1404
+ if isinstance(chunks_list, str):
1405
+ try:
1406
+ chunks_list = json.loads(chunks_list)
1407
+ except json.JSONDecodeError:
1408
+ chunks_list = []
1409
+
1410
+ processed_results.append(
1411
+ {
1412
+ "content": row["content"],
1413
+ "content_length": row["content_length"],
1414
+ "content_summary": row["content_summary"],
1415
+ "status": row["status"],
1416
+ "chunks_count": row["chunks_count"],
1417
+ "created_at": row["created_at"],
1418
+ "updated_at": row["updated_at"],
1419
+ "file_path": row["file_path"],
1420
+ "chunks_list": chunks_list,
1421
+ }
1422
+ )
1423
+
1424
+ return processed_results
1425
 
1426
  async def get_status_counts(self) -> dict[str, int]:
1427
  """Get counts of documents in each status"""
 
1442
  sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
1443
  params = {"workspace": self.db.workspace, "status": status.value}
1444
  result = await self.db.query(sql, params, True)
1445
+
1446
+ docs_by_status = {}
1447
+ for element in result:
1448
+ # Parse chunks_list JSON string back to list
1449
+ chunks_list = element.get("chunks_list", [])
1450
+ if isinstance(chunks_list, str):
1451
+ try:
1452
+ chunks_list = json.loads(chunks_list)
1453
+ except json.JSONDecodeError:
1454
+ chunks_list = []
1455
+
1456
+ docs_by_status[element["id"]] = DocProcessingStatus(
1457
  content=element["content"],
1458
  content_summary=element["content_summary"],
1459
  content_length=element["content_length"],
 
1462
  updated_at=element["updated_at"],
1463
  chunks_count=element["chunks_count"],
1464
  file_path=element["file_path"],
1465
+ chunks_list=chunks_list,
1466
  )
1467
+
 
1468
  return docs_by_status
1469
 
1470
  async def index_done_callback(self) -> None:
 
1528
  logger.warning(f"Unable to parse datetime string: {dt_str}")
1529
  return None
1530
 
1531
+ # Modified SQL to include created_at, updated_at, and chunks_list in both INSERT and UPDATE operations
1532
+ # All fields are updated from the input data in both INSERT and UPDATE cases
1533
+ sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content,content_summary,content_length,chunks_count,status,file_path,chunks_list,created_at,updated_at)
1534
+ values($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11)
1535
  on conflict(id,workspace) do update set
1536
  content = EXCLUDED.content,
1537
  content_summary = EXCLUDED.content_summary,
 
1539
  chunks_count = EXCLUDED.chunks_count,
1540
  status = EXCLUDED.status,
1541
  file_path = EXCLUDED.file_path,
1542
+ chunks_list = EXCLUDED.chunks_list,
1543
  created_at = EXCLUDED.created_at,
1544
  updated_at = EXCLUDED.updated_at"""
1545
  for k, v in data.items():
 
1547
  created_at = parse_datetime(v.get("created_at"))
1548
  updated_at = parse_datetime(v.get("updated_at"))
1549
 
1550
+ # chunks_count and chunks_list are optional
1551
  await self.db.execute(
1552
  sql,
1553
  {
 
1559
  "chunks_count": v["chunks_count"] if "chunks_count" in v else -1,
1560
  "status": v["status"],
1561
  "file_path": v["file_path"],
1562
+ "chunks_list": json.dumps(v.get("chunks_list", [])),
1563
  "created_at": created_at, # Use the converted datetime object
1564
  "updated_at": updated_at, # Use the converted datetime object
1565
  },
 
2842
  NAMESPACE_TABLE_MAP = {
2843
  NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
2844
  NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
2845
+ NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_VDB_CHUNKS",
2846
  NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_VDB_ENTITY",
2847
  NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_VDB_RELATION",
2848
  NameSpace.DOC_STATUS: "LIGHTRAG_DOC_STATUS",
 
2877
  chunk_order_index INTEGER,
2878
  tokens INTEGER,
2879
  content TEXT,
 
2880
  file_path VARCHAR(256),
2881
+ llm_cache_list JSONB NULL DEFAULT '[]'::jsonb,
2882
  create_time TIMESTAMP(0) WITH TIME ZONE,
2883
  update_time TIMESTAMP(0) WITH TIME ZONE,
2884
  CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id)
2885
  )"""
2886
  },
2887
+ "LIGHTRAG_VDB_CHUNKS": {
2888
+ "ddl": """CREATE TABLE LIGHTRAG_VDB_CHUNKS (
2889
+ id VARCHAR(255),
2890
+ workspace VARCHAR(255),
2891
+ full_doc_id VARCHAR(256),
2892
+ chunk_order_index INTEGER,
2893
+ tokens INTEGER,
2894
+ content TEXT,
2895
+ content_vector VECTOR,
2896
+ file_path VARCHAR(256),
2897
+ create_time TIMESTAMP(0) WITH TIME ZONE,
2898
+ update_time TIMESTAMP(0) WITH TIME ZONE,
2899
+ CONSTRAINT LIGHTRAG_VDB_CHUNKS_PK PRIMARY KEY (workspace, id)
2900
+ )"""
2901
+ },
2902
  "LIGHTRAG_VDB_ENTITY": {
2903
  "ddl": """CREATE TABLE LIGHTRAG_VDB_ENTITY (
2904
  id VARCHAR(255),
 
2951
  chunks_count int4 NULL,
2952
  status varchar(64) NULL,
2953
  file_path TEXT NULL,
2954
+ chunks_list JSONB NULL DEFAULT '[]'::jsonb,
2955
  created_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP NULL,
2956
  updated_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP NULL,
2957
  CONSTRAINT LIGHTRAG_DOC_STATUS_PK PRIMARY KEY (workspace, id)
 
2966
  FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id=$2
2967
  """,
2968
  "get_by_id_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
2969
+ chunk_order_index, full_doc_id, file_path,
2970
+ COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list,
2971
+ create_time, update_time
2972
  FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
2973
  """,
2974
+ "get_by_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type,
2975
+ create_time, update_time
2976
+ FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id=$2
2977
  """,
2978
+ "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id
2979
  FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 AND id=$3
2980
  """,
2981
  "get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content
2982
  FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids})
2983
  """,
2984
  "get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
2985
+ chunk_order_index, full_doc_id, file_path,
2986
+ COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list,
2987
+ create_time, update_time
2988
  FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
2989
  """,
2990
+ "get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type,
2991
+ create_time, update_time
2992
+ FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id IN ({ids})
2993
  """,
2994
  "filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})",
2995
  "upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, workspace)
 
2997
  ON CONFLICT (workspace,id) DO UPDATE
2998
  SET content = $2, update_time = CURRENT_TIMESTAMP
2999
  """,
3000
+ "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode,chunk_id,cache_type)
3001
+ VALUES ($1, $2, $3, $4, $5, $6, $7)
3002
  ON CONFLICT (workspace,mode,id) DO UPDATE
3003
  SET original_prompt = EXCLUDED.original_prompt,
3004
  return_value=EXCLUDED.return_value,
3005
  mode=EXCLUDED.mode,
3006
  chunk_id=EXCLUDED.chunk_id,
3007
+ cache_type=EXCLUDED.cache_type,
3008
  update_time = CURRENT_TIMESTAMP
3009
  """,
3010
+ "upsert_text_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens,
3011
+ chunk_order_index, full_doc_id, content, file_path, llm_cache_list,
3012
+ create_time, update_time)
3013
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
3014
+ ON CONFLICT (workspace,id) DO UPDATE
3015
+ SET tokens=EXCLUDED.tokens,
3016
+ chunk_order_index=EXCLUDED.chunk_order_index,
3017
+ full_doc_id=EXCLUDED.full_doc_id,
3018
+ content = EXCLUDED.content,
3019
+ file_path=EXCLUDED.file_path,
3020
+ llm_cache_list=EXCLUDED.llm_cache_list,
3021
+ update_time = EXCLUDED.update_time
3022
+ """,
3023
+ # SQL for VectorStorage
3024
+ "upsert_chunk": """INSERT INTO LIGHTRAG_VDB_CHUNKS (workspace, id, tokens,
3025
  chunk_order_index, full_doc_id, content, content_vector, file_path,
3026
  create_time, update_time)
3027
  VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
 
3034
  file_path=EXCLUDED.file_path,
3035
  update_time = EXCLUDED.update_time
3036
  """,
 
3037
  "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
3038
  content_vector, chunk_ids, file_path, create_time, update_time)
3039
  VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7, $8, $9)
 
3060
  "relationships": """
3061
  WITH relevant_chunks AS (
3062
  SELECT id as chunk_id
3063
+ FROM LIGHTRAG_VDB_CHUNKS
3064
  WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[])
3065
  )
3066
  SELECT source_id as src_id, target_id as tgt_id, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at
 
3077
  "entities": """
3078
  WITH relevant_chunks AS (
3079
  SELECT id as chunk_id
3080
+ FROM LIGHTRAG_VDB_CHUNKS
3081
  WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[])
3082
  )
3083
  SELECT entity_name, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM
 
3094
  "chunks": """
3095
  WITH relevant_chunks AS (
3096
  SELECT id as chunk_id
3097
+ FROM LIGHTRAG_VDB_CHUNKS
3098
  WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[])
3099
  )
3100
  SELECT id, content, file_path, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM
3101
  (
3102
  SELECT id, content, file_path, create_time, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
3103
+ FROM LIGHTRAG_VDB_CHUNKS
3104
  WHERE workspace=$1
3105
  AND id IN (SELECT chunk_id FROM relevant_chunks)
3106
  ) as chunk_distances
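The retrieval templates above rank rows by 1 - (content_vector <=> query_vector); with pgvector, <=> is the cosine-distance operator, so the resulting score is cosine similarity (1.0 means identical direction). The same arithmetic in plain Python, as a reference for what the SQL score means:

import math

def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm

chunk_vec = [0.1, 0.9, 0.2]
query_vec = [0.05, 0.8, 0.3]
distance = 1 - cosine_similarity(chunk_vec, query_vec)  # what <=> returns
print(f"similarity = {1 - distance:.4f}")                # what the SQL selects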
lightrag/kg/qdrant_impl.py CHANGED
@@ -85,7 +85,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
85
  )
86
 
87
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
88
- logger.info(f"Inserting {len(data)} to {self.namespace}")
89
  if not data:
90
  return
91
 
 
85
  )
86
 
87
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
88
+ logger.debug(f"Inserting {len(data)} to {self.namespace}")
89
  if not data:
90
  return
91
 
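This hunk, like the MongoDB one earlier, demotes the per-upsert size log from info to debug. To surface those messages again while debugging, raise the verbosity of the shared library logger; the logger name "lightrag" is an assumption based on `from lightrag.utils import logger` in these files, so verify it for your install:

import logging

logging.basicConfig(level=logging.DEBUG)
# Assumed logger name; adjust if your lightrag version names it differently.
logging.getLogger("lightrag").setLevel(logging.DEBUG)

logging.getLogger("lightrag").debug("Inserting 42 to chunks")  # now emitted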
lightrag/kg/redis_impl.py CHANGED
@@ -1,9 +1,10 @@
1
  import os
2
- from typing import Any, final
3
  from dataclasses import dataclass
4
  import pipmaster as pm
5
  import configparser
6
  from contextlib import asynccontextmanager
 
7
 
8
  if not pm.is_installed("redis"):
9
  pm.install("redis")
@@ -13,7 +14,12 @@ from redis.asyncio import Redis, ConnectionPool # type: ignore
13
  from redis.exceptions import RedisError, ConnectionError # type: ignore
14
  from lightrag.utils import logger
15
 
16
- from lightrag.base import BaseKVStorage
17
  import json
18
 
19
 
@@ -26,6 +32,41 @@ SOCKET_TIMEOUT = 5.0
26
  SOCKET_CONNECT_TIMEOUT = 3.0
27
 
28
 
29
  @final
30
  @dataclass
31
  class RedisKVStorage(BaseKVStorage):
@@ -33,19 +74,28 @@ class RedisKVStorage(BaseKVStorage):
33
  redis_url = os.environ.get(
34
  "REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
35
  )
36
- # Create a connection pool with limits
37
- self._pool = ConnectionPool.from_url(
38
- redis_url,
39
- max_connections=MAX_CONNECTIONS,
40
- decode_responses=True,
41
- socket_timeout=SOCKET_TIMEOUT,
42
- socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
43
- )
44
  self._redis = Redis(connection_pool=self._pool)
45
  logger.info(
46
- f"Initialized Redis connection pool for {self.namespace} with max {MAX_CONNECTIONS} connections"
47
  )
48
 
 
 
49
  @asynccontextmanager
50
  async def _get_redis_connection(self):
51
  """Safe context manager for Redis operations."""
@@ -82,7 +132,13 @@ class RedisKVStorage(BaseKVStorage):
82
  async with self._get_redis_connection() as redis:
83
  try:
84
  data = await redis.get(f"{self.namespace}:{id}")
85
- return json.loads(data) if data else None
86
  except json.JSONDecodeError as e:
87
  logger.error(f"JSON decode error for id {id}: {e}")
88
  return None
@@ -94,35 +150,113 @@ class RedisKVStorage(BaseKVStorage):
94
  for id in ids:
95
  pipe.get(f"{self.namespace}:{id}")
96
  results = await pipe.execute()
97
- return [json.loads(result) if result else None for result in results]
98
  except json.JSONDecodeError as e:
99
  logger.error(f"JSON decode error in batch get: {e}")
100
  return [None] * len(ids)
101
 
102
  async def filter_keys(self, keys: set[str]) -> set[str]:
103
  async with self._get_redis_connection() as redis:
104
  pipe = redis.pipeline()
105
- for key in keys:
 
106
  pipe.exists(f"{self.namespace}:{key}")
107
  results = await pipe.execute()
108
 
109
- existing_ids = {keys[i] for i, exists in enumerate(results) if exists}
110
  return set(keys) - existing_ids
111
 
112
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
113
  if not data:
114
  return
115
 
116
- logger.info(f"Inserting {len(data)} items to {self.namespace}")
117
  async with self._get_redis_connection() as redis:
118
  try:
119
  pipe = redis.pipeline()
120
  for k, v in data.items():
121
  pipe.set(f"{self.namespace}:{k}", json.dumps(v))
122
  await pipe.execute()
123
 
124
- for k in data:
125
- data[k]["_id"] = k
126
  except json.JSONEncodeError as e:
127
  logger.error(f"JSON encode error during upsert: {e}")
128
  raise
@@ -148,13 +282,13 @@ class RedisKVStorage(BaseKVStorage):
148
  )
149
 
150
  async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
151
- """Delete specific records from storage by by cache mode
152
 
153
  Importance notes for Redis storage:
154
  1. This will immediately delete the specified cache modes from Redis
155
 
156
  Args:
157
- modes (list[str]): List of cache mode to be drop from storage
158
 
159
  Returns:
160
  True: if the cache drop successfully
@@ -164,9 +298,47 @@ class RedisKVStorage(BaseKVStorage):
164
  return False
165
 
166
  try:
167
- await self.delete(modes)
168
  return True
169
- except Exception:
 
170
  return False
171
 
172
  async def drop(self) -> dict[str, str]:
@@ -177,24 +349,370 @@ class RedisKVStorage(BaseKVStorage):
177
  """
178
  async with self._get_redis_connection() as redis:
179
  try:
180
- keys = await redis.keys(f"{self.namespace}:*")
181
-
182
- if keys:
183
- pipe = redis.pipeline()
184
- for key in keys:
185
- pipe.delete(key)
186
- results = await pipe.execute()
187
- deleted_count = sum(results)
188
-
189
- logger.info(f"Dropped {deleted_count} keys from {self.namespace}")
190
- return {
191
- "status": "success",
192
- "message": f"{deleted_count} keys dropped",
193
- }
194
- else:
195
- logger.info(f"No keys found to drop in {self.namespace}")
196
- return {"status": "success", "message": "no keys to drop"}
197
 
198
  except Exception as e:
199
  logger.error(f"Error dropping keys from {self.namespace}: {e}")
200
  return {"status": "error", "message": str(e)}
1
  import os
2
+ from typing import Any, final, Union
3
  from dataclasses import dataclass
4
  import pipmaster as pm
5
  import configparser
6
  from contextlib import asynccontextmanager
7
+ import threading
8
 
9
  if not pm.is_installed("redis"):
10
  pm.install("redis")
 
14
  from redis.exceptions import RedisError, ConnectionError # type: ignore
15
  from lightrag.utils import logger
16
 
17
+ from lightrag.base import (
18
+ BaseKVStorage,
19
+ DocStatusStorage,
20
+ DocStatus,
21
+ DocProcessingStatus,
22
+ )
23
  import json
24
 
25
 
 
32
  SOCKET_CONNECT_TIMEOUT = 3.0
33
 
34
 
35
+ class RedisConnectionManager:
36
+ """Shared Redis connection pool manager to avoid creating multiple pools for the same Redis URI"""
37
+
38
+ _pools = {}
39
+ _lock = threading.Lock()
40
+
41
+ @classmethod
42
+ def get_pool(cls, redis_url: str) -> ConnectionPool:
43
+ """Get or create a connection pool for the given Redis URL"""
44
+ if redis_url not in cls._pools:
45
+ with cls._lock:
46
+ if redis_url not in cls._pools:
47
+ cls._pools[redis_url] = ConnectionPool.from_url(
48
+ redis_url,
49
+ max_connections=MAX_CONNECTIONS,
50
+ decode_responses=True,
51
+ socket_timeout=SOCKET_TIMEOUT,
52
+ socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
53
+ )
54
+ logger.info(f"Created shared Redis connection pool for {redis_url}")
55
+ return cls._pools[redis_url]
56
+
57
+ @classmethod
58
+ def close_all_pools(cls):
59
+ """Close all connection pools (for cleanup)"""
60
+ with cls._lock:
61
+ for url, pool in cls._pools.items():
62
+ try:
63
+ pool.disconnect()
64
+ logger.info(f"Closed Redis connection pool for {url}")
65
+ except Exception as e:
66
+ logger.error(f"Error closing Redis pool for {url}: {e}")
67
+ cls._pools.clear()
68
+
69
+
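A short usage sketch for the shared pool manager defined above: repeated lookups for the same Redis URI return the same ConnectionPool object, so every storage instance for that URI reuses one bounded set of connections. This assumes lightrag is installed so the module import resolves; creating the pool does not open sockets, so no Redis server is needed to run it.

from redis.asyncio import Redis
from lightrag.kg.redis_impl import RedisConnectionManager  # class added above

pool_a = RedisConnectionManager.get_pool("redis://localhost:6379")
pool_b = RedisConnectionManager.get_pool("redis://localhost:6379")
assert pool_a is pool_b  # same pool object is shared across storages

client = Redis(connection_pool=pool_a)  # sockets open only on first command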
70
  @final
71
  @dataclass
72
  class RedisKVStorage(BaseKVStorage):
 
74
  redis_url = os.environ.get(
75
  "REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
76
  )
77
+ # Use shared connection pool
78
+ self._pool = RedisConnectionManager.get_pool(redis_url)
 
 
 
 
 
 
79
  self._redis = Redis(connection_pool=self._pool)
80
  logger.info(
81
+ f"Initialized Redis KV storage for {self.namespace} using shared connection pool"
82
  )
83
 
84
+ async def initialize(self):
85
+ """Initialize Redis connection and migrate legacy cache structure if needed"""
86
+ # Test connection
87
+ try:
88
+ async with self._get_redis_connection() as redis:
89
+ await redis.ping()
90
+ logger.info(f"Connected to Redis for namespace {self.namespace}")
91
+ except Exception as e:
92
+ logger.error(f"Failed to connect to Redis: {e}")
93
+ raise
94
+
95
+ # Migrate legacy cache structure if this is a cache namespace
96
+ if self.namespace.endswith("_cache"):
97
+ await self._migrate_legacy_cache_structure()
98
+
99
  @asynccontextmanager
100
  async def _get_redis_connection(self):
101
  """Safe context manager for Redis operations."""
 
132
  async with self._get_redis_connection() as redis:
133
  try:
134
  data = await redis.get(f"{self.namespace}:{id}")
135
+ if data:
136
+ result = json.loads(data)
137
+ # Ensure time fields are present, provide default values for old data
138
+ result.setdefault("create_time", 0)
139
+ result.setdefault("update_time", 0)
140
+ return result
141
+ return None
142
  except json.JSONDecodeError as e:
143
  logger.error(f"JSON decode error for id {id}: {e}")
144
  return None
 
150
  for id in ids:
151
  pipe.get(f"{self.namespace}:{id}")
152
  results = await pipe.execute()
153
+
154
+ processed_results = []
155
+ for result in results:
156
+ if result:
157
+ data = json.loads(result)
158
+ # Ensure time fields are present for all documents
159
+ data.setdefault("create_time", 0)
160
+ data.setdefault("update_time", 0)
161
+ processed_results.append(data)
162
+ else:
163
+ processed_results.append(None)
164
+
165
+ return processed_results
166
  except json.JSONDecodeError as e:
167
  logger.error(f"JSON decode error in batch get: {e}")
168
  return [None] * len(ids)
169
 
170
+ async def get_all(self) -> dict[str, Any]:
171
+ """Get all data from storage
172
+
173
+ Returns:
174
+ Dictionary containing all stored data
175
+ """
176
+ async with self._get_redis_connection() as redis:
177
+ try:
178
+ # Get all keys for this namespace
179
+ keys = await redis.keys(f"{self.namespace}:*")
180
+
181
+ if not keys:
182
+ return {}
183
+
184
+ # Get all values in batch
185
+ pipe = redis.pipeline()
186
+ for key in keys:
187
+ pipe.get(key)
188
+ values = await pipe.execute()
189
+
190
+ # Build result dictionary
191
+ result = {}
192
+ for key, value in zip(keys, values):
193
+ if value:
194
+ # Extract the ID part (after namespace:)
195
+ key_id = key.split(":", 1)[1]
196
+ try:
197
+ data = json.loads(value)
198
+ # Ensure time fields are present for all documents
199
+ data.setdefault("create_time", 0)
200
+ data.setdefault("update_time", 0)
201
+ result[key_id] = data
202
+ except json.JSONDecodeError as e:
203
+ logger.error(f"JSON decode error for key {key}: {e}")
204
+ continue
205
+
206
+ return result
207
+ except Exception as e:
208
+ logger.error(f"Error getting all data from Redis: {e}")
209
+ return {}
210
+
211
  async def filter_keys(self, keys: set[str]) -> set[str]:
212
  async with self._get_redis_connection() as redis:
213
  pipe = redis.pipeline()
214
+ keys_list = list(keys) # Convert set to list for indexing
215
+ for key in keys_list:
216
  pipe.exists(f"{self.namespace}:{key}")
217
  results = await pipe.execute()
218
 
219
+ existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists}
220
  return set(keys) - existing_ids
221
 
222
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
223
  if not data:
224
  return
225
 
226
+ import time
227
+
228
+ current_time = int(time.time()) # Get current Unix timestamp
229
+
230
  async with self._get_redis_connection() as redis:
231
  try:
232
+ # Check which keys already exist to determine create vs update
233
+ pipe = redis.pipeline()
234
+ for k in data.keys():
235
+ pipe.exists(f"{self.namespace}:{k}")
236
+ exists_results = await pipe.execute()
237
+
238
+ # Add timestamps to data
239
+ for i, (k, v) in enumerate(data.items()):
240
+ # For text_chunks namespace, ensure llm_cache_list field exists
241
+ if "text_chunks" in self.namespace:
242
+ if "llm_cache_list" not in v:
243
+ v["llm_cache_list"] = []
244
+
245
+ # Add timestamps based on whether key exists
246
+ if exists_results[i]: # Key exists, only update update_time
247
+ v["update_time"] = current_time
248
+ else: # New key, set both create_time and update_time
249
+ v["create_time"] = current_time
250
+ v["update_time"] = current_time
251
+
252
+ v["_id"] = k
253
+
254
+ # Store the data
255
  pipe = redis.pipeline()
256
  for k, v in data.items():
257
  pipe.set(f"{self.namespace}:{k}", json.dumps(v))
258
  await pipe.execute()
259
 
 
 
260
  except (TypeError, ValueError) as e:  # json has no JSONEncodeError; these are what json.dumps raises
261
  logger.error(f"JSON encode error during upsert: {e}")
262
  raise
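
A minimal illustration of the record shape this upsert produces (field names follow the code above; the values are made up):

    import json
    import time

    now = int(time.time())

    # New key: both timestamps are set, and llm_cache_list is initialized
    # for text_chunks namespaces
    new_record = {
        "content": "first insert",
        "llm_cache_list": [],
        "create_time": now,
        "update_time": now,
        "_id": "chunk-1",
    }

    # Re-upserting the same id later refreshes only update_time
    updated_record = {**new_record, "content": "edited", "update_time": now + 60}
    print(json.dumps(updated_record, indent=2))
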
 
282
  )
283
 
284
  async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
285
+ """Delete specific records from storage by cache mode
286
 
287
  Important notes for Redis storage:
288
  1. This will immediately delete the specified cache modes from Redis
289
 
290
  Args:
291
+ modes (list[str]): List of cache modes to be dropped from storage
292
 
293
  Returns:
294
  True: if the cache was dropped successfully
 
298
  return False
299
 
300
  try:
301
+ async with self._get_redis_connection() as redis:
302
+ keys_to_delete = []
303
+
304
+ # Find matching keys for each mode using SCAN
305
+ for mode in modes:
306
+ # Use correct pattern to match flattened cache key format {namespace}:{mode}:{cache_type}:{hash}
307
+ pattern = f"{self.namespace}:{mode}:*"
308
+ cursor = 0
309
+ mode_keys = []
310
+
311
+ while True:
312
+ cursor, keys = await redis.scan(
313
+ cursor, match=pattern, count=1000
314
+ )
315
+ if keys:
316
+ mode_keys.extend(keys)
317
+
318
+ if cursor == 0:
319
+ break
320
+
321
+ keys_to_delete.extend(mode_keys)
322
+ logger.info(
323
+ f"Found {len(mode_keys)} keys for mode '{mode}' with pattern '{pattern}'"
324
+ )
325
+
326
+ if keys_to_delete:
327
+ # Batch delete
328
+ pipe = redis.pipeline()
329
+ for key in keys_to_delete:
330
+ pipe.delete(key)
331
+ results = await pipe.execute()
332
+ deleted_count = sum(results)
333
+ logger.info(
334
+ f"Dropped {deleted_count} cache entries for modes: {modes}"
335
+ )
336
+ else:
337
+ logger.warning(f"No cache entries found for modes: {modes}")
338
+
339
  return True
340
+ except Exception as e:
341
+ logger.error(f"Error dropping cache by modes in Redis: {e}")
342
  return False
343
 
344
  async def drop(self) -> dict[str, str]:
 
349
  """
350
  async with self._get_redis_connection() as redis:
351
  try:
352
+ # Use SCAN to find all keys with the namespace prefix
353
+ pattern = f"{self.namespace}:*"
354
+ cursor = 0
355
+ deleted_count = 0
356
+
357
+ while True:
358
+ cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
359
+ if keys:
360
+ # Delete keys in batches
361
+ pipe = redis.pipeline()
362
+ for key in keys:
363
+ pipe.delete(key)
364
+ results = await pipe.execute()
365
+ deleted_count += sum(results)
366
+
367
+ if cursor == 0:
368
+ break
369
+
370
+ logger.info(f"Dropped {deleted_count} keys from {self.namespace}")
371
+ return {
372
+ "status": "success",
373
+ "message": f"{deleted_count} keys dropped",
374
+ }
375
 
376
  except Exception as e:
377
  logger.error(f"Error dropping keys from {self.namespace}: {e}")
378
  return {"status": "error", "message": str(e)}
379
+
380
+ async def _migrate_legacy_cache_structure(self):
381
+ """Migrate legacy nested cache structure to flattened structure for Redis
382
+
383
+ Redis already stores data in a flattened way, but we need to check for
384
+ legacy keys that might contain nested JSON structures and migrate them.
385
+
386
+ Early exit if any flattened key is found (indicating migration already done).
387
+ """
388
+ from lightrag.utils import generate_cache_key
389
+
390
+ async with self._get_redis_connection() as redis:
391
+ # Get all keys for this namespace
392
+ keys = await redis.keys(f"{self.namespace}:*")
393
+
394
+ if not keys:
395
+ return
396
+
397
+ # Check if we have any flattened keys already - if so, skip migration
398
+ has_flattened_keys = False
399
+ keys_to_migrate = []
400
+
401
+ for key in keys:
402
+ # Extract the ID part (after namespace:)
403
+ key_id = key.split(":", 1)[1]
404
+
405
+ # Check if already in flattened format (contains exactly 2 colons for mode:cache_type:hash)
406
+ if ":" in key_id and len(key_id.split(":")) == 3:
407
+ has_flattened_keys = True
408
+ break # Early exit - migration already done
409
+
410
+ # Get the data to check if it's a legacy nested structure
411
+ data = await redis.get(key)
412
+ if data:
413
+ try:
414
+ parsed_data = json.loads(data)
415
+ # Check if this looks like a legacy cache mode with nested structure
416
+ if isinstance(parsed_data, dict) and all(
417
+ isinstance(v, dict) and "return" in v
418
+ for v in parsed_data.values()
419
+ ):
420
+ keys_to_migrate.append((key, key_id, parsed_data))
421
+ except json.JSONDecodeError:
422
+ continue
423
+
424
+ # If we found any flattened keys, assume migration is already done
425
+ if has_flattened_keys:
426
+ logger.debug(
427
+ f"Found flattened cache keys in {self.namespace}, skipping migration"
428
+ )
429
+ return
430
+
431
+ if not keys_to_migrate:
432
+ return
433
+
434
+ # Perform migration
435
+ pipe = redis.pipeline()
436
+ migration_count = 0
437
+
438
+ for old_key, mode, nested_data in keys_to_migrate:
439
+ # Delete the old key
440
+ pipe.delete(old_key)
441
+
442
+ # Create new flattened keys
443
+ for cache_hash, cache_entry in nested_data.items():
444
+ cache_type = cache_entry.get("cache_type", "extract")
445
+ flattened_key = generate_cache_key(mode, cache_type, cache_hash)
446
+ full_key = f"{self.namespace}:{flattened_key}"
447
+ pipe.set(full_key, json.dumps(cache_entry))
448
+ migration_count += 1
449
+
450
+ await pipe.execute()
451
+
452
+ if migration_count > 0:
453
+ logger.info(
454
+ f"Migrated {migration_count} legacy cache entries to flattened structure in Redis"
455
+ )
456
+
457
+
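
To make the migration above concrete, here is a small sketch (not the library's generate_cache_key helper; key and field names are illustrative) of how one legacy nested cache value maps onto the flattened {namespace}:{mode}:{cache_type}:{hash} keys:

    # Legacy layout: one Redis key per mode whose value is a nested dict of hash -> entry
    legacy_value = {
        "abc123": {"return": "...extraction text...", "cache_type": "extract", "chunk_id": "chunk-1"},
    }

    def flatten(namespace: str, mode: str, nested: dict) -> dict:
        """Rebuild the flattened key space: {namespace}:{mode}:{cache_type}:{hash}."""
        flat = {}
        for cache_hash, entry in nested.items():
            cache_type = entry.get("cache_type", "extract")
            flat[f"{namespace}:{mode}:{cache_type}:{cache_hash}"] = entry
        return flat

    print(flatten("llm_response_cache", "default", legacy_value))
    # {'llm_response_cache:default:extract:abc123': {...}}
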
458
+ @final
459
+ @dataclass
460
+ class RedisDocStatusStorage(DocStatusStorage):
461
+ """Redis implementation of document status storage"""
462
+
463
+ def __post_init__(self):
464
+ redis_url = os.environ.get(
465
+ "REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
466
+ )
467
+ # Use shared connection pool
468
+ self._pool = RedisConnectionManager.get_pool(redis_url)
469
+ self._redis = Redis(connection_pool=self._pool)
470
+ logger.info(
471
+ f"Initialized Redis doc status storage for {self.namespace} using shared connection pool"
472
+ )
473
+
474
+ async def initialize(self):
475
+ """Initialize Redis connection"""
476
+ try:
477
+ async with self._get_redis_connection() as redis:
478
+ await redis.ping()
479
+ logger.info(
480
+ f"Connected to Redis for doc status namespace {self.namespace}"
481
+ )
482
+ except Exception as e:
483
+ logger.error(f"Failed to connect to Redis for doc status: {e}")
484
+ raise
485
+
486
+ @asynccontextmanager
487
+ async def _get_redis_connection(self):
488
+ """Safe context manager for Redis operations."""
489
+ try:
490
+ yield self._redis
491
+ except ConnectionError as e:
492
+ logger.error(f"Redis connection error in doc status {self.namespace}: {e}")
493
+ raise
494
+ except RedisError as e:
495
+ logger.error(f"Redis operation error in doc status {self.namespace}: {e}")
496
+ raise
497
+ except Exception as e:
498
+ logger.error(
499
+ f"Unexpected error in Redis doc status operation for {self.namespace}: {e}"
500
+ )
501
+ raise
502
+
503
+ async def close(self):
504
+ """Close the Redis connection."""
505
+ if hasattr(self, "_redis") and self._redis:
506
+ await self._redis.close()
507
+ logger.debug(f"Closed Redis connection for doc status {self.namespace}")
508
+
509
+ async def __aenter__(self):
510
+ """Support for async context manager."""
511
+ return self
512
+
513
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
514
+ """Ensure Redis resources are cleaned up when exiting context."""
515
+ await self.close()
516
+
517
+ async def filter_keys(self, keys: set[str]) -> set[str]:
518
+ """Return keys that should be processed (not in storage or not successfully processed)"""
519
+ async with self._get_redis_connection() as redis:
520
+ pipe = redis.pipeline()
521
+ keys_list = list(keys)
522
+ for key in keys_list:
523
+ pipe.exists(f"{self.namespace}:{key}")
524
+ results = await pipe.execute()
525
+
526
+ existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists}
527
+ return set(keys) - existing_ids
528
+
529
+ async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
530
+ result: list[dict[str, Any]] = []
531
+ async with self._get_redis_connection() as redis:
532
+ try:
533
+ pipe = redis.pipeline()
534
+ for id in ids:
535
+ pipe.get(f"{self.namespace}:{id}")
536
+ results = await pipe.execute()
537
+
538
+ for result_data in results:
539
+ if result_data:
540
+ try:
541
+ result.append(json.loads(result_data))
542
+ except json.JSONDecodeError as e:
543
+ logger.error(f"JSON decode error in get_by_ids: {e}")
544
+ continue
545
+ except Exception as e:
546
+ logger.error(f"Error in get_by_ids: {e}")
547
+ return result
548
+
549
+ async def get_status_counts(self) -> dict[str, int]:
550
+ """Get counts of documents in each status"""
551
+ counts = {status.value: 0 for status in DocStatus}
552
+ async with self._get_redis_connection() as redis:
553
+ try:
554
+ # Use SCAN to iterate through all keys in the namespace
555
+ cursor = 0
556
+ while True:
557
+ cursor, keys = await redis.scan(
558
+ cursor, match=f"{self.namespace}:*", count=1000
559
+ )
560
+ if keys:
561
+ # Get all values in batch
562
+ pipe = redis.pipeline()
563
+ for key in keys:
564
+ pipe.get(key)
565
+ values = await pipe.execute()
566
+
567
+ # Count statuses
568
+ for value in values:
569
+ if value:
570
+ try:
571
+ doc_data = json.loads(value)
572
+ status = doc_data.get("status")
573
+ if status in counts:
574
+ counts[status] += 1
575
+ except json.JSONDecodeError:
576
+ continue
577
+
578
+ if cursor == 0:
579
+ break
580
+ except Exception as e:
581
+ logger.error(f"Error getting status counts: {e}")
582
+
583
+ return counts
584
+
585
+ async def get_docs_by_status(
586
+ self, status: DocStatus
587
+ ) -> dict[str, DocProcessingStatus]:
588
+ """Get all documents with a specific status"""
589
+ result = {}
590
+ async with self._get_redis_connection() as redis:
591
+ try:
592
+ # Use SCAN to iterate through all keys in the namespace
593
+ cursor = 0
594
+ while True:
595
+ cursor, keys = await redis.scan(
596
+ cursor, match=f"{self.namespace}:*", count=1000
597
+ )
598
+ if keys:
599
+ # Get all values in batch
600
+ pipe = redis.pipeline()
601
+ for key in keys:
602
+ pipe.get(key)
603
+ values = await pipe.execute()
604
+
605
+ # Filter by status and create DocProcessingStatus objects
606
+ for key, value in zip(keys, values):
607
+ if value:
608
+ try:
609
+ doc_data = json.loads(value)
610
+ if doc_data.get("status") == status.value:
611
+ # Extract document ID from key
612
+ doc_id = key.split(":", 1)[1]
613
+
614
+ # Make a copy of the data to avoid modifying the original
615
+ data = doc_data.copy()
616
+ # If content is missing, use content_summary as content
617
+ if (
618
+ "content" not in data
619
+ and "content_summary" in data
620
+ ):
621
+ data["content"] = data["content_summary"]
622
+ # If file_path is missing, fall back to a placeholder value
623
+ if "file_path" not in data:
624
+ data["file_path"] = "no-file-path"
625
+
626
+ result[doc_id] = DocProcessingStatus(**data)
627
+ except (json.JSONDecodeError, KeyError) as e:
628
+ logger.error(
629
+ f"Error processing document {key}: {e}"
630
+ )
631
+ continue
632
+
633
+ if cursor == 0:
634
+ break
635
+ except Exception as e:
636
+ logger.error(f"Error getting docs by status: {e}")
637
+
638
+ return result
639
+
640
+ async def index_done_callback(self) -> None:
641
+ """Redis handles persistence automatically"""
642
+ pass
643
+
644
+ async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
645
+ """Insert or update document status data"""
646
+ if not data:
647
+ return
648
+
649
+ logger.debug(f"Inserting {len(data)} records to {self.namespace}")
650
+ async with self._get_redis_connection() as redis:
651
+ try:
652
+ # Ensure chunks_list field exists for new documents
653
+ for doc_id, doc_data in data.items():
654
+ if "chunks_list" not in doc_data:
655
+ doc_data["chunks_list"] = []
656
+
657
+ pipe = redis.pipeline()
658
+ for k, v in data.items():
659
+ pipe.set(f"{self.namespace}:{k}", json.dumps(v))
660
+ await pipe.execute()
661
+ except (TypeError, ValueError) as e:  # json has no JSONEncodeError; these are what json.dumps raises
662
+ logger.error(f"JSON encode error during upsert: {e}")
663
+ raise
664
+
665
+ async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
666
+ async with self._get_redis_connection() as redis:
667
+ try:
668
+ data = await redis.get(f"{self.namespace}:{id}")
669
+ return json.loads(data) if data else None
670
+ except json.JSONDecodeError as e:
671
+ logger.error(f"JSON decode error for id {id}: {e}")
672
+ return None
673
+
674
+ async def delete(self, doc_ids: list[str]) -> None:
675
+ """Delete specific records from storage by their IDs"""
676
+ if not doc_ids:
677
+ return
678
+
679
+ async with self._get_redis_connection() as redis:
680
+ pipe = redis.pipeline()
681
+ for doc_id in doc_ids:
682
+ pipe.delete(f"{self.namespace}:{doc_id}")
683
+
684
+ results = await pipe.execute()
685
+ deleted_count = sum(results)
686
+ logger.info(
687
+ f"Deleted {deleted_count} of {len(doc_ids)} doc status entries from {self.namespace}"
688
+ )
689
+
690
+ async def drop(self) -> dict[str, str]:
691
+ """Drop all document status data from storage and clean up resources"""
692
+ try:
693
+ async with self._get_redis_connection() as redis:
694
+ # Use SCAN to find all keys with the namespace prefix
695
+ pattern = f"{self.namespace}:*"
696
+ cursor = 0
697
+ deleted_count = 0
698
+
699
+ while True:
700
+ cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
701
+ if keys:
702
+ # Delete keys in batches
703
+ pipe = redis.pipeline()
704
+ for key in keys:
705
+ pipe.delete(key)
706
+ results = await pipe.execute()
707
+ deleted_count += sum(results)
708
+
709
+ if cursor == 0:
710
+ break
711
+
712
+ logger.info(
713
+ f"Dropped {deleted_count} doc status keys from {self.namespace}"
714
+ )
715
+ return {"status": "success", "message": "data dropped"}
716
+ except Exception as e:
717
+ logger.error(f"Error dropping doc status {self.namespace}: {e}")
718
+ return {"status": "error", "message": str(e)}
lightrag/lightrag.py CHANGED
@@ -22,6 +22,7 @@ from typing import (
22
  Dict,
23
  )
24
  from lightrag.constants import (
 
25
  DEFAULT_MAX_TOKEN_SUMMARY,
26
  DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE,
27
  )
@@ -124,7 +125,9 @@ class LightRAG:
124
  # Entity extraction
125
  # ---
126
 
127
- entity_extract_max_gleaning: int = field(default=1)
 
 
128
  """Maximum number of entity extraction attempts for ambiguous content."""
129
 
130
  summary_to_max_tokens: int = field(
@@ -346,6 +349,7 @@ class LightRAG:
346
 
347
  # Fix global_config now
348
  global_config = asdict(self)
 
349
  _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
350
  logger.debug(f"LightRAG init with param:\n {_print_config}\n")
351
 
@@ -394,13 +398,13 @@ class LightRAG:
394
  embedding_func=self.embedding_func,
395
  )
396
 
397
- # TODO: deprecating, text_chunks is redundant with chunks_vdb
398
  self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
399
  namespace=make_namespace(
400
  self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
401
  ),
402
  embedding_func=self.embedding_func,
403
  )
 
404
  self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls( # type: ignore
405
  namespace=make_namespace(
406
  self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
@@ -949,6 +953,7 @@ class LightRAG:
949
  **dp,
950
  "full_doc_id": doc_id,
951
  "file_path": file_path, # Add file path to each chunk
 
952
  }
953
  for dp in self.chunking_func(
954
  self.tokenizer,
@@ -960,14 +965,17 @@ class LightRAG:
960
  )
961
  }
962
 
963
- # Process document (text chunks and full docs) in parallel
964
- # Create tasks with references for potential cancellation
965
  doc_status_task = asyncio.create_task(
966
  self.doc_status.upsert(
967
  {
968
  doc_id: {
969
  "status": DocStatus.PROCESSING,
970
  "chunks_count": len(chunks),
 
 
 
971
  "content": status_doc.content,
972
  "content_summary": status_doc.content_summary,
973
  "content_length": status_doc.content_length,
@@ -983,11 +991,6 @@ class LightRAG:
983
  chunks_vdb_task = asyncio.create_task(
984
  self.chunks_vdb.upsert(chunks)
985
  )
986
- entity_relation_task = asyncio.create_task(
987
- self._process_entity_relation_graph(
988
- chunks, pipeline_status, pipeline_status_lock
989
- )
990
- )
991
  full_docs_task = asyncio.create_task(
992
  self.full_docs.upsert(
993
  {doc_id: {"content": status_doc.content}}
@@ -996,14 +999,26 @@ class LightRAG:
996
  text_chunks_task = asyncio.create_task(
997
  self.text_chunks.upsert(chunks)
998
  )
999
- tasks = [
 
 
1000
  doc_status_task,
1001
  chunks_vdb_task,
1002
- entity_relation_task,
1003
  full_docs_task,
1004
  text_chunks_task,
1005
  ]
1006
- await asyncio.gather(*tasks)
 
 
 
 
 
 
 
 
 
 
 
1007
  file_extraction_stage_ok = True
1008
 
1009
  except Exception as e:
@@ -1018,14 +1033,14 @@ class LightRAG:
1018
  )
1019
  pipeline_status["history_messages"].append(error_msg)
1020
 
1021
- # Cancel other tasks as they are no longer meaningful
1022
- for task in [
1023
- chunks_vdb_task,
1024
- entity_relation_task,
1025
- full_docs_task,
1026
- text_chunks_task,
1027
- ]:
1028
- if not task.done():
1029
  task.cancel()
1030
 
1031
  # Persistent llm cache
@@ -1075,6 +1090,9 @@ class LightRAG:
1075
  doc_id: {
1076
  "status": DocStatus.PROCESSED,
1077
  "chunks_count": len(chunks),
 
 
 
1078
  "content": status_doc.content,
1079
  "content_summary": status_doc.content_summary,
1080
  "content_length": status_doc.content_length,
@@ -1193,6 +1211,7 @@ class LightRAG:
1193
  pipeline_status=pipeline_status,
1194
  pipeline_status_lock=pipeline_status_lock,
1195
  llm_response_cache=self.llm_response_cache,
 
1196
  )
1197
  return chunk_results
1198
  except Exception as e:
@@ -1723,28 +1742,10 @@ class LightRAG:
1723
  file_path="",
1724
  )
1725
 
1726
- # 2. Get all chunks related to this document
1727
- try:
1728
- all_chunks = await self.text_chunks.get_all()
1729
- related_chunks = {
1730
- chunk_id: chunk_data
1731
- for chunk_id, chunk_data in all_chunks.items()
1732
- if isinstance(chunk_data, dict)
1733
- and chunk_data.get("full_doc_id") == doc_id
1734
- }
1735
 
1736
- # Update pipeline status after getting chunks count
1737
- async with pipeline_status_lock:
1738
- log_message = f"Retrieved {len(related_chunks)} of {len(all_chunks)} related chunks"
1739
- logger.info(log_message)
1740
- pipeline_status["latest_message"] = log_message
1741
- pipeline_status["history_messages"].append(log_message)
1742
-
1743
- except Exception as e:
1744
- logger.error(f"Failed to retrieve chunks for document {doc_id}: {e}")
1745
- raise Exception(f"Failed to retrieve document chunks: {e}") from e
1746
-
1747
- if not related_chunks:
1748
  logger.warning(f"No chunks found for document {doc_id}")
1749
  # Mark that deletion operations have started
1750
  deletion_operations_started = True
@@ -1775,7 +1776,6 @@ class LightRAG:
1775
  file_path=file_path,
1776
  )
1777
 
1778
- chunk_ids = set(related_chunks.keys())
1779
  # Mark that deletion operations have started
1780
  deletion_operations_started = True
1781
 
@@ -1799,26 +1799,12 @@ class LightRAG:
1799
  )
1800
  )
1801
 
1802
- # Update pipeline status after getting affected_nodes
1803
- async with pipeline_status_lock:
1804
- log_message = f"Found {len(affected_nodes)} affected entities"
1805
- logger.info(log_message)
1806
- pipeline_status["latest_message"] = log_message
1807
- pipeline_status["history_messages"].append(log_message)
1808
-
1809
  affected_edges = (
1810
  await self.chunk_entity_relation_graph.get_edges_by_chunk_ids(
1811
  list(chunk_ids)
1812
  )
1813
  )
1814
 
1815
- # Update pipeline status after getting affected_edges
1816
- async with pipeline_status_lock:
1817
- log_message = f"Found {len(affected_edges)} affected relations"
1818
- logger.info(log_message)
1819
- pipeline_status["latest_message"] = log_message
1820
- pipeline_status["history_messages"].append(log_message)
1821
-
1822
  except Exception as e:
1823
  logger.error(f"Failed to analyze affected graph elements: {e}")
1824
  raise Exception(f"Failed to analyze graph dependencies: {e}") from e
@@ -1836,6 +1822,14 @@ class LightRAG:
1836
  elif remaining_sources != sources:
1837
  entities_to_rebuild[node_label] = remaining_sources
1838
 
 
 
 
 
 
 
 
 
1839
  # Process relationships
1840
  for edge_data in affected_edges:
1841
  src = edge_data.get("source")
@@ -1857,6 +1851,14 @@ class LightRAG:
1857
  elif remaining_sources != sources:
1858
  relationships_to_rebuild[edge_tuple] = remaining_sources
1859
 
 
 
 
 
 
 
 
 
1860
  except Exception as e:
1861
  logger.error(f"Failed to process graph analysis results: {e}")
1862
  raise Exception(f"Failed to process graph dependencies: {e}") from e
@@ -1940,17 +1942,13 @@ class LightRAG:
1940
  knowledge_graph_inst=self.chunk_entity_relation_graph,
1941
  entities_vdb=self.entities_vdb,
1942
  relationships_vdb=self.relationships_vdb,
1943
- text_chunks=self.text_chunks,
1944
  llm_response_cache=self.llm_response_cache,
1945
  global_config=asdict(self),
 
 
1946
  )
1947
 
1948
- async with pipeline_status_lock:
1949
- log_message = f"Successfully rebuilt {len(entities_to_rebuild)} entities and {len(relationships_to_rebuild)} relations"
1950
- logger.info(log_message)
1951
- pipeline_status["latest_message"] = log_message
1952
- pipeline_status["history_messages"].append(log_message)
1953
-
1954
  except Exception as e:
1955
  logger.error(f"Failed to rebuild knowledge from chunks: {e}")
1956
  raise Exception(
 
22
  Dict,
23
  )
24
  from lightrag.constants import (
25
+ DEFAULT_MAX_GLEANING,
26
  DEFAULT_MAX_TOKEN_SUMMARY,
27
  DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE,
28
  )
 
125
  # Entity extraction
126
  # ---
127
 
128
+ entity_extract_max_gleaning: int = field(
129
+ default=get_env_value("MAX_GLEANING", DEFAULT_MAX_GLEANING, int)
130
+ )
131
  """Maximum number of entity extraction attempts for ambiguous content."""
132
 
133
  summary_to_max_tokens: int = field(
 
349
 
350
  # Fix global_config now
351
  global_config = asdict(self)
352
+
353
  _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
354
  logger.debug(f"LightRAG init with param:\n {_print_config}\n")
355
 
 
398
  embedding_func=self.embedding_func,
399
  )
400
 
 
401
  self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
402
  namespace=make_namespace(
403
  self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
404
  ),
405
  embedding_func=self.embedding_func,
406
  )
407
+
408
  self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls( # type: ignore
409
  namespace=make_namespace(
410
  self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
 
953
  **dp,
954
  "full_doc_id": doc_id,
955
  "file_path": file_path, # Add file path to each chunk
956
+ "llm_cache_list": [], # Initialize empty LLM cache list for each chunk
957
  }
958
  for dp in self.chunking_func(
959
  self.tokenizer,
 
965
  )
966
  }
967
 
968
+ # Process document in two stages
969
+ # Stage 1: Process text chunks and docs (parallel execution)
970
  doc_status_task = asyncio.create_task(
971
  self.doc_status.upsert(
972
  {
973
  doc_id: {
974
  "status": DocStatus.PROCESSING,
975
  "chunks_count": len(chunks),
976
+ "chunks_list": list(
977
+ chunks.keys()
978
+ ), # Save chunks list
979
  "content": status_doc.content,
980
  "content_summary": status_doc.content_summary,
981
  "content_length": status_doc.content_length,
 
991
  chunks_vdb_task = asyncio.create_task(
992
  self.chunks_vdb.upsert(chunks)
993
  )
 
 
 
 
 
994
  full_docs_task = asyncio.create_task(
995
  self.full_docs.upsert(
996
  {doc_id: {"content": status_doc.content}}
 
999
  text_chunks_task = asyncio.create_task(
1000
  self.text_chunks.upsert(chunks)
1001
  )
1002
+
1003
+ # First stage tasks (parallel execution)
1004
+ first_stage_tasks = [
1005
  doc_status_task,
1006
  chunks_vdb_task,
 
1007
  full_docs_task,
1008
  text_chunks_task,
1009
  ]
1010
+ entity_relation_task = None
1011
+
1012
+ # Execute first stage tasks
1013
+ await asyncio.gather(*first_stage_tasks)
1014
+
1015
+ # Stage 2: Process entity relation graph (after text_chunks are saved)
1016
+ entity_relation_task = asyncio.create_task(
1017
+ self._process_entity_relation_graph(
1018
+ chunks, pipeline_status, pipeline_status_lock
1019
+ )
1020
+ )
1021
+ await entity_relation_task
1022
  file_extraction_stage_ok = True
1023
 
1024
  except Exception as e:
 
1033
  )
1034
  pipeline_status["history_messages"].append(error_msg)
1035
 
1036
+ # Cancel tasks that are not yet completed
1037
+ all_tasks = first_stage_tasks + (
1038
+ [entity_relation_task]
1039
+ if entity_relation_task
1040
+ else []
1041
+ )
1042
+ for task in all_tasks:
1043
+ if task and not task.done():
1044
  task.cancel()
1045
 
1046
  # Persistent llm cache
 
1090
  doc_id: {
1091
  "status": DocStatus.PROCESSED,
1092
  "chunks_count": len(chunks),
1093
+ "chunks_list": list(
1094
+ chunks.keys()
1095
+ ), # Keep chunks_list
1096
  "content": status_doc.content,
1097
  "content_summary": status_doc.content_summary,
1098
  "content_length": status_doc.content_length,
 
1211
  pipeline_status=pipeline_status,
1212
  pipeline_status_lock=pipeline_status_lock,
1213
  llm_response_cache=self.llm_response_cache,
1214
+ text_chunks_storage=self.text_chunks,
1215
  )
1216
  return chunk_results
1217
  except Exception as e:
 
1742
  file_path="",
1743
  )
1744
 
1745
+ # 2. Get chunk IDs from document status
1746
+ chunk_ids = set(doc_status_data.get("chunks_list", []))
 
 
 
 
 
 
 
1747
 
1748
+ if not chunk_ids:
 
 
 
 
 
 
 
 
 
 
 
1749
  logger.warning(f"No chunks found for document {doc_id}")
1750
  # Mark that deletion operations have started
1751
  deletion_operations_started = True
 
1776
  file_path=file_path,
1777
  )
1778
 
 
1779
  # Mark that deletion operations have started
1780
  deletion_operations_started = True
1781
 
 
1799
  )
1800
  )
1801
 
 
 
 
 
 
 
 
1802
  affected_edges = (
1803
  await self.chunk_entity_relation_graph.get_edges_by_chunk_ids(
1804
  list(chunk_ids)
1805
  )
1806
  )
1807
 
 
 
 
 
 
 
 
1808
  except Exception as e:
1809
  logger.error(f"Failed to analyze affected graph elements: {e}")
1810
  raise Exception(f"Failed to analyze graph dependencies: {e}") from e
 
1822
  elif remaining_sources != sources:
1823
  entities_to_rebuild[node_label] = remaining_sources
1824
 
1825
+ async with pipeline_status_lock:
1826
+ log_message = (
1827
+ f"Found {len(entities_to_rebuild)} affected entities"
1828
+ )
1829
+ logger.info(log_message)
1830
+ pipeline_status["latest_message"] = log_message
1831
+ pipeline_status["history_messages"].append(log_message)
1832
+
1833
  # Process relationships
1834
  for edge_data in affected_edges:
1835
  src = edge_data.get("source")
 
1851
  elif remaining_sources != sources:
1852
  relationships_to_rebuild[edge_tuple] = remaining_sources
1853
 
1854
+ async with pipeline_status_lock:
1855
+ log_message = (
1856
+ f"Found {len(relationships_to_rebuild)} affected relations"
1857
+ )
1858
+ logger.info(log_message)
1859
+ pipeline_status["latest_message"] = log_message
1860
+ pipeline_status["history_messages"].append(log_message)
1861
+
1862
  except Exception as e:
1863
  logger.error(f"Failed to process graph analysis results: {e}")
1864
  raise Exception(f"Failed to process graph dependencies: {e}") from e
 
1942
  knowledge_graph_inst=self.chunk_entity_relation_graph,
1943
  entities_vdb=self.entities_vdb,
1944
  relationships_vdb=self.relationships_vdb,
1945
+ text_chunks_storage=self.text_chunks,
1946
  llm_response_cache=self.llm_response_cache,
1947
  global_config=asdict(self),
1948
+ pipeline_status=pipeline_status,
1949
+ pipeline_status_lock=pipeline_status_lock,
1950
  )
1951
 
 
 
 
 
 
 
1952
  except Exception as e:
1953
  logger.error(f"Failed to rebuild knowledge from chunks: {e}")
1954
  raise Exception(
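
Before the operate.py changes, a condensed, self-contained sketch of the two-stage ingestion order introduced above (the storage classes are replaced by no-op stubs; names and payloads are illustrative):

    import asyncio

    class _StubStore:
        async def upsert(self, data):  # stand-in for the real KV/vector/doc-status storages
            await asyncio.sleep(0)

    async def ingest(doc_id: str, chunks: dict) -> None:
        doc_status, chunks_vdb, full_docs, text_chunks = (_StubStore() for _ in range(4))
        # Stage 1: persist document status (with chunks_list), chunk vectors,
        # the full document and the text chunks in parallel
        first_stage = [
            asyncio.create_task(doc_status.upsert({doc_id: {"chunks_list": list(chunks)}})),
            asyncio.create_task(chunks_vdb.upsert(chunks)),
            asyncio.create_task(full_docs.upsert({doc_id: {"content": "..."}})),
            asyncio.create_task(text_chunks.upsert(chunks)),
        ]
        await asyncio.gather(*first_stage)
        # Stage 2: only after the chunks are stored, run entity/relation extraction,
        # which writes llm_cache_list references back onto the stored chunks
        await _StubStore().upsert({doc_id: {"status": "graph extraction runs here"}})

    asyncio.run(ingest("doc-1", {"chunk-1": {"content": "..."}}))
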
lightrag/operate.py CHANGED
@@ -25,6 +25,7 @@ from .utils import (
25
  CacheData,
26
  get_conversation_turns,
27
  use_llm_func_with_cache,
 
28
  )
29
  from .base import (
30
  BaseGraphStorage,
@@ -103,8 +104,6 @@ async def _handle_entity_relation_summary(
103
  entity_or_relation_name: str,
104
  description: str,
105
  global_config: dict,
106
- pipeline_status: dict = None,
107
- pipeline_status_lock=None,
108
  llm_response_cache: BaseKVStorage | None = None,
109
  ) -> str:
110
  """Handle entity relation summary
@@ -247,9 +246,11 @@ async def _rebuild_knowledge_from_chunks(
247
  knowledge_graph_inst: BaseGraphStorage,
248
  entities_vdb: BaseVectorStorage,
249
  relationships_vdb: BaseVectorStorage,
250
- text_chunks: BaseKVStorage,
251
  llm_response_cache: BaseKVStorage,
252
  global_config: dict[str, str],
 
 
253
  ) -> None:
254
  """Rebuild entity and relationship descriptions from cached extraction results
255
 
@@ -259,9 +260,12 @@ async def _rebuild_knowledge_from_chunks(
259
  Args:
260
  entities_to_rebuild: Dict mapping entity_name -> set of remaining chunk_ids
261
  relationships_to_rebuild: Dict mapping (src, tgt) -> set of remaining chunk_ids
 
262
  """
263
  if not entities_to_rebuild and not relationships_to_rebuild:
264
  return
 
 
265
 
266
  # Get all referenced chunk IDs
267
  all_referenced_chunk_ids = set()
@@ -270,36 +274,74 @@ async def _rebuild_knowledge_from_chunks(
270
  for chunk_ids in relationships_to_rebuild.values():
271
  all_referenced_chunk_ids.update(chunk_ids)
272
 
273
- logger.debug(
274
- f"Rebuilding knowledge from {len(all_referenced_chunk_ids)} cached chunk extractions"
275
- )
 
 
 
276
 
277
- # Get cached extraction results for these chunks
 
278
  cached_results = await _get_cached_extraction_results(
279
- llm_response_cache, all_referenced_chunk_ids
 
 
280
  )
281
 
282
  if not cached_results:
283
- logger.warning("No cached extraction results found, cannot rebuild")
 
 
 
 
 
284
  return
285
 
286
  # Process cached results to get entities and relationships for each chunk
287
  chunk_entities = {} # chunk_id -> {entity_name: [entity_data]}
288
  chunk_relationships = {} # chunk_id -> {(src, tgt): [relationship_data]}
289
 
290
- for chunk_id, extraction_result in cached_results.items():
291
  try:
292
- entities, relationships = await _parse_extraction_result(
293
- text_chunks=text_chunks,
294
- extraction_result=extraction_result,
295
- chunk_id=chunk_id,
296
- )
297
- chunk_entities[chunk_id] = entities
298
- chunk_relationships[chunk_id] = relationships
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
  except Exception as e:
300
- logger.error(
301
  f"Failed to parse cached extraction result for chunk {chunk_id}: {e}"
302
  )
 
 
 
 
 
303
  continue
304
 
305
  # Rebuild entities
@@ -314,11 +356,22 @@ async def _rebuild_knowledge_from_chunks(
314
  llm_response_cache=llm_response_cache,
315
  global_config=global_config,
316
  )
317
- logger.debug(
318
- f"Rebuilt entity {entity_name} from {len(chunk_ids)} cached extractions"
 
319
  )
 
 
 
 
 
320
  except Exception as e:
321
- logger.error(f"Failed to rebuild entity {entity_name}: {e}")
 
 
 
 
 
322
 
323
  # Rebuild relationships
324
  for (src, tgt), chunk_ids in relationships_to_rebuild.items():
@@ -333,53 +386,112 @@ async def _rebuild_knowledge_from_chunks(
333
  llm_response_cache=llm_response_cache,
334
  global_config=global_config,
335
  )
336
- logger.debug(
337
- f"Rebuilt relationship {src}-{tgt} from {len(chunk_ids)} cached extractions"
 
338
  )
 
 
 
 
 
339
  except Exception as e:
340
- logger.error(f"Failed to rebuild relationship {src}-{tgt}: {e}")
 
 
 
 
 
341
 
342
- logger.debug("Completed rebuilding knowledge from cached extractions")
 
 
 
 
 
343
 
344
 
345
  async def _get_cached_extraction_results(
346
- llm_response_cache: BaseKVStorage, chunk_ids: set[str]
347
- ) -> dict[str, str]:
 
 
348
  """Get cached extraction results for specific chunk IDs
349
 
350
  Args:
 
351
  chunk_ids: Set of chunk IDs to get cached results for
 
 
352
 
353
  Returns:
354
- Dict mapping chunk_id -> extraction_result_text
355
  """
356
  cached_results = {}
357
 
358
- # Get all cached data for "default" mode (entity extraction cache)
359
- default_cache = await llm_response_cache.get_by_id("default") or {}
360
 
361
- for cache_key, cache_entry in default_cache.items():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
362
  if (
363
- isinstance(cache_entry, dict)
 
364
  and cache_entry.get("cache_type") == "extract"
365
  and cache_entry.get("chunk_id") in chunk_ids
366
  ):
367
  chunk_id = cache_entry["chunk_id"]
368
  extraction_result = cache_entry["return"]
369
- cached_results[chunk_id] = extraction_result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
 
371
- logger.debug(
372
- f"Found {len(cached_results)} cached extraction results for {len(chunk_ids)} chunk IDs"
373
  )
374
  return cached_results
375
 
376
 
377
  async def _parse_extraction_result(
378
- text_chunks: BaseKVStorage, extraction_result: str, chunk_id: str
379
  ) -> tuple[dict, dict]:
380
  """Parse cached extraction result using the same logic as extract_entities
381
 
382
  Args:
 
383
  extraction_result: The cached LLM extraction result
384
  chunk_id: The chunk ID for source tracking
385
 
@@ -387,8 +499,8 @@ async def _parse_extraction_result(
387
  Tuple of (entities_dict, relationships_dict)
388
  """
389
 
390
- # Get chunk data for file_path
391
- chunk_data = await text_chunks.get_by_id(chunk_id)
392
  file_path = (
393
  chunk_data.get("file_path", "unknown_source")
394
  if chunk_data
@@ -761,8 +873,6 @@ async def _merge_nodes_then_upsert(
761
  entity_name,
762
  description,
763
  global_config,
764
- pipeline_status,
765
- pipeline_status_lock,
766
  llm_response_cache,
767
  )
768
  else:
@@ -925,8 +1035,6 @@ async def _merge_edges_then_upsert(
925
  f"({src_id}, {tgt_id})",
926
  description,
927
  global_config,
928
- pipeline_status,
929
- pipeline_status_lock,
930
  llm_response_cache,
931
  )
932
  else:
@@ -1102,6 +1210,7 @@ async def extract_entities(
1102
  pipeline_status: dict = None,
1103
  pipeline_status_lock=None,
1104
  llm_response_cache: BaseKVStorage | None = None,
 
1105
  ) -> list:
1106
  use_llm_func: callable = global_config["llm_model_func"]
1107
  entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
@@ -1208,6 +1317,9 @@ async def extract_entities(
1208
  # Get file path from chunk data or use default
1209
  file_path = chunk_dp.get("file_path", "unknown_source")
1210
 
 
 
 
1211
  # Get initial extraction
1212
  hint_prompt = entity_extract_prompt.format(
1213
  **{**context_base, "input_text": content}
@@ -1219,7 +1331,10 @@ async def extract_entities(
1219
  llm_response_cache=llm_response_cache,
1220
  cache_type="extract",
1221
  chunk_id=chunk_key,
 
1222
  )
 
 
1223
  history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
1224
 
1225
  # Process initial extraction with file path
@@ -1236,6 +1351,7 @@ async def extract_entities(
1236
  history_messages=history,
1237
  cache_type="extract",
1238
  chunk_id=chunk_key,
 
1239
  )
1240
 
1241
  history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
@@ -1266,11 +1382,21 @@ async def extract_entities(
1266
  llm_response_cache=llm_response_cache,
1267
  history_messages=history,
1268
  cache_type="extract",
 
1269
  )
1270
  if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
1271
  if if_loop_result != "yes":
1272
  break
1273
 
 
 
 
 
 
 
 
 
 
1274
  processed_chunks += 1
1275
  entities_count = len(maybe_nodes)
1276
  relations_count = len(maybe_edges)
@@ -1343,7 +1469,7 @@ async def kg_query(
1343
  use_model_func = partial(use_model_func, _priority=5)
1344
 
1345
  # Handle cache
1346
- args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
1347
  cached_response, quantized, min_val, max_val = await handle_cache(
1348
  hashing_kv, args_hash, query, query_param.mode, cache_type="query"
1349
  )
@@ -1390,7 +1516,7 @@ async def kg_query(
1390
  )
1391
 
1392
  if query_param.only_need_context:
1393
- return context
1394
  if context is None:
1395
  return PROMPTS["fail_response"]
1396
 
@@ -1502,7 +1628,7 @@ async def extract_keywords_only(
1502
  """
1503
 
1504
  # 1. Handle cache if needed - add cache type for keywords
1505
- args_hash = compute_args_hash(param.mode, text, cache_type="keywords")
1506
  cached_response, quantized, min_val, max_val = await handle_cache(
1507
  hashing_kv, args_hash, text, param.mode, cache_type="keywords"
1508
  )
@@ -1647,7 +1773,7 @@ async def _get_vector_context(
1647
  f"Truncate chunks from {len(valid_chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
1648
  )
1649
  logger.info(
1650
- f"Vector query: {len(maybe_trun_chunks)} chunks, top_k: {query_param.top_k}"
1651
  )
1652
 
1653
  if not maybe_trun_chunks:
@@ -1871,7 +1997,7 @@ async def _get_node_data(
1871
  )
1872
 
1873
  logger.info(
1874
- f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} chunks"
1875
  )
1876
 
1877
  # build prompt
@@ -2180,7 +2306,7 @@ async def _get_edge_data(
2180
  ),
2181
  )
2182
  logger.info(
2183
- f"Global query uses {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} chunks"
2184
  )
2185
 
2186
  relations_context = []
@@ -2369,7 +2495,7 @@ async def naive_query(
2369
  use_model_func = partial(use_model_func, _priority=5)
2370
 
2371
  # Handle cache
2372
- args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
2373
  cached_response, quantized, min_val, max_val = await handle_cache(
2374
  hashing_kv, args_hash, query, query_param.mode, cache_type="query"
2375
  )
@@ -2485,7 +2611,7 @@ async def kg_query_with_keywords(
2485
  # Apply higher priority (5) to query relation LLM function
2486
  use_model_func = partial(use_model_func, _priority=5)
2487
 
2488
- args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
2489
  cached_response, quantized, min_val, max_val = await handle_cache(
2490
  hashing_kv, args_hash, query, query_param.mode, cache_type="query"
2491
  )
 
25
  CacheData,
26
  get_conversation_turns,
27
  use_llm_func_with_cache,
28
+ update_chunk_cache_list,
29
  )
30
  from .base import (
31
  BaseGraphStorage,
 
104
  entity_or_relation_name: str,
105
  description: str,
106
  global_config: dict,
 
 
107
  llm_response_cache: BaseKVStorage | None = None,
108
  ) -> str:
109
  """Handle entity relation summary
 
246
  knowledge_graph_inst: BaseGraphStorage,
247
  entities_vdb: BaseVectorStorage,
248
  relationships_vdb: BaseVectorStorage,
249
+ text_chunks_storage: BaseKVStorage,
250
  llm_response_cache: BaseKVStorage,
251
  global_config: dict[str, str],
252
+ pipeline_status: dict | None = None,
253
+ pipeline_status_lock=None,
254
  ) -> None:
255
  """Rebuild entity and relationship descriptions from cached extraction results
256
 
 
260
  Args:
261
  entities_to_rebuild: Dict mapping entity_name -> set of remaining chunk_ids
262
  relationships_to_rebuild: Dict mapping (src, tgt) -> set of remaining chunk_ids
263
+ text_chunks_data: Pre-loaded chunk data dict {chunk_id: chunk_data}
264
  """
265
  if not entities_to_rebuild and not relationships_to_rebuild:
266
  return
267
+ rebuilt_entities_count = 0
268
+ rebuilt_relationships_count = 0
269
 
270
  # Get all referenced chunk IDs
271
  all_referenced_chunk_ids = set()
 
274
  for chunk_ids in relationships_to_rebuild.values():
275
  all_referenced_chunk_ids.update(chunk_ids)
276
 
277
+ status_message = f"Rebuilding knowledge from {len(all_referenced_chunk_ids)} cached chunk extractions"
278
+ logger.info(status_message)
279
+ if pipeline_status is not None and pipeline_status_lock is not None:
280
+ async with pipeline_status_lock:
281
+ pipeline_status["latest_message"] = status_message
282
+ pipeline_status["history_messages"].append(status_message)
283
 
284
+ # Get cached extraction results for these chunks using storage
285
+ # cached_results: chunk_id -> [list of extraction result from LLM cache sorted by created_at]
286
  cached_results = await _get_cached_extraction_results(
287
+ llm_response_cache,
288
+ all_referenced_chunk_ids,
289
+ text_chunks_storage=text_chunks_storage,
290
  )
291
 
292
  if not cached_results:
293
+ status_message = "No cached extraction results found, cannot rebuild"
294
+ logger.warning(status_message)
295
+ if pipeline_status is not None and pipeline_status_lock is not None:
296
+ async with pipeline_status_lock:
297
+ pipeline_status["latest_message"] = status_message
298
+ pipeline_status["history_messages"].append(status_message)
299
  return
300
 
301
  # Process cached results to get entities and relationships for each chunk
302
  chunk_entities = {} # chunk_id -> {entity_name: [entity_data]}
303
  chunk_relationships = {} # chunk_id -> {(src, tgt): [relationship_data]}
304
 
305
+ for chunk_id, extraction_results in cached_results.items():
306
  try:
307
+ # Handle multiple extraction results per chunk
308
+ chunk_entities[chunk_id] = defaultdict(list)
309
+ chunk_relationships[chunk_id] = defaultdict(list)
310
+
311
+ # process multiple LLM extraction results for a single chunk_id
312
+ for extraction_result in extraction_results:
313
+ entities, relationships = await _parse_extraction_result(
314
+ text_chunks_storage=text_chunks_storage,
315
+ extraction_result=extraction_result,
316
+ chunk_id=chunk_id,
317
+ )
318
+
319
+ # Merge entities and relationships from this extraction result
320
+ # Only keep the first occurrence of each entity_name in the same chunk_id
321
+ for entity_name, entity_list in entities.items():
322
+ if (
323
+ entity_name not in chunk_entities[chunk_id]
324
+ or len(chunk_entities[chunk_id][entity_name]) == 0
325
+ ):
326
+ chunk_entities[chunk_id][entity_name].extend(entity_list)
327
+
328
+ # Only keep the first occurrence of each rel_key in the same chunk_id
329
+ for rel_key, rel_list in relationships.items():
330
+ if (
331
+ rel_key not in chunk_relationships[chunk_id]
332
+ or len(chunk_relationships[chunk_id][rel_key]) == 0
333
+ ):
334
+ chunk_relationships[chunk_id][rel_key].extend(rel_list)
335
+
336
  except Exception as e:
337
+ status_message = (
338
  f"Failed to parse cached extraction result for chunk {chunk_id}: {e}"
339
  )
340
+ logger.info(status_message)  # Logged at INFO level by design
341
+ if pipeline_status is not None and pipeline_status_lock is not None:
342
+ async with pipeline_status_lock:
343
+ pipeline_status["latest_message"] = status_message
344
+ pipeline_status["history_messages"].append(status_message)
345
  continue
346
 
347
  # Rebuild entities
 
356
  llm_response_cache=llm_response_cache,
357
  global_config=global_config,
358
  )
359
+ rebuilt_entities_count += 1
360
+ status_message = (
361
+ f"Rebuilt entity: {entity_name} from {len(chunk_ids)} chunks"
362
  )
363
+ logger.info(status_message)
364
+ if pipeline_status is not None and pipeline_status_lock is not None:
365
+ async with pipeline_status_lock:
366
+ pipeline_status["latest_message"] = status_message
367
+ pipeline_status["history_messages"].append(status_message)
368
  except Exception as e:
369
+ status_message = f"Failed to rebuild entity {entity_name}: {e}"
370
+ logger.info(status_message) # Per requirement, change to info
371
+ if pipeline_status is not None and pipeline_status_lock is not None:
372
+ async with pipeline_status_lock:
373
+ pipeline_status["latest_message"] = status_message
374
+ pipeline_status["history_messages"].append(status_message)
375
 
376
  # Rebuild relationships
377
  for (src, tgt), chunk_ids in relationships_to_rebuild.items():
 
386
  llm_response_cache=llm_response_cache,
387
  global_config=global_config,
388
  )
389
+ rebuilt_relationships_count += 1
390
+ status_message = (
391
+ f"Rebuilt relationship: {src}->{tgt} from {len(chunk_ids)} chunks"
392
  )
393
+ logger.info(status_message)
394
+ if pipeline_status is not None and pipeline_status_lock is not None:
395
+ async with pipeline_status_lock:
396
+ pipeline_status["latest_message"] = status_message
397
+ pipeline_status["history_messages"].append(status_message)
398
  except Exception as e:
399
+ status_message = f"Failed to rebuild relationship {src}->{tgt}: {e}"
400
+ logger.info(status_message)
401
+ if pipeline_status is not None and pipeline_status_lock is not None:
402
+ async with pipeline_status_lock:
403
+ pipeline_status["latest_message"] = status_message
404
+ pipeline_status["history_messages"].append(status_message)
405
 
406
+ status_message = f"KG rebuild completed: {rebuilt_entities_count} entities and {rebuilt_relationships_count} relationships."
407
+ logger.info(status_message)
408
+ if pipeline_status is not None and pipeline_status_lock is not None:
409
+ async with pipeline_status_lock:
410
+ pipeline_status["latest_message"] = status_message
411
+ pipeline_status["history_messages"].append(status_message)
412
 
413
 
414
  async def _get_cached_extraction_results(
415
+ llm_response_cache: BaseKVStorage,
416
+ chunk_ids: set[str],
417
+ text_chunks_storage: BaseKVStorage,
418
+ ) -> dict[str, list[str]]:
419
  """Get cached extraction results for specific chunk IDs
420
 
421
  Args:
422
+ llm_response_cache: LLM response cache storage
423
  chunk_ids: Set of chunk IDs to get cached results for
424
+ text_chunks_storage: Text chunks storage used to read each chunk's llm_cache_list
425
+ (the referenced cache entries are then fetched from llm_response_cache in batch)
426
 
427
  Returns:
428
+ Dict mapping chunk_id -> list of extraction_result_text
429
  """
430
  cached_results = {}
431
 
432
+ # Collect all LLM cache IDs from chunks
433
+ all_cache_ids = set()
434
 
435
+ # Read from storage
436
+ chunk_data_list = await text_chunks_storage.get_by_ids(list(chunk_ids))
437
+ for chunk_id, chunk_data in zip(chunk_ids, chunk_data_list):
438
+ if chunk_data and isinstance(chunk_data, dict):
439
+ llm_cache_list = chunk_data.get("llm_cache_list", [])
440
+ if llm_cache_list:
441
+ all_cache_ids.update(llm_cache_list)
442
+ else:
443
+ logger.warning(
444
+ f"Chunk {chunk_id} data is invalid or None: {type(chunk_data)}"
445
+ )
446
+
447
+ if not all_cache_ids:
448
+ logger.warning(f"No LLM cache IDs found for {len(chunk_ids)} chunk IDs")
449
+ return cached_results
450
+
451
+ # Batch get LLM cache entries
452
+ cache_data_list = await llm_response_cache.get_by_ids(list(all_cache_ids))
453
+
454
+ # Process cache entries and group by chunk_id
455
+ valid_entries = 0
456
+ for cache_id, cache_entry in zip(all_cache_ids, cache_data_list):
457
  if (
458
+ cache_entry is not None
459
+ and isinstance(cache_entry, dict)
460
  and cache_entry.get("cache_type") == "extract"
461
  and cache_entry.get("chunk_id") in chunk_ids
462
  ):
463
  chunk_id = cache_entry["chunk_id"]
464
  extraction_result = cache_entry["return"]
465
+ create_time = cache_entry.get(
466
+ "create_time", 0
467
+ ) # Get creation time, default to 0
468
+ valid_entries += 1
469
+
470
+ # Support multiple LLM caches per chunk
471
+ if chunk_id not in cached_results:
472
+ cached_results[chunk_id] = []
473
+ # Store tuple with extraction result and creation time for sorting
474
+ cached_results[chunk_id].append((extraction_result, create_time))
475
+
476
+ # Sort extraction results by create_time for each chunk
477
+ for chunk_id in cached_results:
478
+ # Sort by create_time (x[1]), then extract only extraction_result (x[0])
479
+ cached_results[chunk_id].sort(key=lambda x: x[1])
480
+ cached_results[chunk_id] = [item[0] for item in cached_results[chunk_id]]
481
 
482
+ logger.info(
483
+ f"Found {valid_entries} valid cache entries, {len(cached_results)} chunks with results"
484
  )
485
  return cached_results
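
A minimal sketch of the lookup-and-sort step above, using plain dicts in place of the storage classes (the field names llm_cache_list, cache_type, chunk_id, return and create_time follow the code above; the values are made up):

    # Chunk records point at their LLM cache entries via llm_cache_list
    text_chunks = {"chunk-1": {"llm_cache_list": ["k2", "k1"], "file_path": "a.txt"}}
    llm_cache = {
        "k1": {"cache_type": "extract", "chunk_id": "chunk-1", "return": "initial pass", "create_time": 100},
        "k2": {"cache_type": "extract", "chunk_id": "chunk-1", "return": "gleaning pass", "create_time": 200},
    }

    cached_results: dict[str, list[str]] = {}
    for chunk_id, chunk in text_chunks.items():
        entries = [llm_cache[k] for k in chunk.get("llm_cache_list", []) if k in llm_cache]
        entries = [e for e in entries
                   if e.get("cache_type") == "extract" and e.get("chunk_id") == chunk_id]
        entries.sort(key=lambda e: e.get("create_time", 0))  # oldest extraction first
        cached_results[chunk_id] = [e["return"] for e in entries]

    print(cached_results)  # {'chunk-1': ['initial pass', 'gleaning pass']}
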
486
 
487
 
488
  async def _parse_extraction_result(
489
+ text_chunks_storage: BaseKVStorage, extraction_result: str, chunk_id: str
490
  ) -> tuple[dict, dict]:
491
  """Parse cached extraction result using the same logic as extract_entities
492
 
493
  Args:
494
+ text_chunks_storage: Text chunks storage to get chunk data
495
  extraction_result: The cached LLM extraction result
496
  chunk_id: The chunk ID for source tracking
497
 
 
499
  Tuple of (entities_dict, relationships_dict)
500
  """
501
 
502
+ # Get chunk data for file_path from storage
503
+ chunk_data = await text_chunks_storage.get_by_id(chunk_id)
504
  file_path = (
505
  chunk_data.get("file_path", "unknown_source")
506
  if chunk_data
 
873
  entity_name,
874
  description,
875
  global_config,
 
 
876
  llm_response_cache,
877
  )
878
  else:
 
1035
  f"({src_id}, {tgt_id})",
1036
  description,
1037
  global_config,
 
 
1038
  llm_response_cache,
1039
  )
1040
  else:
 
1210
  pipeline_status: dict = None,
1211
  pipeline_status_lock=None,
1212
  llm_response_cache: BaseKVStorage | None = None,
1213
+ text_chunks_storage: BaseKVStorage | None = None,
1214
  ) -> list:
1215
  use_llm_func: callable = global_config["llm_model_func"]
1216
  entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
 
1317
  # Get file path from chunk data or use default
1318
  file_path = chunk_dp.get("file_path", "unknown_source")
1319
 
1320
+ # Create cache keys collector for batch processing
1321
+ cache_keys_collector = []
1322
+
1323
  # Get initial extraction
1324
  hint_prompt = entity_extract_prompt.format(
1325
  **{**context_base, "input_text": content}
 
1331
  llm_response_cache=llm_response_cache,
1332
  cache_type="extract",
1333
  chunk_id=chunk_key,
1334
+ cache_keys_collector=cache_keys_collector,
1335
  )
1336
+
1337
+ # Store LLM cache reference in chunk (will be handled by use_llm_func_with_cache)
1338
  history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
1339
 
1340
  # Process initial extraction with file path
 
1351
  history_messages=history,
1352
  cache_type="extract",
1353
  chunk_id=chunk_key,
1354
+ cache_keys_collector=cache_keys_collector,
1355
  )
1356
 
1357
  history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
 
1382
  llm_response_cache=llm_response_cache,
1383
  history_messages=history,
1384
  cache_type="extract",
1385
+ cache_keys_collector=cache_keys_collector,
1386
  )
1387
  if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
1388
  if if_loop_result != "yes":
1389
  break
1390
 
1391
+ # Batch update chunk's llm_cache_list with all collected cache keys
1392
+ if cache_keys_collector and text_chunks_storage:
1393
+ await update_chunk_cache_list(
1394
+ chunk_key,
1395
+ text_chunks_storage,
1396
+ cache_keys_collector,
1397
+ "entity_extraction",
1398
+ )
1399
+
1400
  processed_chunks += 1
1401
  entities_count = len(maybe_nodes)
1402
  relations_count = len(maybe_edges)
 
1469
  use_model_func = partial(use_model_func, _priority=5)
1470
 
1471
  # Handle cache
1472
+ args_hash = compute_args_hash(query_param.mode, query)
1473
  cached_response, quantized, min_val, max_val = await handle_cache(
1474
  hashing_kv, args_hash, query, query_param.mode, cache_type="query"
1475
  )
 
1516
  )
1517
 
1518
  if query_param.only_need_context:
1519
+ return context if context is not None else PROMPTS["fail_response"]
1520
  if context is None:
1521
  return PROMPTS["fail_response"]
1522
 
 
1628
  """
1629
 
1630
  # 1. Handle cache if needed - add cache type for keywords
1631
+ args_hash = compute_args_hash(param.mode, text)
1632
  cached_response, quantized, min_val, max_val = await handle_cache(
1633
  hashing_kv, args_hash, text, param.mode, cache_type="keywords"
1634
  )
 
1773
  f"Truncate chunks from {len(valid_chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
1774
  )
1775
  logger.info(
1776
+ f"Query chunks: {len(maybe_trun_chunks)} chunks, top_k: {query_param.top_k}"
1777
  )
1778
 
1779
  if not maybe_trun_chunks:
 
1997
  )
1998
 
1999
  logger.info(
2000
+ f"Local query: {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} chunks"
2001
  )
2002
 
2003
  # build prompt
 
2306
  ),
2307
  )
2308
  logger.info(
2309
+ f"Global query: {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} chunks"
2310
  )
2311
 
2312
  relations_context = []
 
2495
  use_model_func = partial(use_model_func, _priority=5)
2496
 
2497
  # Handle cache
2498
+ args_hash = compute_args_hash(query_param.mode, query)
2499
  cached_response, quantized, min_val, max_val = await handle_cache(
2500
  hashing_kv, args_hash, query, query_param.mode, cache_type="query"
2501
  )
 
2611
  # Apply higher priority (5) to query relation LLM function
2612
  use_model_func = partial(use_model_func, _priority=5)
2613
 
2614
+ args_hash = compute_args_hash(query_param.mode, query)
2615
  cached_response, quantized, min_val, max_val = await handle_cache(
2616
  hashing_kv, args_hash, query, query_param.mode, cache_type="query"
2617
  )
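
The cache_keys_collector passed through use_llm_func_with_cache above is flushed once per chunk via update_chunk_cache_list. That helper is added in lightrag/utils.py but its body is not shown in this diff; a hypothetical stand-in that captures the intent (an in-memory KV stub replaces the real text_chunks storage):

    import asyncio

    class _MemoryKV:
        def __init__(self):
            self._data = {}

        async def get_by_id(self, key):
            return self._data.get(key)

        async def upsert(self, data):
            self._data.update(data)

    async def update_chunk_cache_list_sketch(chunk_key, storage, cache_keys, source="entity_extraction"):
        """Hypothetical: append the LLM cache keys collected for one chunk onto its llm_cache_list."""
        chunk = await storage.get_by_id(chunk_key) or {}
        cache_list = chunk.setdefault("llm_cache_list", [])
        cache_list.extend(k for k in cache_keys if k not in cache_list)
        await storage.upsert({chunk_key: chunk})

    async def _demo():
        store = _MemoryKV()
        collector = ["default:extract:hash-initial", "default:extract:hash-glean-1"]
        await update_chunk_cache_list_sketch("chunk-1", store, collector)
        print(await store.get_by_id("chunk-1"))

    asyncio.run(_demo())
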
lightrag/utils.py CHANGED
@@ -14,7 +14,6 @@ from functools import wraps
14
  from hashlib import md5
15
  from typing import Any, Protocol, Callable, TYPE_CHECKING, List
16
  import numpy as np
17
- from lightrag.prompt import PROMPTS
18
  from dotenv import load_dotenv
19
  from lightrag.constants import (
20
  DEFAULT_LOG_MAX_BYTES,
@@ -278,11 +277,10 @@ def convert_response_to_json(response: str) -> dict[str, Any]:
278
  raise e from None
279
 
280
 
281
- def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:
282
  """Compute a hash for the given arguments.
283
  Args:
284
  *args: Arguments to hash
285
- cache_type: Type of cache (e.g., 'keywords', 'query', 'extract')
286
  Returns:
287
  str: Hash string
288
  """
@@ -290,13 +288,40 @@ def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:
290
 
291
  # Convert all arguments to strings and join them
292
  args_str = "".join([str(arg) for arg in args])
293
- if cache_type:
294
- args_str = f"{cache_type}:{args_str}"
295
 
296
  # Compute MD5 hash
297
  return hashlib.md5(args_str.encode()).hexdigest()
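
With cache_type no longer folded into the hash, callers pass only the mode and the text, and the cache type is carried by the flattened cache key instead. A tiny illustration of the resulting key shape (assuming the {mode}:{cache_type}:{hash} layout used by the Redis storage above):

    import hashlib

    def compute_args_hash(*args) -> str:
        # Mirrors the simplified helper above: plain MD5 over the concatenated arguments
        return hashlib.md5("".join(str(a) for a in args).encode()).hexdigest()

    args_hash = compute_args_hash("local", "What is LightRAG?")
    flattened_cache_key = f"local:query:{args_hash}"  # cache_type lives in the key, not the hash
    print(flattened_cache_key)
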
298
 
299
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
  def compute_mdhash_id(content: str, prefix: str = "") -> str:
301
  """
302
  Compute a unique ID for a given content string.
@@ -783,131 +808,6 @@ def process_combine_contexts(*context_lists):
      return combined_data


- async def get_best_cached_response(
-     hashing_kv,
-     current_embedding,
-     similarity_threshold=0.95,
-     mode="default",
-     use_llm_check=False,
-     llm_func=None,
-     original_prompt=None,
-     cache_type=None,
- ) -> str | None:
-     logger.debug(
-         f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
-     )
-     mode_cache = await hashing_kv.get_by_id(mode)
-     if not mode_cache:
-         return None
-
-     best_similarity = -1
-     best_response = None
-     best_prompt = None
-     best_cache_id = None
-
-     # Only iterate through cache entries for this mode
-     for cache_id, cache_data in mode_cache.items():
-         # Skip if cache_type doesn't match
-         if cache_type and cache_data.get("cache_type") != cache_type:
-             continue
-
-         # Check if cache data is valid
-         if cache_data["embedding"] is None:
-             continue
-
-         try:
-             # Safely convert cached embedding
-             cached_quantized = np.frombuffer(
-                 bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
-             ).reshape(cache_data["embedding_shape"])
-
-             # Ensure min_val and max_val are valid float values
-             embedding_min = cache_data.get("embedding_min")
-             embedding_max = cache_data.get("embedding_max")
-
-             if (
-                 embedding_min is None
-                 or embedding_max is None
-                 or embedding_min >= embedding_max
-             ):
-                 logger.warning(
-                     f"Invalid embedding min/max values: min={embedding_min}, max={embedding_max}"
-                 )
-                 continue
-
-             cached_embedding = dequantize_embedding(
-                 cached_quantized,
-                 embedding_min,
-                 embedding_max,
-             )
-         except Exception as e:
-             logger.warning(f"Error processing cached embedding: {str(e)}")
-             continue
-
-         similarity = cosine_similarity(current_embedding, cached_embedding)
-         if similarity > best_similarity:
-             best_similarity = similarity
-             best_response = cache_data["return"]
-             best_prompt = cache_data["original_prompt"]
-             best_cache_id = cache_id
-
-     if best_similarity > similarity_threshold:
-         # If LLM check is enabled and all required parameters are provided
-         if (
-             use_llm_check
-             and llm_func
-             and original_prompt
-             and best_prompt
-             and best_response is not None
-         ):
-             compare_prompt = PROMPTS["similarity_check"].format(
-                 original_prompt=original_prompt, cached_prompt=best_prompt
-             )
-
-             try:
-                 llm_result = await llm_func(compare_prompt)
-                 llm_result = llm_result.strip()
-                 llm_similarity = float(llm_result)
-
-                 # Replace vector similarity with LLM similarity score
-                 best_similarity = llm_similarity
-                 if best_similarity < similarity_threshold:
-                     log_data = {
-                         "event": "cache_rejected_by_llm",
-                         "type": cache_type,
-                         "mode": mode,
-                         "original_question": original_prompt[:100] + "..."
-                         if len(original_prompt) > 100
-                         else original_prompt,
-                         "cached_question": best_prompt[:100] + "..."
-                         if len(best_prompt) > 100
-                         else best_prompt,
-                         "similarity_score": round(best_similarity, 4),
-                         "threshold": similarity_threshold,
-                     }
-                     logger.debug(json.dumps(log_data, ensure_ascii=False))
-                     logger.info(f"Cache rejected by LLM(mode:{mode} tpye:{cache_type})")
-                     return None
-             except Exception as e:  # Catch all possible exceptions
-                 logger.warning(f"LLM similarity check failed: {e}")
-                 return None  # Return None directly when LLM check fails
-
-         prompt_display = (
-             best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt
-         )
-         log_data = {
-             "event": "cache_hit",
-             "type": cache_type,
-             "mode": mode,
-             "similarity": round(best_similarity, 4),
-             "cache_id": best_cache_id,
-             "original_prompt": prompt_display,
-         }
-         logger.debug(json.dumps(log_data, ensure_ascii=False))
-         return best_response
-     return None
-
-
  def cosine_similarity(v1, v2):
      """Calculate cosine similarity between two vectors"""
      dot_product = np.dot(v1, v2)
@@ -957,7 +857,7 @@ async def handle_cache(
      mode="default",
      cache_type=None,
  ):
-     """Generic cache handling function"""
+     """Generic cache handling function with flattened cache keys"""
      if hashing_kv is None:
          return None, None, None, None

@@ -968,15 +868,14 @@ async def handle_cache(
          if not hashing_kv.global_config.get("enable_llm_cache_for_entity_extract"):
              return None, None, None, None

-     if exists_func(hashing_kv, "get_by_mode_and_id"):
-         mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
-     else:
-         mode_cache = await hashing_kv.get_by_id(mode) or {}
-     if args_hash in mode_cache:
-         logger.debug(f"Non-embedding cached hit(mode:{mode} type:{cache_type})")
-         return mode_cache[args_hash]["return"], None, None, None
+     # Use flattened cache key format: {mode}:{cache_type}:{hash}
+     flattened_key = generate_cache_key(mode, cache_type, args_hash)
+     cache_entry = await hashing_kv.get_by_id(flattened_key)
+     if cache_entry:
+         logger.debug(f"Flattened cache hit(key:{flattened_key})")
+         return cache_entry["return"], None, None, None

-     logger.debug(f"Non-embedding cached missed(mode:{mode} type:{cache_type})")
+     logger.debug(f"Cache missed(mode:{mode} type:{cache_type})")
      return None, None, None, None


@@ -994,7 +893,7 @@ class CacheData:


  async def save_to_cache(hashing_kv, cache_data: CacheData):
-     """Save data to cache, with improved handling for streaming responses and duplicate content.
+     """Save data to cache using flattened key structure.

      Args:
          hashing_kv: The key-value storage for caching
@@ -1009,26 +908,21 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
          logger.debug("Streaming response detected, skipping cache")
          return

-     # Get existing cache data
-     if exists_func(hashing_kv, "get_by_mode_and_id"):
-         mode_cache = (
-             await hashing_kv.get_by_mode_and_id(cache_data.mode, cache_data.args_hash)
-             or {}
-         )
-     else:
-         mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
+     # Use flattened cache key format: {mode}:{cache_type}:{hash}
+     flattened_key = generate_cache_key(
+         cache_data.mode, cache_data.cache_type, cache_data.args_hash
+     )

      # Check if we already have identical content cached
-     if cache_data.args_hash in mode_cache:
-         existing_content = mode_cache[cache_data.args_hash].get("return")
+     existing_cache = await hashing_kv.get_by_id(flattened_key)
+     if existing_cache:
+         existing_content = existing_cache.get("return")
          if existing_content == cache_data.content:
-             logger.info(
-                 f"Cache content unchanged for {cache_data.args_hash}, skipping update"
-             )
+             logger.info(f"Cache content unchanged for {flattened_key}, skipping update")
              return

-     # Update cache with new content
-     mode_cache[cache_data.args_hash] = {
+     # Create cache entry with flattened structure
+     cache_entry = {
          "return": cache_data.content,
          "cache_type": cache_data.cache_type,
          "chunk_id": cache_data.chunk_id if cache_data.chunk_id is not None else None,
@@ -1043,10 +937,10 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
          "original_prompt": cache_data.prompt,
      }

-     logger.info(f" == LLM cache == saving {cache_data.mode}: {cache_data.args_hash}")
+     logger.info(f" == LLM cache == saving: {flattened_key}")

-     # Only upsert if there's actual new content
-     await hashing_kv.upsert({cache_data.mode: mode_cache})
+     # Save using flattened key
+     await hashing_kv.upsert({flattened_key: cache_entry})


  def safe_unicode_decode(content):
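Taken together, the handle_cache and save_to_cache hunks replace the old two-level layout (one KV record per mode, holding a dict keyed by hash) with one KV record per flattened key. A rough illustration of what the LLM response cache might contain after a save; the values are hypothetical and only the fields visible in the hunks are shown (the real entry also stores the remaining CacheData fields elided by the diff context):

    # hypothetical contents of llm_response_cache, keyed by "{mode}:{cache_type}:{hash}"
    {
        "default:extract:<md5-of-prompt>": {
            "return": "(\"entity\"<|>\"Tokyo\"<|>...)",   # cached LLM output
            "cache_type": "extract",
            "chunk_id": "chunk-0001",                      # hypothetical chunk id
            "original_prompt": "...entity extraction prompt...",
        },
    }

    # a repeated call recomputes the same flattened key and hits hashing_kv.get_by_id()
    # directly, instead of loading and rewriting a whole per-mode dictionary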
@@ -1529,6 +1423,48 @@ def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any
      return import_class


+ async def update_chunk_cache_list(
+     chunk_id: str,
+     text_chunks_storage: "BaseKVStorage",
+     cache_keys: list[str],
+     cache_scenario: str = "batch_update",
+ ) -> None:
+     """Update chunk's llm_cache_list with the given cache keys
+
+     Args:
+         chunk_id: Chunk identifier
+         text_chunks_storage: Text chunks storage instance
+         cache_keys: List of cache keys to add to the list
+         cache_scenario: Description of the cache scenario for logging
+     """
+     if not cache_keys:
+         return
+
+     try:
+         chunk_data = await text_chunks_storage.get_by_id(chunk_id)
+         if chunk_data:
+             # Ensure llm_cache_list exists
+             if "llm_cache_list" not in chunk_data:
+                 chunk_data["llm_cache_list"] = []
+
+             # Add cache keys to the list if not already present
+             existing_keys = set(chunk_data["llm_cache_list"])
+             new_keys = [key for key in cache_keys if key not in existing_keys]
+
+             if new_keys:
+                 chunk_data["llm_cache_list"].extend(new_keys)
+
+                 # Update the chunk in storage
+                 await text_chunks_storage.upsert({chunk_id: chunk_data})
+                 logger.debug(
+                     f"Updated chunk {chunk_id} with {len(new_keys)} cache keys ({cache_scenario})"
+                 )
+     except Exception as e:
+         logger.warning(
+             f"Failed to update chunk {chunk_id} with cache references on {cache_scenario}: {e}"
+         )
+
+
  async def use_llm_func_with_cache(
      input_text: str,
      use_llm_func: callable,
@@ -1537,6 +1473,7 @@ async def use_llm_func_with_cache(
      history_messages: list[dict[str, str]] = None,
      cache_type: str = "extract",
      chunk_id: str | None = None,
+     cache_keys_collector: list = None,
  ) -> str:
      """Call LLM function with cache support

@@ -1551,6 +1488,8 @@ async def use_llm_func_with_cache(
          history_messages: History messages list
          cache_type: Type of cache
          chunk_id: Chunk identifier to store in cache
+         text_chunks_storage: Text chunks storage to update llm_cache_list
+         cache_keys_collector: Optional list to collect cache keys for batch processing

      Returns:
          LLM response text
@@ -1563,6 +1502,9 @@ async def use_llm_func_with_cache(
              _prompt = input_text

          arg_hash = compute_args_hash(_prompt)
+         # Generate cache key for this LLM call
+         cache_key = generate_cache_key("default", cache_type, arg_hash)
+
          cached_return, _1, _2, _3 = await handle_cache(
              llm_response_cache,
              arg_hash,
@@ -1573,6 +1515,11 @@ async def use_llm_func_with_cache(
          if cached_return:
              logger.debug(f"Found cache for {arg_hash}")
              statistic_data["llm_cache"] += 1
+
+             # Add cache key to collector if provided
+             if cache_keys_collector is not None:
+                 cache_keys_collector.append(cache_key)
+
              return cached_return
          statistic_data["llm_call"] += 1

@@ -1597,6 +1544,10 @@ async def use_llm_func_with_cache(
              ),
          )

+         # Add cache key to collector if provided
+         if cache_keys_collector is not None:
+             cache_keys_collector.append(cache_key)
+
          return res

      # When cache is disabled, directly call LLM
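The new cache_keys_collector parameter and update_chunk_cache_list (added in the @@ -1529 hunk above) are designed to be used as a pair: every LLM call made while processing a chunk appends its flattened key to a shared list, which is then written back onto the chunk record in one batch. A simplified sketch of that calling pattern; chunk_text, chunk_id, llm_model_func, llm_response_cache and text_chunks_storage are placeholders for whatever the extraction pipeline actually passes in, and the llm_response_cache keyword is assumed from the function body shown in the hunks:

    cache_keys_collector: list[str] = []

    # all extraction calls made for the same chunk share one collector
    result = await use_llm_func_with_cache(
        chunk_text,
        llm_model_func,
        llm_response_cache=llm_response_cache,
        cache_type="extract",
        chunk_id=chunk_id,
        cache_keys_collector=cache_keys_collector,
    )

    # afterwards the collected keys are persisted on the chunk, so every cache entry that
    # contributed to this chunk can be located (and cleaned up) later via llm_cache_list
    await update_chunk_cache_list(
        chunk_id,
        text_chunks_storage,
        cache_keys_collector,
        cache_scenario="entity_extraction",  # hypothetical label, used only for logging
    )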
 
lightrag/utils_graph.py CHANGED
@@ -6,7 +6,7 @@ from typing import Any, cast

  from .base import DeletionResult
  from .kg.shared_storage import get_graph_db_lock
- from .prompt import GRAPH_FIELD_SEP
+ from .constants import GRAPH_FIELD_SEP
  from .utils import compute_mdhash_id, logger
  from .base import StorageNameSpace

 
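GRAPH_FIELD_SEP is the separator used when several values (for example the source chunk ids behind an entity or relation) are packed into a single graph field; this hunk only moves its import from lightrag.prompt to lightrag.constants. A small sketch of the pack/unpack pattern that utils_graph.py relies on, assuming the constant is importable from its new location:

    from lightrag.constants import GRAPH_FIELD_SEP

    source_ids = ["chunk-0001", "chunk-0002"]           # hypothetical chunk ids
    packed = GRAPH_FIELD_SEP.join(source_ids)           # stored as one string field
    assert packed.split(GRAPH_FIELD_SEP) == source_ids  # recovered when needed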
 
reproduce/batch_eval.py CHANGED
@@ -57,6 +57,10 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
          "Winner": "[Answer 1 or Answer 2]",
          "Explanation": "[Provide explanation here]"
      }},
+     "Diversity": {{
+         "Winner": "[Answer 1 or Answer 2]",
+         "Explanation": "[Provide explanation here]"
+     }},
      "Empowerment": {{
          "Winner": "[Answer 1 or Answer 2]",
          "Explanation": "[Provide explanation here]"
 
tests/test_graph_storage.py CHANGED
@@ -8,6 +8,7 @@
  Supported graph storage types include:
  - NetworkXStorage
  - Neo4JStorage
+ - MongoDBStorage
  - PGGraphStorage
  - MemgraphStorage
  """
 