DavIvek
committed on
Commit
·
7c8ef0b
1
Parent(s):
040b53c
polish Memgraph implementation
Browse files- lightrag/kg/memgraph_impl.py +550 -251
lightrag/kg/memgraph_impl.py
CHANGED
@@ -89,183 +89,419 @@ class MemgraphStorage(BaseGraphStorage):
|
|
89 |
pass
|
90 |
|
91 |
async def has_node(self, node_id: str) -> bool:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
async with self._driver.session(
|
93 |
database=self._DATABASE, default_access_mode="READ"
|
94 |
) as session:
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
async with self._driver.session(
|
103 |
database=self._DATABASE, default_access_mode="READ"
|
104 |
) as session:
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
|
118 |
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
async with self._driver.session(
|
120 |
database=self._DATABASE, default_access_mode="READ"
|
121 |
) as session:
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
if "labels" in node_dict:
|
130 |
-
node_dict["labels"] = [
|
131 |
-
label for label in node_dict["labels"] if label != "base"
|
132 |
-
]
|
133 |
-
return node_dict
|
134 |
-
return None
|
135 |
|
136 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
async with self._driver.session(
|
138 |
database=self._DATABASE, default_access_mode="READ"
|
139 |
) as session:
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
labels.append(record["label"])
|
150 |
-
await result.consume()
|
151 |
-
return labels
|
152 |
|
153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
async with self._driver.session(
|
155 |
database=self._DATABASE, default_access_mode="READ"
|
156 |
) as session:
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
|
177 |
async def get_edge(
|
178 |
self, source_node_id: str, target_node_id: str
|
179 |
) -> dict[str, str] | None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
async with self._driver.session(
|
181 |
database=self._DATABASE, default_access_mode="READ"
|
182 |
) as session:
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
206 |
|
207 |
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
208 |
properties = node_data
|
209 |
-
entity_type = properties
|
210 |
if "entity_id" not in properties:
|
211 |
-
raise ValueError(
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
|
|
219 |
SET n += $properties
|
220 |
-
SET n
|
221 |
"""
|
222 |
-
|
223 |
-
|
|
|
|
|
|
|
|
|
224 |
|
225 |
-
|
|
|
|
|
|
|
226 |
|
227 |
async def upsert_edge(
|
228 |
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
229 |
) -> None:
|
230 |
-
|
231 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
232 |
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
|
|
|
|
|
|
249 |
|
250 |
-
|
|
|
|
|
|
|
251 |
|
252 |
async def delete_node(self, node_id: str) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
253 |
async def _do_delete(tx: AsyncManagedTransaction):
|
254 |
query = """
|
255 |
MATCH (n:base {entity_id: $entity_id})
|
256 |
DETACH DELETE n
|
257 |
"""
|
258 |
result = await tx.run(query, entity_id=node_id)
|
|
|
259 |
await result.consume()
|
260 |
|
261 |
-
|
262 |
-
|
|
|
|
|
|
|
|
|
263 |
|
264 |
async def remove_nodes(self, nodes: list[str]):
|
|
|
|
|
|
|
|
|
|
|
265 |
for node in nodes:
|
266 |
await self.delete_node(node)
|
267 |
|
268 |
async def remove_edges(self, edges: list[tuple[str, str]]):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
269 |
for source, target in edges:
|
270 |
|
271 |
async def _do_delete_edge(tx: AsyncManagedTransaction):
|
@@ -276,15 +512,32 @@ class MemgraphStorage(BaseGraphStorage):
|
|
276 |
result = await tx.run(
|
277 |
query, source_entity_id=source, target_entity_id=target
|
278 |
)
|
279 |
-
|
|
|
280 |
|
281 |
-
|
282 |
-
|
|
|
|
|
|
|
|
|
283 |
|
284 |
async def drop(self) -> dict[str, str]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
285 |
try:
|
286 |
async with self._driver.session(database=self._DATABASE) as session:
|
287 |
-
query = "
|
288 |
result = await session.run(query)
|
289 |
await result.consume()
|
290 |
logger.info(
|
@@ -295,30 +548,36 @@ class MemgraphStorage(BaseGraphStorage):
|
|
295 |
logger.error(f"Error dropping Memgraph database {self._DATABASE}: {e}")
|
296 |
return {"status": "error", "message": str(e)}
|
297 |
|
298 |
-
async def node_degree(self, node_id: str) -> int:
|
299 |
-
async with self._driver.session(
|
300 |
-
database=self._DATABASE, default_access_mode="READ"
|
301 |
-
) as session:
|
302 |
-
query = """
|
303 |
-
MATCH (n:base {entity_id: $entity_id})
|
304 |
-
OPTIONAL MATCH (n)-[r]-()
|
305 |
-
RETURN COUNT(r) AS degree
|
306 |
-
"""
|
307 |
-
result = await session.run(query, entity_id=node_id)
|
308 |
-
record = await result.single()
|
309 |
-
await result.consume()
|
310 |
-
if not record:
|
311 |
-
return 0
|
312 |
-
return record["degree"]
|
313 |
-
|
314 |
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
315 |
src_degree = await self.node_degree(src_id)
|
316 |
trg_degree = await self.node_degree(tgt_id)
|
|
|
|
|
317 |
src_degree = 0 if src_degree is None else src_degree
|
318 |
trg_degree = 0 if trg_degree is None else trg_degree
|
319 |
-
|
|
|
|
|
320 |
|
321 |
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
322 |
async with self._driver.session(
|
323 |
database=self._DATABASE, default_access_mode="READ"
|
324 |
) as session:
|
@@ -335,10 +594,19 @@ class MemgraphStorage(BaseGraphStorage):
|
|
335 |
node_dict = dict(node)
|
336 |
node_dict["id"] = node_dict.get("entity_id")
|
337 |
nodes.append(node_dict)
|
338 |
-
|
339 |
-
|
340 |
|
341 |
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
342 |
async with self._driver.session(
|
343 |
database=self._DATABASE, default_access_mode="READ"
|
344 |
) as session:
|
@@ -364,118 +632,149 @@ class MemgraphStorage(BaseGraphStorage):
|
|
364 |
max_depth: int = 3,
|
365 |
max_nodes: int = MAX_GRAPH_NODES,
|
366 |
) -> KnowledgeGraph:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
367 |
result = KnowledgeGraph()
|
368 |
seen_nodes = set()
|
369 |
seen_edges = set()
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
MATCH (n)
|
385 |
-
OPTIONAL MATCH (n)-[r]-()
|
386 |
-
WITH n, COALESCE(count(r), 0) AS degree
|
387 |
-
ORDER BY degree DESC
|
388 |
-
LIMIT $max_nodes
|
389 |
-
WITH collect({node: n}) AS filtered_nodes
|
390 |
-
UNWIND filtered_nodes AS node_info
|
391 |
-
WITH collect(node_info.node) AS kept_nodes, filtered_nodes
|
392 |
-
OPTIONAL MATCH (a)-[r]-(b)
|
393 |
-
WHERE a IN kept_nodes AND b IN kept_nodes
|
394 |
-
RETURN filtered_nodes AS node_info,
|
395 |
-
collect(DISTINCT r) AS relationships
|
396 |
-
"""
|
397 |
-
result_set = await session.run(main_query, {"max_nodes": max_nodes})
|
398 |
-
record = await result_set.single()
|
399 |
-
await result_set.consume()
|
400 |
-
else:
|
401 |
-
# BFS fallback for Memgraph (no APOC)
|
402 |
-
from collections import deque
|
403 |
-
|
404 |
-
# Get the starting node
|
405 |
-
start_query = "MATCH (n:base {entity_id: $entity_id}) RETURN n"
|
406 |
-
node_result = await session.run(start_query, entity_id=node_label)
|
407 |
-
node_record = await node_result.single()
|
408 |
-
await node_result.consume()
|
409 |
-
if not node_record:
|
410 |
-
return result
|
411 |
-
start_node = node_record["n"]
|
412 |
-
queue = deque([(start_node, 0)])
|
413 |
-
visited = set()
|
414 |
-
bfs_nodes = []
|
415 |
-
while queue and len(bfs_nodes) < max_nodes:
|
416 |
-
current_node, depth = queue.popleft()
|
417 |
-
node_id = current_node.get("entity_id")
|
418 |
-
if node_id in visited:
|
419 |
-
continue
|
420 |
-
visited.add(node_id)
|
421 |
-
bfs_nodes.append(current_node)
|
422 |
-
if depth < max_depth:
|
423 |
-
# Get neighbors
|
424 |
-
neighbor_query = """
|
425 |
-
MATCH (n:base {entity_id: $entity_id})-[]-(m:base)
|
426 |
-
RETURN m
|
427 |
-
"""
|
428 |
-
neighbors_result = await session.run(
|
429 |
-
neighbor_query, entity_id=node_id
|
430 |
-
)
|
431 |
-
neighbors = [
|
432 |
-
rec["m"] for rec in await neighbors_result.to_list()
|
433 |
-
]
|
434 |
-
await neighbors_result.consume()
|
435 |
-
for neighbor in neighbors:
|
436 |
-
neighbor_id = neighbor.get("entity_id")
|
437 |
-
if neighbor_id not in visited:
|
438 |
-
queue.append((neighbor, depth + 1))
|
439 |
-
# Build subgraph
|
440 |
-
subgraph_ids = [n.get("entity_id") for n in bfs_nodes]
|
441 |
-
# Nodes
|
442 |
-
for n in bfs_nodes:
|
443 |
-
node_id = n.get("entity_id")
|
444 |
-
if node_id not in seen_nodes:
|
445 |
-
result.nodes.append(
|
446 |
-
KnowledgeGraphNode(
|
447 |
-
id=node_id,
|
448 |
-
labels=[node_id],
|
449 |
-
properties=dict(n),
|
450 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
451 |
)
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
459 |
"""
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
466 |
-
|
467 |
-
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
474 |
)
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
pass
|
90 |
|
91 |
async def has_node(self, node_id: str) -> bool:
    """Check whether a node with the given entity_id exists in the graph.

    Args:
        node_id: Value matched against the ``entity_id`` property of ``base`` nodes.

    Returns:
        bool: True if at least one matching node exists, False otherwise.

    Raises:
        Exception: Any error raised while running the query is logged and re-raised.
    """
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        # Bound before the try block so the error path never touches an
        # unassigned name when session.run() itself is what failed.
        result = None
        try:
            query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists"
            result = await session.run(query, entity_id=node_id)
            single_result = await result.single()
            await result.consume()  # Ensure result is fully consumed
            return single_result["node_exists"]
        except Exception as e:
            logger.error(f"Error checking node existence for {node_id}: {str(e)}")
            if result is not None:
                await result.consume()  # Ensure the result is consumed even on error
            raise
|
117 |
|
118 |
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
    """Check whether any relationship connects two nodes, in either direction.

    Args:
        source_node_id: ``entity_id`` of one endpoint.
        target_node_id: ``entity_id`` of the other endpoint.

    Returns:
        bool: True if at least one relationship exists between the two nodes,
        False otherwise.

    Raises:
        Exception: Any error raised while running the query is logged and re-raised.
    """
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        # Bound before the try block so the error path never touches an
        # unassigned name when session.run() itself is what failed.
        result = None
        try:
            query = (
                "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) "
                "RETURN COUNT(r) > 0 AS edgeExists"
            )
            result = await session.run(
                query,
                source_entity_id=source_node_id,
                target_entity_id=target_node_id,
            )
            single_result = await result.single()
            await result.consume()  # Ensure result is fully consumed
            return single_result["edgeExists"]
        except Exception as e:
            logger.error(
                f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
            )
            if result is not None:
                await result.consume()  # Ensure the result is consumed even on error
            raise
|
154 |
|
155 |
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
156 |
+
"""Get node by its label identifier, return only node properties
|
157 |
+
|
158 |
+
Args:
|
159 |
+
node_id: The node label to look up
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
dict: Node properties if found
|
163 |
+
None: If node not found
|
164 |
+
|
165 |
+
Raises:
|
166 |
+
Exception: If there is an error executing the query
|
167 |
+
"""
|
168 |
async with self._driver.session(
|
169 |
database=self._DATABASE, default_access_mode="READ"
|
170 |
) as session:
|
171 |
+
try:
|
172 |
+
query = "MATCH (n:base {entity_id: $entity_id}) RETURN n"
|
173 |
+
result = await session.run(query, entity_id=node_id)
|
174 |
+
try:
|
175 |
+
records = await result.fetch(
|
176 |
+
2
|
177 |
+
) # Get 2 records for duplication check
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
|
179 |
+
if len(records) > 1:
|
180 |
+
logger.warning(
|
181 |
+
f"Multiple nodes found with label '{node_id}'. Using first node."
|
182 |
+
)
|
183 |
+
if records:
|
184 |
+
node = records[0]["n"]
|
185 |
+
node_dict = dict(node)
|
186 |
+
# Remove base label from labels list if it exists
|
187 |
+
if "labels" in node_dict:
|
188 |
+
node_dict["labels"] = [
|
189 |
+
label
|
190 |
+
for label in node_dict["labels"]
|
191 |
+
if label != "base"
|
192 |
+
]
|
193 |
+
return node_dict
|
194 |
+
return None
|
195 |
+
finally:
|
196 |
+
await result.consume() # Ensure result is fully consumed
|
197 |
+
except Exception as e:
|
198 |
+
logger.error(f"Error getting node for {node_id}: {str(e)}")
|
199 |
+
raise
|
200 |
+
|
201 |
+
async def node_degree(self, node_id: str) -> int:
    """Count the relationships attached to the node with the given entity_id.

    Args:
        node_id: Value matched against the ``entity_id`` property of ``base`` nodes.

    Returns:
        int: The ``degree`` value returned by the query, or 0 when the query
        yields no record.

    Raises:
        Exception: Any error raised while running the query is logged and re-raised.
    """
    degree_query = """
        MATCH (n:base {entity_id: $entity_id})
        OPTIONAL MATCH (n)-[r]-()
        RETURN COUNT(r) AS degree
    """
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        try:
            cursor = await session.run(degree_query, entity_id=node_id)
            try:
                row = await cursor.single()
                if row:
                    return row["degree"]

                logger.warning(f"No node found with label '{node_id}'")
                return 0
            finally:
                await cursor.consume()  # Always drain the result
        except Exception as exc:
            logger.error(f"Error getting node degree for {node_id}: {str(exc)}")
            raise
|
239 |
+
|
240 |
+
async def get_all_labels(self) -> list[str]:
    """List the distinct ``entity_id`` values of all ``base`` nodes.

    Nodes whose ``entity_id`` is null are ignored.

    Returns:
        list[str]: The distinct ``entity_id`` values, ordered by the query's
        ``ORDER BY``.

    Raises:
        Exception: Any error raised while running the query is logged and re-raised.
    """
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        # Bound before the try block so the error path never touches an
        # unassigned name when session.run() itself is what failed.
        result = None
        try:
            query = """
            MATCH (n:base)
            WHERE n.entity_id IS NOT NULL
            RETURN DISTINCT n.entity_id AS label
            ORDER BY label
            """
            result = await session.run(query)
            labels = [record["label"] async for record in result]
            await result.consume()
            return labels
        except Exception as e:
            logger.error(f"Error getting all labels: {str(e)}")
            if result is not None:
                await result.consume()  # Ensure the result is consumed even on error
            raise
|
269 |
+
|
270 |
+
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
271 |
+
"""Retrieves all edges (relationships) for a particular node identified by its label.
|
272 |
+
|
273 |
+
Args:
|
274 |
+
source_node_id: Label of the node to get edges for
|
275 |
+
|
276 |
+
Returns:
|
277 |
+
list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges
|
278 |
+
None: If no edges found
|
279 |
+
|
280 |
+
Raises:
|
281 |
+
Exception: If there is an error executing the query
|
282 |
+
"""
|
283 |
+
try:
|
284 |
+
async with self._driver.session(
|
285 |
+
database=self._DATABASE, default_access_mode="READ"
|
286 |
+
) as session:
|
287 |
+
try:
|
288 |
+
query = """MATCH (n:base {entity_id: $entity_id})
|
289 |
+
OPTIONAL MATCH (n)-[r]-(connected:base)
|
290 |
+
WHERE connected.entity_id IS NOT NULL
|
291 |
+
RETURN n, r, connected"""
|
292 |
+
results = await session.run(query, entity_id=source_node_id)
|
293 |
+
|
294 |
+
edges = []
|
295 |
+
async for record in results:
|
296 |
+
source_node = record["n"]
|
297 |
+
connected_node = record["connected"]
|
298 |
+
|
299 |
+
# Skip if either node is None
|
300 |
+
if not source_node or not connected_node:
|
301 |
+
continue
|
302 |
+
|
303 |
+
source_label = (
|
304 |
+
source_node.get("entity_id")
|
305 |
+
if source_node.get("entity_id")
|
306 |
+
else None
|
307 |
+
)
|
308 |
+
target_label = (
|
309 |
+
connected_node.get("entity_id")
|
310 |
+
if connected_node.get("entity_id")
|
311 |
+
else None
|
312 |
+
)
|
313 |
+
|
314 |
+
if source_label and target_label:
|
315 |
+
edges.append((source_label, target_label))
|
316 |
+
|
317 |
+
await results.consume() # Ensure results are consumed
|
318 |
+
return edges
|
319 |
+
except Exception as e:
|
320 |
+
logger.error(
|
321 |
+
f"Error getting edges for node {source_node_id}: {str(e)}"
|
322 |
+
)
|
323 |
+
await results.consume() # Ensure results are consumed even on error
|
324 |
+
raise
|
325 |
+
except Exception as e:
|
326 |
+
logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
|
327 |
+
raise
|
328 |
|
329 |
async def get_edge(
|
330 |
self, source_node_id: str, target_node_id: str
|
331 |
) -> dict[str, str] | None:
|
332 |
+
"""Get edge properties between two nodes.
|
333 |
+
|
334 |
+
Args:
|
335 |
+
source_node_id: Label of the source node
|
336 |
+
target_node_id: Label of the target node
|
337 |
+
|
338 |
+
Returns:
|
339 |
+
dict: Edge properties if found, default properties if not found or on error
|
340 |
+
|
341 |
+
Raises:
|
342 |
+
Exception: If there is an error executing the query
|
343 |
+
"""
|
344 |
async with self._driver.session(
|
345 |
database=self._DATABASE, default_access_mode="READ"
|
346 |
) as session:
|
347 |
+
try:
|
348 |
+
query = """
|
349 |
+
MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id})
|
350 |
+
RETURN properties(r) as edge_properties
|
351 |
+
"""
|
352 |
+
result = await session.run(
|
353 |
+
query,
|
354 |
+
source_entity_id=source_node_id,
|
355 |
+
target_entity_id=target_node_id,
|
356 |
+
)
|
357 |
+
records = await result.fetch(2)
|
358 |
+
await result.consume()
|
359 |
+
if records:
|
360 |
+
edge_result = dict(records[0]["edge_properties"])
|
361 |
+
for key, default_value in {
|
362 |
+
"weight": 0.0,
|
363 |
+
"source_id": None,
|
364 |
+
"description": None,
|
365 |
+
"keywords": None,
|
366 |
+
}.items():
|
367 |
+
if key not in edge_result:
|
368 |
+
edge_result[key] = default_value
|
369 |
+
logger.warning(
|
370 |
+
f"Edge between {source_node_id} and {target_node_id} is missing property: {key}. Using default value: {default_value}"
|
371 |
+
)
|
372 |
+
return edge_result
|
373 |
+
return None
|
374 |
+
except Exception as e:
|
375 |
+
logger.error(
|
376 |
+
f"Error getting edge between {source_node_id} and {target_node_id}: {str(e)}"
|
377 |
+
)
|
378 |
+
await result.consume() # Ensure the result is consumed even on error
|
379 |
+
raise
|
380 |
|
381 |
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
    """Create or update a node in the Memgraph database.

    The node is merged on ``entity_id`` under the ``base`` label, all entries
    of ``node_data`` are written as properties, and the node additionally
    receives a label named after its ``entity_type``.

    Args:
        node_id: Value used as the ``entity_id`` to merge the node on.
        node_data: Node properties. Must contain ``entity_id`` and
            ``entity_type``.

    Raises:
        ValueError: If ``node_data`` has no ``entity_id`` field.
        KeyError: If ``node_data`` has no ``entity_type`` field.
        Exception: Any error raised while running the query is logged and re-raised.
    """
    properties = node_data
    # Validate first so a malformed payload reports the documented ValueError
    # rather than a bare KeyError from the entity_type lookup below.
    if "entity_id" not in properties:
        raise ValueError("Memgraph: node properties must contain an 'entity_id' field")

    # Labels cannot be passed as query parameters, so the value is spliced into
    # the query text. Doubling embedded backticks keeps it inside the quoted
    # name and prevents it from terminating the label early.
    entity_type = str(properties["entity_type"]).replace("`", "``")

    try:
        async with self._driver.session(database=self._DATABASE) as session:

            async def execute_upsert(tx: AsyncManagedTransaction):
                query = (
                    """
                    MERGE (n:base {entity_id: $entity_id})
                    SET n += $properties
                    SET n:`%s`
                    """
                    % entity_type
                )
                result = await tx.run(
                    query, entity_id=node_id, properties=properties
                )
                await result.consume()  # Ensure result is fully consumed

            await session.execute_write(execute_upsert)
    except Exception as e:
        logger.error(f"Error during upsert: {str(e)}")
        raise
|
415 |
|
416 |
async def upsert_edge(
    self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None:
    """Create or update the ``DIRECTED`` relationship between two existing nodes.

    Both endpoints are looked up by ``entity_id`` under the ``base`` label; the
    relationship is merged between them and ``edge_data`` is written onto it.
    The query only matches the endpoints, so it does nothing if either one is
    absent.

    Args:
        source_node_id (str): ``entity_id`` of the source node.
        target_node_id (str): ``entity_id`` of the target node.
        edge_data (dict): Properties to set on the relationship.

    Raises:
        Exception: Any error raised while running the query is logged and re-raised.
    """
    merge_statement = """
        MATCH (source:base {entity_id: $source_entity_id})
        WITH source
        MATCH (target:base {entity_id: $target_entity_id})
        MERGE (source)-[r:DIRECTED]-(target)
        SET r += $properties
        RETURN r, source, target
    """
    bindings = {
        "source_entity_id": source_node_id,
        "target_entity_id": target_node_id,
        "properties": edge_data,
    }

    try:
        async with self._driver.session(database=self._DATABASE) as session:

            async def _merge_relationship(tx: AsyncManagedTransaction):
                cursor = await tx.run(merge_statement, **bindings)
                try:
                    await cursor.fetch(2)
                finally:
                    await cursor.consume()  # Always drain the result

            await session.execute_write(_merge_relationship)
    except Exception as exc:
        logger.error(f"Error during edge upsert: {str(exc)}")
        raise
|
460 |
|
461 |
async def delete_node(self, node_id: str) -> None:
    """Delete the node with the given entity_id together with its relationships.

    Args:
        node_id: Value matched against the ``entity_id`` property of ``base`` nodes.

    Raises:
        Exception: Any error raised while running the query is logged and re-raised.
    """

    async def _do_delete(tx: AsyncManagedTransaction):
        query = """
        MATCH (n:base {entity_id: $entity_id})
        DETACH DELETE n
        """
        result = await tx.run(query, entity_id=node_id)
        await result.consume()
        # Logged only after the result has been consumed, so the message is
        # not emitted for a statement that fails while being drained.
        logger.debug(f"Deleted node with label {node_id}")

    try:
        async with self._driver.session(database=self._DATABASE) as session:
            await session.execute_write(_do_delete)
    except Exception as e:
        logger.error(f"Error during node deletion: {str(e)}")
        raise
|
486 |
|
487 |
async def remove_nodes(self, nodes: list[str]):
    """Delete several nodes, one after another.

    Each entry is handed to ``delete_node`` in list order; the next deletion
    starts only after the previous one has finished.

    Args:
        nodes: Identifiers of the nodes to delete.
    """
    for entity_id in nodes:
        await self.delete_node(entity_id)
|
495 |
|
496 |
async def remove_edges(self, edges: list[tuple[str, str]]):
|
497 |
+
"""Delete multiple edges
|
498 |
+
|
499 |
+
Args:
|
500 |
+
edges: List of edges to be deleted, each edge is a (source, target) tuple
|
501 |
+
|
502 |
+
Raises:
|
503 |
+
Exception: If there is an error executing the query
|
504 |
+
"""
|
505 |
for source, target in edges:
|
506 |
|
507 |
async def _do_delete_edge(tx: AsyncManagedTransaction):
|
|
|
512 |
result = await tx.run(
|
513 |
query, source_entity_id=source, target_entity_id=target
|
514 |
)
|
515 |
+
logger.debug(f"Deleted edge from '{source}' to '{target}'")
|
516 |
+
await result.consume() # Ensure result is fully consumed
|
517 |
|
518 |
+
try:
|
519 |
+
async with self._driver.session(database=self._DATABASE) as session:
|
520 |
+
await session.execute_write(_do_delete_edge)
|
521 |
+
except Exception as e:
|
522 |
+
logger.error(f"Error during edge deletion: {str(e)}")
|
523 |
+
raise
|
524 |
|
525 |
async def drop(self) -> dict[str, str]:
|
526 |
+
"""Drop all data from storage and clean up resources
|
527 |
+
|
528 |
+
This method will delete all nodes and relationships in the Neo4j database.
|
529 |
+
|
530 |
+
Returns:
|
531 |
+
dict[str, str]: Operation status and message
|
532 |
+
- On success: {"status": "success", "message": "data dropped"}
|
533 |
+
- On failure: {"status": "error", "message": "<error details>"}
|
534 |
+
|
535 |
+
Raises:
|
536 |
+
Exception: If there is an error executing the query
|
537 |
+
"""
|
538 |
try:
|
539 |
async with self._driver.session(database=self._DATABASE) as session:
|
540 |
+
query = "DROP GRAPH"
|
541 |
result = await session.run(query)
|
542 |
await result.consume()
|
543 |
logger.info(
|
|
|
548 |
logger.error(f"Error dropping Memgraph database {self._DATABASE}: {e}")
|
549 |
return {"status": "error", "message": str(e)}
|
550 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
551 |
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
    """Return the combined degree of an edge's two endpoint nodes.

    Args:
        src_id: Identifier of the source node.
        tgt_id: Identifier of the target node.

    Returns:
        int: ``node_degree(src_id) + node_degree(tgt_id)``, where a degree of
        None counts as 0.
    """
    # Both lookups run (source first) before any value is converted.
    endpoint_degrees = [
        await self.node_degree(src_id),
        await self.node_degree(tgt_id),
    ]
    return sum(int(degree) for degree in endpoint_degrees if degree is not None)
|
570 |
|
571 |
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
572 |
+
"""Get all nodes that are associated with the given chunk_ids.
|
573 |
+
|
574 |
+
Args:
|
575 |
+
chunk_ids: List of chunk IDs to find associated nodes for
|
576 |
+
|
577 |
+
Returns:
|
578 |
+
list[dict]: A list of nodes, where each node is a dictionary of its properties.
|
579 |
+
An empty list if no matching nodes are found.
|
580 |
+
"""
|
581 |
async with self._driver.session(
|
582 |
database=self._DATABASE, default_access_mode="READ"
|
583 |
) as session:
|
|
|
594 |
node_dict = dict(node)
|
595 |
node_dict["id"] = node_dict.get("entity_id")
|
596 |
nodes.append(node_dict)
|
597 |
+
await result.consume()
|
598 |
+
return nodes
|
599 |
|
600 |
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
601 |
+
"""Get all edges that are associated with the given chunk_ids.
|
602 |
+
|
603 |
+
Args:
|
604 |
+
chunk_ids: List of chunk IDs to find associated edges for
|
605 |
+
|
606 |
+
Returns:
|
607 |
+
list[dict]: A list of edges, where each edge is a dictionary of its properties.
|
608 |
+
An empty list if no matching edges are found.
|
609 |
+
"""
|
610 |
async with self._driver.session(
|
611 |
database=self._DATABASE, default_access_mode="READ"
|
612 |
) as session:
|
|
|
632 |
max_depth: int = 3,
|
633 |
max_nodes: int = MAX_GRAPH_NODES,
|
634 |
) -> KnowledgeGraph:
|
635 |
+
"""
|
636 |
+
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
637 |
+
|
638 |
+
Args:
|
639 |
+
node_label: Label of the starting node, * means all nodes
|
640 |
+
max_depth: Maximum depth of the subgraph, Defaults to 3
|
641 |
+
max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000
|
642 |
+
|
643 |
+
Returns:
|
644 |
+
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
|
645 |
+
indicating whether the graph was truncated due to max_nodes limit
|
646 |
+
|
647 |
+
Raises:
|
648 |
+
Exception: If there is an error executing the query
|
649 |
+
"""
|
650 |
result = KnowledgeGraph()
|
651 |
seen_nodes = set()
|
652 |
seen_edges = set()
|
653 |
+
try:
|
654 |
+
async with self._driver.session(
|
655 |
+
database=self._DATABASE, default_access_mode="READ"
|
656 |
+
) as session:
|
657 |
+
if node_label == "*":
|
658 |
+
count_query = "MATCH (n) RETURN count(n) as total"
|
659 |
+
count_result = None
|
660 |
+
try:
|
661 |
+
count_result = await session.run(count_query)
|
662 |
+
count_record = await count_result.single()
|
663 |
+
if count_record and count_record["total"] > max_nodes:
|
664 |
+
result.is_truncated = True
|
665 |
+
logger.info(
|
666 |
+
f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
667 |
)
|
668 |
+
finally:
|
669 |
+
if count_result:
|
670 |
+
await count_result.consume()
|
671 |
+
|
672 |
+
# Run the main query to get nodes with highest degree
|
673 |
+
main_query = """
|
674 |
+
MATCH (n)
|
675 |
+
OPTIONAL MATCH (n)-[r]-()
|
676 |
+
WITH n, COALESCE(count(r), 0) AS degree
|
677 |
+
ORDER BY degree DESC
|
678 |
+
LIMIT $max_nodes
|
679 |
+
WITH collect({node: n}) AS filtered_nodes
|
680 |
+
UNWIND filtered_nodes AS node_info
|
681 |
+
WITH collect(node_info.node) AS kept_nodes, filtered_nodes
|
682 |
+
OPTIONAL MATCH (a)-[r]-(b)
|
683 |
+
WHERE a IN kept_nodes AND b IN kept_nodes
|
684 |
+
RETURN filtered_nodes AS node_info,
|
685 |
+
collect(DISTINCT r) AS relationships
|
686 |
+
"""
|
687 |
+
result_set = None
|
688 |
+
try:
|
689 |
+
result_set = await session.run(
|
690 |
+
main_query, {"max_nodes": max_nodes}
|
691 |
)
|
692 |
+
record = await result_set.single()
|
693 |
+
finally:
|
694 |
+
if result_set:
|
695 |
+
await result_set.consume()
|
696 |
+
|
697 |
+
else:
|
698 |
+
bfs_query = """
|
699 |
+
MATCH (start) WHERE start.entity_id = $entity_id
|
700 |
+
WITH start
|
701 |
+
CALL {
|
702 |
+
WITH start
|
703 |
+
MATCH path = (start)-[*0..$max_depth]-(node)
|
704 |
+
WITH nodes(path) AS path_nodes, relationships(path) AS path_rels
|
705 |
+
UNWIND path_nodes AS n
|
706 |
+
WITH collect(DISTINCT n) AS all_nodes, collect(DISTINCT path_rels) AS all_rel_lists
|
707 |
+
WITH all_nodes, reduce(r = [], x IN all_rel_lists | r + x) AS all_rels
|
708 |
+
RETURN all_nodes, all_rels
|
709 |
+
}
|
710 |
+
WITH all_nodes AS nodes, all_rels AS relationships, size(all_nodes) AS total_nodes
|
711 |
+
|
712 |
+
// Apply node limiting here
|
713 |
+
WITH CASE
|
714 |
+
WHEN total_nodes <= $max_nodes THEN nodes
|
715 |
+
ELSE nodes[0..$max_nodes]
|
716 |
+
END AS limited_nodes,
|
717 |
+
relationships,
|
718 |
+
total_nodes,
|
719 |
+
total_nodes > $max_nodes AS is_truncated
|
720 |
+
UNWIND limited_nodes AS node
|
721 |
+
WITH collect({node: node}) AS node_info, relationships, total_nodes, is_truncated
|
722 |
+
RETURN node_info, relationships, total_nodes, is_truncated
|
723 |
"""
|
724 |
+
result_set = None
|
725 |
+
try:
|
726 |
+
result_set = await session.run(
|
727 |
+
bfs_query,
|
728 |
+
{
|
729 |
+
"entity_id": node_label,
|
730 |
+
"max_depth": max_depth,
|
731 |
+
"max_nodes": max_nodes,
|
732 |
+
},
|
733 |
+
)
|
734 |
+
record = await result_set.single()
|
735 |
+
if not record:
|
736 |
+
logger.debug(f"No record found for node {node_label}")
|
737 |
+
return result
|
738 |
+
|
739 |
+
for node_info in record["node_info"]:
|
740 |
+
node = node_info["node"]
|
741 |
+
node_id = node.id
|
742 |
+
if node_id not in seen_nodes:
|
743 |
+
seen_nodes.add(node_id)
|
744 |
+
result.nodes.append(
|
745 |
+
KnowledgeGraphNode(
|
746 |
+
id=f"{node_id}",
|
747 |
+
labels=[node.get("entity_id")],
|
748 |
+
properties=dict(node),
|
749 |
+
)
|
750 |
)
|
751 |
+
|
752 |
+
for rel in record["relationships"]:
|
753 |
+
edge_id = rel.id
|
754 |
+
if edge_id not in seen_edges:
|
755 |
+
seen_edges.add(edge_id)
|
756 |
+
start = rel.start_node
|
757 |
+
end = rel.end_node
|
758 |
+
result.edges.append(
|
759 |
+
KnowledgeGraphEdge(
|
760 |
+
id=f"{edge_id}",
|
761 |
+
type=rel.type,
|
762 |
+
source=f"{start.id}",
|
763 |
+
target=f"{end.id}",
|
764 |
+
properties=dict(rel),
|
765 |
+
)
|
766 |
+
)
|
767 |
+
|
768 |
+
logger.info(
|
769 |
+
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
|
770 |
+
)
|
771 |
+
|
772 |
+
return result
|
773 |
+
|
774 |
+
finally:
|
775 |
+
if result_set:
|
776 |
+
await result_set.consume()
|
777 |
+
|
778 |
+
except Exception as e:
|
779 |
+
logger.error(f"Error getting knowledge graph: {str(e)}")
|
780 |
+
return result
|