DavIvek
committed on
Commit
·
be6f7dd
1
Parent(s):
5cc69d3
run pre-commit
Browse files
- lightrag/kg/memgraph_impl.py +82 -31
lightrag/kg/memgraph_impl.py
CHANGED
@@ -73,8 +73,12 @@ class MemgraphStorage(BaseGraphStorage):
|
|
73 |
# Create index for base nodes on entity_id if it doesn't exist
|
74 |
try:
|
75 |
workspace_label = self._get_workspace_label()
|
76 |
-
await session.run(
|
77 |
-
|
|
|
|
|
|
|
|
|
78 |
except Exception as e:
|
79 |
# Index may already exist, which is not an error
|
80 |
logger.warning(
|
@@ -112,7 +116,9 @@ class MemgraphStorage(BaseGraphStorage):
|
|
112 |
Exception: If there is an error checking the node existence.
|
113 |
"""
|
114 |
if self._driver is None:
|
115 |
-
raise RuntimeError(
|
|
|
|
|
116 |
async with self._driver.session(
|
117 |
database=self._DATABASE, default_access_mode="READ"
|
118 |
) as session:
|
@@ -122,7 +128,9 @@ class MemgraphStorage(BaseGraphStorage):
|
|
122 |
result = await session.run(query, entity_id=node_id)
|
123 |
single_result = await result.single()
|
124 |
await result.consume() # Ensure result is fully consumed
|
125 |
-
return
|
|
|
|
|
126 |
except Exception as e:
|
127 |
logger.error(f"Error checking node existence for {node_id}: {str(e)}")
|
128 |
await result.consume() # Ensure the result is consumed even on error
|
@@ -143,7 +151,9 @@ class MemgraphStorage(BaseGraphStorage):
|
|
143 |
Exception: If there is an error checking the edge existence.
|
144 |
"""
|
145 |
if self._driver is None:
|
146 |
-
raise RuntimeError(
|
|
|
|
|
147 |
async with self._driver.session(
|
148 |
database=self._DATABASE, default_access_mode="READ"
|
149 |
) as session:
|
@@ -153,10 +163,16 @@ class MemgraphStorage(BaseGraphStorage):
|
|
153 |
f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) "
|
154 |
"RETURN COUNT(r) > 0 AS edgeExists"
|
155 |
)
|
156 |
-
result = await session.run(
|
|
|
|
|
|
|
|
|
157 |
single_result = await result.single()
|
158 |
await result.consume() # Ensure result is fully consumed
|
159 |
-
return
|
|
|
|
|
160 |
except Exception as e:
|
161 |
logger.error(
|
162 |
f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
|
@@ -178,13 +194,17 @@ class MemgraphStorage(BaseGraphStorage):
|
|
178 |
Exception: If there is an error executing the query
|
179 |
"""
|
180 |
if self._driver is None:
|
181 |
-
raise RuntimeError(
|
|
|
|
|
182 |
async with self._driver.session(
|
183 |
database=self._DATABASE, default_access_mode="READ"
|
184 |
) as session:
|
185 |
try:
|
186 |
workspace_label = self._get_workspace_label()
|
187 |
-
query =
|
|
|
|
|
188 |
result = await session.run(query, entity_id=node_id)
|
189 |
try:
|
190 |
records = await result.fetch(
|
@@ -228,7 +248,9 @@ class MemgraphStorage(BaseGraphStorage):
|
|
228 |
Exception: If there is an error executing the query
|
229 |
"""
|
230 |
if self._driver is None:
|
231 |
-
raise RuntimeError(
|
|
|
|
|
232 |
async with self._driver.session(
|
233 |
database=self._DATABASE, default_access_mode="READ"
|
234 |
) as session:
|
@@ -265,7 +287,9 @@ class MemgraphStorage(BaseGraphStorage):
|
|
265 |
Exception: If there is an error executing the query
|
266 |
"""
|
267 |
if self._driver is None:
|
268 |
-
raise RuntimeError(
|
|
|
|
|
269 |
async with self._driver.session(
|
270 |
database=self._DATABASE, default_access_mode="READ"
|
271 |
) as session:
|
@@ -302,7 +326,9 @@ class MemgraphStorage(BaseGraphStorage):
|
|
302 |
Exception: If there is an error executing the query
|
303 |
"""
|
304 |
if self._driver is None:
|
305 |
-
raise RuntimeError(
|
|
|
|
|
306 |
try:
|
307 |
async with self._driver.session(
|
308 |
database=self._DATABASE, default_access_mode="READ"
|
@@ -366,7 +392,9 @@ class MemgraphStorage(BaseGraphStorage):
|
|
366 |
Exception: If there is an error executing the query
|
367 |
"""
|
368 |
if self._driver is None:
|
369 |
-
raise RuntimeError(
|
|
|
|
|
370 |
async with self._driver.session(
|
371 |
database=self._DATABASE, default_access_mode="READ"
|
372 |
) as session:
|
@@ -414,7 +442,9 @@ class MemgraphStorage(BaseGraphStorage):
|
|
414 |
node_data: Dictionary of node properties
|
415 |
"""
|
416 |
if self._driver is None:
|
417 |
-
raise RuntimeError(
|
|
|
|
|
418 |
properties = node_data
|
419 |
entity_type = properties["entity_type"]
|
420 |
if "entity_id" not in properties:
|
@@ -423,14 +453,13 @@ class MemgraphStorage(BaseGraphStorage):
|
|
423 |
try:
|
424 |
async with self._driver.session(database=self._DATABASE) as session:
|
425 |
workspace_label = self._get_workspace_label()
|
|
|
426 |
async def execute_upsert(tx: AsyncManagedTransaction):
|
427 |
-
query =
|
428 |
-
f"""
|
429 |
MERGE (n:`{workspace_label}` {{entity_id: $entity_id}})
|
430 |
SET n += $properties
|
431 |
SET n:`{entity_type}`
|
432 |
"""
|
433 |
-
)
|
434 |
result = await tx.run(
|
435 |
query, entity_id=node_id, properties=properties
|
436 |
)
|
@@ -458,7 +487,9 @@ class MemgraphStorage(BaseGraphStorage):
|
|
458 |
Exception: If there is an error executing the query
|
459 |
"""
|
460 |
if self._driver is None:
|
461 |
-
raise RuntimeError(
|
|
|
|
|
462 |
try:
|
463 |
edge_properties = edge_data
|
464 |
async with self._driver.session(database=self._DATABASE) as session:
|
@@ -499,7 +530,9 @@ class MemgraphStorage(BaseGraphStorage):
|
|
499 |
Exception: If there is an error executing the query
|
500 |
"""
|
501 |
if self._driver is None:
|
502 |
-
raise RuntimeError(
|
|
|
|
|
503 |
|
504 |
async def _do_delete(tx: AsyncManagedTransaction):
|
505 |
workspace_label = self._get_workspace_label()
|
@@ -525,7 +558,9 @@ class MemgraphStorage(BaseGraphStorage):
|
|
525 |
nodes: List of node labels to be deleted
|
526 |
"""
|
527 |
if self._driver is None:
|
528 |
-
raise RuntimeError(
|
|
|
|
|
529 |
for node in nodes:
|
530 |
await self.delete_node(node)
|
531 |
|
@@ -539,7 +574,9 @@ class MemgraphStorage(BaseGraphStorage):
|
|
539 |
Exception: If there is an error executing the query
|
540 |
"""
|
541 |
if self._driver is None:
|
542 |
-
raise RuntimeError(
|
|
|
|
|
543 |
for source, target in edges:
|
544 |
|
545 |
async def _do_delete_edge(tx: AsyncManagedTransaction):
|
@@ -575,17 +612,23 @@ class MemgraphStorage(BaseGraphStorage):
|
|
575 |
Exception: If there is an error executing the query
|
576 |
"""
|
577 |
if self._driver is None:
|
578 |
-
raise RuntimeError(
|
|
|
|
|
579 |
try:
|
580 |
async with self._driver.session(database=self._DATABASE) as session:
|
581 |
workspace_label = self._get_workspace_label()
|
582 |
query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
|
583 |
result = await session.run(query)
|
584 |
await result.consume()
|
585 |
-
logger.info(
|
|
|
|
|
586 |
return {"status": "success", "message": "workspace data dropped"}
|
587 |
except Exception as e:
|
588 |
-
logger.error(
|
|
|
|
|
589 |
return {"status": "error", "message": str(e)}
|
590 |
|
591 |
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
@@ -599,7 +642,9 @@ class MemgraphStorage(BaseGraphStorage):
|
|
599 |
int: Sum of the degrees of both nodes
|
600 |
"""
|
601 |
if self._driver is None:
|
602 |
-
raise RuntimeError(
|
|
|
|
|
603 |
src_degree = await self.node_degree(src_id)
|
604 |
trg_degree = await self.node_degree(tgt_id)
|
605 |
|
@@ -621,7 +666,9 @@ class MemgraphStorage(BaseGraphStorage):
|
|
621 |
An empty list if no matching nodes are found.
|
622 |
"""
|
623 |
if self._driver is None:
|
624 |
-
raise RuntimeError(
|
|
|
|
|
625 |
workspace_label = self._get_workspace_label()
|
626 |
async with self._driver.session(
|
627 |
database=self._DATABASE, default_access_mode="READ"
|
@@ -653,7 +700,9 @@ class MemgraphStorage(BaseGraphStorage):
|
|
653 |
An empty list if no matching edges are found.
|
654 |
"""
|
655 |
if self._driver is None:
|
656 |
-
raise RuntimeError(
|
|
|
|
|
657 |
workspace_label = self._get_workspace_label()
|
658 |
async with self._driver.session(
|
659 |
database=self._DATABASE, default_access_mode="READ"
|
@@ -701,7 +750,9 @@ class MemgraphStorage(BaseGraphStorage):
|
|
701 |
Exception: If there is an error executing the query
|
702 |
"""
|
703 |
if self._driver is None:
|
704 |
-
raise RuntimeError(
|
|
|
|
|
705 |
|
706 |
result = KnowledgeGraph()
|
707 |
seen_nodes = set()
|
@@ -761,7 +812,7 @@ class MemgraphStorage(BaseGraphStorage):
|
|
761 |
|
762 |
else:
|
763 |
bfs_query = f"""
|
764 |
-
MATCH (start:`{workspace_label}`)
|
765 |
WHERE start.entity_id = $entity_id
|
766 |
WITH start
|
767 |
CALL {{
|
@@ -774,7 +825,7 @@ class MemgraphStorage(BaseGraphStorage):
|
|
774 |
RETURN all_nodes, all_rels
|
775 |
}}
|
776 |
WITH all_nodes AS nodes, all_rels AS relationships, size(all_nodes) AS total_nodes
|
777 |
-
WITH
|
778 |
CASE
|
779 |
WHEN total_nodes <= {max_nodes} THEN nodes
|
780 |
ELSE nodes[0..{max_nodes}]
|
@@ -782,7 +833,7 @@ class MemgraphStorage(BaseGraphStorage):
|
|
782 |
relationships,
|
783 |
total_nodes,
|
784 |
total_nodes > {max_nodes} AS is_truncated
|
785 |
-
RETURN
|
786 |
[node IN limited_nodes | {{node: node}}] AS node_info,
|
787 |
relationships,
|
788 |
total_nodes,
|
|
|
73 |
# Create index for base nodes on entity_id if it doesn't exist
|
74 |
try:
|
75 |
workspace_label = self._get_workspace_label()
|
76 |
+
await session.run(
|
77 |
+
f"""CREATE INDEX ON :{workspace_label}(entity_id)"""
|
78 |
+
)
|
79 |
+
logger.info(
|
80 |
+
f"Created index on :{workspace_label}(entity_id) in Memgraph."
|
81 |
+
)
|
82 |
except Exception as e:
|
83 |
# Index may already exist, which is not an error
|
84 |
logger.warning(
|
|
|
116 |
Exception: If there is an error checking the node existence.
|
117 |
"""
|
118 |
if self._driver is None:
|
119 |
+
raise RuntimeError(
|
120 |
+
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
121 |
+
)
|
122 |
async with self._driver.session(
|
123 |
database=self._DATABASE, default_access_mode="READ"
|
124 |
) as session:
|
|
|
128 |
result = await session.run(query, entity_id=node_id)
|
129 |
single_result = await result.single()
|
130 |
await result.consume() # Ensure result is fully consumed
|
131 |
+
return (
|
132 |
+
single_result["node_exists"] if single_result is not None else False
|
133 |
+
)
|
134 |
except Exception as e:
|
135 |
logger.error(f"Error checking node existence for {node_id}: {str(e)}")
|
136 |
await result.consume() # Ensure the result is consumed even on error
|
|
|
151 |
Exception: If there is an error checking the edge existence.
|
152 |
"""
|
153 |
if self._driver is None:
|
154 |
+
raise RuntimeError(
|
155 |
+
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
156 |
+
)
|
157 |
async with self._driver.session(
|
158 |
database=self._DATABASE, default_access_mode="READ"
|
159 |
) as session:
|
|
|
163 |
f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) "
|
164 |
"RETURN COUNT(r) > 0 AS edgeExists"
|
165 |
)
|
166 |
+
result = await session.run(
|
167 |
+
query,
|
168 |
+
source_entity_id=source_node_id,
|
169 |
+
target_entity_id=target_node_id,
|
170 |
+
) # type: ignore
|
171 |
single_result = await result.single()
|
172 |
await result.consume() # Ensure result is fully consumed
|
173 |
+
return (
|
174 |
+
single_result["edgeExists"] if single_result is not None else False
|
175 |
+
)
|
176 |
except Exception as e:
|
177 |
logger.error(
|
178 |
f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
|
|
|
194 |
Exception: If there is an error executing the query
|
195 |
"""
|
196 |
if self._driver is None:
|
197 |
+
raise RuntimeError(
|
198 |
+
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
199 |
+
)
|
200 |
async with self._driver.session(
|
201 |
database=self._DATABASE, default_access_mode="READ"
|
202 |
) as session:
|
203 |
try:
|
204 |
workspace_label = self._get_workspace_label()
|
205 |
+
query = (
|
206 |
+
f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN n"
|
207 |
+
)
|
208 |
result = await session.run(query, entity_id=node_id)
|
209 |
try:
|
210 |
records = await result.fetch(
|
|
|
248 |
Exception: If there is an error executing the query
|
249 |
"""
|
250 |
if self._driver is None:
|
251 |
+
raise RuntimeError(
|
252 |
+
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
253 |
+
)
|
254 |
async with self._driver.session(
|
255 |
database=self._DATABASE, default_access_mode="READ"
|
256 |
) as session:
|
|
|
287 |
Exception: If there is an error executing the query
|
288 |
"""
|
289 |
if self._driver is None:
|
290 |
+
raise RuntimeError(
|
291 |
+
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
292 |
+
)
|
293 |
async with self._driver.session(
|
294 |
database=self._DATABASE, default_access_mode="READ"
|
295 |
) as session:
|
|
|
326 |
Exception: If there is an error executing the query
|
327 |
"""
|
328 |
if self._driver is None:
|
329 |
+
raise RuntimeError(
|
330 |
+
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
331 |
+
)
|
332 |
try:
|
333 |
async with self._driver.session(
|
334 |
database=self._DATABASE, default_access_mode="READ"
|
|
|
392 |
Exception: If there is an error executing the query
|
393 |
"""
|
394 |
if self._driver is None:
|
395 |
+
raise RuntimeError(
|
396 |
+
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
397 |
+
)
|
398 |
async with self._driver.session(
|
399 |
database=self._DATABASE, default_access_mode="READ"
|
400 |
) as session:
|
|
|
442 |
node_data: Dictionary of node properties
|
443 |
"""
|
444 |
if self._driver is None:
|
445 |
+
raise RuntimeError(
|
446 |
+
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
447 |
+
)
|
448 |
properties = node_data
|
449 |
entity_type = properties["entity_type"]
|
450 |
if "entity_id" not in properties:
|
|
|
453 |
try:
|
454 |
async with self._driver.session(database=self._DATABASE) as session:
|
455 |
workspace_label = self._get_workspace_label()
|
456 |
+
|
457 |
async def execute_upsert(tx: AsyncManagedTransaction):
|
458 |
+
query = f"""
|
|
|
459 |
MERGE (n:`{workspace_label}` {{entity_id: $entity_id}})
|
460 |
SET n += $properties
|
461 |
SET n:`{entity_type}`
|
462 |
"""
|
|
|
463 |
result = await tx.run(
|
464 |
query, entity_id=node_id, properties=properties
|
465 |
)
|
|
|
487 |
Exception: If there is an error executing the query
|
488 |
"""
|
489 |
if self._driver is None:
|
490 |
+
raise RuntimeError(
|
491 |
+
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
492 |
+
)
|
493 |
try:
|
494 |
edge_properties = edge_data
|
495 |
async with self._driver.session(database=self._DATABASE) as session:
|
|
|
530 |
Exception: If there is an error executing the query
|
531 |
"""
|
532 |
if self._driver is None:
|
533 |
+
raise RuntimeError(
|
534 |
+
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
535 |
+
)
|
536 |
|
537 |
async def _do_delete(tx: AsyncManagedTransaction):
|
538 |
workspace_label = self._get_workspace_label()
|
|
|
558 |
nodes: List of node labels to be deleted
|
559 |
"""
|
560 |
if self._driver is None:
|
561 |
+
raise RuntimeError(
|
562 |
+
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
563 |
+
)
|
564 |
for node in nodes:
|
565 |
await self.delete_node(node)
|
566 |
|
|
|
574 |
Exception: If there is an error executing the query
|
575 |
"""
|
576 |
if self._driver is None:
|
577 |
+
raise RuntimeError(
|
578 |
+
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
579 |
+
)
|
580 |
for source, target in edges:
|
581 |
|
582 |
async def _do_delete_edge(tx: AsyncManagedTransaction):
|
|
|
612 |
Exception: If there is an error executing the query
|
613 |
"""
|
614 |
if self._driver is None:
|
615 |
+
raise RuntimeError(
|
616 |
+
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
617 |
+
)
|
618 |
try:
|
619 |
async with self._driver.session(database=self._DATABASE) as session:
|
620 |
workspace_label = self._get_workspace_label()
|
621 |
query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
|
622 |
result = await session.run(query)
|
623 |
await result.consume()
|
624 |
+
logger.info(
|
625 |
+
f"Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}"
|
626 |
+
)
|
627 |
return {"status": "success", "message": "workspace data dropped"}
|
628 |
except Exception as e:
|
629 |
+
logger.error(
|
630 |
+
f"Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}"
|
631 |
+
)
|
632 |
return {"status": "error", "message": str(e)}
|
633 |
|
634 |
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
|
|
642 |
int: Sum of the degrees of both nodes
|
643 |
"""
|
644 |
if self._driver is None:
|
645 |
+
raise RuntimeError(
|
646 |
+
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
647 |
+
)
|
648 |
src_degree = await self.node_degree(src_id)
|
649 |
trg_degree = await self.node_degree(tgt_id)
|
650 |
|
|
|
666 |
An empty list if no matching nodes are found.
|
667 |
"""
|
668 |
if self._driver is None:
|
669 |
+
raise RuntimeError(
|
670 |
+
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
671 |
+
)
|
672 |
workspace_label = self._get_workspace_label()
|
673 |
async with self._driver.session(
|
674 |
database=self._DATABASE, default_access_mode="READ"
|
|
|
700 |
An empty list if no matching edges are found.
|
701 |
"""
|
702 |
if self._driver is None:
|
703 |
+
raise RuntimeError(
|
704 |
+
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
705 |
+
)
|
706 |
workspace_label = self._get_workspace_label()
|
707 |
async with self._driver.session(
|
708 |
database=self._DATABASE, default_access_mode="READ"
|
|
|
750 |
Exception: If there is an error executing the query
|
751 |
"""
|
752 |
if self._driver is None:
|
753 |
+
raise RuntimeError(
|
754 |
+
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
755 |
+
)
|
756 |
|
757 |
result = KnowledgeGraph()
|
758 |
seen_nodes = set()
|
|
|
812 |
|
813 |
else:
|
814 |
bfs_query = f"""
|
815 |
+
MATCH (start:`{workspace_label}`)
|
816 |
WHERE start.entity_id = $entity_id
|
817 |
WITH start
|
818 |
CALL {{
|
|
|
825 |
RETURN all_nodes, all_rels
|
826 |
}}
|
827 |
WITH all_nodes AS nodes, all_rels AS relationships, size(all_nodes) AS total_nodes
|
828 |
+
WITH
|
829 |
CASE
|
830 |
WHEN total_nodes <= {max_nodes} THEN nodes
|
831 |
ELSE nodes[0..{max_nodes}]
|
|
|
833 |
relationships,
|
834 |
total_nodes,
|
835 |
total_nodes > {max_nodes} AS is_truncated
|
836 |
+
RETURN
|
837 |
[node IN limited_nodes | {{node: node}}] AS node_info,
|
838 |
relationships,
|
839 |
total_nodes,
|