gzdaniel committed
Commit a8003f1 · 1 Parent(s): 9ca235a

Refactor: Optimize document deletion performance

- Adding chunks_list to doc_status
- Adding llm_cache_list to text_chunks
- Implemented the new fields in the JsonKV and Redis storage types

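In practical terms, deletion no longer scans every record in text_chunks to find a document's chunks, and no longer scans the whole LLM cache to find the extraction results tied to those chunks: chunk IDs come from the document's status record (chunks_list) and cache keys from each chunk (llm_cache_list). Below is a minimal sketch of that lookup path, assuming only the get_by_id / get_by_ids accessors the diff itself uses; it is not the actual deletion code in lightrag/lightrag.py.

```python
# Sketch only: the lookup path enabled by chunks_list / llm_cache_list.
# Assumes KV storages exposing get_by_id / get_by_ids, as used in the diff.

async def collect_deletion_targets(doc_status, text_chunks, llm_response_cache, doc_id):
    # Before: every record in text_chunks was scanned and filtered by full_doc_id.
    # Now: the document's status record already lists its chunk IDs.
    status = await doc_status.get_by_id(doc_id) or {}
    chunk_ids = set(status.get("chunks_list", []))

    # Each chunk records the cache keys of the LLM calls made for it, so the
    # related cache entries can be fetched in one batch instead of a full scan.
    cache_ids: set[str] = set()
    for chunk in await text_chunks.get_by_ids(list(chunk_ids)):
        if isinstance(chunk, dict):
            cache_ids.update(chunk.get("llm_cache_list", []))

    cache_entries = await llm_response_cache.get_by_ids(list(cache_ids))
    return chunk_ids, cache_ids, cache_entries
```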
lightrag/base.py CHANGED
@@ -634,6 +634,8 @@ class DocProcessingStatus:
    """ISO format timestamp when document was last updated"""
    chunks_count: int | None = None
    """Number of chunks after splitting, used for processing"""
+    chunks_list: list[str] | None = field(default_factory=list)
+    """List of chunk IDs associated with this document, used for deletion"""
    error: str | None = None
    """Error message if failed"""
    metadata: dict[str, Any] = field(default_factory=dict)
lightrag/kg/json_doc_status_impl.py CHANGED
@@ -118,6 +118,10 @@ class JsonDocStatusStorage(DocStatusStorage):
            return
        logger.debug(f"Inserting {len(data)} records to {self.namespace}")
        async with self._storage_lock:
+            # Ensure chunks_list field exists for new documents
+            for doc_id, doc_data in data.items():
+                if "chunks_list" not in doc_data:
+                    doc_data["chunks_list"] = []
            self._data.update(data)
            await set_all_update_flags(self.namespace)
 
lightrag/kg/json_kv_impl.py CHANGED
@@ -109,6 +109,11 @@ class JsonKVStorage(BaseKVStorage):
            return
        logger.debug(f"Inserting {len(data)} records to {self.namespace}")
        async with self._storage_lock:
+            # For text_chunks namespace, ensure llm_cache_list field exists
+            if "text_chunks" in self.namespace:
+                for chunk_id, chunk_data in data.items():
+                    if "llm_cache_list" not in chunk_data:
+                        chunk_data["llm_cache_list"] = []
            self._data.update(data)
            await set_all_update_flags(self.namespace)
 
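The same guard recurs in each backend: document status records get chunks_list, text-chunk records get llm_cache_list, and the Redis classes below repeat both. A hypothetical helper (not part of this commit) that captures the shared pattern:

```python
from typing import Any


def ensure_list_field(records: dict[str, dict[str, Any]], field: str) -> None:
    """Backfill an empty list for `field` on any record that lacks it."""
    for record in records.values():
        record.setdefault(field, [])


# Usage inside an upsert, as sketched:
#   ensure_list_field(data, "chunks_list")      # document status records
#   ensure_list_field(data, "llm_cache_list")   # text chunk records
```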
lightrag/kg/redis_impl.py CHANGED
@@ -202,6 +202,12 @@ class RedisKVStorage(BaseKVStorage):
            return
        async with self._get_redis_connection() as redis:
            try:
+                # For text_chunks namespace, ensure llm_cache_list field exists
+                if "text_chunks" in self.namespace:
+                    for chunk_id, chunk_data in data.items():
+                        if "llm_cache_list" not in chunk_data:
+                            chunk_data["llm_cache_list"] = []
+
                pipe = redis.pipeline()
                for k, v in data.items():
                    pipe.set(f"{self.namespace}:{k}", json.dumps(v))
@@ -601,6 +607,11 @@ class RedisDocStatusStorage(DocStatusStorage):
        logger.debug(f"Inserting {len(data)} records to {self.namespace}")
        async with self._get_redis_connection() as redis:
            try:
+                # Ensure chunks_list field exists for new documents
+                for doc_id, doc_data in data.items():
+                    if "chunks_list" not in doc_data:
+                        doc_data["chunks_list"] = []
+
                pipe = redis.pipeline()
                for k, v in data.items():
                    pipe.set(f"{self.namespace}:{k}", json.dumps(v))
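
Because Redis stores each record as an opaque JSON string under `{namespace}:{key}` (as the pipe.set calls above show), the default fields have to be filled in before json.dumps runs; once serialized, a missing field can only be repaired by a later read-modify-write. A minimal sketch of that write path, assuming redis.asyncio and the key layout from the diff:

```python
import json

from redis.asyncio import Redis


async def upsert_with_defaults(r: Redis, namespace: str, data: dict[str, dict]) -> None:
    # Fill in defaults before serialization; the stored value is a JSON blob.
    for record in data.values():
        record.setdefault("llm_cache_list", [])

    # Queue one SET per record and flush the pipeline in a single round trip.
    pipe = r.pipeline()
    for key, value in data.items():
        pipe.set(f"{namespace}:{key}", json.dumps(value))
    await pipe.execute()
```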
lightrag/lightrag.py CHANGED
@@ -349,6 +349,7 @@ class LightRAG:

        # Fix global_config now
        global_config = asdict(self)
+
        _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
        logger.debug(f"LightRAG init with param:\n {_print_config}\n")

@@ -952,6 +953,7 @@ class LightRAG:
                            **dp,
                            "full_doc_id": doc_id,
                            "file_path": file_path,  # Add file path to each chunk
+                            "llm_cache_list": [],  # Initialize empty LLM cache list for each chunk
                        }
                        for dp in self.chunking_func(
                            self.tokenizer,
@@ -963,14 +965,17 @@
                        )
                    }

-                # Process document (text chunks and full docs) in parallel
-                # Create tasks with references for potential cancellation
+                # Process document in two stages
+                # Stage 1: Process text chunks and docs (parallel execution)
                doc_status_task = asyncio.create_task(
                    self.doc_status.upsert(
                        {
                            doc_id: {
                                "status": DocStatus.PROCESSING,
                                "chunks_count": len(chunks),
+                                "chunks_list": list(
+                                    chunks.keys()
+                                ),  # Save chunks list
                                "content": status_doc.content,
                                "content_summary": status_doc.content_summary,
                                "content_length": status_doc.content_length,
@@ -986,11 +991,6 @@
                chunks_vdb_task = asyncio.create_task(
                    self.chunks_vdb.upsert(chunks)
                )
-                entity_relation_task = asyncio.create_task(
-                    self._process_entity_relation_graph(
-                        chunks, pipeline_status, pipeline_status_lock
-                    )
-                )
                full_docs_task = asyncio.create_task(
                    self.full_docs.upsert(
                        {doc_id: {"content": status_doc.content}}
@@ -999,14 +999,26 @@
                text_chunks_task = asyncio.create_task(
                    self.text_chunks.upsert(chunks)
                )
-                tasks = [
+
+                # First stage tasks (parallel execution)
+                first_stage_tasks = [
                    doc_status_task,
                    chunks_vdb_task,
-                    entity_relation_task,
                    full_docs_task,
                    text_chunks_task,
                ]
-                await asyncio.gather(*tasks)
+                entity_relation_task = None
+
+                # Execute first stage tasks
+                await asyncio.gather(*first_stage_tasks)
+
+                # Stage 2: Process entity relation graph (after text_chunks are saved)
+                entity_relation_task = asyncio.create_task(
+                    self._process_entity_relation_graph(
+                        chunks, pipeline_status, pipeline_status_lock
+                    )
+                )
+                await entity_relation_task
                file_extraction_stage_ok = True

            except Exception as e:
@@ -1021,14 +1033,14 @@
                    )
                    pipeline_status["history_messages"].append(error_msg)

-                # Cancel other tasks as they are no longer meaningful
-                for task in [
-                    chunks_vdb_task,
-                    entity_relation_task,
-                    full_docs_task,
-                    text_chunks_task,
-                ]:
-                    if not task.done():
+                # Cancel tasks that are not yet completed
+                all_tasks = first_stage_tasks + (
+                    [entity_relation_task]
+                    if entity_relation_task
+                    else []
+                )
+                for task in all_tasks:
+                    if task and not task.done():
                        task.cancel()

                # Persistent llm cache
@@ -1078,6 +1090,9 @@
                            doc_id: {
                                "status": DocStatus.PROCESSED,
                                "chunks_count": len(chunks),
+                                "chunks_list": list(
+                                    chunks.keys()
+                                ),  # Keep chunks_list
                                "content": status_doc.content,
                                "content_summary": status_doc.content_summary,
                                "content_length": status_doc.content_length,
@@ -1196,6 +1211,7 @@
                pipeline_status=pipeline_status,
                pipeline_status_lock=pipeline_status_lock,
                llm_response_cache=self.llm_response_cache,
+                text_chunks_storage=self.text_chunks,
            )
            return chunk_results
        except Exception as e:
@@ -1726,28 +1742,10 @@
                file_path="",
            )

-            # 2. Get all chunks related to this document
-            try:
-                all_chunks = await self.text_chunks.get_all()
-                related_chunks = {
-                    chunk_id: chunk_data
-                    for chunk_id, chunk_data in all_chunks.items()
-                    if isinstance(chunk_data, dict)
-                    and chunk_data.get("full_doc_id") == doc_id
-                }
-
-                # Update pipeline status after getting chunks count
-                async with pipeline_status_lock:
-                    log_message = f"Retrieved {len(related_chunks)} of {len(all_chunks)} related chunks"
-                    logger.info(log_message)
-                    pipeline_status["latest_message"] = log_message
-                    pipeline_status["history_messages"].append(log_message)
-
-            except Exception as e:
-                logger.error(f"Failed to retrieve chunks for document {doc_id}: {e}")
-                raise Exception(f"Failed to retrieve document chunks: {e}") from e
+            # 2. Get chunk IDs from document status
+            chunk_ids = set(doc_status_data.get("chunks_list", []))

-            if not related_chunks:
+            if not chunk_ids:
                logger.warning(f"No chunks found for document {doc_id}")
                # Mark that deletion operations have started
                deletion_operations_started = True
@@ -1778,7 +1776,6 @@
                file_path=file_path,
            )

-            chunk_ids = set(related_chunks.keys())
            # Mark that deletion operations have started
            deletion_operations_started = True

@@ -1943,7 +1940,7 @@
                knowledge_graph_inst=self.chunk_entity_relation_graph,
                entities_vdb=self.entities_vdb,
                relationships_vdb=self.relationships_vdb,
-                text_chunks=self.text_chunks,
+                text_chunks_storage=self.text_chunks,
                llm_response_cache=self.llm_response_cache,
                global_config=asdict(self),
                pipeline_status=pipeline_status,
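
The scheduling change above splits what used to be a single asyncio.gather into two stages: persistence first, entity/relation extraction second, so the chunk records that extraction writes llm_cache_list back onto are guaranteed to exist. A simplified sketch of that ordering (the names here are illustrative, not the actual method):

```python
import asyncio


async def run_extraction_stages(first_stage_coros, make_extract_coro):
    """Stage 1 persists doc status, chunk vectors, full docs and text chunks in
    parallel; stage 2 (entity/relation extraction) starts only afterwards, so the
    chunk records it updates with llm_cache_list already exist."""
    first_stage_tasks = [asyncio.create_task(c) for c in first_stage_coros]
    try:
        await asyncio.gather(*first_stage_tasks)
    except Exception:
        # Mirror the diff: cancel whatever has not finished, then re-raise.
        for task in first_stage_tasks:
            if not task.done():
                task.cancel()
        raise
    # Stage 2 runs strictly after stage 1 succeeded.
    return await make_extract_coro()
```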
lightrag/operate.py CHANGED
@@ -25,6 +25,7 @@ from .utils import (
    CacheData,
    get_conversation_turns,
    use_llm_func_with_cache,
+    update_chunk_cache_list,
)
from .base import (
    BaseGraphStorage,
@@ -103,8 +104,6 @@ async def _handle_entity_relation_summary(
    entity_or_relation_name: str,
    description: str,
    global_config: dict,
-    pipeline_status: dict = None,
-    pipeline_status_lock=None,
    llm_response_cache: BaseKVStorage | None = None,
) -> str:
    """Handle entity relation summary
@@ -247,7 +246,7 @@ async def _rebuild_knowledge_from_chunks(
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
-    text_chunks: BaseKVStorage,
+    text_chunks_storage: BaseKVStorage,
    llm_response_cache: BaseKVStorage,
    global_config: dict[str, str],
    pipeline_status: dict | None = None,
@@ -261,6 +260,7 @@
    Args:
        entities_to_rebuild: Dict mapping entity_name -> set of remaining chunk_ids
        relationships_to_rebuild: Dict mapping (src, tgt) -> set of remaining chunk_ids
+        text_chunks_data: Pre-loaded chunk data dict {chunk_id: chunk_data}
    """
    if not entities_to_rebuild and not relationships_to_rebuild:
        return
@@ -273,6 +273,8 @@
        all_referenced_chunk_ids.update(chunk_ids)
    for chunk_ids in relationships_to_rebuild.values():
        all_referenced_chunk_ids.update(chunk_ids)
+    # sort all_referenced_chunk_ids to get a stable order in merge stage
+    all_referenced_chunk_ids = sorted(all_referenced_chunk_ids)

    status_message = f"Rebuilding knowledge from {len(all_referenced_chunk_ids)} cached chunk extractions"
    logger.info(status_message)
@@ -281,9 +283,11 @@
        pipeline_status["latest_message"] = status_message
        pipeline_status["history_messages"].append(status_message)

-    # Get cached extraction results for these chunks
+    # Get cached extraction results for these chunks using storage
    cached_results = await _get_cached_extraction_results(
-        llm_response_cache, all_referenced_chunk_ids
+        llm_response_cache,
+        all_referenced_chunk_ids,
+        text_chunks_storage=text_chunks_storage,
    )

    if not cached_results:
@@ -299,15 +303,25 @@
    chunk_entities = {}  # chunk_id -> {entity_name: [entity_data]}
    chunk_relationships = {}  # chunk_id -> {(src, tgt): [relationship_data]}

-    for chunk_id, extraction_result in cached_results.items():
+    for chunk_id, extraction_results in cached_results.items():
        try:
-            entities, relationships = await _parse_extraction_result(
-                text_chunks=text_chunks,
-                extraction_result=extraction_result,
-                chunk_id=chunk_id,
-            )
-            chunk_entities[chunk_id] = entities
-            chunk_relationships[chunk_id] = relationships
+            # Handle multiple extraction results per chunk
+            chunk_entities[chunk_id] = defaultdict(list)
+            chunk_relationships[chunk_id] = defaultdict(list)
+
+            for extraction_result in extraction_results:
+                entities, relationships = await _parse_extraction_result(
+                    text_chunks_storage=text_chunks_storage,
+                    extraction_result=extraction_result,
+                    chunk_id=chunk_id,
+                )
+
+                # Merge entities and relationships from this extraction result
+                for entity_name, entity_list in entities.items():
+                    chunk_entities[chunk_id][entity_name].extend(entity_list)
+                for rel_key, rel_list in relationships.items():
+                    chunk_relationships[chunk_id][rel_key].extend(rel_list)
+
        except Exception as e:
            status_message = (
                f"Failed to parse cached extraction result for chunk {chunk_id}: {e}"
@@ -387,43 +401,76 @@


async def _get_cached_extraction_results(
-    llm_response_cache: BaseKVStorage, chunk_ids: set[str]
-) -> dict[str, str]:
+    llm_response_cache: BaseKVStorage,
+    chunk_ids: set[str],
+    text_chunks_storage: BaseKVStorage,
+) -> dict[str, list[str]]:
    """Get cached extraction results for specific chunk IDs

    Args:
+        llm_response_cache: LLM response cache storage
        chunk_ids: Set of chunk IDs to get cached results for
+        text_chunks_data: Pre-loaded chunk data (optional, for performance)
+        text_chunks_storage: Text chunks storage (fallback if text_chunks_data is None)

    Returns:
-        Dict mapping chunk_id -> extraction_result_text
+        Dict mapping chunk_id -> list of extraction_result_text
    """
    cached_results = {}

-    # Get all cached data (flattened cache structure)
-    all_cache = await llm_response_cache.get_all()
+    # Collect all LLM cache IDs from chunks
+    all_cache_ids = set()
+
+    # Read from storage
+    chunk_data_list = await text_chunks_storage.get_by_ids(list(chunk_ids))
+    for chunk_id, chunk_data in zip(chunk_ids, chunk_data_list):
+        if chunk_data and isinstance(chunk_data, dict):
+            llm_cache_list = chunk_data.get("llm_cache_list", [])
+            if llm_cache_list:
+                all_cache_ids.update(llm_cache_list)
+        else:
+            logger.warning(
+                f"Chunk {chunk_id} data is invalid or None: {type(chunk_data)}"
+            )
+
+    if not all_cache_ids:
+        logger.warning(f"No LLM cache IDs found for {len(chunk_ids)} chunk IDs")
+        return cached_results

-    for cache_key, cache_entry in all_cache.items():
+    # Batch get LLM cache entries
+    cache_data_list = await llm_response_cache.get_by_ids(list(all_cache_ids))
+
+    # Process cache entries and group by chunk_id
+    valid_entries = 0
+    for cache_id, cache_entry in zip(all_cache_ids, cache_data_list):
        if (
-            isinstance(cache_entry, dict)
+            cache_entry is not None
+            and isinstance(cache_entry, dict)
            and cache_entry.get("cache_type") == "extract"
            and cache_entry.get("chunk_id") in chunk_ids
        ):
            chunk_id = cache_entry["chunk_id"]
            extraction_result = cache_entry["return"]
-            cached_results[chunk_id] = extraction_result
+            valid_entries += 1

-    logger.debug(
-        f"Found {len(cached_results)} cached extraction results for {len(chunk_ids)} chunk IDs"
+            # Support multiple LLM caches per chunk
+            if chunk_id not in cached_results:
+                cached_results[chunk_id] = []
+            cached_results[chunk_id].append(extraction_result)
+
+    logger.info(
+        f"Found {valid_entries} valid cache entries, {len(cached_results)} chunks with results"
    )
    return cached_results


async def _parse_extraction_result(
-    text_chunks: BaseKVStorage, extraction_result: str, chunk_id: str
+    text_chunks_storage: BaseKVStorage, extraction_result: str, chunk_id: str
) -> tuple[dict, dict]:
    """Parse cached extraction result using the same logic as extract_entities

    Args:
+        text_chunks_storage: Text chunks storage to get chunk data
        extraction_result: The cached LLM extraction result
        chunk_id: The chunk ID for source tracking

@@ -431,8 +478,8 @@ async def _parse_extraction_result(
        Tuple of (entities_dict, relationships_dict)
    """

-    # Get chunk data for file_path
-    chunk_data = await text_chunks.get_by_id(chunk_id)
+    # Get chunk data for file_path from storage
+    chunk_data = await text_chunks_storage.get_by_id(chunk_id)
    file_path = (
        chunk_data.get("file_path", "unknown_source")
        if chunk_data
@@ -805,8 +852,6 @@ async def _merge_nodes_then_upsert(
            entity_name,
            description,
            global_config,
-            pipeline_status,
-            pipeline_status_lock,
            llm_response_cache,
        )
    else:
@@ -969,8 +1014,6 @@ async def _merge_edges_then_upsert(
            f"({src_id}, {tgt_id})",
            description,
            global_config,
-            pipeline_status,
-            pipeline_status_lock,
            llm_response_cache,
        )
    else:
@@ -1146,6 +1189,7 @@ async def extract_entities(
    pipeline_status: dict = None,
    pipeline_status_lock=None,
    llm_response_cache: BaseKVStorage | None = None,
+    text_chunks_storage: BaseKVStorage | None = None,
) -> list:
    use_llm_func: callable = global_config["llm_model_func"]
    entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
@@ -1252,6 +1296,9 @@
        # Get file path from chunk data or use default
        file_path = chunk_dp.get("file_path", "unknown_source")

+        # Create cache keys collector for batch processing
+        cache_keys_collector = []
+
        # Get initial extraction
        hint_prompt = entity_extract_prompt.format(
            **{**context_base, "input_text": content}
@@ -1263,7 +1310,10 @@
            llm_response_cache=llm_response_cache,
            cache_type="extract",
            chunk_id=chunk_key,
+            cache_keys_collector=cache_keys_collector,
        )
+
+        # Store LLM cache reference in chunk (will be handled by use_llm_func_with_cache)
        history = pack_user_ass_to_openai_messages(hint_prompt, final_result)

        # Process initial extraction with file path
@@ -1280,6 +1330,7 @@
                history_messages=history,
                cache_type="extract",
                chunk_id=chunk_key,
+                cache_keys_collector=cache_keys_collector,
            )

            history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
@@ -1310,11 +1361,21 @@
                llm_response_cache=llm_response_cache,
                history_messages=history,
                cache_type="extract",
+                cache_keys_collector=cache_keys_collector,
            )
            if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
            if if_loop_result != "yes":
                break

+        # Batch update chunk's llm_cache_list with all collected cache keys
+        if cache_keys_collector and text_chunks_storage:
+            await update_chunk_cache_list(
+                chunk_key,
+                text_chunks_storage,
+                cache_keys_collector,
+                "entity_extraction",
+            )
+
        processed_chunks += 1
        entities_count = len(maybe_nodes)
        relations_count = len(maybe_edges)
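
Condensed, the per-chunk data flow that extract_entities now follows looks roughly like the sketch below. It leaves out gleaning and result parsing; the helper names come from lightrag.utils as changed in this commit, while the wrapper function itself is illustrative.

```python
from lightrag.utils import update_chunk_cache_list, use_llm_func_with_cache


async def extract_one_chunk(chunk_key, content, llm_func, llm_response_cache, text_chunks_storage):
    # Every LLM call made for this chunk appends its cache key to the collector
    # (initial extraction, gleaning rounds, and the "continue?" check).
    cache_keys_collector: list[str] = []
    extraction_result = await use_llm_func_with_cache(
        content,
        llm_func,
        llm_response_cache=llm_response_cache,
        cache_type="extract",
        chunk_id=chunk_key,
        cache_keys_collector=cache_keys_collector,
    )

    # One batched write stores the collected keys on the chunk's llm_cache_list,
    # which deletion and knowledge rebuild read later instead of scanning the cache.
    if cache_keys_collector and text_chunks_storage:
        await update_chunk_cache_list(
            chunk_key, text_chunks_storage, cache_keys_collector, "entity_extraction"
        )
    return extraction_result
```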
lightrag/utils.py CHANGED
@@ -1423,6 +1423,48 @@ def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
    return import_class


+async def update_chunk_cache_list(
+    chunk_id: str,
+    text_chunks_storage: "BaseKVStorage",
+    cache_keys: list[str],
+    cache_scenario: str = "batch_update",
+) -> None:
+    """Update chunk's llm_cache_list with the given cache keys
+
+    Args:
+        chunk_id: Chunk identifier
+        text_chunks_storage: Text chunks storage instance
+        cache_keys: List of cache keys to add to the list
+        cache_scenario: Description of the cache scenario for logging
+    """
+    if not cache_keys:
+        return
+
+    try:
+        chunk_data = await text_chunks_storage.get_by_id(chunk_id)
+        if chunk_data:
+            # Ensure llm_cache_list exists
+            if "llm_cache_list" not in chunk_data:
+                chunk_data["llm_cache_list"] = []
+
+            # Add cache keys to the list if not already present
+            existing_keys = set(chunk_data["llm_cache_list"])
+            new_keys = [key for key in cache_keys if key not in existing_keys]
+
+            if new_keys:
+                chunk_data["llm_cache_list"].extend(new_keys)
+
+                # Update the chunk in storage
+                await text_chunks_storage.upsert({chunk_id: chunk_data})
+                logger.debug(
+                    f"Updated chunk {chunk_id} with {len(new_keys)} cache keys ({cache_scenario})"
+                )
+    except Exception as e:
+        logger.warning(
+            f"Failed to update chunk {chunk_id} with cache references on {cache_scenario}: {e}"
+        )
+
+
async def use_llm_func_with_cache(
    input_text: str,
    use_llm_func: callable,
@@ -1431,6 +1473,7 @@ async def use_llm_func_with_cache(
    history_messages: list[dict[str, str]] = None,
    cache_type: str = "extract",
    chunk_id: str | None = None,
+    cache_keys_collector: list = None,
) -> str:
    """Call LLM function with cache support

@@ -1445,6 +1488,8 @@ async def use_llm_func_with_cache(
        history_messages: History messages list
        cache_type: Type of cache
        chunk_id: Chunk identifier to store in cache
+        text_chunks_storage: Text chunks storage to update llm_cache_list
+        cache_keys_collector: Optional list to collect cache keys for batch processing

    Returns:
        LLM response text
@@ -1457,6 +1502,9 @@ async def use_llm_func_with_cache(
        _prompt = input_text

        arg_hash = compute_args_hash(_prompt)
+        # Generate cache key for this LLM call
+        cache_key = generate_cache_key("default", cache_type, arg_hash)
+
        cached_return, _1, _2, _3 = await handle_cache(
            llm_response_cache,
            arg_hash,
@@ -1467,6 +1515,11 @@ async def use_llm_func_with_cache(
        if cached_return:
            logger.debug(f"Found cache for {arg_hash}")
            statistic_data["llm_cache"] += 1
+
+            # Add cache key to collector if provided
+            if cache_keys_collector is not None:
+                cache_keys_collector.append(cache_key)
+
            return cached_return
        statistic_data["llm_call"] += 1

@@ -1491,6 +1544,10 @@ async def use_llm_func_with_cache(
            ),
        )

+        # Add cache key to collector if provided
+        if cache_keys_collector is not None:
+            cache_keys_collector.append(cache_key)
+
        return res

    # When cache is disabled, directly call LLM