yangdx committed on
Commit
8eb5fb6
·
1 Parent(s): 731e6b1

Fix edge direction problem for Neo4j storage

Browse files
Files changed (1) hide show
  1. lightrag/kg/neo4j_impl.py +30 -5
lightrag/kg/neo4j_impl.py CHANGED
@@ -665,31 +665,56 @@ class Neo4JStorage(BaseGraphStorage):
665
  ) -> dict[str, list[tuple[str, str]]]:
666
  """
667
  Batch retrieve edges for multiple nodes in one query using UNWIND.
 
 
668
 
669
  Args:
670
  node_ids: List of node IDs (entity_id) for which to retrieve edges.
671
 
672
  Returns:
673
  A dictionary mapping each node ID to its list of edge tuples (source, target).
 
 
 
674
  """
675
  async with self._driver.session(
676
  database=self._DATABASE, default_access_mode="READ"
677
  ) as session:
 
678
  query = """
679
  UNWIND $node_ids AS id
680
  MATCH (n:base {entity_id: id})
681
  OPTIONAL MATCH (n)-[r]-(connected:base)
682
- RETURN id AS queried_id, n.entity_id AS source_entity_id, connected.entity_id AS target_entity_id
 
 
683
  """
684
  result = await session.run(query, node_ids=node_ids)
 
685
  # Initialize the dictionary with empty lists for each node ID
686
  edges_dict = {node_id: [] for node_id in node_ids}
 
 
687
  async for record in result:
688
  queried_id = record["queried_id"]
689
- source_label = record["source_entity_id"]
690
- target_label = record["target_entity_id"]
691
- if source_label and target_label:
692
- edges_dict[queried_id].append((source_label, target_label))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
693
  await result.consume() # Ensure results are fully consumed
694
  return edges_dict
695
 
 
665
  ) -> dict[str, list[tuple[str, str]]]:
666
  """
667
  Batch retrieve edges for multiple nodes in one query using UNWIND.
668
+ For each node, returns both outgoing and incoming edges to properly represent
669
+ the undirected graph nature.
670
 
671
  Args:
672
  node_ids: List of node IDs (entity_id) for which to retrieve edges.
673
 
674
  Returns:
675
  A dictionary mapping each node ID to its list of edge tuples (source, target).
676
+ For each node, the list includes both:
677
+ - Outgoing edges: (queried_node, connected_node)
678
+ - Incoming edges: (connected_node, queried_node)
679
  """
680
  async with self._driver.session(
681
  database=self._DATABASE, default_access_mode="READ"
682
  ) as session:
683
+ # Query to get both outgoing and incoming edges
684
  query = """
685
  UNWIND $node_ids AS id
686
  MATCH (n:base {entity_id: id})
687
  OPTIONAL MATCH (n)-[r]-(connected:base)
688
+ RETURN id AS queried_id, n.entity_id AS node_entity_id,
689
+ connected.entity_id AS connected_entity_id,
690
+ startNode(r).entity_id AS start_entity_id
691
  """
692
  result = await session.run(query, node_ids=node_ids)
693
+
694
  # Initialize the dictionary with empty lists for each node ID
695
  edges_dict = {node_id: [] for node_id in node_ids}
696
+
697
+ # Process results to include both outgoing and incoming edges
698
  async for record in result:
699
  queried_id = record["queried_id"]
700
+ node_entity_id = record["node_entity_id"]
701
+ connected_entity_id = record["connected_entity_id"]
702
+ start_entity_id = record["start_entity_id"]
703
+
704
+ # Skip if either node is None
705
+ if not node_entity_id or not connected_entity_id:
706
+ continue
707
+
708
+ # Determine the actual direction of the edge
709
+ # If the start node is the queried node, it's an outgoing edge
710
+ # Otherwise, it's an incoming edge
711
+ if start_entity_id == node_entity_id:
712
+ # Outgoing edge: (queried_node -> connected_node)
713
+ edges_dict[queried_id].append((node_entity_id, connected_entity_id))
714
+ else:
715
+ # Incoming edge: (connected_node -> queried_node)
716
+ edges_dict[queried_id].append((connected_entity_id, node_entity_id))
717
+
718
  await result.consume() # Ensure results are fully consumed
719
  return edges_dict
720