yangdx
commited on
Commit
·
faceaee
1
Parent(s):
52aae9d
Add is_truncated to graph query for Neo4j
Browse files- lightrag/kg/neo4j_impl.py +143 -56
lightrag/kg/neo4j_impl.py
CHANGED
@@ -658,7 +658,8 @@ class Neo4JStorage(BaseGraphStorage):
|
|
658 |
max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000
|
659 |
|
660 |
Returns:
|
661 |
-
KnowledgeGraph object containing nodes and edges
|
|
|
662 |
"""
|
663 |
result = KnowledgeGraph()
|
664 |
seen_nodes = set()
|
@@ -669,6 +670,23 @@ class Neo4JStorage(BaseGraphStorage):
|
|
669 |
) as session:
|
670 |
try:
|
671 |
if node_label == "*":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
672 |
main_query = """
|
673 |
MATCH (n)
|
674 |
OPTIONAL MATCH (n)-[r]-()
|
@@ -683,14 +701,20 @@ class Neo4JStorage(BaseGraphStorage):
|
|
683 |
RETURN filtered_nodes AS node_info,
|
684 |
collect(DISTINCT r) AS relationships
|
685 |
"""
|
686 |
-
result_set =
|
687 |
-
|
688 |
-
|
689 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
690 |
|
691 |
else:
|
692 |
-
#
|
693 |
-
|
694 |
MATCH (start)
|
695 |
WHERE start.entity_id = $entity_id
|
696 |
WITH start
|
@@ -698,63 +722,118 @@ class Neo4JStorage(BaseGraphStorage):
|
|
698 |
relationshipFilter: '',
|
699 |
minLevel: 0,
|
700 |
maxLevel: $max_depth,
|
701 |
-
limit: $max_nodes,
|
702 |
bfs: true
|
703 |
})
|
704 |
YIELD nodes, relationships
|
|
|
705 |
UNWIND nodes AS node
|
706 |
-
WITH collect({node: node}) AS node_info, relationships
|
707 |
-
RETURN node_info, relationships
|
708 |
"""
|
709 |
-
result_set = await session.run(
|
710 |
-
main_query,
|
711 |
-
{
|
712 |
-
"entity_id": node_label,
|
713 |
-
"max_depth": max_depth,
|
714 |
-
"max_nodes": max_nodes,
|
715 |
-
},
|
716 |
-
)
|
717 |
|
718 |
-
|
719 |
-
|
720 |
-
|
721 |
-
|
722 |
-
|
723 |
-
|
724 |
-
|
725 |
-
|
726 |
-
|
727 |
-
|
728 |
-
|
729 |
-
|
730 |
-
|
731 |
-
|
732 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
733 |
)
|
734 |
-
|
735 |
-
|
736 |
-
|
737 |
-
|
738 |
-
|
739 |
-
|
740 |
-
|
741 |
-
|
742 |
-
|
743 |
-
|
744 |
-
|
745 |
-
|
746 |
-
|
747 |
-
|
748 |
-
|
749 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
750 |
)
|
751 |
-
|
|
|
752 |
|
753 |
-
|
754 |
-
|
755 |
-
|
756 |
-
finally:
|
757 |
-
await result_set.consume() # Ensure result set is consumed
|
758 |
|
759 |
except neo4jExceptions.ClientError as e:
|
760 |
logger.warning(f"APOC plugin error: {str(e)}")
|
@@ -763,6 +842,10 @@ class Neo4JStorage(BaseGraphStorage):
|
|
763 |
"Neo4j: falling back to basic Cypher recursive search..."
|
764 |
)
|
765 |
return await self._robust_fallback(node_label, max_depth, max_nodes)
|
|
|
|
|
|
|
|
|
766 |
|
767 |
return result
|
768 |
|
@@ -788,7 +871,11 @@ class Neo4JStorage(BaseGraphStorage):
|
|
788 |
logger.debug(f"Reached max depth: {max_depth}")
|
789 |
return
|
790 |
if len(visited_nodes) >= max_nodes:
|
791 |
-
|
|
|
|
|
|
|
|
|
792 |
return
|
793 |
|
794 |
# Check if node already visited
|
|
|
658 |
max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000
|
659 |
|
660 |
Returns:
|
661 |
+
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
|
662 |
+
indicating whether the graph was truncated due to max_nodes limit
|
663 |
"""
|
664 |
result = KnowledgeGraph()
|
665 |
seen_nodes = set()
|
|
|
670 |
) as session:
|
671 |
try:
|
672 |
if node_label == "*":
|
673 |
+
# First check total node count to determine if graph is truncated
|
674 |
+
count_query = "MATCH (n) RETURN count(n) as total"
|
675 |
+
count_result = None
|
676 |
+
try:
|
677 |
+
count_result = await session.run(count_query)
|
678 |
+
count_record = await count_result.single()
|
679 |
+
|
680 |
+
if count_record and count_record["total"] > max_nodes:
|
681 |
+
result.is_truncated = True
|
682 |
+
logger.info(
|
683 |
+
f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}"
|
684 |
+
)
|
685 |
+
finally:
|
686 |
+
if count_result:
|
687 |
+
await count_result.consume()
|
688 |
+
|
689 |
+
# Run main query to get nodes with highest degree
|
690 |
main_query = """
|
691 |
MATCH (n)
|
692 |
OPTIONAL MATCH (n)-[r]-()
|
|
|
701 |
RETURN filtered_nodes AS node_info,
|
702 |
collect(DISTINCT r) AS relationships
|
703 |
"""
|
704 |
+
result_set = None
|
705 |
+
try:
|
706 |
+
result_set = await session.run(
|
707 |
+
main_query,
|
708 |
+
{"max_nodes": max_nodes},
|
709 |
+
)
|
710 |
+
record = await result_set.single()
|
711 |
+
finally:
|
712 |
+
if result_set:
|
713 |
+
await result_set.consume()
|
714 |
|
715 |
else:
|
716 |
+
# First try without limit to check if we need to truncate
|
717 |
+
full_query = """
|
718 |
MATCH (start)
|
719 |
WHERE start.entity_id = $entity_id
|
720 |
WITH start
|
|
|
722 |
relationshipFilter: '',
|
723 |
minLevel: 0,
|
724 |
maxLevel: $max_depth,
|
|
|
725 |
bfs: true
|
726 |
})
|
727 |
YIELD nodes, relationships
|
728 |
+
WITH nodes, relationships, size(nodes) AS total_nodes
|
729 |
UNWIND nodes AS node
|
730 |
+
WITH collect({node: node}) AS node_info, relationships, total_nodes
|
731 |
+
RETURN node_info, relationships, total_nodes
|
732 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
733 |
|
734 |
+
# Try to get full result
|
735 |
+
full_result = None
|
736 |
+
try:
|
737 |
+
full_result = await session.run(
|
738 |
+
full_query,
|
739 |
+
{
|
740 |
+
"entity_id": node_label,
|
741 |
+
"max_depth": max_depth,
|
742 |
+
},
|
743 |
+
)
|
744 |
+
full_record = await full_result.single()
|
745 |
+
|
746 |
+
# If no record found, return empty KnowledgeGraph
|
747 |
+
if not full_record:
|
748 |
+
logger.debug(f"No nodes found for entity_id: {node_label}")
|
749 |
+
return result
|
750 |
+
|
751 |
+
# If record found, check node count
|
752 |
+
total_nodes = full_record["total_nodes"]
|
753 |
+
|
754 |
+
if total_nodes <= max_nodes:
|
755 |
+
# If node count is within limit, use full result directly
|
756 |
+
logger.debug(
|
757 |
+
f"Using full result with {total_nodes} nodes (no truncation needed)"
|
758 |
+
)
|
759 |
+
record = full_record
|
760 |
+
else:
|
761 |
+
# If node count exceeds limit, set truncated flag and run limited query
|
762 |
+
result.is_truncated = True
|
763 |
+
logger.info(
|
764 |
+
f"Graph truncated: {total_nodes} nodes found, breadth-first search limited to {max_nodes}"
|
765 |
+
)
|
766 |
+
|
767 |
+
# Run limited query
|
768 |
+
limited_query = """
|
769 |
+
MATCH (start)
|
770 |
+
WHERE start.entity_id = $entity_id
|
771 |
+
WITH start
|
772 |
+
CALL apoc.path.subgraphAll(start, {
|
773 |
+
relationshipFilter: '',
|
774 |
+
minLevel: 0,
|
775 |
+
maxLevel: $max_depth,
|
776 |
+
limit: $max_nodes,
|
777 |
+
bfs: true
|
778 |
+
})
|
779 |
+
YIELD nodes, relationships
|
780 |
+
UNWIND nodes AS node
|
781 |
+
WITH collect({node: node}) AS node_info, relationships
|
782 |
+
RETURN node_info, relationships
|
783 |
+
"""
|
784 |
+
result_set = None
|
785 |
+
try:
|
786 |
+
result_set = await session.run(
|
787 |
+
limited_query,
|
788 |
+
{
|
789 |
+
"entity_id": node_label,
|
790 |
+
"max_depth": max_depth,
|
791 |
+
"max_nodes": max_nodes,
|
792 |
+
},
|
793 |
)
|
794 |
+
record = await result_set.single()
|
795 |
+
finally:
|
796 |
+
if result_set:
|
797 |
+
await result_set.consume()
|
798 |
+
finally:
|
799 |
+
if full_result:
|
800 |
+
await full_result.consume()
|
801 |
+
|
802 |
+
if record:
|
803 |
+
# Handle nodes (compatible with multi-label cases)
|
804 |
+
for node_info in record["node_info"]:
|
805 |
+
node = node_info["node"]
|
806 |
+
node_id = node.id
|
807 |
+
if node_id not in seen_nodes:
|
808 |
+
result.nodes.append(
|
809 |
+
KnowledgeGraphNode(
|
810 |
+
id=f"{node_id}",
|
811 |
+
labels=[node.get("entity_id")],
|
812 |
+
properties=dict(node),
|
813 |
+
)
|
814 |
+
)
|
815 |
+
seen_nodes.add(node_id)
|
816 |
+
|
817 |
+
# Handle relationships (including direction information)
|
818 |
+
for rel in record["relationships"]:
|
819 |
+
edge_id = rel.id
|
820 |
+
if edge_id not in seen_edges:
|
821 |
+
start = rel.start_node
|
822 |
+
end = rel.end_node
|
823 |
+
result.edges.append(
|
824 |
+
KnowledgeGraphEdge(
|
825 |
+
id=f"{edge_id}",
|
826 |
+
type=rel.type,
|
827 |
+
source=f"{start.id}",
|
828 |
+
target=f"{end.id}",
|
829 |
+
properties=dict(rel),
|
830 |
)
|
831 |
+
)
|
832 |
+
seen_edges.add(edge_id)
|
833 |
|
834 |
+
logger.info(
|
835 |
+
f"Process {os.getpid()} graph query return: {len(result.nodes)} nodes, {len(result.edges)} edges"
|
836 |
+
)
|
|
|
|
|
837 |
|
838 |
except neo4jExceptions.ClientError as e:
|
839 |
logger.warning(f"APOC plugin error: {str(e)}")
|
|
|
842 |
"Neo4j: falling back to basic Cypher recursive search..."
|
843 |
)
|
844 |
return await self._robust_fallback(node_label, max_depth, max_nodes)
|
845 |
+
else:
|
846 |
+
logger.warning(
|
847 |
+
"Neo4j: APOC plugin error with wildcard query, returning empty result"
|
848 |
+
)
|
849 |
|
850 |
return result
|
851 |
|
|
|
871 |
logger.debug(f"Reached max depth: {max_depth}")
|
872 |
return
|
873 |
if len(visited_nodes) >= max_nodes:
|
874 |
+
# Set truncated flag when we hit the max_nodes limit
|
875 |
+
result.is_truncated = True
|
876 |
+
logger.info(
|
877 |
+
f"Graph truncated: breadth-first search limited to: {max_nodes} nodes"
|
878 |
+
)
|
879 |
return
|
880 |
|
881 |
# Check if node already visited
|