yangdx commited on
Commit
faceaee
·
1 Parent(s): 52aae9d

Add is_truncated to graph query for Neo4j

Browse files
Files changed (1) hide show
  1. lightrag/kg/neo4j_impl.py +143 -56
lightrag/kg/neo4j_impl.py CHANGED
@@ -658,7 +658,8 @@ class Neo4JStorage(BaseGraphStorage):
658
  max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000
659
 
660
  Returns:
661
- KnowledgeGraph object containing nodes and edges
 
662
  """
663
  result = KnowledgeGraph()
664
  seen_nodes = set()
@@ -669,6 +670,23 @@ class Neo4JStorage(BaseGraphStorage):
669
  ) as session:
670
  try:
671
  if node_label == "*":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
672
  main_query = """
673
  MATCH (n)
674
  OPTIONAL MATCH (n)-[r]-()
@@ -683,14 +701,20 @@ class Neo4JStorage(BaseGraphStorage):
683
  RETURN filtered_nodes AS node_info,
684
  collect(DISTINCT r) AS relationships
685
  """
686
- result_set = await session.run(
687
- main_query,
688
- {"max_nodes": max_nodes},
689
- )
 
 
 
 
 
 
690
 
691
  else:
692
- # Main query uses partial matching
693
- main_query = """
694
  MATCH (start)
695
  WHERE start.entity_id = $entity_id
696
  WITH start
@@ -698,63 +722,118 @@ class Neo4JStorage(BaseGraphStorage):
698
  relationshipFilter: '',
699
  minLevel: 0,
700
  maxLevel: $max_depth,
701
- limit: $max_nodes,
702
  bfs: true
703
  })
704
  YIELD nodes, relationships
 
705
  UNWIND nodes AS node
706
- WITH collect({node: node}) AS node_info, relationships
707
- RETURN node_info, relationships
708
  """
709
- result_set = await session.run(
710
- main_query,
711
- {
712
- "entity_id": node_label,
713
- "max_depth": max_depth,
714
- "max_nodes": max_nodes,
715
- },
716
- )
717
 
718
- try:
719
- record = await result_set.single()
720
-
721
- if record:
722
- # Handle nodes (compatible with multi-label cases)
723
- for node_info in record["node_info"]:
724
- node = node_info["node"]
725
- node_id = node.id
726
- if node_id not in seen_nodes:
727
- result.nodes.append(
728
- KnowledgeGraphNode(
729
- id=f"{node_id}",
730
- labels=[node.get("entity_id")],
731
- properties=dict(node),
732
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
733
  )
734
- seen_nodes.add(node_id)
735
-
736
- # Handle relationships (including direction information)
737
- for rel in record["relationships"]:
738
- edge_id = rel.id
739
- if edge_id not in seen_edges:
740
- start = rel.start_node
741
- end = rel.end_node
742
- result.edges.append(
743
- KnowledgeGraphEdge(
744
- id=f"{edge_id}",
745
- type=rel.type,
746
- source=f"{start.id}",
747
- target=f"{end.id}",
748
- properties=dict(rel),
749
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
750
  )
751
- seen_edges.add(edge_id)
 
752
 
753
- logger.info(
754
- f"Process {os.getpid()} graph query return: {len(result.nodes)} nodes, {len(result.edges)} edges"
755
- )
756
- finally:
757
- await result_set.consume() # Ensure result set is consumed
758
 
759
  except neo4jExceptions.ClientError as e:
760
  logger.warning(f"APOC plugin error: {str(e)}")
@@ -763,6 +842,10 @@ class Neo4JStorage(BaseGraphStorage):
763
  "Neo4j: falling back to basic Cypher recursive search..."
764
  )
765
  return await self._robust_fallback(node_label, max_depth, max_nodes)
 
 
 
 
766
 
767
  return result
768
 
@@ -788,7 +871,11 @@ class Neo4JStorage(BaseGraphStorage):
788
  logger.debug(f"Reached max depth: {max_depth}")
789
  return
790
  if len(visited_nodes) >= max_nodes:
791
- logger.debug(f"Reached max nodes limit: {MAX_GRAPH_NODES}")
 
 
 
 
792
  return
793
 
794
  # Check if node already visited
 
658
  max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000
659
 
660
  Returns:
661
+ KnowledgeGraph object containing nodes and edges, with an is_truncated flag
662
+ indicating whether the graph was truncated due to max_nodes limit
663
  """
664
  result = KnowledgeGraph()
665
  seen_nodes = set()
 
670
  ) as session:
671
  try:
672
  if node_label == "*":
673
+ # First check total node count to determine if graph is truncated
674
+ count_query = "MATCH (n) RETURN count(n) as total"
675
+ count_result = None
676
+ try:
677
+ count_result = await session.run(count_query)
678
+ count_record = await count_result.single()
679
+
680
+ if count_record and count_record["total"] > max_nodes:
681
+ result.is_truncated = True
682
+ logger.info(
683
+ f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}"
684
+ )
685
+ finally:
686
+ if count_result:
687
+ await count_result.consume()
688
+
689
+ # Run main query to get nodes with highest degree
690
  main_query = """
691
  MATCH (n)
692
  OPTIONAL MATCH (n)-[r]-()
 
701
  RETURN filtered_nodes AS node_info,
702
  collect(DISTINCT r) AS relationships
703
  """
704
+ result_set = None
705
+ try:
706
+ result_set = await session.run(
707
+ main_query,
708
+ {"max_nodes": max_nodes},
709
+ )
710
+ record = await result_set.single()
711
+ finally:
712
+ if result_set:
713
+ await result_set.consume()
714
 
715
  else:
716
+ # First try without limit to check if we need to truncate
717
+ full_query = """
718
  MATCH (start)
719
  WHERE start.entity_id = $entity_id
720
  WITH start
 
722
  relationshipFilter: '',
723
  minLevel: 0,
724
  maxLevel: $max_depth,
 
725
  bfs: true
726
  })
727
  YIELD nodes, relationships
728
+ WITH nodes, relationships, size(nodes) AS total_nodes
729
  UNWIND nodes AS node
730
+ WITH collect({node: node}) AS node_info, relationships, total_nodes
731
+ RETURN node_info, relationships, total_nodes
732
  """
 
 
 
 
 
 
 
 
733
 
734
+ # Try to get full result
735
+ full_result = None
736
+ try:
737
+ full_result = await session.run(
738
+ full_query,
739
+ {
740
+ "entity_id": node_label,
741
+ "max_depth": max_depth,
742
+ },
743
+ )
744
+ full_record = await full_result.single()
745
+
746
+ # If no record found, return empty KnowledgeGraph
747
+ if not full_record:
748
+ logger.debug(f"No nodes found for entity_id: {node_label}")
749
+ return result
750
+
751
+ # If record found, check node count
752
+ total_nodes = full_record["total_nodes"]
753
+
754
+ if total_nodes <= max_nodes:
755
+ # If node count is within limit, use full result directly
756
+ logger.debug(
757
+ f"Using full result with {total_nodes} nodes (no truncation needed)"
758
+ )
759
+ record = full_record
760
+ else:
761
+ # If node count exceeds limit, set truncated flag and run limited query
762
+ result.is_truncated = True
763
+ logger.info(
764
+ f"Graph truncated: {total_nodes} nodes found, breadth-first search limited to {max_nodes}"
765
+ )
766
+
767
+ # Run limited query
768
+ limited_query = """
769
+ MATCH (start)
770
+ WHERE start.entity_id = $entity_id
771
+ WITH start
772
+ CALL apoc.path.subgraphAll(start, {
773
+ relationshipFilter: '',
774
+ minLevel: 0,
775
+ maxLevel: $max_depth,
776
+ limit: $max_nodes,
777
+ bfs: true
778
+ })
779
+ YIELD nodes, relationships
780
+ UNWIND nodes AS node
781
+ WITH collect({node: node}) AS node_info, relationships
782
+ RETURN node_info, relationships
783
+ """
784
+ result_set = None
785
+ try:
786
+ result_set = await session.run(
787
+ limited_query,
788
+ {
789
+ "entity_id": node_label,
790
+ "max_depth": max_depth,
791
+ "max_nodes": max_nodes,
792
+ },
793
  )
794
+ record = await result_set.single()
795
+ finally:
796
+ if result_set:
797
+ await result_set.consume()
798
+ finally:
799
+ if full_result:
800
+ await full_result.consume()
801
+
802
+ if record:
803
+ # Handle nodes (compatible with multi-label cases)
804
+ for node_info in record["node_info"]:
805
+ node = node_info["node"]
806
+ node_id = node.id
807
+ if node_id not in seen_nodes:
808
+ result.nodes.append(
809
+ KnowledgeGraphNode(
810
+ id=f"{node_id}",
811
+ labels=[node.get("entity_id")],
812
+ properties=dict(node),
813
+ )
814
+ )
815
+ seen_nodes.add(node_id)
816
+
817
+ # Handle relationships (including direction information)
818
+ for rel in record["relationships"]:
819
+ edge_id = rel.id
820
+ if edge_id not in seen_edges:
821
+ start = rel.start_node
822
+ end = rel.end_node
823
+ result.edges.append(
824
+ KnowledgeGraphEdge(
825
+ id=f"{edge_id}",
826
+ type=rel.type,
827
+ source=f"{start.id}",
828
+ target=f"{end.id}",
829
+ properties=dict(rel),
830
  )
831
+ )
832
+ seen_edges.add(edge_id)
833
 
834
+ logger.info(
835
+ f"Process {os.getpid()} graph query return: {len(result.nodes)} nodes, {len(result.edges)} edges"
836
+ )
 
 
837
 
838
  except neo4jExceptions.ClientError as e:
839
  logger.warning(f"APOC plugin error: {str(e)}")
 
842
  "Neo4j: falling back to basic Cypher recursive search..."
843
  )
844
  return await self._robust_fallback(node_label, max_depth, max_nodes)
845
+ else:
846
+ logger.warning(
847
+ "Neo4j: APOC plugin error with wildcard query, returning empty result"
848
+ )
849
 
850
  return result
851
 
 
871
  logger.debug(f"Reached max depth: {max_depth}")
872
  return
873
  if len(visited_nodes) >= max_nodes:
874
+ # Set truncated flag when we hit the max_nodes limit
875
+ result.is_truncated = True
876
+ logger.info(
877
+ f"Graph truncated: breadth-first search limited to: {max_nodes} nodes"
878
+ )
879
  return
880
 
881
  # Check if node already visited