DavIvek commited on
Commit
5cc69d3
·
1 Parent(s): 277f2d7

add changes based on review

Browse files
Files changed (2) hide show
  1. env.example +7 -0
  2. lightrag/kg/memgraph_impl.py +136 -87
env.example CHANGED
@@ -179,3 +179,10 @@ QDRANT_URL=http://localhost:6333
179
  ### Redis
180
  REDIS_URI=redis://localhost:6379
181
  # REDIS_WORKSPACE=forced_workspace_name
 
 
 
 
 
 
 
 
179
  ### Redis
180
  REDIS_URI=redis://localhost:6379
181
  # REDIS_WORKSPACE=forced_workspace_name
182
+
183
+ ### Memgraph Configuration
184
+ MEMGRAPH_URI=bolt://localhost:7687
185
+ MEMGRAPH_USERNAME=
186
+ MEMGRAPH_PASSWORD=
187
+ MEMGRAPH_DATABASE=memgraph
188
+ # MEMGRAPH_WORKSPACE=forced_workspace_name
lightrag/kg/memgraph_impl.py CHANGED
@@ -31,14 +31,23 @@ config.read("config.ini", "utf-8")
31
  @final
32
  @dataclass
33
  class MemgraphStorage(BaseGraphStorage):
34
- def __init__(self, namespace, global_config, embedding_func):
 
 
 
35
  super().__init__(
36
  namespace=namespace,
 
37
  global_config=global_config,
38
  embedding_func=embedding_func,
39
  )
40
  self._driver = None
41
 
 
 
 
 
 
42
  async def initialize(self):
43
  URI = os.environ.get(
44
  "MEMGRAPH_URI",
@@ -63,12 +72,13 @@ class MemgraphStorage(BaseGraphStorage):
63
  async with self._driver.session(database=DATABASE) as session:
64
  # Create index for base nodes on entity_id if it doesn't exist
65
  try:
66
- await session.run("""CREATE INDEX ON :base(entity_id)""")
67
- logger.info("Created index on :base(entity_id) in Memgraph.")
 
68
  except Exception as e:
69
  # Index may already exist, which is not an error
70
  logger.warning(
71
- f"Index creation on :base(entity_id) may have failed or already exists: {e}"
72
  )
73
  await session.run("RETURN 1")
74
  logger.info(f"Connected to Memgraph at {URI}")
@@ -101,15 +111,18 @@ class MemgraphStorage(BaseGraphStorage):
101
  Raises:
102
  Exception: If there is an error checking the node existence.
103
  """
 
 
104
  async with self._driver.session(
105
  database=self._DATABASE, default_access_mode="READ"
106
  ) as session:
107
  try:
108
- query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists"
 
109
  result = await session.run(query, entity_id=node_id)
110
  single_result = await result.single()
111
  await result.consume() # Ensure result is fully consumed
112
- return single_result["node_exists"]
113
  except Exception as e:
114
  logger.error(f"Error checking node existence for {node_id}: {str(e)}")
115
  await result.consume() # Ensure the result is consumed even on error
@@ -129,22 +142,21 @@ class MemgraphStorage(BaseGraphStorage):
129
  Raises:
130
  Exception: If there is an error checking the edge existence.
131
  """
 
 
132
  async with self._driver.session(
133
  database=self._DATABASE, default_access_mode="READ"
134
  ) as session:
135
  try:
 
136
  query = (
137
- "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) "
138
  "RETURN COUNT(r) > 0 AS edgeExists"
139
  )
140
- result = await session.run(
141
- query,
142
- source_entity_id=source_node_id,
143
- target_entity_id=target_node_id,
144
- )
145
  single_result = await result.single()
146
  await result.consume() # Ensure result is fully consumed
147
- return single_result["edgeExists"]
148
  except Exception as e:
149
  logger.error(
150
  f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
@@ -165,11 +177,14 @@ class MemgraphStorage(BaseGraphStorage):
165
  Raises:
166
  Exception: If there is an error executing the query
167
  """
 
 
168
  async with self._driver.session(
169
  database=self._DATABASE, default_access_mode="READ"
170
  ) as session:
171
  try:
172
- query = "MATCH (n:base {entity_id: $entity_id}) RETURN n"
 
173
  result = await session.run(query, entity_id=node_id)
174
  try:
175
  records = await result.fetch(
@@ -183,12 +198,12 @@ class MemgraphStorage(BaseGraphStorage):
183
  if records:
184
  node = records[0]["n"]
185
  node_dict = dict(node)
186
- # Remove base label from labels list if it exists
187
  if "labels" in node_dict:
188
  node_dict["labels"] = [
189
  label
190
  for label in node_dict["labels"]
191
- if label != "base"
192
  ]
193
  return node_dict
194
  return None
@@ -212,12 +227,15 @@ class MemgraphStorage(BaseGraphStorage):
212
  Raises:
213
  Exception: If there is an error executing the query
214
  """
 
 
215
  async with self._driver.session(
216
  database=self._DATABASE, default_access_mode="READ"
217
  ) as session:
218
  try:
219
- query = """
220
- MATCH (n:base {entity_id: $entity_id})
 
221
  OPTIONAL MATCH (n)-[r]-()
222
  RETURN COUNT(r) AS degree
223
  """
@@ -246,12 +264,15 @@ class MemgraphStorage(BaseGraphStorage):
246
  Raises:
247
  Exception: If there is an error executing the query
248
  """
 
 
249
  async with self._driver.session(
250
  database=self._DATABASE, default_access_mode="READ"
251
  ) as session:
252
  try:
253
- query = """
254
- MATCH (n:base)
 
255
  WHERE n.entity_id IS NOT NULL
256
  RETURN DISTINCT n.entity_id AS label
257
  ORDER BY label
@@ -280,13 +301,16 @@ class MemgraphStorage(BaseGraphStorage):
280
  Raises:
281
  Exception: If there is an error executing the query
282
  """
 
 
283
  try:
284
  async with self._driver.session(
285
  database=self._DATABASE, default_access_mode="READ"
286
  ) as session:
287
  try:
288
- query = """MATCH (n:base {entity_id: $entity_id})
289
- OPTIONAL MATCH (n)-[r]-(connected:base)
 
290
  WHERE connected.entity_id IS NOT NULL
291
  RETURN n, r, connected"""
292
  results = await session.run(query, entity_id=source_node_id)
@@ -341,12 +365,15 @@ class MemgraphStorage(BaseGraphStorage):
341
  Raises:
342
  Exception: If there is an error executing the query
343
  """
 
 
344
  async with self._driver.session(
345
  database=self._DATABASE, default_access_mode="READ"
346
  ) as session:
347
  try:
348
- query = """
349
- MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id})
 
350
  RETURN properties(r) as edge_properties
351
  """
352
  result = await session.run(
@@ -386,6 +413,8 @@ class MemgraphStorage(BaseGraphStorage):
386
  node_id: The unique identifier for the node (used as label)
387
  node_data: Dictionary of node properties
388
  """
 
 
389
  properties = node_data
390
  entity_type = properties["entity_type"]
391
  if "entity_id" not in properties:
@@ -393,15 +422,14 @@ class MemgraphStorage(BaseGraphStorage):
393
 
394
  try:
395
  async with self._driver.session(database=self._DATABASE) as session:
396
-
397
  async def execute_upsert(tx: AsyncManagedTransaction):
398
  query = (
399
- """
400
- MERGE (n:base {entity_id: $entity_id})
401
  SET n += $properties
402
- SET n:`%s`
403
  """
404
- % entity_type
405
  )
406
  result = await tx.run(
407
  query, entity_id=node_id, properties=properties
@@ -429,15 +457,18 @@ class MemgraphStorage(BaseGraphStorage):
429
  Raises:
430
  Exception: If there is an error executing the query
431
  """
 
 
432
  try:
433
  edge_properties = edge_data
434
  async with self._driver.session(database=self._DATABASE) as session:
435
 
436
  async def execute_upsert(tx: AsyncManagedTransaction):
437
- query = """
438
- MATCH (source:base {entity_id: $source_entity_id})
 
439
  WITH source
440
- MATCH (target:base {entity_id: $target_entity_id})
441
  MERGE (source)-[r:DIRECTED]-(target)
442
  SET r += $properties
443
  RETURN r, source, target
@@ -467,10 +498,13 @@ class MemgraphStorage(BaseGraphStorage):
467
  Raises:
468
  Exception: If there is an error executing the query
469
  """
 
 
470
 
471
  async def _do_delete(tx: AsyncManagedTransaction):
472
- query = """
473
- MATCH (n:base {entity_id: $entity_id})
 
474
  DETACH DELETE n
475
  """
476
  result = await tx.run(query, entity_id=node_id)
@@ -490,6 +524,8 @@ class MemgraphStorage(BaseGraphStorage):
490
  Args:
491
  nodes: List of node labels to be deleted
492
  """
 
 
493
  for node in nodes:
494
  await self.delete_node(node)
495
 
@@ -502,11 +538,14 @@ class MemgraphStorage(BaseGraphStorage):
502
  Raises:
503
  Exception: If there is an error executing the query
504
  """
 
 
505
  for source, target in edges:
506
 
507
  async def _do_delete_edge(tx: AsyncManagedTransaction):
508
- query = """
509
- MATCH (source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id})
 
510
  DELETE r
511
  """
512
  result = await tx.run(
@@ -523,9 +562,9 @@ class MemgraphStorage(BaseGraphStorage):
523
  raise
524
 
525
  async def drop(self) -> dict[str, str]:
526
- """Drop all data from storage and clean up resources
527
 
528
- This method will delete all nodes and relationships in the Neo4j database.
529
 
530
  Returns:
531
  dict[str, str]: Operation status and message
@@ -535,17 +574,18 @@ class MemgraphStorage(BaseGraphStorage):
535
  Raises:
536
  Exception: If there is an error executing the query
537
  """
 
 
538
  try:
539
  async with self._driver.session(database=self._DATABASE) as session:
540
- query = "MATCH (n) DETACH DELETE n"
 
541
  result = await session.run(query)
542
  await result.consume()
543
- logger.info(
544
- f"Process {os.getpid()} drop Memgraph database {self._DATABASE}"
545
- )
546
- return {"status": "success", "message": "data dropped"}
547
  except Exception as e:
548
- logger.error(f"Error dropping Memgraph database {self._DATABASE}: {e}")
549
  return {"status": "error", "message": str(e)}
550
 
551
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
@@ -558,6 +598,8 @@ class MemgraphStorage(BaseGraphStorage):
558
  Returns:
559
  int: Sum of the degrees of both nodes
560
  """
 
 
561
  src_degree = await self.node_degree(src_id)
562
  trg_degree = await self.node_degree(tgt_id)
563
 
@@ -578,12 +620,15 @@ class MemgraphStorage(BaseGraphStorage):
578
  list[dict]: A list of nodes, where each node is a dictionary of its properties.
579
  An empty list if no matching nodes are found.
580
  """
 
 
 
581
  async with self._driver.session(
582
  database=self._DATABASE, default_access_mode="READ"
583
  ) as session:
584
- query = """
585
  UNWIND $chunk_ids AS chunk_id
586
- MATCH (n:base)
587
  WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
588
  RETURN DISTINCT n
589
  """
@@ -607,12 +652,15 @@ class MemgraphStorage(BaseGraphStorage):
607
  list[dict]: A list of edges, where each edge is a dictionary of its properties.
608
  An empty list if no matching edges are found.
609
  """
 
 
 
610
  async with self._driver.session(
611
  database=self._DATABASE, default_access_mode="READ"
612
  ) as session:
613
- query = """
614
  UNWIND $chunk_ids AS chunk_id
615
- MATCH (a:base)-[r]-(b:base)
616
  WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
617
  WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id
618
  // Ensure we only return each unique edge once by ordering the source and target
@@ -652,9 +700,13 @@ class MemgraphStorage(BaseGraphStorage):
652
  Raises:
653
  Exception: If there is an error executing the query
654
  """
 
 
 
655
  result = KnowledgeGraph()
656
  seen_nodes = set()
657
  seen_edges = set()
 
658
  async with self._driver.session(
659
  database=self._DATABASE, default_access_mode="READ"
660
  ) as session:
@@ -682,19 +734,17 @@ class MemgraphStorage(BaseGraphStorage):
682
  await count_result.consume()
683
 
684
  # Run the main query to get nodes with highest degree
685
- main_query = """
686
- MATCH (n)
687
  OPTIONAL MATCH (n)-[r]-()
688
  WITH n, COALESCE(count(r), 0) AS degree
689
  ORDER BY degree DESC
690
  LIMIT $max_nodes
691
- WITH collect({node: n}) AS filtered_nodes
692
- UNWIND filtered_nodes AS node_info
693
- WITH collect(node_info.node) AS kept_nodes, filtered_nodes
694
- OPTIONAL MATCH (a)-[r]-(b)
695
  WHERE a IN kept_nodes AND b IN kept_nodes
696
- RETURN filtered_nodes AS node_info,
697
- collect(DISTINCT r) AS relationships
698
  """
699
  result_set = None
700
  try:
@@ -710,31 +760,33 @@ class MemgraphStorage(BaseGraphStorage):
710
  await result_set.consume()
711
 
712
  else:
713
- bfs_query = """
714
- MATCH (start) WHERE start.entity_id = $entity_id
 
715
  WITH start
716
- CALL {
717
  WITH start
718
- MATCH path = (start)-[*0..$max_depth]-(node)
719
  WITH nodes(path) AS path_nodes, relationships(path) AS path_rels
720
  UNWIND path_nodes AS n
721
  WITH collect(DISTINCT n) AS all_nodes, collect(DISTINCT path_rels) AS all_rel_lists
722
  WITH all_nodes, reduce(r = [], x IN all_rel_lists | r + x) AS all_rels
723
  RETURN all_nodes, all_rels
724
- }
725
  WITH all_nodes AS nodes, all_rels AS relationships, size(all_nodes) AS total_nodes
726
-
727
- // Apply node limiting here
728
- WITH CASE
729
- WHEN total_nodes <= $max_nodes THEN nodes
730
- ELSE nodes[0..$max_nodes]
731
  END AS limited_nodes,
732
  relationships,
733
  total_nodes,
734
- total_nodes > $max_nodes AS is_truncated
735
- UNWIND limited_nodes AS node
736
- WITH collect({node: node}) AS node_info, relationships, total_nodes, is_truncated
737
- RETURN node_info, relationships, total_nodes, is_truncated
 
 
738
  """
739
  result_set = None
740
  try:
@@ -742,8 +794,6 @@ class MemgraphStorage(BaseGraphStorage):
742
  bfs_query,
743
  {
744
  "entity_id": node_label,
745
- "max_depth": max_depth,
746
- "max_nodes": max_nodes,
747
  },
748
  )
749
  record = await result_set.single()
@@ -777,22 +827,21 @@ class MemgraphStorage(BaseGraphStorage):
777
  )
778
  )
779
 
780
- if "relationships" in record and record["relationships"]:
781
- for rel in record["relationships"]:
782
- edge_id = rel.id
783
- if edge_id not in seen_edges:
784
- seen_edges.add(edge_id)
785
- start = rel.start_node
786
- end = rel.end_node
787
- result.edges.append(
788
- KnowledgeGraphEdge(
789
- id=f"{edge_id}",
790
- type=rel.type,
791
- source=f"{start.id}",
792
- target=f"{end.id}",
793
- properties=dict(rel),
794
- )
795
  )
 
796
 
797
  logger.info(
798
  f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
 
31
  @final
32
  @dataclass
33
  class MemgraphStorage(BaseGraphStorage):
34
+ def __init__(self, namespace, global_config, embedding_func, workspace=None):
35
+ memgraph_workspace = os.environ.get("MEMGRAPH_WORKSPACE")
36
+ if memgraph_workspace and memgraph_workspace.strip():
37
+ workspace = memgraph_workspace
38
  super().__init__(
39
  namespace=namespace,
40
+ workspace=workspace or "",
41
  global_config=global_config,
42
  embedding_func=embedding_func,
43
  )
44
  self._driver = None
45
 
46
+ def _get_workspace_label(self) -> str:
47
+ """Get workspace label, return 'base' for compatibility when workspace is empty"""
48
+ workspace = getattr(self, "workspace", None)
49
+ return workspace if workspace else "base"
50
+
51
  async def initialize(self):
52
  URI = os.environ.get(
53
  "MEMGRAPH_URI",
 
72
  async with self._driver.session(database=DATABASE) as session:
73
  # Create index for base nodes on entity_id if it doesn't exist
74
  try:
75
+ workspace_label = self._get_workspace_label()
76
+ await session.run(f"""CREATE INDEX ON :{workspace_label}(entity_id)""")
77
+ logger.info(f"Created index on :{workspace_label}(entity_id) in Memgraph.")
78
  except Exception as e:
79
  # Index may already exist, which is not an error
80
  logger.warning(
81
+ f"Index creation on :{workspace_label}(entity_id) may have failed or already exists: {e}"
82
  )
83
  await session.run("RETURN 1")
84
  logger.info(f"Connected to Memgraph at {URI}")
 
111
  Raises:
112
  Exception: If there is an error checking the node existence.
113
  """
114
+ if self._driver is None:
115
+ raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
116
  async with self._driver.session(
117
  database=self._DATABASE, default_access_mode="READ"
118
  ) as session:
119
  try:
120
+ workspace_label = self._get_workspace_label()
121
+ query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
122
  result = await session.run(query, entity_id=node_id)
123
  single_result = await result.single()
124
  await result.consume() # Ensure result is fully consumed
125
+ return single_result["node_exists"] if single_result is not None else False
126
  except Exception as e:
127
  logger.error(f"Error checking node existence for {node_id}: {str(e)}")
128
  await result.consume() # Ensure the result is consumed even on error
 
142
  Raises:
143
  Exception: If there is an error checking the edge existence.
144
  """
145
+ if self._driver is None:
146
+ raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
147
  async with self._driver.session(
148
  database=self._DATABASE, default_access_mode="READ"
149
  ) as session:
150
  try:
151
+ workspace_label = self._get_workspace_label()
152
  query = (
153
+ f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) "
154
  "RETURN COUNT(r) > 0 AS edgeExists"
155
  )
156
+ result = await session.run(query, source_entity_id=source_node_id, target_entity_id=target_node_id) # type: ignore
 
 
 
 
157
  single_result = await result.single()
158
  await result.consume() # Ensure result is fully consumed
159
+ return single_result["edgeExists"] if single_result is not None else False
160
  except Exception as e:
161
  logger.error(
162
  f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
 
177
  Raises:
178
  Exception: If there is an error executing the query
179
  """
180
+ if self._driver is None:
181
+ raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
182
  async with self._driver.session(
183
  database=self._DATABASE, default_access_mode="READ"
184
  ) as session:
185
  try:
186
+ workspace_label = self._get_workspace_label()
187
+ query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN n"
188
  result = await session.run(query, entity_id=node_id)
189
  try:
190
  records = await result.fetch(
 
198
  if records:
199
  node = records[0]["n"]
200
  node_dict = dict(node)
201
+ # Remove workspace label from labels list if it exists
202
  if "labels" in node_dict:
203
  node_dict["labels"] = [
204
  label
205
  for label in node_dict["labels"]
206
+ if label != workspace_label
207
  ]
208
  return node_dict
209
  return None
 
227
  Raises:
228
  Exception: If there is an error executing the query
229
  """
230
+ if self._driver is None:
231
+ raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
232
  async with self._driver.session(
233
  database=self._DATABASE, default_access_mode="READ"
234
  ) as session:
235
  try:
236
+ workspace_label = self._get_workspace_label()
237
+ query = f"""
238
+ MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
239
  OPTIONAL MATCH (n)-[r]-()
240
  RETURN COUNT(r) AS degree
241
  """
 
264
  Raises:
265
  Exception: If there is an error executing the query
266
  """
267
+ if self._driver is None:
268
+ raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
269
  async with self._driver.session(
270
  database=self._DATABASE, default_access_mode="READ"
271
  ) as session:
272
  try:
273
+ workspace_label = self._get_workspace_label()
274
+ query = f"""
275
+ MATCH (n:`{workspace_label}`)
276
  WHERE n.entity_id IS NOT NULL
277
  RETURN DISTINCT n.entity_id AS label
278
  ORDER BY label
 
301
  Raises:
302
  Exception: If there is an error executing the query
303
  """
304
+ if self._driver is None:
305
+ raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
306
  try:
307
  async with self._driver.session(
308
  database=self._DATABASE, default_access_mode="READ"
309
  ) as session:
310
  try:
311
+ workspace_label = self._get_workspace_label()
312
+ query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
313
+ OPTIONAL MATCH (n)-[r]-(connected:`{workspace_label}`)
314
  WHERE connected.entity_id IS NOT NULL
315
  RETURN n, r, connected"""
316
  results = await session.run(query, entity_id=source_node_id)
 
365
  Raises:
366
  Exception: If there is an error executing the query
367
  """
368
+ if self._driver is None:
369
+ raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
370
  async with self._driver.session(
371
  database=self._DATABASE, default_access_mode="READ"
372
  ) as session:
373
  try:
374
+ workspace_label = self._get_workspace_label()
375
+ query = f"""
376
+ MATCH (start:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(end:`{workspace_label}` {{entity_id: $target_entity_id}})
377
  RETURN properties(r) as edge_properties
378
  """
379
  result = await session.run(
 
413
  node_id: The unique identifier for the node (used as label)
414
  node_data: Dictionary of node properties
415
  """
416
+ if self._driver is None:
417
+ raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
418
  properties = node_data
419
  entity_type = properties["entity_type"]
420
  if "entity_id" not in properties:
 
422
 
423
  try:
424
  async with self._driver.session(database=self._DATABASE) as session:
425
+ workspace_label = self._get_workspace_label()
426
  async def execute_upsert(tx: AsyncManagedTransaction):
427
  query = (
428
+ f"""
429
+ MERGE (n:`{workspace_label}` {{entity_id: $entity_id}})
430
  SET n += $properties
431
+ SET n:`{entity_type}`
432
  """
 
433
  )
434
  result = await tx.run(
435
  query, entity_id=node_id, properties=properties
 
457
  Raises:
458
  Exception: If there is an error executing the query
459
  """
460
+ if self._driver is None:
461
+ raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
462
  try:
463
  edge_properties = edge_data
464
  async with self._driver.session(database=self._DATABASE) as session:
465
 
466
  async def execute_upsert(tx: AsyncManagedTransaction):
467
+ workspace_label = self._get_workspace_label()
468
+ query = f"""
469
+ MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})
470
  WITH source
471
+ MATCH (target:`{workspace_label}` {{entity_id: $target_entity_id}})
472
  MERGE (source)-[r:DIRECTED]-(target)
473
  SET r += $properties
474
  RETURN r, source, target
 
498
  Raises:
499
  Exception: If there is an error executing the query
500
  """
501
+ if self._driver is None:
502
+ raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
503
 
504
  async def _do_delete(tx: AsyncManagedTransaction):
505
+ workspace_label = self._get_workspace_label()
506
+ query = f"""
507
+ MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
508
  DETACH DELETE n
509
  """
510
  result = await tx.run(query, entity_id=node_id)
 
524
  Args:
525
  nodes: List of node labels to be deleted
526
  """
527
+ if self._driver is None:
528
+ raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
529
  for node in nodes:
530
  await self.delete_node(node)
531
 
 
538
  Raises:
539
  Exception: If there is an error executing the query
540
  """
541
+ if self._driver is None:
542
+ raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
543
  for source, target in edges:
544
 
545
  async def _do_delete_edge(tx: AsyncManagedTransaction):
546
+ workspace_label = self._get_workspace_label()
547
+ query = f"""
548
+ MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(target:`{workspace_label}` {{entity_id: $target_entity_id}})
549
  DELETE r
550
  """
551
  result = await tx.run(
 
562
  raise
563
 
564
  async def drop(self) -> dict[str, str]:
565
+ """Drop all data from the current workspace and clean up resources
566
 
567
+ This method will delete all nodes and relationships in the Memgraph database.
568
 
569
  Returns:
570
  dict[str, str]: Operation status and message
 
574
  Raises:
575
  Exception: If there is an error executing the query
576
  """
577
+ if self._driver is None:
578
+ raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
579
  try:
580
  async with self._driver.session(database=self._DATABASE) as session:
581
+ workspace_label = self._get_workspace_label()
582
+ query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
583
  result = await session.run(query)
584
  await result.consume()
585
+ logger.info(f"Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}")
586
+ return {"status": "success", "message": "workspace data dropped"}
 
 
587
  except Exception as e:
588
+ logger.error(f"Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}")
589
  return {"status": "error", "message": str(e)}
590
 
591
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
 
598
  Returns:
599
  int: Sum of the degrees of both nodes
600
  """
601
+ if self._driver is None:
602
+ raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
603
  src_degree = await self.node_degree(src_id)
604
  trg_degree = await self.node_degree(tgt_id)
605
 
 
620
  list[dict]: A list of nodes, where each node is a dictionary of its properties.
621
  An empty list if no matching nodes are found.
622
  """
623
+ if self._driver is None:
624
+ raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
625
+ workspace_label = self._get_workspace_label()
626
  async with self._driver.session(
627
  database=self._DATABASE, default_access_mode="READ"
628
  ) as session:
629
+ query = f"""
630
  UNWIND $chunk_ids AS chunk_id
631
+ MATCH (n:`{workspace_label}`)
632
  WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
633
  RETURN DISTINCT n
634
  """
 
652
  list[dict]: A list of edges, where each edge is a dictionary of its properties.
653
  An empty list if no matching edges are found.
654
  """
655
+ if self._driver is None:
656
+ raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
657
+ workspace_label = self._get_workspace_label()
658
  async with self._driver.session(
659
  database=self._DATABASE, default_access_mode="READ"
660
  ) as session:
661
+ query = f"""
662
  UNWIND $chunk_ids AS chunk_id
663
+ MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`)
664
  WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
665
  WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id
666
  // Ensure we only return each unique edge once by ordering the source and target
 
700
  Raises:
701
  Exception: If there is an error executing the query
702
  """
703
+ if self._driver is None:
704
+ raise RuntimeError("Memgraph driver is not initialized. Call 'await initialize()' first.")
705
+
706
  result = KnowledgeGraph()
707
  seen_nodes = set()
708
  seen_edges = set()
709
+ workspace_label = self._get_workspace_label()
710
  async with self._driver.session(
711
  database=self._DATABASE, default_access_mode="READ"
712
  ) as session:
 
734
  await count_result.consume()
735
 
736
  # Run the main query to get nodes with highest degree
737
+ main_query = f"""
738
+ MATCH (n:`{workspace_label}`)
739
  OPTIONAL MATCH (n)-[r]-()
740
  WITH n, COALESCE(count(r), 0) AS degree
741
  ORDER BY degree DESC
742
  LIMIT $max_nodes
743
+ WITH collect(n) AS kept_nodes
744
+ MATCH (a)-[r]-(b)
 
 
745
  WHERE a IN kept_nodes AND b IN kept_nodes
746
+ RETURN [node IN kept_nodes | {{node: node}}] AS node_info,
747
+ collect(DISTINCT r) AS relationships
748
  """
749
  result_set = None
750
  try:
 
760
  await result_set.consume()
761
 
762
  else:
763
+ bfs_query = f"""
764
+ MATCH (start:`{workspace_label}`)
765
+ WHERE start.entity_id = $entity_id
766
  WITH start
767
+ CALL {{
768
  WITH start
769
+ MATCH path = (start)-[*0..{max_depth}]-(node)
770
  WITH nodes(path) AS path_nodes, relationships(path) AS path_rels
771
  UNWIND path_nodes AS n
772
  WITH collect(DISTINCT n) AS all_nodes, collect(DISTINCT path_rels) AS all_rel_lists
773
  WITH all_nodes, reduce(r = [], x IN all_rel_lists | r + x) AS all_rels
774
  RETURN all_nodes, all_rels
775
+ }}
776
  WITH all_nodes AS nodes, all_rels AS relationships, size(all_nodes) AS total_nodes
777
+ WITH
778
+ CASE
779
+ WHEN total_nodes <= {max_nodes} THEN nodes
780
+ ELSE nodes[0..{max_nodes}]
 
781
  END AS limited_nodes,
782
  relationships,
783
  total_nodes,
784
+ total_nodes > {max_nodes} AS is_truncated
785
+ RETURN
786
+ [node IN limited_nodes | {{node: node}}] AS node_info,
787
+ relationships,
788
+ total_nodes,
789
+ is_truncated
790
  """
791
  result_set = None
792
  try:
 
794
  bfs_query,
795
  {
796
  "entity_id": node_label,
 
 
797
  },
798
  )
799
  record = await result_set.single()
 
827
  )
828
  )
829
 
830
+ for rel in record["relationships"]:
831
+ edge_id = rel.id
832
+ if edge_id not in seen_edges:
833
+ seen_edges.add(edge_id)
834
+ start = rel.start_node
835
+ end = rel.end_node
836
+ result.edges.append(
837
+ KnowledgeGraphEdge(
838
+ id=f"{edge_id}",
839
+ type=rel.type,
840
+ source=f"{start.id}",
841
+ target=f"{end.id}",
842
+ properties=dict(rel),
 
 
843
  )
844
+ )
845
 
846
  logger.info(
847
  f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"