yangdx committed
Commit c586d56 · Parent(s): 6e3f788

Improve parallel handling logic between the extraction and merge operations

Files changed (2)
  1. lightrag/lightrag.py +62 -14
  2. lightrag/operate.py +137 -106
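
This commit splits per-document processing into two stages: an extraction stage that runs concurrently under a semaphore, and a merge stage that folds the extraction results into the shared knowledge graph afterwards. The following minimal sketch shows only that coordination pattern; extract, merge_results, and the toy graph set are hypothetical stand-ins, not the LightRAG API, and the real merge stage additionally serializes on a graph-database lock obtained from shared storage.

import asyncio


async def extract(doc: str) -> list[str]:
    # Stand-in for LLM-based entity/relation extraction on one document
    await asyncio.sleep(0.01)
    return [f"{doc}:entity"]


async def merge_results(results: list[str], graph: set, graph_lock: asyncio.Lock) -> None:
    # Merging touches shared graph state, so it is serialized by a lock
    async with graph_lock:
        graph.update(results)


async def process_document(
    doc: str, sem: asyncio.Semaphore, graph: set, graph_lock: asyncio.Lock
) -> None:
    extraction_ok = False
    results: list[str] = []
    async with sem:  # bound the number of concurrent extractions
        try:
            results = await extract(doc)
            extraction_ok = True
        except Exception as exc:
            print(f"extraction failed for {doc}: {exc}")
    # Merge only when extraction succeeded (the diff gates this with file_extraction_stage_ok)
    if extraction_ok:
        await merge_results(results, graph, graph_lock)


async def main() -> None:
    graph: set = set()
    sem = asyncio.Semaphore(2)
    graph_lock = asyncio.Lock()
    docs = ["doc-a", "doc-b", "doc-c"]
    await asyncio.gather(*(process_document(d, sem, graph, graph_lock) for d in docs))
    print(sorted(graph))


asyncio.run(main())

In the diff below, the same gate is the file_extraction_stage_ok flag, and merge_nodes_and_edges plays the role of merge_results.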
lightrag/lightrag.py CHANGED
@@ -46,6 +46,7 @@ from .namespace import NameSpace, make_namespace
 from .operate import (
     chunking_by_token_size,
     extract_entities,
+    merge_nodes_and_edges,
     kg_query,
     mix_kg_vector_query,
     naive_query,
@@ -902,6 +903,7 @@ class LightRAG:
             semaphore: asyncio.Semaphore,
         ) -> None:
             """Process single document"""
+            file_extraction_stage_ok = False
             async with semaphore:
                 nonlocal processed_count
                 current_file_number = 0
@@ -919,7 +921,7 @@
                     )
                     pipeline_status["cur_batch"] = processed_count

-                    log_message = f"Processing file ({current_file_number}/{total_files}): {file_path}"
+                    log_message = f"Processing file {current_file_number}/{total_files}: {file_path}"
                     logger.info(log_message)
                     pipeline_status["history_messages"].append(log_message)
                     log_message = f"Processing d-id: {doc_id}"
@@ -986,6 +988,61 @@
                         text_chunks_task,
                     ]
                     await asyncio.gather(*tasks)
+                    file_extraction_stage_ok = True
+
+                except Exception as e:
+                    # Log error and update pipeline status
+                    error_msg = f"Failed to extract document {doc_id}: {traceback.format_exc()}"
+                    logger.error(error_msg)
+                    async with pipeline_status_lock:
+                        pipeline_status["latest_message"] = error_msg
+                        pipeline_status["history_messages"].append(error_msg)
+
+                    # Cancel other tasks as they are no longer meaningful
+                    for task in [
+                        chunks_vdb_task,
+                        entity_relation_task,
+                        full_docs_task,
+                        text_chunks_task,
+                    ]:
+                        if not task.done():
+                            task.cancel()
+
+                    # Update document status to failed
+                    await self.doc_status.upsert(
+                        {
+                            doc_id: {
+                                "status": DocStatus.FAILED,
+                                "error": str(e),
+                                "content": status_doc.content,
+                                "content_summary": status_doc.content_summary,
+                                "content_length": status_doc.content_length,
+                                "created_at": status_doc.created_at,
+                                "updated_at": datetime.now().isoformat(),
+                                "file_path": file_path,
+                            }
+                        }
+                    )
+
+                # Release semaphore before entering the merge stage
+                if file_extraction_stage_ok:
+                    try:
+                        # Get chunk_results from entity_relation_task
+                        chunk_results = await entity_relation_task
+                        await merge_nodes_and_edges(
+                            chunk_results=chunk_results,  # results collected from entity_relation_task
+                            knowledge_graph_inst=self.chunk_entity_relation_graph,
+                            entity_vdb=self.entities_vdb,
+                            relationships_vdb=self.relationships_vdb,
+                            global_config=asdict(self),
+                            pipeline_status=pipeline_status,
+                            pipeline_status_lock=pipeline_status_lock,
+                            llm_response_cache=self.llm_response_cache,
+                            current_file_number=current_file_number,
+                            total_files=total_files,
+                            file_path=file_path,
+                        )
+
                         await self.doc_status.upsert(
                             {
                                 doc_id: {
@@ -1012,22 +1069,12 @@

                     except Exception as e:
                         # Log error and update pipeline status
-                        error_msg = f"Failed to process document {doc_id}: {traceback.format_exc()}"
-
+                        error_msg = f"Merging stage failed in document {doc_id}: {traceback.format_exc()}"
                         logger.error(error_msg)
                         async with pipeline_status_lock:
                             pipeline_status["latest_message"] = error_msg
                             pipeline_status["history_messages"].append(error_msg)

-                        # Cancel other tasks as they are no longer meaningful
-                        for task in [
-                            chunks_vdb_task,
-                            entity_relation_task,
-                            full_docs_task,
-                            text_chunks_task,
-                        ]:
-                            if not task.done():
-                                task.cancel()
                         # Update document status to failed
                         await self.doc_status.upsert(
                             {
@@ -1101,9 +1148,9 @@

     async def _process_entity_relation_graph(
         self, chunk: dict[str, Any], pipeline_status=None, pipeline_status_lock=None
-    ) -> None:
+    ) -> list:
         try:
-            await extract_entities(
+            chunk_results = await extract_entities(
                 chunk,
                 knowledge_graph_inst=self.chunk_entity_relation_graph,
                 entity_vdb=self.entities_vdb,
@@ -1113,6 +1160,7 @@
                 pipeline_status_lock=pipeline_status_lock,
                 llm_response_cache=self.llm_response_cache,
             )
+            return chunk_results
         except Exception as e:
            error_msg = f"Failed to extract entities and relationships: {str(e)}"
            logger.error(error_msg)
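
When one of the concurrent extraction-stage tasks fails, the handler above cancels its siblings because their results can no longer be used. Below is a self-contained sketch of that cancel-on-failure idiom; the worker names and delays are illustrative stand-ins, not the actual chunk/entity/document storage tasks.

import asyncio


async def worker(name: str, delay: float, fail: bool = False) -> str:
    await asyncio.sleep(delay)
    if fail:
        raise RuntimeError(f"{name} failed")
    return name


async def main() -> None:
    tasks = [
        asyncio.create_task(worker("chunks_vdb", 0.05)),
        asyncio.create_task(worker("entity_relation", 0.01, fail=True)),
        asyncio.create_task(worker("full_docs", 0.05)),
    ]
    try:
        await asyncio.gather(*tasks)
    except Exception as exc:
        print(f"extraction stage failed: {exc}")
        # Cancel any tasks that have not finished; their results are no longer meaningful
        for task in tasks:
            if not task.done():
                task.cancel()
        # Let the cancellations settle, collecting (and ignoring) the resulting errors
        await asyncio.gather(*tasks, return_exceptions=True)


asyncio.run(main())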
lightrag/operate.py CHANGED
@@ -476,6 +476,139 @@ async def _merge_edges_then_upsert(
     return edge_data


+async def merge_nodes_and_edges(
+    chunk_results: list,
+    knowledge_graph_inst: BaseGraphStorage,
+    entity_vdb: BaseVectorStorage,
+    relationships_vdb: BaseVectorStorage,
+    global_config: dict[str, str],
+    pipeline_status: dict = None,
+    pipeline_status_lock=None,
+    llm_response_cache: BaseKVStorage | None = None,
+    current_file_number: int = 0,
+    total_files: int = 0,
+    file_path: str = "unknown_source",
+) -> None:
+    """Merge nodes and edges from extraction results
+
+    Args:
+        chunk_results: List of tuples (maybe_nodes, maybe_edges) containing extracted entities and relationships
+        knowledge_graph_inst: Knowledge graph storage
+        entity_vdb: Entity vector database
+        relationships_vdb: Relationship vector database
+        global_config: Global configuration
+        pipeline_status: Pipeline status dictionary
+        pipeline_status_lock: Lock for pipeline status
+        llm_response_cache: LLM response cache
+    """
+    # Get lock manager from shared storage
+    from .kg.shared_storage import get_graph_db_lock
+    graph_db_lock = get_graph_db_lock(enable_logging=False)
+
+    # Collect all nodes and edges from all chunks
+    all_nodes = defaultdict(list)
+    all_edges = defaultdict(list)
+
+    for maybe_nodes, maybe_edges in chunk_results:
+        # Collect nodes
+        for entity_name, entities in maybe_nodes.items():
+            all_nodes[entity_name].extend(entities)
+
+        # Collect edges with sorted keys for undirected graph
+        for edge_key, edges in maybe_edges.items():
+            sorted_edge_key = tuple(sorted(edge_key))
+            all_edges[sorted_edge_key].extend(edges)
+
+    # Centralized processing of all nodes and edges
+    entities_data = []
+    relationships_data = []
+
+    # Merge nodes and edges
+    # Use graph database lock to ensure atomic merges and updates
+    async with graph_db_lock:
+        async with pipeline_status_lock:
+            log_message = f"Merging nodes/edges {current_file_number}/{total_files}: {file_path}"
+            logger.info(log_message)
+            pipeline_status["latest_message"] = log_message
+            pipeline_status["history_messages"].append(log_message)
+
+        # Process and update all entities at once
+        for entity_name, entities in all_nodes.items():
+            entity_data = await _merge_nodes_then_upsert(
+                entity_name,
+                entities,
+                knowledge_graph_inst,
+                global_config,
+                pipeline_status,
+                pipeline_status_lock,
+                llm_response_cache,
+            )
+            entities_data.append(entity_data)
+
+        # Process and update all relationships at once
+        for edge_key, edges in all_edges.items():
+            edge_data = await _merge_edges_then_upsert(
+                edge_key[0],
+                edge_key[1],
+                edges,
+                knowledge_graph_inst,
+                global_config,
+                pipeline_status,
+                pipeline_status_lock,
+                llm_response_cache,
+            )
+            if edge_data is not None:
+                relationships_data.append(edge_data)
+
+        # Update total counts
+        total_entities_count = len(entities_data)
+        total_relations_count = len(relationships_data)
+
+        log_message = f"Updating {total_entities_count} entities {current_file_number}/{total_files}: {file_path}"
+        logger.info(log_message)
+        if pipeline_status is not None:
+            async with pipeline_status_lock:
+                pipeline_status["latest_message"] = log_message
+                pipeline_status["history_messages"].append(log_message)
+
+        # Update vector databases with all collected data
+        if entity_vdb is not None and entities_data:
+            data_for_vdb = {
+                compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
+                    "entity_name": dp["entity_name"],
+                    "entity_type": dp["entity_type"],
+                    "content": f"{dp['entity_name']}\n{dp['description']}",
+                    "source_id": dp["source_id"],
+                    "file_path": dp.get("file_path", "unknown_source"),
+                }
+                for dp in entities_data
+            }
+            await entity_vdb.upsert(data_for_vdb)
+
+        log_message = (
+            f"Updating {total_relations_count} relations {current_file_number}/{total_files}: {file_path}"
+        )
+        logger.info(log_message)
+        if pipeline_status is not None:
+            async with pipeline_status_lock:
+                pipeline_status["latest_message"] = log_message
+                pipeline_status["history_messages"].append(log_message)
+
+        if relationships_vdb is not None and relationships_data:
+            data_for_vdb = {
+                compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
+                    "src_id": dp["src_id"],
+                    "tgt_id": dp["tgt_id"],
+                    "keywords": dp["keywords"],
+                    "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
+                    "source_id": dp["source_id"],
+                    "file_path": dp.get("file_path", "unknown_source"),
+                }
+                for dp in relationships_data
+            }
+            await relationships_vdb.upsert(data_for_vdb)
+
+
 async def extract_entities(
     chunks: dict[str, TextChunkSchema],
     knowledge_graph_inst: BaseGraphStorage,
@@ -485,7 +618,7 @@ async def extract_entities(
     pipeline_status: dict = None,
     pipeline_status_lock=None,
     llm_response_cache: BaseKVStorage | None = None,
-) -> None:
+) -> list:
     use_llm_func: callable = global_config["llm_model_func"]
     entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]

@@ -530,15 +663,6 @@

     processed_chunks = 0
     total_chunks = len(ordered_chunks)
-    total_entities_count = 0
-    total_relations_count = 0
-
-    # Get lock manager from shared storage
-    from .kg.shared_storage import get_graph_db_lock
-
-    graph_db_lock = get_graph_db_lock(enable_logging=False)
-
-    # Use the global use_llm_func_with_cache function from utils.py

     async def _process_extraction_result(
         result: str, chunk_key: str, file_path: str = "unknown_source"
@@ -708,102 +832,9 @@

     # If all tasks completed successfully, collect results
     chunk_results = [task.result() for task in tasks]
-
-    # Collect all nodes and edges from all chunks
-    all_nodes = defaultdict(list)
-    all_edges = defaultdict(list)
-
-    for maybe_nodes, maybe_edges in chunk_results:
-        # Collect nodes
-        for entity_name, entities in maybe_nodes.items():
-            all_nodes[entity_name].extend(entities)
-
-        # Collect edges with sorted keys for undirected graph
-        for edge_key, edges in maybe_edges.items():
-            sorted_edge_key = tuple(sorted(edge_key))
-            all_edges[sorted_edge_key].extend(edges)
-
-    # Centralized processing of all nodes and edges
-    entities_data = []
-    relationships_data = []
-
-    # Use graph database lock to ensure atomic merges and updates
-    async with graph_db_lock:
-        # Process and update all entities at once
-        for entity_name, entities in all_nodes.items():
-            entity_data = await _merge_nodes_then_upsert(
-                entity_name,
-                entities,
-                knowledge_graph_inst,
-                global_config,
-                pipeline_status,
-                pipeline_status_lock,
-                llm_response_cache,
-            )
-            entities_data.append(entity_data)
-
-        # Process and update all relationships at once
-        for edge_key, edges in all_edges.items():
-            edge_data = await _merge_edges_then_upsert(
-                edge_key[0],
-                edge_key[1],
-                edges,
-                knowledge_graph_inst,
-                global_config,
-                pipeline_status,
-                pipeline_status_lock,
-                llm_response_cache,
-            )
-            if edge_data is not None:
-                relationships_data.append(edge_data)
-
-        # Update total counts
-        total_entities_count = len(entities_data)
-        total_relations_count = len(relationships_data)
-
-        log_message = f"Updating vector storage: {total_entities_count} entities..."
-        logger.info(log_message)
-        if pipeline_status is not None:
-            async with pipeline_status_lock:
-                pipeline_status["latest_message"] = log_message
-                pipeline_status["history_messages"].append(log_message)
-
-        # Update vector databases with all collected data
-        if entity_vdb is not None and entities_data:
-            data_for_vdb = {
-                compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
-                    "entity_name": dp["entity_name"],
-                    "entity_type": dp["entity_type"],
-                    "content": f"{dp['entity_name']}\n{dp['description']}",
-                    "source_id": dp["source_id"],
-                    "file_path": dp.get("file_path", "unknown_source"),
-                }
-                for dp in entities_data
-            }
-            await entity_vdb.upsert(data_for_vdb)
-
-        log_message = (
-            f"Updating vector storage: {total_relations_count} relationships..."
-        )
-        logger.info(log_message)
-        if pipeline_status is not None:
-            async with pipeline_status_lock:
-                pipeline_status["latest_message"] = log_message
-                pipeline_status["history_messages"].append(log_message)
-
-        if relationships_vdb is not None and relationships_data:
-            data_for_vdb = {
-                compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
-                    "src_id": dp["src_id"],
-                    "tgt_id": dp["tgt_id"],
-                    "keywords": dp["keywords"],
-                    "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
-                    "source_id": dp["source_id"],
-                    "file_path": dp.get("file_path", "unknown_source"),
-                }
-                for dp in relationships_data
-            }
-            await relationships_vdb.upsert(data_for_vdb)
+
+    # Return the chunk_results for later processing in merge_nodes_and_edges
+    return chunk_results


 async def kg_query(
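
For reference, the first step of merge_nodes_and_edges groups the per-chunk extraction results into per-entity and per-edge buckets, sorting each edge key so that the undirected pair (A, B) and (B, A) land in the same bucket. A standalone sketch of just that grouping step, using toy data in place of real extraction output:

from collections import defaultdict

# Toy stand-in for chunk_results: a list of (maybe_nodes, maybe_edges) tuples
chunk_results = [
    ({"Alice": [{"description": "a person"}]}, {("Alice", "Bob"): [{"keywords": "knows"}]}),
    ({"Bob": [{"description": "a person"}]}, {("Bob", "Alice"): [{"keywords": "works with"}]}),
]

all_nodes = defaultdict(list)
all_edges = defaultdict(list)

for maybe_nodes, maybe_edges in chunk_results:
    for entity_name, entities in maybe_nodes.items():
        all_nodes[entity_name].extend(entities)
    # Sort each edge key so (A, B) and (B, A) merge into one bucket (undirected graph)
    for edge_key, edges in maybe_edges.items():
        all_edges[tuple(sorted(edge_key))].extend(edges)

print(dict(all_nodes))
print(dict(all_edges))  # a single ("Alice", "Bob") bucket holding both edge records

Sorting the key is what makes the graph effectively undirected at merge time; without it, the two orientations of the same relationship would be merged separately.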