yangdx
committed on
Commit
·
8eb5fb6
1
Parent(s):
731e6b1
Fix edge direction problem for Neo4j storage
Browse files
- lightrag/kg/neo4j_impl.py +30 -5
lightrag/kg/neo4j_impl.py
CHANGED
@@ -665,31 +665,56 @@ class Neo4JStorage(BaseGraphStorage):
|
|
665 |
) -> dict[str, list[tuple[str, str]]]:
|
666 |
"""
|
667 |
Batch retrieve edges for multiple nodes in one query using UNWIND.
|
|
|
|
|
668 |
|
669 |
Args:
|
670 |
node_ids: List of node IDs (entity_id) for which to retrieve edges.
|
671 |
|
672 |
Returns:
|
673 |
A dictionary mapping each node ID to its list of edge tuples (source, target).
|
|
|
|
|
|
|
674 |
"""
|
675 |
async with self._driver.session(
|
676 |
database=self._DATABASE, default_access_mode="READ"
|
677 |
) as session:
|
|
|
678 |
query = """
|
679 |
UNWIND $node_ids AS id
|
680 |
MATCH (n:base {entity_id: id})
|
681 |
OPTIONAL MATCH (n)-[r]-(connected:base)
|
682 |
-
RETURN id AS queried_id, n.entity_id AS
|
|
|
|
|
683 |
"""
|
684 |
result = await session.run(query, node_ids=node_ids)
|
|
|
685 |
# Initialize the dictionary with empty lists for each node ID
|
686 |
edges_dict = {node_id: [] for node_id in node_ids}
|
|
|
|
|
687 |
async for record in result:
|
688 |
queried_id = record["queried_id"]
|
689 |
-
|
690 |
-
|
691 |
-
|
692 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
693 |
await result.consume() # Ensure results are fully consumed
|
694 |
return edges_dict
|
695 |
|
|
|
665 |
) -> dict[str, list[tuple[str, str]]]:
|
666 |
"""
|
667 |
Batch retrieve edges for multiple nodes in one query using UNWIND.
|
668 |
+
For each node, returns both outgoing and incoming edges to properly represent
|
669 |
+
the undirected graph nature.
|
670 |
|
671 |
Args:
|
672 |
node_ids: List of node IDs (entity_id) for which to retrieve edges.
|
673 |
|
674 |
Returns:
|
675 |
A dictionary mapping each node ID to its list of edge tuples (source, target).
|
676 |
+
For each node, the list includes both:
|
677 |
+
- Outgoing edges: (queried_node, connected_node)
|
678 |
+
- Incoming edges: (connected_node, queried_node)
|
679 |
"""
|
680 |
async with self._driver.session(
|
681 |
database=self._DATABASE, default_access_mode="READ"
|
682 |
) as session:
|
683 |
+
# Query to get both outgoing and incoming edges
|
684 |
query = """
|
685 |
UNWIND $node_ids AS id
|
686 |
MATCH (n:base {entity_id: id})
|
687 |
OPTIONAL MATCH (n)-[r]-(connected:base)
|
688 |
+
RETURN id AS queried_id, n.entity_id AS node_entity_id,
|
689 |
+
connected.entity_id AS connected_entity_id,
|
690 |
+
startNode(r).entity_id AS start_entity_id
|
691 |
"""
|
692 |
result = await session.run(query, node_ids=node_ids)
|
693 |
+
|
694 |
# Initialize the dictionary with empty lists for each node ID
|
695 |
edges_dict = {node_id: [] for node_id in node_ids}
|
696 |
+
|
697 |
+
# Process results to include both outgoing and incoming edges
|
698 |
async for record in result:
|
699 |
queried_id = record["queried_id"]
|
700 |
+
node_entity_id = record["node_entity_id"]
|
701 |
+
connected_entity_id = record["connected_entity_id"]
|
702 |
+
start_entity_id = record["start_entity_id"]
|
703 |
+
|
704 |
+
# Skip if either node is None
|
705 |
+
if not node_entity_id or not connected_entity_id:
|
706 |
+
continue
|
707 |
+
|
708 |
+
# Determine the actual direction of the edge
|
709 |
+
# If the start node is the queried node, it's an outgoing edge
|
710 |
+
# Otherwise, it's an incoming edge
|
711 |
+
if start_entity_id == node_entity_id:
|
712 |
+
# Outgoing edge: (queried_node -> connected_node)
|
713 |
+
edges_dict[queried_id].append((node_entity_id, connected_entity_id))
|
714 |
+
else:
|
715 |
+
# Incoming edge: (connected_node -> queried_node)
|
716 |
+
edges_dict[queried_id].append((connected_entity_id, node_entity_id))
|
717 |
+
|
718 |
await result.consume() # Ensure results are fully consumed
|
719 |
return edges_dict
|
720 |
|