DavIvek
committed on
Commit
·
7f20a21
1
Parent(s):
8365801
add Memgraph graph storage backend
Browse files- config.ini.example +3 -0
- examples/graph_visual_with_neo4j.py +1 -1
- examples/lightrag_openai_demo.py +1 -0
- lightrag/kg/__init__.py +3 -0
- lightrag/kg/memgraph_impl.py +423 -0
config.ini.example
CHANGED
@@ -21,3 +21,6 @@ password = your_password
|
|
21 |
database = your_database
|
22 |
workspace = default # optional, defaults to "default"
|
23 |
max_connections = 12
|
|
|
|
|
|
|
|
21 |
database = your_database
|
22 |
workspace = default # optional, defaults to "default"
|
23 |
max_connections = 12
|
24 |
+
|
25 |
+
[memgraph]
|
26 |
+
uri = bolt://localhost:7687
|
examples/graph_visual_with_neo4j.py
CHANGED
@@ -11,7 +11,7 @@ BATCH_SIZE_EDGES = 100
|
|
11 |
# Neo4j connection credentials
|
12 |
NEO4J_URI = "bolt://localhost:7687"
|
13 |
NEO4J_USERNAME = "neo4j"
|
14 |
-
NEO4J_PASSWORD = "
|
15 |
|
16 |
|
17 |
def xml_to_json(xml_file):
|
|
|
11 |
# Neo4j connection credentials
|
12 |
NEO4J_URI = "bolt://localhost:7687"
|
13 |
NEO4J_USERNAME = "neo4j"
|
14 |
+
NEO4J_PASSWORD = "david123"
|
15 |
|
16 |
|
17 |
def xml_to_json(xml_file):
|
examples/lightrag_openai_demo.py
CHANGED
@@ -82,6 +82,7 @@ async def initialize_rag():
|
|
82 |
working_dir=WORKING_DIR,
|
83 |
embedding_func=openai_embed,
|
84 |
llm_model_func=gpt_4o_mini_complete,
|
|
|
85 |
)
|
86 |
|
87 |
await rag.initialize_storages()
|
|
|
82 |
working_dir=WORKING_DIR,
|
83 |
embedding_func=openai_embed,
|
84 |
llm_model_func=gpt_4o_mini_complete,
|
85 |
+
graph_storage="MemgraphStorage",
|
86 |
)
|
87 |
|
88 |
await rag.initialize_storages()
|
lightrag/kg/__init__.py
CHANGED
@@ -15,6 +15,7 @@ STORAGE_IMPLEMENTATIONS = {
|
|
15 |
"Neo4JStorage",
|
16 |
"PGGraphStorage",
|
17 |
"MongoGraphStorage",
|
|
|
18 |
# "AGEStorage",
|
19 |
# "TiDBGraphStorage",
|
20 |
# "GremlinStorage",
|
@@ -56,6 +57,7 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
|
|
56 |
"NetworkXStorage": [],
|
57 |
"Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
|
58 |
"MongoGraphStorage": [],
|
|
|
59 |
# "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
|
60 |
"AGEStorage": [
|
61 |
"AGE_POSTGRES_DB",
|
@@ -108,6 +110,7 @@ STORAGES = {
|
|
108 |
"PGDocStatusStorage": ".kg.postgres_impl",
|
109 |
"FaissVectorDBStorage": ".kg.faiss_impl",
|
110 |
"QdrantVectorDBStorage": ".kg.qdrant_impl",
|
|
|
111 |
}
|
112 |
|
113 |
|
|
|
15 |
"Neo4JStorage",
|
16 |
"PGGraphStorage",
|
17 |
"MongoGraphStorage",
|
18 |
+
"MemgraphStorage",
|
19 |
# "AGEStorage",
|
20 |
# "TiDBGraphStorage",
|
21 |
# "GremlinStorage",
|
|
|
57 |
"NetworkXStorage": [],
|
58 |
"Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
|
59 |
"MongoGraphStorage": [],
|
60 |
+
"MemgraphStorage": ["MEMGRAPH_URI"],
|
61 |
# "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
|
62 |
"AGEStorage": [
|
63 |
"AGE_POSTGRES_DB",
|
|
|
110 |
"PGDocStatusStorage": ".kg.postgres_impl",
|
111 |
"FaissVectorDBStorage": ".kg.faiss_impl",
|
112 |
"QdrantVectorDBStorage": ".kg.qdrant_impl",
|
113 |
+
"MemgraphStorage": ".kg.memgraph_impl",
|
114 |
}
|
115 |
|
116 |
|
lightrag/kg/memgraph_impl.py
ADDED
@@ -0,0 +1,423 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
import re
from dataclasses import dataclass
from typing import final
import configparser

from ..utils import logger
from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..constants import GRAPH_FIELD_SEP
import pipmaster as pm

# Memgraph speaks the Bolt protocol, so the standard Neo4j driver is reused.
if not pm.is_installed("neo4j"):
    pm.install("neo4j")

from neo4j import (
    AsyncGraphDatabase,
    AsyncManagedTransaction,
)

from dotenv import load_dotenv

# use the .env that is inside the current folder
load_dotenv(dotenv_path=".env", override=False)

# Hard cap on nodes returned by get_knowledge_graph (overridable via env var).
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))

# Fallback configuration source, consulted when env vars are absent.
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@final
@dataclass
class MemgraphStorage(BaseGraphStorage):
    """Graph storage backend backed by a Memgraph instance over Bolt."""

    def __init__(self, namespace, global_config, embedding_func):
        super().__init__(
            namespace=namespace,
            global_config=global_config,
            embedding_func=embedding_func,
        )
        # Async Bolt driver; created lazily in initialize().
        self._driver = None
42 |
+
async def initialize(self):
|
43 |
+
URI = os.environ.get("MEMGRAPH_URI", config.get("memgraph", "uri", fallback="bolt://localhost:7687"))
|
44 |
+
USERNAME = os.environ.get("MEMGRAPH_USERNAME", config.get("memgraph", "username", fallback=""))
|
45 |
+
PASSWORD = os.environ.get("MEMGRAPH_PASSWORD", config.get("memgraph", "password", fallback=""))
|
46 |
+
DATABASE = os.environ.get("MEMGRAPH_DATABASE", config.get("memgraph", "database", fallback="memgraph"))
|
47 |
+
|
48 |
+
self._driver = AsyncGraphDatabase.driver(
|
49 |
+
URI,
|
50 |
+
auth=(USERNAME, PASSWORD),
|
51 |
+
)
|
52 |
+
self._DATABASE = DATABASE
|
53 |
+
try:
|
54 |
+
async with self._driver.session(database=DATABASE) as session:
|
55 |
+
# Create index for base nodes on entity_id if it doesn't exist
|
56 |
+
try:
|
57 |
+
await session.run("""CREATE INDEX ON :base(entity_id)""")
|
58 |
+
logger.info("Created index on :base(entity_id) in Memgraph.")
|
59 |
+
except Exception as e:
|
60 |
+
# Index may already exist, which is not an error
|
61 |
+
logger.warning(f"Index creation on :base(entity_id) may have failed or already exists: {e}")
|
62 |
+
await session.run("RETURN 1")
|
63 |
+
logger.info(f"Connected to Memgraph at {URI}")
|
64 |
+
except Exception as e:
|
65 |
+
logger.error(f"Failed to connect to Memgraph at {URI}: {e}")
|
66 |
+
raise
|
67 |
+
|
68 |
+
async def finalize(self):
|
69 |
+
if self._driver is not None:
|
70 |
+
await self._driver.close()
|
71 |
+
self._driver = None
|
72 |
+
|
73 |
+
async def __aexit__(self, exc_type, exc, tb):
|
74 |
+
await self.finalize()
|
75 |
+
|
76 |
+
async def index_done_callback(self):
|
77 |
+
# Memgraph handles persistence automatically
|
78 |
+
pass
|
79 |
+
|
80 |
+
async def has_node(self, node_id: str) -> bool:
|
81 |
+
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
|
82 |
+
query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists"
|
83 |
+
result = await session.run(query, entity_id=node_id)
|
84 |
+
single_result = await result.single()
|
85 |
+
await result.consume()
|
86 |
+
return single_result["node_exists"]
|
87 |
+
|
88 |
+
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
89 |
+
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
|
90 |
+
query = (
|
91 |
+
"MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) "
|
92 |
+
"RETURN COUNT(r) > 0 AS edgeExists"
|
93 |
+
)
|
94 |
+
result = await session.run(
|
95 |
+
query,
|
96 |
+
source_entity_id=source_node_id,
|
97 |
+
target_entity_id=target_node_id,
|
98 |
+
)
|
99 |
+
single_result = await result.single()
|
100 |
+
await result.consume()
|
101 |
+
return single_result["edgeExists"]
|
102 |
+
|
103 |
+
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
104 |
+
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
|
105 |
+
query = "MATCH (n:base {entity_id: $entity_id}) RETURN n"
|
106 |
+
result = await session.run(query, entity_id=node_id)
|
107 |
+
records = await result.fetch(2)
|
108 |
+
await result.consume()
|
109 |
+
if records:
|
110 |
+
node = records[0]["n"]
|
111 |
+
node_dict = dict(node)
|
112 |
+
if "labels" in node_dict:
|
113 |
+
node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"]
|
114 |
+
return node_dict
|
115 |
+
return None
|
116 |
+
|
117 |
+
async def get_all_labels(self) -> list[str]:
|
118 |
+
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
|
119 |
+
query = """
|
120 |
+
MATCH (n:base)
|
121 |
+
WHERE n.entity_id IS NOT NULL
|
122 |
+
RETURN DISTINCT n.entity_id AS label
|
123 |
+
ORDER BY label
|
124 |
+
"""
|
125 |
+
result = await session.run(query)
|
126 |
+
labels = []
|
127 |
+
async for record in result:
|
128 |
+
labels.append(record["label"])
|
129 |
+
await result.consume()
|
130 |
+
return labels
|
131 |
+
|
132 |
+
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
133 |
+
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
|
134 |
+
query = """
|
135 |
+
MATCH (n:base {entity_id: $entity_id})
|
136 |
+
OPTIONAL MATCH (n)-[r]-(connected:base)
|
137 |
+
WHERE connected.entity_id IS NOT NULL
|
138 |
+
RETURN n, r, connected
|
139 |
+
"""
|
140 |
+
results = await session.run(query, entity_id=source_node_id)
|
141 |
+
edges = []
|
142 |
+
async for record in results:
|
143 |
+
source_node = record["n"]
|
144 |
+
connected_node = record["connected"]
|
145 |
+
if not source_node or not connected_node:
|
146 |
+
continue
|
147 |
+
source_label = source_node.get("entity_id")
|
148 |
+
target_label = connected_node.get("entity_id")
|
149 |
+
if source_label and target_label:
|
150 |
+
edges.append((source_label, target_label))
|
151 |
+
await results.consume()
|
152 |
+
return edges
|
153 |
+
|
154 |
+
async def get_edge(self, source_node_id: str, target_node_id: str) -> dict[str, str] | None:
|
155 |
+
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
|
156 |
+
query = """
|
157 |
+
MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id})
|
158 |
+
RETURN properties(r) as edge_properties
|
159 |
+
"""
|
160 |
+
result = await session.run(
|
161 |
+
query,
|
162 |
+
source_entity_id=source_node_id,
|
163 |
+
target_entity_id=target_node_id,
|
164 |
+
)
|
165 |
+
records = await result.fetch(2)
|
166 |
+
await result.consume()
|
167 |
+
if records:
|
168 |
+
edge_result = dict(records[0]["edge_properties"])
|
169 |
+
for key, default_value in {
|
170 |
+
"weight": 0.0,
|
171 |
+
"source_id": None,
|
172 |
+
"description": None,
|
173 |
+
"keywords": None,
|
174 |
+
}.items():
|
175 |
+
if key not in edge_result:
|
176 |
+
edge_result[key] = default_value
|
177 |
+
return edge_result
|
178 |
+
return None
|
179 |
+
|
180 |
+
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
181 |
+
properties = node_data
|
182 |
+
entity_type = properties.get("entity_type", "base")
|
183 |
+
if "entity_id" not in properties:
|
184 |
+
raise ValueError("Memgraph: node properties must contain an 'entity_id' field")
|
185 |
+
async with self._driver.session(database=self._DATABASE) as session:
|
186 |
+
async def execute_upsert(tx: AsyncManagedTransaction):
|
187 |
+
query = (
|
188 |
+
f"""
|
189 |
+
MERGE (n:base {{entity_id: $entity_id}})
|
190 |
+
SET n += $properties
|
191 |
+
SET n:`{entity_type}`
|
192 |
+
"""
|
193 |
+
)
|
194 |
+
result = await tx.run(query, entity_id=node_id, properties=properties)
|
195 |
+
await result.consume()
|
196 |
+
await session.execute_write(execute_upsert)
|
197 |
+
|
198 |
+
async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]) -> None:
|
199 |
+
edge_properties = edge_data
|
200 |
+
async with self._driver.session(database=self._DATABASE) as session:
|
201 |
+
async def execute_upsert(tx: AsyncManagedTransaction):
|
202 |
+
query = """
|
203 |
+
MATCH (source:base {entity_id: $source_entity_id})
|
204 |
+
WITH source
|
205 |
+
MATCH (target:base {entity_id: $target_entity_id})
|
206 |
+
MERGE (source)-[r:DIRECTED]-(target)
|
207 |
+
SET r += $properties
|
208 |
+
RETURN r, source, target
|
209 |
+
"""
|
210 |
+
result = await tx.run(
|
211 |
+
query,
|
212 |
+
source_entity_id=source_node_id,
|
213 |
+
target_entity_id=target_node_id,
|
214 |
+
properties=edge_properties,
|
215 |
+
)
|
216 |
+
await result.consume()
|
217 |
+
await session.execute_write(execute_upsert)
|
218 |
+
|
219 |
+
async def delete_node(self, node_id: str) -> None:
|
220 |
+
async def _do_delete(tx: AsyncManagedTransaction):
|
221 |
+
query = """
|
222 |
+
MATCH (n:base {entity_id: $entity_id})
|
223 |
+
DETACH DELETE n
|
224 |
+
"""
|
225 |
+
result = await tx.run(query, entity_id=node_id)
|
226 |
+
await result.consume()
|
227 |
+
async with self._driver.session(database=self._DATABASE) as session:
|
228 |
+
await session.execute_write(_do_delete)
|
229 |
+
|
230 |
+
async def remove_nodes(self, nodes: list[str]):
|
231 |
+
for node in nodes:
|
232 |
+
await self.delete_node(node)
|
233 |
+
|
234 |
+
async def remove_edges(self, edges: list[tuple[str, str]]):
|
235 |
+
for source, target in edges:
|
236 |
+
async def _do_delete_edge(tx: AsyncManagedTransaction):
|
237 |
+
query = """
|
238 |
+
MATCH (source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id})
|
239 |
+
DELETE r
|
240 |
+
"""
|
241 |
+
result = await tx.run(
|
242 |
+
query, source_entity_id=source, target_entity_id=target
|
243 |
+
)
|
244 |
+
await result.consume()
|
245 |
+
async with self._driver.session(database=self._DATABASE) as session:
|
246 |
+
await session.execute_write(_do_delete_edge)
|
247 |
+
|
248 |
+
async def drop(self) -> dict[str, str]:
|
249 |
+
try:
|
250 |
+
async with self._driver.session(database=self._DATABASE) as session:
|
251 |
+
query = "MATCH (n) DETACH DELETE n"
|
252 |
+
result = await session.run(query)
|
253 |
+
await result.consume()
|
254 |
+
logger.info(f"Process {os.getpid()} drop Memgraph database {self._DATABASE}")
|
255 |
+
return {"status": "success", "message": "data dropped"}
|
256 |
+
except Exception as e:
|
257 |
+
logger.error(f"Error dropping Memgraph database {self._DATABASE}: {e}")
|
258 |
+
return {"status": "error", "message": str(e)}
|
259 |
+
|
260 |
+
async def node_degree(self, node_id: str) -> int:
|
261 |
+
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
|
262 |
+
query = """
|
263 |
+
MATCH (n:base {entity_id: $entity_id})
|
264 |
+
OPTIONAL MATCH (n)-[r]-()
|
265 |
+
RETURN COUNT(r) AS degree
|
266 |
+
"""
|
267 |
+
result = await session.run(query, entity_id=node_id)
|
268 |
+
record = await result.single()
|
269 |
+
await result.consume()
|
270 |
+
if not record:
|
271 |
+
return 0
|
272 |
+
return record["degree"]
|
273 |
+
|
274 |
+
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
275 |
+
src_degree = await self.node_degree(src_id)
|
276 |
+
trg_degree = await self.node_degree(tgt_id)
|
277 |
+
src_degree = 0 if src_degree is None else src_degree
|
278 |
+
trg_degree = 0 if trg_degree is None else trg_degree
|
279 |
+
return int(src_degree) + int(trg_degree)
|
280 |
+
|
281 |
+
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
282 |
+
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
|
283 |
+
query = """
|
284 |
+
UNWIND $chunk_ids AS chunk_id
|
285 |
+
MATCH (n:base)
|
286 |
+
WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
|
287 |
+
RETURN DISTINCT n
|
288 |
+
"""
|
289 |
+
result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
|
290 |
+
nodes = []
|
291 |
+
async for record in result:
|
292 |
+
node = record["n"]
|
293 |
+
node_dict = dict(node)
|
294 |
+
node_dict["id"] = node_dict.get("entity_id")
|
295 |
+
nodes.append(node_dict)
|
296 |
+
await result.consume()
|
297 |
+
return nodes
|
298 |
+
|
299 |
+
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
300 |
+
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
|
301 |
+
query = """
|
302 |
+
UNWIND $chunk_ids AS chunk_id
|
303 |
+
MATCH (a:base)-[r]-(b:base)
|
304 |
+
WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
|
305 |
+
RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties
|
306 |
+
"""
|
307 |
+
result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
|
308 |
+
edges = []
|
309 |
+
async for record in result:
|
310 |
+
edge_properties = record["properties"]
|
311 |
+
edge_properties["source"] = record["source"]
|
312 |
+
edge_properties["target"] = record["target"]
|
313 |
+
edges.append(edge_properties)
|
314 |
+
await result.consume()
|
315 |
+
return edges
|
316 |
+
|
317 |
+
async def get_knowledge_graph(
|
318 |
+
self,
|
319 |
+
node_label: str,
|
320 |
+
max_depth: int = 3,
|
321 |
+
max_nodes: int = MAX_GRAPH_NODES,
|
322 |
+
) -> KnowledgeGraph:
|
323 |
+
result = KnowledgeGraph()
|
324 |
+
seen_nodes = set()
|
325 |
+
seen_edges = set()
|
326 |
+
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
|
327 |
+
if node_label == "*":
|
328 |
+
count_query = "MATCH (n) RETURN count(n) as total"
|
329 |
+
count_result = await session.run(count_query)
|
330 |
+
count_record = await count_result.single()
|
331 |
+
await count_result.consume()
|
332 |
+
if count_record and count_record["total"] > max_nodes:
|
333 |
+
result.is_truncated = True
|
334 |
+
logger.info(f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}")
|
335 |
+
main_query = """
|
336 |
+
MATCH (n)
|
337 |
+
OPTIONAL MATCH (n)-[r]-()
|
338 |
+
WITH n, COALESCE(count(r), 0) AS degree
|
339 |
+
ORDER BY degree DESC
|
340 |
+
LIMIT $max_nodes
|
341 |
+
WITH collect({node: n}) AS filtered_nodes
|
342 |
+
UNWIND filtered_nodes AS node_info
|
343 |
+
WITH collect(node_info.node) AS kept_nodes, filtered_nodes
|
344 |
+
OPTIONAL MATCH (a)-[r]-(b)
|
345 |
+
WHERE a IN kept_nodes AND b IN kept_nodes
|
346 |
+
RETURN filtered_nodes AS node_info,
|
347 |
+
collect(DISTINCT r) AS relationships
|
348 |
+
"""
|
349 |
+
result_set = await session.run(main_query, {"max_nodes": max_nodes})
|
350 |
+
record = await result_set.single()
|
351 |
+
await result_set.consume()
|
352 |
+
else:
|
353 |
+
# BFS fallback for Memgraph (no APOC)
|
354 |
+
from collections import deque
|
355 |
+
# Get the starting node
|
356 |
+
start_query = "MATCH (n:base {entity_id: $entity_id}) RETURN n"
|
357 |
+
node_result = await session.run(start_query, entity_id=node_label)
|
358 |
+
node_record = await node_result.single()
|
359 |
+
await node_result.consume()
|
360 |
+
if not node_record:
|
361 |
+
return result
|
362 |
+
start_node = node_record["n"]
|
363 |
+
start_node_id = start_node.get("entity_id")
|
364 |
+
queue = deque([(start_node, 0)])
|
365 |
+
visited = set()
|
366 |
+
bfs_nodes = []
|
367 |
+
while queue and len(bfs_nodes) < max_nodes:
|
368 |
+
current_node, depth = queue.popleft()
|
369 |
+
node_id = current_node.get("entity_id")
|
370 |
+
if node_id in visited:
|
371 |
+
continue
|
372 |
+
visited.add(node_id)
|
373 |
+
bfs_nodes.append(current_node)
|
374 |
+
if depth < max_depth:
|
375 |
+
# Get neighbors
|
376 |
+
neighbor_query = """
|
377 |
+
MATCH (n:base {entity_id: $entity_id})-[]-(m:base)
|
378 |
+
RETURN m
|
379 |
+
"""
|
380 |
+
neighbors_result = await session.run(neighbor_query, entity_id=node_id)
|
381 |
+
neighbors = [rec["m"] for rec in await neighbors_result.to_list()]
|
382 |
+
await neighbors_result.consume()
|
383 |
+
for neighbor in neighbors:
|
384 |
+
neighbor_id = neighbor.get("entity_id")
|
385 |
+
if neighbor_id not in visited:
|
386 |
+
queue.append((neighbor, depth + 1))
|
387 |
+
# Build subgraph
|
388 |
+
subgraph_ids = [n.get("entity_id") for n in bfs_nodes]
|
389 |
+
# Nodes
|
390 |
+
for n in bfs_nodes:
|
391 |
+
node_id = n.get("entity_id")
|
392 |
+
if node_id not in seen_nodes:
|
393 |
+
result.nodes.append(KnowledgeGraphNode(
|
394 |
+
id=node_id,
|
395 |
+
labels=[node_id],
|
396 |
+
properties=dict(n),
|
397 |
+
))
|
398 |
+
seen_nodes.add(node_id)
|
399 |
+
# Edges
|
400 |
+
if subgraph_ids:
|
401 |
+
edge_query = """
|
402 |
+
MATCH (a:base)-[r]-(b:base)
|
403 |
+
WHERE a.entity_id IN $ids AND b.entity_id IN $ids
|
404 |
+
RETURN DISTINCT r, a, b
|
405 |
+
"""
|
406 |
+
edge_result = await session.run(edge_query, ids=subgraph_ids)
|
407 |
+
async for record in edge_result:
|
408 |
+
r = record["r"]
|
409 |
+
a = record["a"]
|
410 |
+
b = record["b"]
|
411 |
+
edge_id = f"{a.get('entity_id')}-{b.get('entity_id')}"
|
412 |
+
if edge_id not in seen_edges:
|
413 |
+
result.edges.append(KnowledgeGraphEdge(
|
414 |
+
id=edge_id,
|
415 |
+
type="DIRECTED",
|
416 |
+
source=a.get("entity_id"),
|
417 |
+
target=b.get("entity_id"),
|
418 |
+
properties=dict(r),
|
419 |
+
))
|
420 |
+
seen_edges.add(edge_id)
|
421 |
+
await edge_result.consume()
|
422 |
+
logger.info(f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}")
|
423 |
+
return result
|