gzdaniel committed on
Commit
1601d3e
·
2 Parent(s): 1a6ee7b d0f775b

Merge branch 'add-Memgraph-graph-db' into memgraph

Browse files
Files changed (2) hide show
  1. env.example +9 -1
  2. lightrag/kg/memgraph_impl.py +182 -82
env.example CHANGED
@@ -134,13 +134,14 @@ EMBEDDING_BINDING_HOST=http://localhost:11434
134
  # LIGHTRAG_VECTOR_STORAGE=QdrantVectorDBStorage
135
  ### Graph Storage (Recommended for production deployment)
136
  # LIGHTRAG_GRAPH_STORAGE=Neo4JStorage
 
137
 
138
  ####################################################################
139
  ### Default workspace for all storage types
140
  ### For the purpose of isolation of data for each LightRAG instance
141
  ### Valid characters: a-z, A-Z, 0-9, and _
142
  ####################################################################
143
- # WORKSPACE=doc—
144
 
145
  ### PostgreSQL Configuration
146
  POSTGRES_HOST=localhost
@@ -179,3 +180,10 @@ QDRANT_URL=http://localhost:6333
179
  ### Redis
180
  REDIS_URI=redis://localhost:6379
181
  # REDIS_WORKSPACE=forced_workspace_name
 
 
 
 
 
 
 
 
134
  # LIGHTRAG_VECTOR_STORAGE=QdrantVectorDBStorage
135
  ### Graph Storage (Recommended for production deployment)
136
  # LIGHTRAG_GRAPH_STORAGE=Neo4JStorage
137
+ # LIGHTRAG_GRAPH_STORAGE=MemgraphStorage
138
 
139
  ####################################################################
140
  ### Default workspace for all storage types
141
  ### For the purpose of isolation of data for each LightRAG instance
142
  ### Valid characters: a-z, A-Z, 0-9, and _
143
  ####################################################################
144
+ # WORKSPACE=space1
145
 
146
  ### PostgreSQL Configuration
147
  POSTGRES_HOST=localhost
 
180
  ### Redis
181
  REDIS_URI=redis://localhost:6379
182
  # REDIS_WORKSPACE=forced_workspace_name
183
+
184
+ ### Memgraph Configuration
185
+ MEMGRAPH_URI=bolt://localhost:7687
186
+ MEMGRAPH_USERNAME=
187
+ MEMGRAPH_PASSWORD=
188
+ MEMGRAPH_DATABASE=memgraph
189
+ # MEMGRAPH_WORKSPACE=forced_workspace_name
lightrag/kg/memgraph_impl.py CHANGED
@@ -31,14 +31,23 @@ config.read("config.ini", "utf-8")
31
  @final
32
  @dataclass
33
  class MemgraphStorage(BaseGraphStorage):
34
- def __init__(self, namespace, global_config, embedding_func):
 
 
 
35
  super().__init__(
36
  namespace=namespace,
 
37
  global_config=global_config,
38
  embedding_func=embedding_func,
39
  )
40
  self._driver = None
41
 
 
 
 
 
 
42
  async def initialize(self):
43
  URI = os.environ.get(
44
  "MEMGRAPH_URI",
@@ -63,12 +72,17 @@ class MemgraphStorage(BaseGraphStorage):
63
  async with self._driver.session(database=DATABASE) as session:
64
  # Create index for base nodes on entity_id if it doesn't exist
65
  try:
66
- await session.run("""CREATE INDEX ON :base(entity_id)""")
67
- logger.info("Created index on :base(entity_id) in Memgraph.")
 
 
 
 
 
68
  except Exception as e:
69
  # Index may already exist, which is not an error
70
  logger.warning(
71
- f"Index creation on :base(entity_id) may have failed or already exists: {e}"
72
  )
73
  await session.run("RETURN 1")
74
  logger.info(f"Connected to Memgraph at {URI}")
@@ -101,15 +115,22 @@ class MemgraphStorage(BaseGraphStorage):
101
  Raises:
102
  Exception: If there is an error checking the node existence.
103
  """
 
 
 
 
104
  async with self._driver.session(
105
  database=self._DATABASE, default_access_mode="READ"
106
  ) as session:
107
  try:
108
- query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists"
 
109
  result = await session.run(query, entity_id=node_id)
110
  single_result = await result.single()
111
  await result.consume() # Ensure result is fully consumed
112
- return single_result["node_exists"]
 
 
113
  except Exception as e:
114
  logger.error(f"Error checking node existence for {node_id}: {str(e)}")
115
  await result.consume() # Ensure the result is consumed even on error
@@ -129,22 +150,29 @@ class MemgraphStorage(BaseGraphStorage):
129
  Raises:
130
  Exception: If there is an error checking the edge existence.
131
  """
 
 
 
 
132
  async with self._driver.session(
133
  database=self._DATABASE, default_access_mode="READ"
134
  ) as session:
135
  try:
 
136
  query = (
137
- "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) "
138
  "RETURN COUNT(r) > 0 AS edgeExists"
139
  )
140
  result = await session.run(
141
  query,
142
  source_entity_id=source_node_id,
143
  target_entity_id=target_node_id,
144
- )
145
  single_result = await result.single()
146
  await result.consume() # Ensure result is fully consumed
147
- return single_result["edgeExists"]
 
 
148
  except Exception as e:
149
  logger.error(
150
  f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
@@ -165,11 +193,18 @@ class MemgraphStorage(BaseGraphStorage):
165
  Raises:
166
  Exception: If there is an error executing the query
167
  """
 
 
 
 
168
  async with self._driver.session(
169
  database=self._DATABASE, default_access_mode="READ"
170
  ) as session:
171
  try:
172
- query = "MATCH (n:base {entity_id: $entity_id}) RETURN n"
 
 
 
173
  result = await session.run(query, entity_id=node_id)
174
  try:
175
  records = await result.fetch(
@@ -183,12 +218,12 @@ class MemgraphStorage(BaseGraphStorage):
183
  if records:
184
  node = records[0]["n"]
185
  node_dict = dict(node)
186
- # Remove base label from labels list if it exists
187
  if "labels" in node_dict:
188
  node_dict["labels"] = [
189
  label
190
  for label in node_dict["labels"]
191
- if label != "base"
192
  ]
193
  return node_dict
194
  return None
@@ -212,12 +247,17 @@ class MemgraphStorage(BaseGraphStorage):
212
  Raises:
213
  Exception: If there is an error executing the query
214
  """
 
 
 
 
215
  async with self._driver.session(
216
  database=self._DATABASE, default_access_mode="READ"
217
  ) as session:
218
  try:
219
- query = """
220
- MATCH (n:base {entity_id: $entity_id})
 
221
  OPTIONAL MATCH (n)-[r]-()
222
  RETURN COUNT(r) AS degree
223
  """
@@ -246,12 +286,17 @@ class MemgraphStorage(BaseGraphStorage):
246
  Raises:
247
  Exception: If there is an error executing the query
248
  """
 
 
 
 
249
  async with self._driver.session(
250
  database=self._DATABASE, default_access_mode="READ"
251
  ) as session:
252
  try:
253
- query = """
254
- MATCH (n:base)
 
255
  WHERE n.entity_id IS NOT NULL
256
  RETURN DISTINCT n.entity_id AS label
257
  ORDER BY label
@@ -280,13 +325,18 @@ class MemgraphStorage(BaseGraphStorage):
280
  Raises:
281
  Exception: If there is an error executing the query
282
  """
 
 
 
 
283
  try:
284
  async with self._driver.session(
285
  database=self._DATABASE, default_access_mode="READ"
286
  ) as session:
287
  try:
288
- query = """MATCH (n:base {entity_id: $entity_id})
289
- OPTIONAL MATCH (n)-[r]-(connected:base)
 
290
  WHERE connected.entity_id IS NOT NULL
291
  RETURN n, r, connected"""
292
  results = await session.run(query, entity_id=source_node_id)
@@ -341,12 +391,17 @@ class MemgraphStorage(BaseGraphStorage):
341
  Raises:
342
  Exception: If there is an error executing the query
343
  """
 
 
 
 
344
  async with self._driver.session(
345
  database=self._DATABASE, default_access_mode="READ"
346
  ) as session:
347
  try:
348
- query = """
349
- MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id})
 
350
  RETURN properties(r) as edge_properties
351
  """
352
  result = await session.run(
@@ -386,6 +441,10 @@ class MemgraphStorage(BaseGraphStorage):
386
  node_id: The unique identifier for the node (used as label)
387
  node_data: Dictionary of node properties
388
  """
 
 
 
 
389
  properties = node_data
390
  entity_type = properties["entity_type"]
391
  if "entity_id" not in properties:
@@ -393,16 +452,14 @@ class MemgraphStorage(BaseGraphStorage):
393
 
394
  try:
395
  async with self._driver.session(database=self._DATABASE) as session:
 
396
 
397
  async def execute_upsert(tx: AsyncManagedTransaction):
398
- query = (
399
- """
400
- MERGE (n:base {entity_id: $entity_id})
401
  SET n += $properties
402
- SET n:`%s`
403
  """
404
- % entity_type
405
- )
406
  result = await tx.run(
407
  query, entity_id=node_id, properties=properties
408
  )
@@ -429,15 +486,20 @@ class MemgraphStorage(BaseGraphStorage):
429
  Raises:
430
  Exception: If there is an error executing the query
431
  """
 
 
 
 
432
  try:
433
  edge_properties = edge_data
434
  async with self._driver.session(database=self._DATABASE) as session:
435
 
436
  async def execute_upsert(tx: AsyncManagedTransaction):
437
- query = """
438
- MATCH (source:base {entity_id: $source_entity_id})
 
439
  WITH source
440
- MATCH (target:base {entity_id: $target_entity_id})
441
  MERGE (source)-[r:DIRECTED]-(target)
442
  SET r += $properties
443
  RETURN r, source, target
@@ -467,10 +529,15 @@ class MemgraphStorage(BaseGraphStorage):
467
  Raises:
468
  Exception: If there is an error executing the query
469
  """
 
 
 
 
470
 
471
  async def _do_delete(tx: AsyncManagedTransaction):
472
- query = """
473
- MATCH (n:base {entity_id: $entity_id})
 
474
  DETACH DELETE n
475
  """
476
  result = await tx.run(query, entity_id=node_id)
@@ -490,6 +557,10 @@ class MemgraphStorage(BaseGraphStorage):
490
  Args:
491
  nodes: List of node labels to be deleted
492
  """
 
 
 
 
493
  for node in nodes:
494
  await self.delete_node(node)
495
 
@@ -502,11 +573,16 @@ class MemgraphStorage(BaseGraphStorage):
502
  Raises:
503
  Exception: If there is an error executing the query
504
  """
 
 
 
 
505
  for source, target in edges:
506
 
507
  async def _do_delete_edge(tx: AsyncManagedTransaction):
508
- query = """
509
- MATCH (source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id})
 
510
  DELETE r
511
  """
512
  result = await tx.run(
@@ -523,9 +599,9 @@ class MemgraphStorage(BaseGraphStorage):
523
  raise
524
 
525
  async def drop(self) -> dict[str, str]:
526
- """Drop all data from storage and clean up resources
527
 
528
- This method will delete all nodes and relationships in the Neo4j database.
529
 
530
  Returns:
531
  dict[str, str]: Operation status and message
@@ -535,17 +611,24 @@ class MemgraphStorage(BaseGraphStorage):
535
  Raises:
536
  Exception: If there is an error executing the query
537
  """
 
 
 
 
538
  try:
539
  async with self._driver.session(database=self._DATABASE) as session:
540
- query = "MATCH (n) DETACH DELETE n"
 
541
  result = await session.run(query)
542
  await result.consume()
543
  logger.info(
544
- f"Process {os.getpid()} drop Memgraph database {self._DATABASE}"
545
  )
546
- return {"status": "success", "message": "data dropped"}
547
  except Exception as e:
548
- logger.error(f"Error dropping Memgraph database {self._DATABASE}: {e}")
 
 
549
  return {"status": "error", "message": str(e)}
550
 
551
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
@@ -558,6 +641,10 @@ class MemgraphStorage(BaseGraphStorage):
558
  Returns:
559
  int: Sum of the degrees of both nodes
560
  """
 
 
 
 
561
  src_degree = await self.node_degree(src_id)
562
  trg_degree = await self.node_degree(tgt_id)
563
 
@@ -578,12 +665,17 @@ class MemgraphStorage(BaseGraphStorage):
578
  list[dict]: A list of nodes, where each node is a dictionary of its properties.
579
  An empty list if no matching nodes are found.
580
  """
 
 
 
 
 
581
  async with self._driver.session(
582
  database=self._DATABASE, default_access_mode="READ"
583
  ) as session:
584
- query = """
585
  UNWIND $chunk_ids AS chunk_id
586
- MATCH (n:base)
587
  WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
588
  RETURN DISTINCT n
589
  """
@@ -607,12 +699,17 @@ class MemgraphStorage(BaseGraphStorage):
607
  list[dict]: A list of edges, where each edge is a dictionary of its properties.
608
  An empty list if no matching edges are found.
609
  """
 
 
 
 
 
610
  async with self._driver.session(
611
  database=self._DATABASE, default_access_mode="READ"
612
  ) as session:
613
- query = """
614
  UNWIND $chunk_ids AS chunk_id
615
- MATCH (a:base)-[r]-(b:base)
616
  WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
617
  WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id
618
  // Ensure we only return each unique edge once by ordering the source and target
@@ -652,9 +749,15 @@ class MemgraphStorage(BaseGraphStorage):
652
  Raises:
653
  Exception: If there is an error executing the query
654
  """
 
 
 
 
 
655
  result = KnowledgeGraph()
656
  seen_nodes = set()
657
  seen_edges = set()
 
658
  async with self._driver.session(
659
  database=self._DATABASE, default_access_mode="READ"
660
  ) as session:
@@ -682,19 +785,17 @@ class MemgraphStorage(BaseGraphStorage):
682
  await count_result.consume()
683
 
684
  # Run the main query to get nodes with highest degree
685
- main_query = """
686
- MATCH (n)
687
  OPTIONAL MATCH (n)-[r]-()
688
  WITH n, COALESCE(count(r), 0) AS degree
689
  ORDER BY degree DESC
690
  LIMIT $max_nodes
691
- WITH collect({node: n}) AS filtered_nodes
692
- UNWIND filtered_nodes AS node_info
693
- WITH collect(node_info.node) AS kept_nodes, filtered_nodes
694
- OPTIONAL MATCH (a)-[r]-(b)
695
  WHERE a IN kept_nodes AND b IN kept_nodes
696
- RETURN filtered_nodes AS node_info,
697
- collect(DISTINCT r) AS relationships
698
  """
699
  result_set = None
700
  try:
@@ -710,31 +811,33 @@ class MemgraphStorage(BaseGraphStorage):
710
  await result_set.consume()
711
 
712
  else:
713
- bfs_query = """
714
- MATCH (start) WHERE start.entity_id = $entity_id
 
715
  WITH start
716
- CALL {
717
  WITH start
718
- MATCH path = (start)-[*0..$max_depth]-(node)
719
  WITH nodes(path) AS path_nodes, relationships(path) AS path_rels
720
  UNWIND path_nodes AS n
721
  WITH collect(DISTINCT n) AS all_nodes, collect(DISTINCT path_rels) AS all_rel_lists
722
  WITH all_nodes, reduce(r = [], x IN all_rel_lists | r + x) AS all_rels
723
  RETURN all_nodes, all_rels
724
- }
725
  WITH all_nodes AS nodes, all_rels AS relationships, size(all_nodes) AS total_nodes
726
-
727
- // Apply node limiting here
728
- WITH CASE
729
- WHEN total_nodes <= $max_nodes THEN nodes
730
- ELSE nodes[0..$max_nodes]
731
  END AS limited_nodes,
732
  relationships,
733
  total_nodes,
734
- total_nodes > $max_nodes AS is_truncated
735
- UNWIND limited_nodes AS node
736
- WITH collect({node: node}) AS node_info, relationships, total_nodes, is_truncated
737
- RETURN node_info, relationships, total_nodes, is_truncated
 
 
738
  """
739
  result_set = None
740
  try:
@@ -742,8 +845,6 @@ class MemgraphStorage(BaseGraphStorage):
742
  bfs_query,
743
  {
744
  "entity_id": node_label,
745
- "max_depth": max_depth,
746
- "max_nodes": max_nodes,
747
  },
748
  )
749
  record = await result_set.single()
@@ -777,22 +878,21 @@ class MemgraphStorage(BaseGraphStorage):
777
  )
778
  )
779
 
780
- if "relationships" in record and record["relationships"]:
781
- for rel in record["relationships"]:
782
- edge_id = rel.id
783
- if edge_id not in seen_edges:
784
- seen_edges.add(edge_id)
785
- start = rel.start_node
786
- end = rel.end_node
787
- result.edges.append(
788
- KnowledgeGraphEdge(
789
- id=f"{edge_id}",
790
- type=rel.type,
791
- source=f"{start.id}",
792
- target=f"{end.id}",
793
- properties=dict(rel),
794
- )
795
  )
 
796
 
797
  logger.info(
798
  f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
 
31
  @final
32
  @dataclass
33
  class MemgraphStorage(BaseGraphStorage):
34
+ def __init__(self, namespace, global_config, embedding_func, workspace=None):
35
+ memgraph_workspace = os.environ.get("MEMGRAPH_WORKSPACE")
36
+ if memgraph_workspace and memgraph_workspace.strip():
37
+ workspace = memgraph_workspace
38
  super().__init__(
39
  namespace=namespace,
40
+ workspace=workspace or "",
41
  global_config=global_config,
42
  embedding_func=embedding_func,
43
  )
44
  self._driver = None
45
 
46
+ def _get_workspace_label(self) -> str:
47
+ """Get workspace label, return 'base' for compatibility when workspace is empty"""
48
+ workspace = getattr(self, "workspace", None)
49
+ return workspace if workspace else "base"
50
+
51
  async def initialize(self):
52
  URI = os.environ.get(
53
  "MEMGRAPH_URI",
 
72
  async with self._driver.session(database=DATABASE) as session:
73
  # Create index for base nodes on entity_id if it doesn't exist
74
  try:
75
+ workspace_label = self._get_workspace_label()
76
+ await session.run(
77
+ f"""CREATE INDEX ON :{workspace_label}(entity_id)"""
78
+ )
79
+ logger.info(
80
+ f"Created index on :{workspace_label}(entity_id) in Memgraph."
81
+ )
82
  except Exception as e:
83
  # Index may already exist, which is not an error
84
  logger.warning(
85
+ f"Index creation on :{workspace_label}(entity_id) may have failed or already exists: {e}"
86
  )
87
  await session.run("RETURN 1")
88
  logger.info(f"Connected to Memgraph at {URI}")
 
115
  Raises:
116
  Exception: If there is an error checking the node existence.
117
  """
118
+ if self._driver is None:
119
+ raise RuntimeError(
120
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
121
+ )
122
  async with self._driver.session(
123
  database=self._DATABASE, default_access_mode="READ"
124
  ) as session:
125
  try:
126
+ workspace_label = self._get_workspace_label()
127
+ query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
128
  result = await session.run(query, entity_id=node_id)
129
  single_result = await result.single()
130
  await result.consume() # Ensure result is fully consumed
131
+ return (
132
+ single_result["node_exists"] if single_result is not None else False
133
+ )
134
  except Exception as e:
135
  logger.error(f"Error checking node existence for {node_id}: {str(e)}")
136
  await result.consume() # Ensure the result is consumed even on error
 
150
  Raises:
151
  Exception: If there is an error checking the edge existence.
152
  """
153
+ if self._driver is None:
154
+ raise RuntimeError(
155
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
156
+ )
157
  async with self._driver.session(
158
  database=self._DATABASE, default_access_mode="READ"
159
  ) as session:
160
  try:
161
+ workspace_label = self._get_workspace_label()
162
  query = (
163
+ f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) "
164
  "RETURN COUNT(r) > 0 AS edgeExists"
165
  )
166
  result = await session.run(
167
  query,
168
  source_entity_id=source_node_id,
169
  target_entity_id=target_node_id,
170
+ ) # type: ignore
171
  single_result = await result.single()
172
  await result.consume() # Ensure result is fully consumed
173
+ return (
174
+ single_result["edgeExists"] if single_result is not None else False
175
+ )
176
  except Exception as e:
177
  logger.error(
178
  f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
 
193
  Raises:
194
  Exception: If there is an error executing the query
195
  """
196
+ if self._driver is None:
197
+ raise RuntimeError(
198
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
199
+ )
200
  async with self._driver.session(
201
  database=self._DATABASE, default_access_mode="READ"
202
  ) as session:
203
  try:
204
+ workspace_label = self._get_workspace_label()
205
+ query = (
206
+ f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN n"
207
+ )
208
  result = await session.run(query, entity_id=node_id)
209
  try:
210
  records = await result.fetch(
 
218
  if records:
219
  node = records[0]["n"]
220
  node_dict = dict(node)
221
+ # Remove workspace label from labels list if it exists
222
  if "labels" in node_dict:
223
  node_dict["labels"] = [
224
  label
225
  for label in node_dict["labels"]
226
+ if label != workspace_label
227
  ]
228
  return node_dict
229
  return None
 
247
  Raises:
248
  Exception: If there is an error executing the query
249
  """
250
+ if self._driver is None:
251
+ raise RuntimeError(
252
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
253
+ )
254
  async with self._driver.session(
255
  database=self._DATABASE, default_access_mode="READ"
256
  ) as session:
257
  try:
258
+ workspace_label = self._get_workspace_label()
259
+ query = f"""
260
+ MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
261
  OPTIONAL MATCH (n)-[r]-()
262
  RETURN COUNT(r) AS degree
263
  """
 
286
  Raises:
287
  Exception: If there is an error executing the query
288
  """
289
+ if self._driver is None:
290
+ raise RuntimeError(
291
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
292
+ )
293
  async with self._driver.session(
294
  database=self._DATABASE, default_access_mode="READ"
295
  ) as session:
296
  try:
297
+ workspace_label = self._get_workspace_label()
298
+ query = f"""
299
+ MATCH (n:`{workspace_label}`)
300
  WHERE n.entity_id IS NOT NULL
301
  RETURN DISTINCT n.entity_id AS label
302
  ORDER BY label
 
325
  Raises:
326
  Exception: If there is an error executing the query
327
  """
328
+ if self._driver is None:
329
+ raise RuntimeError(
330
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
331
+ )
332
  try:
333
  async with self._driver.session(
334
  database=self._DATABASE, default_access_mode="READ"
335
  ) as session:
336
  try:
337
+ workspace_label = self._get_workspace_label()
338
+ query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
339
+ OPTIONAL MATCH (n)-[r]-(connected:`{workspace_label}`)
340
  WHERE connected.entity_id IS NOT NULL
341
  RETURN n, r, connected"""
342
  results = await session.run(query, entity_id=source_node_id)
 
391
  Raises:
392
  Exception: If there is an error executing the query
393
  """
394
+ if self._driver is None:
395
+ raise RuntimeError(
396
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
397
+ )
398
  async with self._driver.session(
399
  database=self._DATABASE, default_access_mode="READ"
400
  ) as session:
401
  try:
402
+ workspace_label = self._get_workspace_label()
403
+ query = f"""
404
+ MATCH (start:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(end:`{workspace_label}` {{entity_id: $target_entity_id}})
405
  RETURN properties(r) as edge_properties
406
  """
407
  result = await session.run(
 
441
  node_id: The unique identifier for the node (used as label)
442
  node_data: Dictionary of node properties
443
  """
444
+ if self._driver is None:
445
+ raise RuntimeError(
446
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
447
+ )
448
  properties = node_data
449
  entity_type = properties["entity_type"]
450
  if "entity_id" not in properties:
 
452
 
453
  try:
454
  async with self._driver.session(database=self._DATABASE) as session:
455
+ workspace_label = self._get_workspace_label()
456
 
457
  async def execute_upsert(tx: AsyncManagedTransaction):
458
+ query = f"""
459
+ MERGE (n:`{workspace_label}` {{entity_id: $entity_id}})
 
460
  SET n += $properties
461
+ SET n:`{entity_type}`
462
  """
 
 
463
  result = await tx.run(
464
  query, entity_id=node_id, properties=properties
465
  )
 
486
  Raises:
487
  Exception: If there is an error executing the query
488
  """
489
+ if self._driver is None:
490
+ raise RuntimeError(
491
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
492
+ )
493
  try:
494
  edge_properties = edge_data
495
  async with self._driver.session(database=self._DATABASE) as session:
496
 
497
  async def execute_upsert(tx: AsyncManagedTransaction):
498
+ workspace_label = self._get_workspace_label()
499
+ query = f"""
500
+ MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})
501
  WITH source
502
+ MATCH (target:`{workspace_label}` {{entity_id: $target_entity_id}})
503
  MERGE (source)-[r:DIRECTED]-(target)
504
  SET r += $properties
505
  RETURN r, source, target
 
529
  Raises:
530
  Exception: If there is an error executing the query
531
  """
532
+ if self._driver is None:
533
+ raise RuntimeError(
534
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
535
+ )
536
 
537
  async def _do_delete(tx: AsyncManagedTransaction):
538
+ workspace_label = self._get_workspace_label()
539
+ query = f"""
540
+ MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
541
  DETACH DELETE n
542
  """
543
  result = await tx.run(query, entity_id=node_id)
 
557
  Args:
558
  nodes: List of node labels to be deleted
559
  """
560
+ if self._driver is None:
561
+ raise RuntimeError(
562
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
563
+ )
564
  for node in nodes:
565
  await self.delete_node(node)
566
 
 
573
  Raises:
574
  Exception: If there is an error executing the query
575
  """
576
+ if self._driver is None:
577
+ raise RuntimeError(
578
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
579
+ )
580
  for source, target in edges:
581
 
582
  async def _do_delete_edge(tx: AsyncManagedTransaction):
583
+ workspace_label = self._get_workspace_label()
584
+ query = f"""
585
+ MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(target:`{workspace_label}` {{entity_id: $target_entity_id}})
586
  DELETE r
587
  """
588
  result = await tx.run(
 
599
  raise
600
 
601
  async def drop(self) -> dict[str, str]:
602
+ """Drop all data from the current workspace and clean up resources
603
 
604
+ This method will delete all nodes carrying the current workspace label, together with their relationships, in the Memgraph database.
605
 
606
  Returns:
607
  dict[str, str]: Operation status and message
 
611
  Raises:
612
  Exception: If there is an error executing the query
613
  """
614
+ if self._driver is None:
615
+ raise RuntimeError(
616
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
617
+ )
618
  try:
619
  async with self._driver.session(database=self._DATABASE) as session:
620
+ workspace_label = self._get_workspace_label()
621
+ query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
622
  result = await session.run(query)
623
  await result.consume()
624
  logger.info(
625
+ f"Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}"
626
  )
627
+ return {"status": "success", "message": "workspace data dropped"}
628
  except Exception as e:
629
+ logger.error(
630
+ f"Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}"
631
+ )
632
  return {"status": "error", "message": str(e)}
633
 
634
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
 
641
  Returns:
642
  int: Sum of the degrees of both nodes
643
  """
644
+ if self._driver is None:
645
+ raise RuntimeError(
646
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
647
+ )
648
  src_degree = await self.node_degree(src_id)
649
  trg_degree = await self.node_degree(tgt_id)
650
 
 
665
  list[dict]: A list of nodes, where each node is a dictionary of its properties.
666
  An empty list if no matching nodes are found.
667
  """
668
+ if self._driver is None:
669
+ raise RuntimeError(
670
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
671
+ )
672
+ workspace_label = self._get_workspace_label()
673
  async with self._driver.session(
674
  database=self._DATABASE, default_access_mode="READ"
675
  ) as session:
676
+ query = f"""
677
  UNWIND $chunk_ids AS chunk_id
678
+ MATCH (n:`{workspace_label}`)
679
  WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
680
  RETURN DISTINCT n
681
  """
 
699
  list[dict]: A list of edges, where each edge is a dictionary of its properties.
700
  An empty list if no matching edges are found.
701
  """
702
+ if self._driver is None:
703
+ raise RuntimeError(
704
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
705
+ )
706
+ workspace_label = self._get_workspace_label()
707
  async with self._driver.session(
708
  database=self._DATABASE, default_access_mode="READ"
709
  ) as session:
710
+ query = f"""
711
  UNWIND $chunk_ids AS chunk_id
712
+ MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`)
713
  WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
714
  WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id
715
  // Ensure we only return each unique edge once by ordering the source and target
 
749
  Raises:
750
  Exception: If there is an error executing the query
751
  """
752
+ if self._driver is None:
753
+ raise RuntimeError(
754
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
755
+ )
756
+
757
  result = KnowledgeGraph()
758
  seen_nodes = set()
759
  seen_edges = set()
760
+ workspace_label = self._get_workspace_label()
761
  async with self._driver.session(
762
  database=self._DATABASE, default_access_mode="READ"
763
  ) as session:
 
785
  await count_result.consume()
786
 
787
  # Run the main query to get nodes with highest degree
788
+ main_query = f"""
789
+ MATCH (n:`{workspace_label}`)
790
  OPTIONAL MATCH (n)-[r]-()
791
  WITH n, COALESCE(count(r), 0) AS degree
792
  ORDER BY degree DESC
793
  LIMIT $max_nodes
794
+ WITH collect(n) AS kept_nodes
795
+ MATCH (a)-[r]-(b)
 
 
796
  WHERE a IN kept_nodes AND b IN kept_nodes
797
+ RETURN [node IN kept_nodes | {{node: node}}] AS node_info,
798
+ collect(DISTINCT r) AS relationships
799
  """
800
  result_set = None
801
  try:
 
811
  await result_set.consume()
812
 
813
  else:
814
+ bfs_query = f"""
815
+ MATCH (start:`{workspace_label}`)
816
+ WHERE start.entity_id = $entity_id
817
  WITH start
818
+ CALL {{
819
  WITH start
820
+ MATCH path = (start)-[*0..{max_depth}]-(node)
821
  WITH nodes(path) AS path_nodes, relationships(path) AS path_rels
822
  UNWIND path_nodes AS n
823
  WITH collect(DISTINCT n) AS all_nodes, collect(DISTINCT path_rels) AS all_rel_lists
824
  WITH all_nodes, reduce(r = [], x IN all_rel_lists | r + x) AS all_rels
825
  RETURN all_nodes, all_rels
826
+ }}
827
  WITH all_nodes AS nodes, all_rels AS relationships, size(all_nodes) AS total_nodes
828
+ WITH
829
+ CASE
830
+ WHEN total_nodes <= {max_nodes} THEN nodes
831
+ ELSE nodes[0..{max_nodes}]
 
832
  END AS limited_nodes,
833
  relationships,
834
  total_nodes,
835
+ total_nodes > {max_nodes} AS is_truncated
836
+ RETURN
837
+ [node IN limited_nodes | {{node: node}}] AS node_info,
838
+ relationships,
839
+ total_nodes,
840
+ is_truncated
841
  """
842
  result_set = None
843
  try:
 
845
  bfs_query,
846
  {
847
  "entity_id": node_label,
 
 
848
  },
849
  )
850
  record = await result_set.single()
 
878
  )
879
  )
880
 
881
+ for rel in record["relationships"]:
882
+ edge_id = rel.id
883
+ if edge_id not in seen_edges:
884
+ seen_edges.add(edge_id)
885
+ start = rel.start_node
886
+ end = rel.end_node
887
+ result.edges.append(
888
+ KnowledgeGraphEdge(
889
+ id=f"{edge_id}",
890
+ type=rel.type,
891
+ source=f"{start.id}",
892
+ target=f"{end.id}",
893
+ properties=dict(rel),
 
 
894
  )
895
+ )
896
 
897
  logger.info(
898
  f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"