DavIvek committed on
Commit
be6f7dd
·
1 Parent(s): 5cc69d3

run pre-commit

Browse files
Files changed (1) hide show
  1. lightrag/kg/memgraph_impl.py +82 -31
lightrag/kg/memgraph_impl.py CHANGED
@@ -73,8 +73,12 @@ class MemgraphStorage(BaseGraphStorage):
73
  # Create index for base nodes on entity_id if it doesn't exist
74
  try:
75
  workspace_label = self._get_workspace_label()
76
- await session.run(f"""CREATE INDEX ON :{workspace_label}(entity_id)""")
77
- logger.info(f"Created index on :{workspace_label}(entity_id) in Memgraph.")
 
 
 
 
78
  except Exception as e:
79
  # Index may already exist, which is not an error
80
  logger.warning(
@@ -112,7 +116,9 @@ class MemgraphStorage(BaseGraphStorage):
112
  Exception: If there is an error checking the node existence.
113
  """
114
  if self._driver is None:
115
- raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
 
 
116
  async with self._driver.session(
117
  database=self._DATABASE, default_access_mode="READ"
118
  ) as session:
@@ -122,7 +128,9 @@ class MemgraphStorage(BaseGraphStorage):
122
  result = await session.run(query, entity_id=node_id)
123
  single_result = await result.single()
124
  await result.consume() # Ensure result is fully consumed
125
- return single_result["node_exists"] if single_result is not None else False
 
 
126
  except Exception as e:
127
  logger.error(f"Error checking node existence for {node_id}: {str(e)}")
128
  await result.consume() # Ensure the result is consumed even on error
@@ -143,7 +151,9 @@ class MemgraphStorage(BaseGraphStorage):
143
  Exception: If there is an error checking the edge existence.
144
  """
145
  if self._driver is None:
146
- raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
 
 
147
  async with self._driver.session(
148
  database=self._DATABASE, default_access_mode="READ"
149
  ) as session:
@@ -153,10 +163,16 @@ class MemgraphStorage(BaseGraphStorage):
153
  f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) "
154
  "RETURN COUNT(r) > 0 AS edgeExists"
155
  )
156
- result = await session.run(query, source_entity_id=source_node_id, target_entity_id=target_node_id) # type: ignore
 
 
 
 
157
  single_result = await result.single()
158
  await result.consume() # Ensure result is fully consumed
159
- return single_result["edgeExists"] if single_result is not None else False
 
 
160
  except Exception as e:
161
  logger.error(
162
  f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
@@ -178,13 +194,17 @@ class MemgraphStorage(BaseGraphStorage):
178
  Exception: If there is an error executing the query
179
  """
180
  if self._driver is None:
181
- raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
 
 
182
  async with self._driver.session(
183
  database=self._DATABASE, default_access_mode="READ"
184
  ) as session:
185
  try:
186
  workspace_label = self._get_workspace_label()
187
- query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN n"
 
 
188
  result = await session.run(query, entity_id=node_id)
189
  try:
190
  records = await result.fetch(
@@ -228,7 +248,9 @@ class MemgraphStorage(BaseGraphStorage):
228
  Exception: If there is an error executing the query
229
  """
230
  if self._driver is None:
231
- raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
 
 
232
  async with self._driver.session(
233
  database=self._DATABASE, default_access_mode="READ"
234
  ) as session:
@@ -265,7 +287,9 @@ class MemgraphStorage(BaseGraphStorage):
265
  Exception: If there is an error executing the query
266
  """
267
  if self._driver is None:
268
- raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
 
 
269
  async with self._driver.session(
270
  database=self._DATABASE, default_access_mode="READ"
271
  ) as session:
@@ -302,7 +326,9 @@ class MemgraphStorage(BaseGraphStorage):
302
  Exception: If there is an error executing the query
303
  """
304
  if self._driver is None:
305
- raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
 
 
306
  try:
307
  async with self._driver.session(
308
  database=self._DATABASE, default_access_mode="READ"
@@ -366,7 +392,9 @@ class MemgraphStorage(BaseGraphStorage):
366
  Exception: If there is an error executing the query
367
  """
368
  if self._driver is None:
369
- raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
 
 
370
  async with self._driver.session(
371
  database=self._DATABASE, default_access_mode="READ"
372
  ) as session:
@@ -414,7 +442,9 @@ class MemgraphStorage(BaseGraphStorage):
414
  node_data: Dictionary of node properties
415
  """
416
  if self._driver is None:
417
- raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
 
 
418
  properties = node_data
419
  entity_type = properties["entity_type"]
420
  if "entity_id" not in properties:
@@ -423,14 +453,13 @@ class MemgraphStorage(BaseGraphStorage):
423
  try:
424
  async with self._driver.session(database=self._DATABASE) as session:
425
  workspace_label = self._get_workspace_label()
 
426
  async def execute_upsert(tx: AsyncManagedTransaction):
427
- query = (
428
- f"""
429
  MERGE (n:`{workspace_label}` {{entity_id: $entity_id}})
430
  SET n += $properties
431
  SET n:`{entity_type}`
432
  """
433
- )
434
  result = await tx.run(
435
  query, entity_id=node_id, properties=properties
436
  )
@@ -458,7 +487,9 @@ class MemgraphStorage(BaseGraphStorage):
458
  Exception: If there is an error executing the query
459
  """
460
  if self._driver is None:
461
- raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
 
 
462
  try:
463
  edge_properties = edge_data
464
  async with self._driver.session(database=self._DATABASE) as session:
@@ -499,7 +530,9 @@ class MemgraphStorage(BaseGraphStorage):
499
  Exception: If there is an error executing the query
500
  """
501
  if self._driver is None:
502
- raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
 
 
503
 
504
  async def _do_delete(tx: AsyncManagedTransaction):
505
  workspace_label = self._get_workspace_label()
@@ -525,7 +558,9 @@ class MemgraphStorage(BaseGraphStorage):
525
  nodes: List of node labels to be deleted
526
  """
527
  if self._driver is None:
528
- raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
 
 
529
  for node in nodes:
530
  await self.delete_node(node)
531
 
@@ -539,7 +574,9 @@ class MemgraphStorage(BaseGraphStorage):
539
  Exception: If there is an error executing the query
540
  """
541
  if self._driver is None:
542
- raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
 
 
543
  for source, target in edges:
544
 
545
  async def _do_delete_edge(tx: AsyncManagedTransaction):
@@ -575,17 +612,23 @@ class MemgraphStorage(BaseGraphStorage):
575
  Exception: If there is an error executing the query
576
  """
577
  if self._driver is None:
578
- raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
 
 
579
  try:
580
  async with self._driver.session(database=self._DATABASE) as session:
581
  workspace_label = self._get_workspace_label()
582
  query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
583
  result = await session.run(query)
584
  await result.consume()
585
- logger.info(f"Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}")
 
 
586
  return {"status": "success", "message": "workspace data dropped"}
587
  except Exception as e:
588
- logger.error(f"Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}")
 
 
589
  return {"status": "error", "message": str(e)}
590
 
591
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
@@ -599,7 +642,9 @@ class MemgraphStorage(BaseGraphStorage):
599
  int: Sum of the degrees of both nodes
600
  """
601
  if self._driver is None:
602
- raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
 
 
603
  src_degree = await self.node_degree(src_id)
604
  trg_degree = await self.node_degree(tgt_id)
605
 
@@ -621,7 +666,9 @@ class MemgraphStorage(BaseGraphStorage):
621
  An empty list if no matching nodes are found.
622
  """
623
  if self._driver is None:
624
- raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
 
 
625
  workspace_label = self._get_workspace_label()
626
  async with self._driver.session(
627
  database=self._DATABASE, default_access_mode="READ"
@@ -653,7 +700,9 @@ class MemgraphStorage(BaseGraphStorage):
653
  An empty list if no matching edges are found.
654
  """
655
  if self._driver is None:
656
- raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
 
 
657
  workspace_label = self._get_workspace_label()
658
  async with self._driver.session(
659
  database=self._DATABASE, default_access_mode="READ"
@@ -701,7 +750,9 @@ class MemgraphStorage(BaseGraphStorage):
701
  Exception: If there is an error executing the query
702
  """
703
  if self._driver is None:
704
- raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
 
 
705
 
706
  result = KnowledgeGraph()
707
  seen_nodes = set()
@@ -761,7 +812,7 @@ class MemgraphStorage(BaseGraphStorage):
761
 
762
  else:
763
  bfs_query = f"""
764
- MATCH (start:`{workspace_label}`)
765
  WHERE start.entity_id = $entity_id
766
  WITH start
767
  CALL {{
@@ -774,7 +825,7 @@ class MemgraphStorage(BaseGraphStorage):
774
  RETURN all_nodes, all_rels
775
  }}
776
  WITH all_nodes AS nodes, all_rels AS relationships, size(all_nodes) AS total_nodes
777
- WITH
778
  CASE
779
  WHEN total_nodes <= {max_nodes} THEN nodes
780
  ELSE nodes[0..{max_nodes}]
@@ -782,7 +833,7 @@ class MemgraphStorage(BaseGraphStorage):
782
  relationships,
783
  total_nodes,
784
  total_nodes > {max_nodes} AS is_truncated
785
- RETURN
786
  [node IN limited_nodes | {{node: node}}] AS node_info,
787
  relationships,
788
  total_nodes,
 
73
  # Create index for base nodes on entity_id if it doesn't exist
74
  try:
75
  workspace_label = self._get_workspace_label()
76
+ await session.run(
77
+ f"""CREATE INDEX ON :{workspace_label}(entity_id)"""
78
+ )
79
+ logger.info(
80
+ f"Created index on :{workspace_label}(entity_id) in Memgraph."
81
+ )
82
  except Exception as e:
83
  # Index may already exist, which is not an error
84
  logger.warning(
 
116
  Exception: If there is an error checking the node existence.
117
  """
118
  if self._driver is None:
119
+ raise RuntimeError(
120
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
121
+ )
122
  async with self._driver.session(
123
  database=self._DATABASE, default_access_mode="READ"
124
  ) as session:
 
128
  result = await session.run(query, entity_id=node_id)
129
  single_result = await result.single()
130
  await result.consume() # Ensure result is fully consumed
131
+ return (
132
+ single_result["node_exists"] if single_result is not None else False
133
+ )
134
  except Exception as e:
135
  logger.error(f"Error checking node existence for {node_id}: {str(e)}")
136
  await result.consume() # Ensure the result is consumed even on error
 
151
  Exception: If there is an error checking the edge existence.
152
  """
153
  if self._driver is None:
154
+ raise RuntimeError(
155
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
156
+ )
157
  async with self._driver.session(
158
  database=self._DATABASE, default_access_mode="READ"
159
  ) as session:
 
163
  f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) "
164
  "RETURN COUNT(r) > 0 AS edgeExists"
165
  )
166
+ result = await session.run(
167
+ query,
168
+ source_entity_id=source_node_id,
169
+ target_entity_id=target_node_id,
170
+ ) # type: ignore
171
  single_result = await result.single()
172
  await result.consume() # Ensure result is fully consumed
173
+ return (
174
+ single_result["edgeExists"] if single_result is not None else False
175
+ )
176
  except Exception as e:
177
  logger.error(
178
  f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
 
194
  Exception: If there is an error executing the query
195
  """
196
  if self._driver is None:
197
+ raise RuntimeError(
198
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
199
+ )
200
  async with self._driver.session(
201
  database=self._DATABASE, default_access_mode="READ"
202
  ) as session:
203
  try:
204
  workspace_label = self._get_workspace_label()
205
+ query = (
206
+ f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN n"
207
+ )
208
  result = await session.run(query, entity_id=node_id)
209
  try:
210
  records = await result.fetch(
 
248
  Exception: If there is an error executing the query
249
  """
250
  if self._driver is None:
251
+ raise RuntimeError(
252
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
253
+ )
254
  async with self._driver.session(
255
  database=self._DATABASE, default_access_mode="READ"
256
  ) as session:
 
287
  Exception: If there is an error executing the query
288
  """
289
  if self._driver is None:
290
+ raise RuntimeError(
291
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
292
+ )
293
  async with self._driver.session(
294
  database=self._DATABASE, default_access_mode="READ"
295
  ) as session:
 
326
  Exception: If there is an error executing the query
327
  """
328
  if self._driver is None:
329
+ raise RuntimeError(
330
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
331
+ )
332
  try:
333
  async with self._driver.session(
334
  database=self._DATABASE, default_access_mode="READ"
 
392
  Exception: If there is an error executing the query
393
  """
394
  if self._driver is None:
395
+ raise RuntimeError(
396
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
397
+ )
398
  async with self._driver.session(
399
  database=self._DATABASE, default_access_mode="READ"
400
  ) as session:
 
442
  node_data: Dictionary of node properties
443
  """
444
  if self._driver is None:
445
+ raise RuntimeError(
446
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
447
+ )
448
  properties = node_data
449
  entity_type = properties["entity_type"]
450
  if "entity_id" not in properties:
 
453
  try:
454
  async with self._driver.session(database=self._DATABASE) as session:
455
  workspace_label = self._get_workspace_label()
456
+
457
  async def execute_upsert(tx: AsyncManagedTransaction):
458
+ query = f"""
 
459
  MERGE (n:`{workspace_label}` {{entity_id: $entity_id}})
460
  SET n += $properties
461
  SET n:`{entity_type}`
462
  """
 
463
  result = await tx.run(
464
  query, entity_id=node_id, properties=properties
465
  )
 
487
  Exception: If there is an error executing the query
488
  """
489
  if self._driver is None:
490
+ raise RuntimeError(
491
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
492
+ )
493
  try:
494
  edge_properties = edge_data
495
  async with self._driver.session(database=self._DATABASE) as session:
 
530
  Exception: If there is an error executing the query
531
  """
532
  if self._driver is None:
533
+ raise RuntimeError(
534
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
535
+ )
536
 
537
  async def _do_delete(tx: AsyncManagedTransaction):
538
  workspace_label = self._get_workspace_label()
 
558
  nodes: List of node labels to be deleted
559
  """
560
  if self._driver is None:
561
+ raise RuntimeError(
562
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
563
+ )
564
  for node in nodes:
565
  await self.delete_node(node)
566
 
 
574
  Exception: If there is an error executing the query
575
  """
576
  if self._driver is None:
577
+ raise RuntimeError(
578
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
579
+ )
580
  for source, target in edges:
581
 
582
  async def _do_delete_edge(tx: AsyncManagedTransaction):
 
612
  Exception: If there is an error executing the query
613
  """
614
  if self._driver is None:
615
+ raise RuntimeError(
616
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
617
+ )
618
  try:
619
  async with self._driver.session(database=self._DATABASE) as session:
620
  workspace_label = self._get_workspace_label()
621
  query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
622
  result = await session.run(query)
623
  await result.consume()
624
+ logger.info(
625
+ f"Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}"
626
+ )
627
  return {"status": "success", "message": "workspace data dropped"}
628
  except Exception as e:
629
+ logger.error(
630
+ f"Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}"
631
+ )
632
  return {"status": "error", "message": str(e)}
633
 
634
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
 
642
  int: Sum of the degrees of both nodes
643
  """
644
  if self._driver is None:
645
+ raise RuntimeError(
646
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
647
+ )
648
  src_degree = await self.node_degree(src_id)
649
  trg_degree = await self.node_degree(tgt_id)
650
 
 
666
  An empty list if no matching nodes are found.
667
  """
668
  if self._driver is None:
669
+ raise RuntimeError(
670
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
671
+ )
672
  workspace_label = self._get_workspace_label()
673
  async with self._driver.session(
674
  database=self._DATABASE, default_access_mode="READ"
 
700
  An empty list if no matching edges are found.
701
  """
702
  if self._driver is None:
703
+ raise RuntimeError(
704
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
705
+ )
706
  workspace_label = self._get_workspace_label()
707
  async with self._driver.session(
708
  database=self._DATABASE, default_access_mode="READ"
 
750
  Exception: If there is an error executing the query
751
  """
752
  if self._driver is None:
753
+ raise RuntimeError(
754
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
755
+ )
756
 
757
  result = KnowledgeGraph()
758
  seen_nodes = set()
 
812
 
813
  else:
814
  bfs_query = f"""
815
+ MATCH (start:`{workspace_label}`)
816
  WHERE start.entity_id = $entity_id
817
  WITH start
818
  CALL {{
 
825
  RETURN all_nodes, all_rels
826
  }}
827
  WITH all_nodes AS nodes, all_rels AS relationships, size(all_nodes) AS total_nodes
828
+ WITH
829
  CASE
830
  WHEN total_nodes <= {max_nodes} THEN nodes
831
  ELSE nodes[0..{max_nodes}]
 
833
  relationships,
834
  total_nodes,
835
  total_nodes > {max_nodes} AS is_truncated
836
+ RETURN
837
  [node IN limited_nodes | {{node: node}}] AS node_info,
838
  relationships,
839
  total_nodes,