DavIvek committed on
Commit 7c8ef0b · 1 Parent(s): 040b53c

polish Memgraph implementation

Files changed (1)
  1. lightrag/kg/memgraph_impl.py +550 -251
lightrag/kg/memgraph_impl.py CHANGED
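
The recurring change in this commit is the same defensive pattern applied around every query: run it, make sure the result is consumed so the server releases its resources, and log and re-raise on failure. Below is a minimal sketch of that pattern, assuming an already-initialized async Bolt driver pointed at Memgraph (Memgraph is accessed through the neo4j Python driver, as in this file); the standalone function name is hypothetical and only illustrates the shape of the methods in the diff.

import logging

from neo4j import AsyncDriver

logger = logging.getLogger(__name__)

async def has_node_sketch(driver: AsyncDriver, database: str, node_id: str) -> bool:
    # Hypothetical standalone version of MemgraphStorage.has_node, showing the
    # try / consume / log-and-reraise pattern this commit introduces.
    async with driver.session(database=database, default_access_mode="READ") as session:
        result = None
        try:
            query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists"
            result = await session.run(query, entity_id=node_id)
            record = await result.single()
            return bool(record and record["node_exists"])
        except Exception as e:
            logger.error(f"Error checking node existence for {node_id}: {e}")
            raise
        finally:
            if result is not None:
                await result.consume()  # always release server-side resources
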
@@ -89,183 +89,419 @@ class MemgraphStorage(BaseGraphStorage):
  pass
 
  async def has_node(self, node_id: str) -> bool:
  async with self._driver.session(
  database=self._DATABASE, default_access_mode="READ"
  ) as session:
- query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists"
- result = await session.run(query, entity_id=node_id)
- single_result = await result.single()
- await result.consume()
- return single_result["node_exists"]
 
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
  async with self._driver.session(
  database=self._DATABASE, default_access_mode="READ"
  ) as session:
- query = (
- "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) "
- "RETURN COUNT(r) > 0 AS edgeExists"
- )
- result = await session.run(
- query,
- source_entity_id=source_node_id,
- target_entity_id=target_node_id,
- )
- single_result = await result.single()
- await result.consume()
- return single_result["edgeExists"]
 
  async def get_node(self, node_id: str) -> dict[str, str] | None:
  async with self._driver.session(
  database=self._DATABASE, default_access_mode="READ"
  ) as session:
- query = "MATCH (n:base {entity_id: $entity_id}) RETURN n"
- result = await session.run(query, entity_id=node_id)
- records = await result.fetch(2)
- await result.consume()
- if records:
- node = records[0]["n"]
- node_dict = dict(node)
- if "labels" in node_dict:
- node_dict["labels"] = [
- label for label in node_dict["labels"] if label != "base"
- ]
- return node_dict
- return None
 
- async def get_all_labels(self) -> list[str]:
  async with self._driver.session(
  database=self._DATABASE, default_access_mode="READ"
  ) as session:
- query = """
- MATCH (n:base)
- WHERE n.entity_id IS NOT NULL
- RETURN DISTINCT n.entity_id AS label
- ORDER BY label
- """
- result = await session.run(query)
- labels = []
- async for record in result:
- labels.append(record["label"])
- await result.consume()
- return labels
 
- async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
  async with self._driver.session(
  database=self._DATABASE, default_access_mode="READ"
  ) as session:
- query = """
- MATCH (n:base {entity_id: $entity_id})
- OPTIONAL MATCH (n)-[r]-(connected:base)
- WHERE connected.entity_id IS NOT NULL
- RETURN n, r, connected
- """
- results = await session.run(query, entity_id=source_node_id)
- edges = []
- async for record in results:
- source_node = record["n"]
- connected_node = record["connected"]
- if not source_node or not connected_node:
- continue
- source_label = source_node.get("entity_id")
- target_label = connected_node.get("entity_id")
- if source_label and target_label:
- edges.append((source_label, target_label))
- await results.consume()
- return edges
 
  async def get_edge(
  self, source_node_id: str, target_node_id: str
  ) -> dict[str, str] | None:
  async with self._driver.session(
  database=self._DATABASE, default_access_mode="READ"
  ) as session:
- query = """
- MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id})
- RETURN properties(r) as edge_properties
- """
- result = await session.run(
- query,
- source_entity_id=source_node_id,
- target_entity_id=target_node_id,
- )
- records = await result.fetch(2)
- await result.consume()
- if records:
- edge_result = dict(records[0]["edge_properties"])
- for key, default_value in {
- "weight": 0.0,
- "source_id": None,
- "description": None,
- "keywords": None,
- }.items():
- if key not in edge_result:
- edge_result[key] = default_value
- return edge_result
- return None
 
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
  properties = node_data
- entity_type = properties.get("entity_type", "base")
  if "entity_id" not in properties:
- raise ValueError(
- "Memgraph: node properties must contain an 'entity_id' field"
- )
- async with self._driver.session(database=self._DATABASE) as session:
-
- async def execute_upsert(tx: AsyncManagedTransaction):
- query = f"""
- MERGE (n:base {{entity_id: $entity_id}})
  SET n += $properties
- SET n:`{entity_type}`
  """
- result = await tx.run(query, entity_id=node_id, properties=properties)
- await result.consume()
 
- await session.execute_write(execute_upsert)
 
  async def upsert_edge(
  self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
  ) -> None:
- edge_properties = edge_data
- async with self._driver.session(database=self._DATABASE) as session:
 
- async def execute_upsert(tx: AsyncManagedTransaction):
- query = """
- MATCH (source:base {entity_id: $source_entity_id})
- WITH source
- MATCH (target:base {entity_id: $target_entity_id})
- MERGE (source)-[r:DIRECTED]-(target)
- SET r += $properties
- RETURN r, source, target
- """
- result = await tx.run(
- query,
- source_entity_id=source_node_id,
- target_entity_id=target_node_id,
- properties=edge_properties,
- )
- await result.consume()
 
- await session.execute_write(execute_upsert)
 
  async def delete_node(self, node_id: str) -> None:
  async def _do_delete(tx: AsyncManagedTransaction):
  query = """
  MATCH (n:base {entity_id: $entity_id})
  DETACH DELETE n
  """
  result = await tx.run(query, entity_id=node_id)
  await result.consume()
 
- async with self._driver.session(database=self._DATABASE) as session:
- await session.execute_write(_do_delete)
 
  async def remove_nodes(self, nodes: list[str]):
  for node in nodes:
  await self.delete_node(node)
 
  async def remove_edges(self, edges: list[tuple[str, str]]):
  for source, target in edges:
 
  async def _do_delete_edge(tx: AsyncManagedTransaction):
@@ -276,15 +512,32 @@ class MemgraphStorage(BaseGraphStorage):
  result = await tx.run(
  query, source_entity_id=source, target_entity_id=target
  )
- await result.consume()
 
- async with self._driver.session(database=self._DATABASE) as session:
- await session.execute_write(_do_delete_edge)
 
  async def drop(self) -> dict[str, str]:
  try:
  async with self._driver.session(database=self._DATABASE) as session:
- query = "MATCH (n) DETACH DELETE n"
  result = await session.run(query)
  await result.consume()
  logger.info(
@@ -295,30 +548,36 @@ class MemgraphStorage(BaseGraphStorage):
  logger.error(f"Error dropping Memgraph database {self._DATABASE}: {e}")
  return {"status": "error", "message": str(e)}
 
- async def node_degree(self, node_id: str) -> int:
- async with self._driver.session(
- database=self._DATABASE, default_access_mode="READ"
- ) as session:
- query = """
- MATCH (n:base {entity_id: $entity_id})
- OPTIONAL MATCH (n)-[r]-()
- RETURN COUNT(r) AS degree
- """
- result = await session.run(query, entity_id=node_id)
- record = await result.single()
- await result.consume()
- if not record:
- return 0
- return record["degree"]
-
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
  src_degree = await self.node_degree(src_id)
  trg_degree = await self.node_degree(tgt_id)
  src_degree = 0 if src_degree is None else src_degree
  trg_degree = 0 if trg_degree is None else trg_degree
- return int(src_degree) + int(trg_degree)
 
  async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
  async with self._driver.session(
  database=self._DATABASE, default_access_mode="READ"
  ) as session:
@@ -335,10 +594,19 @@ class MemgraphStorage(BaseGraphStorage):
  node_dict = dict(node)
  node_dict["id"] = node_dict.get("entity_id")
  nodes.append(node_dict)
- await result.consume()
- return nodes
 
  async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
  async with self._driver.session(
  database=self._DATABASE, default_access_mode="READ"
  ) as session:
@@ -364,118 +632,149 @@ class MemgraphStorage(BaseGraphStorage):
  max_depth: int = 3,
  max_nodes: int = MAX_GRAPH_NODES,
  ) -> KnowledgeGraph:
  result = KnowledgeGraph()
  seen_nodes = set()
  seen_edges = set()
- async with self._driver.session(
- database=self._DATABASE, default_access_mode="READ"
- ) as session:
- if node_label == "*":
- count_query = "MATCH (n) RETURN count(n) as total"
- count_result = await session.run(count_query)
- count_record = await count_result.single()
- await count_result.consume()
- if count_record and count_record["total"] > max_nodes:
- result.is_truncated = True
- logger.info(
- f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}"
- )
- main_query = """
- MATCH (n)
- OPTIONAL MATCH (n)-[r]-()
- WITH n, COALESCE(count(r), 0) AS degree
- ORDER BY degree DESC
- LIMIT $max_nodes
- WITH collect({node: n}) AS filtered_nodes
- UNWIND filtered_nodes AS node_info
- WITH collect(node_info.node) AS kept_nodes, filtered_nodes
- OPTIONAL MATCH (a)-[r]-(b)
- WHERE a IN kept_nodes AND b IN kept_nodes
- RETURN filtered_nodes AS node_info,
- collect(DISTINCT r) AS relationships
- """
- result_set = await session.run(main_query, {"max_nodes": max_nodes})
- record = await result_set.single()
- await result_set.consume()
- else:
- # BFS fallback for Memgraph (no APOC)
- from collections import deque
-
- # Get the starting node
- start_query = "MATCH (n:base {entity_id: $entity_id}) RETURN n"
- node_result = await session.run(start_query, entity_id=node_label)
- node_record = await node_result.single()
- await node_result.consume()
- if not node_record:
- return result
- start_node = node_record["n"]
- queue = deque([(start_node, 0)])
- visited = set()
- bfs_nodes = []
- while queue and len(bfs_nodes) < max_nodes:
- current_node, depth = queue.popleft()
- node_id = current_node.get("entity_id")
- if node_id in visited:
- continue
- visited.add(node_id)
- bfs_nodes.append(current_node)
- if depth < max_depth:
- # Get neighbors
- neighbor_query = """
- MATCH (n:base {entity_id: $entity_id})-[]-(m:base)
- RETURN m
- """
- neighbors_result = await session.run(
- neighbor_query, entity_id=node_id
- )
- neighbors = [
- rec["m"] for rec in await neighbors_result.to_list()
- ]
- await neighbors_result.consume()
- for neighbor in neighbors:
- neighbor_id = neighbor.get("entity_id")
- if neighbor_id not in visited:
- queue.append((neighbor, depth + 1))
- # Build subgraph
- subgraph_ids = [n.get("entity_id") for n in bfs_nodes]
- # Nodes
- for n in bfs_nodes:
- node_id = n.get("entity_id")
- if node_id not in seen_nodes:
- result.nodes.append(
- KnowledgeGraphNode(
- id=node_id,
- labels=[node_id],
- properties=dict(n),
  )
  )
- seen_nodes.add(node_id)
- # Edges
- if subgraph_ids:
- edge_query = """
- MATCH (a:base)-[r]-(b:base)
- WHERE a.entity_id IN $ids AND b.entity_id IN $ids
- RETURN DISTINCT r, a, b
  """
- edge_result = await session.run(edge_query, ids=subgraph_ids)
- async for record in edge_result:
- r = record["r"]
- a = record["a"]
- b = record["b"]
- edge_id = f"{a.get('entity_id')}-{b.get('entity_id')}"
- if edge_id not in seen_edges:
- result.edges.append(
- KnowledgeGraphEdge(
- id=edge_id,
- type="DIRECTED",
- source=a.get("entity_id"),
- target=b.get("entity_id"),
- properties=dict(r),
  )
- )
- seen_edges.add(edge_id)
- await edge_result.consume()
- logger.info(
- f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
- )
- return result
  pass
 
  async def has_node(self, node_id: str) -> bool:
+ """
+ Check if a node exists in the graph.
+
+ Args:
+ node_id: The ID of the node to check.
+
+ Returns:
+ bool: True if the node exists, False otherwise.
+
+ Raises:
+ Exception: If there is an error checking the node existence.
+ """
  async with self._driver.session(
  database=self._DATABASE, default_access_mode="READ"
  ) as session:
+ try:
+ query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists"
+ result = await session.run(query, entity_id=node_id)
+ single_result = await result.single()
+ await result.consume()  # Ensure result is fully consumed
+ return single_result["node_exists"]
+ except Exception as e:
+ logger.error(f"Error checking node existence for {node_id}: {str(e)}")
+ await result.consume()  # Ensure the result is consumed even on error
+ raise
 
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
+ """
+ Check if an edge exists between two nodes in the graph.
+
+ Args:
+ source_node_id: The ID of the source node.
+ target_node_id: The ID of the target node.
+
+ Returns:
+ bool: True if the edge exists, False otherwise.
+
+ Raises:
+ Exception: If there is an error checking the edge existence.
+ """
  async with self._driver.session(
  database=self._DATABASE, default_access_mode="READ"
  ) as session:
+ try:
+ query = (
+ "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) "
+ "RETURN COUNT(r) > 0 AS edgeExists"
+ )
+ result = await session.run(
+ query,
+ source_entity_id=source_node_id,
+ target_entity_id=target_node_id,
+ )
+ single_result = await result.single()
+ await result.consume()  # Ensure result is fully consumed
+ return single_result["edgeExists"]
+ except Exception as e:
+ logger.error(
+ f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
+ )
+ await result.consume()  # Ensure the result is consumed even on error
+ raise
 
  async def get_node(self, node_id: str) -> dict[str, str] | None:
+ """Get node by its label identifier, return only node properties
+
+ Args:
+ node_id: The node label to look up
+
+ Returns:
+ dict: Node properties if found
+ None: If node not found
+
+ Raises:
+ Exception: If there is an error executing the query
+ """
  async with self._driver.session(
  database=self._DATABASE, default_access_mode="READ"
  ) as session:
+ try:
+ query = "MATCH (n:base {entity_id: $entity_id}) RETURN n"
+ result = await session.run(query, entity_id=node_id)
+ try:
+ records = await result.fetch(
+ 2
+ )  # Get 2 records for duplication check
 
+ if len(records) > 1:
+ logger.warning(
+ f"Multiple nodes found with label '{node_id}'. Using first node."
+ )
+ if records:
+ node = records[0]["n"]
+ node_dict = dict(node)
+ # Remove base label from labels list if it exists
+ if "labels" in node_dict:
+ node_dict["labels"] = [
+ label
+ for label in node_dict["labels"]
+ if label != "base"
+ ]
+ return node_dict
+ return None
+ finally:
+ await result.consume()  # Ensure result is fully consumed
+ except Exception as e:
+ logger.error(f"Error getting node for {node_id}: {str(e)}")
+ raise
+
+ async def node_degree(self, node_id: str) -> int:
+ """Get the degree (number of relationships) of a node with the given label.
+ If multiple nodes have the same label, returns the degree of the first node.
+ If no node is found, returns 0.
+
+ Args:
+ node_id: The label of the node
+
+ Returns:
+ int: The number of relationships the node has, or 0 if no node found
+
+ Raises:
+ Exception: If there is an error executing the query
+ """
  async with self._driver.session(
  database=self._DATABASE, default_access_mode="READ"
  ) as session:
+ try:
+ query = """
+ MATCH (n:base {entity_id: $entity_id})
+ OPTIONAL MATCH (n)-[r]-()
+ RETURN COUNT(r) AS degree
+ """
+ result = await session.run(query, entity_id=node_id)
+ try:
+ record = await result.single()
 
+ if not record:
+ logger.warning(f"No node found with label '{node_id}'")
+ return 0
+
+ degree = record["degree"]
+ return degree
+ finally:
+ await result.consume()  # Ensure result is fully consumed
+ except Exception as e:
+ logger.error(f"Error getting node degree for {node_id}: {str(e)}")
+ raise
+
+ async def get_all_labels(self) -> list[str]:
+ """
+ Get all existing node labels in the database
+ Returns:
+ ["Person", "Company", ...] # Alphabetically sorted label list
+
+ Raises:
+ Exception: If there is an error executing the query
+ """
  async with self._driver.session(
  database=self._DATABASE, default_access_mode="READ"
  ) as session:
+ try:
+ query = """
+ MATCH (n:base)
+ WHERE n.entity_id IS NOT NULL
+ RETURN DISTINCT n.entity_id AS label
+ ORDER BY label
+ """
+ result = await session.run(query)
+ labels = []
+ async for record in result:
+ labels.append(record["label"])
+ await result.consume()
+ return labels
+ except Exception as e:
+ logger.error(f"Error getting all labels: {str(e)}")
+ await result.consume()  # Ensure the result is consumed even on error
+ raise
+
+ async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
+ """Retrieves all edges (relationships) for a particular node identified by its label.
+
+ Args:
+ source_node_id: Label of the node to get edges for
+
+ Returns:
+ list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges
+ None: If no edges found
+
+ Raises:
+ Exception: If there is an error executing the query
+ """
+ try:
+ async with self._driver.session(
+ database=self._DATABASE, default_access_mode="READ"
+ ) as session:
+ try:
+ query = """MATCH (n:base {entity_id: $entity_id})
+ OPTIONAL MATCH (n)-[r]-(connected:base)
+ WHERE connected.entity_id IS NOT NULL
+ RETURN n, r, connected"""
+ results = await session.run(query, entity_id=source_node_id)
+
+ edges = []
+ async for record in results:
+ source_node = record["n"]
+ connected_node = record["connected"]
+
+ # Skip if either node is None
+ if not source_node or not connected_node:
+ continue
+
+ source_label = (
+ source_node.get("entity_id")
+ if source_node.get("entity_id")
+ else None
+ )
+ target_label = (
+ connected_node.get("entity_id")
+ if connected_node.get("entity_id")
+ else None
+ )
+
+ if source_label and target_label:
+ edges.append((source_label, target_label))
+
+ await results.consume()  # Ensure results are consumed
+ return edges
+ except Exception as e:
+ logger.error(
+ f"Error getting edges for node {source_node_id}: {str(e)}"
+ )
+ await results.consume()  # Ensure results are consumed even on error
+ raise
+ except Exception as e:
+ logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
+ raise
 
  async def get_edge(
  self, source_node_id: str, target_node_id: str
  ) -> dict[str, str] | None:
+ """Get edge properties between two nodes.
+
+ Args:
+ source_node_id: Label of the source node
+ target_node_id: Label of the target node
+
+ Returns:
+ dict: Edge properties if found (missing keys filled with defaults); None if no edge exists
+
+ Raises:
+ Exception: If there is an error executing the query
+ """
  async with self._driver.session(
  database=self._DATABASE, default_access_mode="READ"
  ) as session:
+ try:
+ query = """
+ MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id})
+ RETURN properties(r) as edge_properties
+ """
+ result = await session.run(
+ query,
+ source_entity_id=source_node_id,
+ target_entity_id=target_node_id,
+ )
+ records = await result.fetch(2)
+ await result.consume()
+ if records:
+ edge_result = dict(records[0]["edge_properties"])
+ for key, default_value in {
+ "weight": 0.0,
+ "source_id": None,
+ "description": None,
+ "keywords": None,
+ }.items():
+ if key not in edge_result:
+ edge_result[key] = default_value
+ logger.warning(
+ f"Edge between {source_node_id} and {target_node_id} is missing property: {key}. Using default value: {default_value}"
+ )
+ return edge_result
+ return None
+ except Exception as e:
+ logger.error(
+ f"Error getting edge between {source_node_id} and {target_node_id}: {str(e)}"
+ )
+ await result.consume()  # Ensure the result is consumed even on error
+ raise
 
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
+ """
+ Upsert a node in the Memgraph database.
+
+ Args:
+ node_id: The unique identifier for the node (used as label)
+ node_data: Dictionary of node properties
+ """
  properties = node_data
+ entity_type = properties["entity_type"]
  if "entity_id" not in properties:
+ raise ValueError("Memgraph: node properties must contain an 'entity_id' field")
+
+ try:
+ async with self._driver.session(database=self._DATABASE) as session:
+
+ async def execute_upsert(tx: AsyncManagedTransaction):
+ query = (
+ """
+ MERGE (n:base {entity_id: $entity_id})
  SET n += $properties
+ SET n:`%s`
  """
+ % entity_type
+ )
+ result = await tx.run(
+ query, entity_id=node_id, properties=properties
+ )
+ await result.consume()  # Ensure result is fully consumed
 
+ await session.execute_write(execute_upsert)
+ except Exception as e:
+ logger.error(f"Error during upsert: {str(e)}")
+ raise
 
  async def upsert_edge(
  self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
  ) -> None:
+ """
+ Upsert an edge and its properties between two nodes identified by their labels.
+ Ensures both source and target nodes exist and are unique before creating the edge.
+ Uses entity_id property to uniquely identify nodes.
+
+ Args:
+ source_node_id (str): Label of the source node (used as identifier)
+ target_node_id (str): Label of the target node (used as identifier)
+ edge_data (dict): Dictionary of properties to set on the edge
+
+ Raises:
+ Exception: If there is an error executing the query
+ """
+ try:
+ edge_properties = edge_data
+ async with self._driver.session(database=self._DATABASE) as session:
 
+ async def execute_upsert(tx: AsyncManagedTransaction):
+ query = """
+ MATCH (source:base {entity_id: $source_entity_id})
+ WITH source
+ MATCH (target:base {entity_id: $target_entity_id})
+ MERGE (source)-[r:DIRECTED]-(target)
+ SET r += $properties
+ RETURN r, source, target
+ """
+ result = await tx.run(
+ query,
+ source_entity_id=source_node_id,
+ target_entity_id=target_node_id,
+ properties=edge_properties,
+ )
+ try:
+ await result.fetch(2)
+ finally:
+ await result.consume()  # Ensure result is consumed
 
+ await session.execute_write(execute_upsert)
+ except Exception as e:
+ logger.error(f"Error during edge upsert: {str(e)}")
+ raise
 
  async def delete_node(self, node_id: str) -> None:
+ """Delete a node with the specified label
+
+ Args:
+ node_id: The label of the node to delete
+
+ Raises:
+ Exception: If there is an error executing the query
+ """
+
  async def _do_delete(tx: AsyncManagedTransaction):
  query = """
  MATCH (n:base {entity_id: $entity_id})
  DETACH DELETE n
  """
  result = await tx.run(query, entity_id=node_id)
+ logger.debug(f"Deleted node with label {node_id}")
  await result.consume()
 
+ try:
+ async with self._driver.session(database=self._DATABASE) as session:
+ await session.execute_write(_do_delete)
+ except Exception as e:
+ logger.error(f"Error during node deletion: {str(e)}")
+ raise
 
  async def remove_nodes(self, nodes: list[str]):
+ """Delete multiple nodes
+
+ Args:
+ nodes: List of node labels to be deleted
+ """
  for node in nodes:
  await self.delete_node(node)
 
  async def remove_edges(self, edges: list[tuple[str, str]]):
+ """Delete multiple edges
+
+ Args:
+ edges: List of edges to be deleted, each edge is a (source, target) tuple
+
+ Raises:
+ Exception: If there is an error executing the query
+ """
  for source, target in edges:
 
  async def _do_delete_edge(tx: AsyncManagedTransaction):
  result = await tx.run(
  query, source_entity_id=source, target_entity_id=target
  )
+ logger.debug(f"Deleted edge from '{source}' to '{target}'")
+ await result.consume()  # Ensure result is fully consumed
 
+ try:
+ async with self._driver.session(database=self._DATABASE) as session:
+ await session.execute_write(_do_delete_edge)
+ except Exception as e:
+ logger.error(f"Error during edge deletion: {str(e)}")
+ raise
 
  async def drop(self) -> dict[str, str]:
+ """Drop all data from storage and clean up resources
+
+ This method will delete all nodes and relationships in the Memgraph database.
+
+ Returns:
+ dict[str, str]: Operation status and message
+ - On success: {"status": "success", "message": "data dropped"}
+ - On failure: {"status": "error", "message": "<error details>"}
+
+ Raises:
+ Exception: If there is an error executing the query
+ """
  try:
  async with self._driver.session(database=self._DATABASE) as session:
+ query = "DROP GRAPH"
  result = await session.run(query)
  await result.consume()
  logger.info(
  logger.error(f"Error dropping Memgraph database {self._DATABASE}: {e}")
  return {"status": "error", "message": str(e)}
 
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
+ """Get the total degree (sum of relationships) of two nodes.
+
+ Args:
+ src_id: Label of the source node
+ tgt_id: Label of the target node
+
+ Returns:
+ int: Sum of the degrees of both nodes
+ """
  src_degree = await self.node_degree(src_id)
  trg_degree = await self.node_degree(tgt_id)
+
+ # Convert None to 0 for addition
  src_degree = 0 if src_degree is None else src_degree
  trg_degree = 0 if trg_degree is None else trg_degree
+
+ degrees = int(src_degree) + int(trg_degree)
+ return degrees
 
  async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
+ """Get all nodes that are associated with the given chunk_ids.
+
+ Args:
+ chunk_ids: List of chunk IDs to find associated nodes for
+
+ Returns:
+ list[dict]: A list of nodes, where each node is a dictionary of its properties.
+ An empty list if no matching nodes are found.
+ """
  async with self._driver.session(
  database=self._DATABASE, default_access_mode="READ"
  ) as session:
  node_dict = dict(node)
  node_dict["id"] = node_dict.get("entity_id")
  nodes.append(node_dict)
+ await result.consume()
+ return nodes
 
  async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
+ """Get all edges that are associated with the given chunk_ids.
+
+ Args:
+ chunk_ids: List of chunk IDs to find associated edges for
+
+ Returns:
+ list[dict]: A list of edges, where each edge is a dictionary of its properties.
+ An empty list if no matching edges are found.
+ """
  async with self._driver.session(
  database=self._DATABASE, default_access_mode="READ"
  ) as session:
  max_depth: int = 3,
  max_nodes: int = MAX_GRAPH_NODES,
  ) -> KnowledgeGraph:
+ """
+ Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
+
+ Args:
+ node_label: Label of the starting node, * means all nodes
+ max_depth: Maximum depth of the subgraph, defaults to 3
+ max_nodes: Maximum number of nodes to return by BFS, defaults to 1000
+
+ Returns:
+ KnowledgeGraph object containing nodes and edges, with an is_truncated flag
+ indicating whether the graph was truncated due to max_nodes limit
+
+ Raises:
+ Exception: If there is an error executing the query
+ """
  result = KnowledgeGraph()
  seen_nodes = set()
  seen_edges = set()
+ try:
+ async with self._driver.session(
+ database=self._DATABASE, default_access_mode="READ"
+ ) as session:
+ if node_label == "*":
+ count_query = "MATCH (n) RETURN count(n) as total"
+ count_result = None
+ try:
+ count_result = await session.run(count_query)
+ count_record = await count_result.single()
+ if count_record and count_record["total"] > max_nodes:
+ result.is_truncated = True
+ logger.info(
+ f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}"
  )
+ finally:
+ if count_result:
+ await count_result.consume()
+
+ # Run the main query to get nodes with highest degree
+ main_query = """
+ MATCH (n)
+ OPTIONAL MATCH (n)-[r]-()
+ WITH n, COALESCE(count(r), 0) AS degree
+ ORDER BY degree DESC
+ LIMIT $max_nodes
+ WITH collect({node: n}) AS filtered_nodes
+ UNWIND filtered_nodes AS node_info
+ WITH collect(node_info.node) AS kept_nodes, filtered_nodes
+ OPTIONAL MATCH (a)-[r]-(b)
+ WHERE a IN kept_nodes AND b IN kept_nodes
+ RETURN filtered_nodes AS node_info,
+ collect(DISTINCT r) AS relationships
+ """
+ result_set = None
+ try:
+ result_set = await session.run(
+ main_query, {"max_nodes": max_nodes}
  )
+ record = await result_set.single()
+ finally:
+ if result_set:
+ await result_set.consume()
+
+ else:
+ bfs_query = """
+ MATCH (start) WHERE start.entity_id = $entity_id
+ WITH start
+ CALL {
+ WITH start
+ MATCH path = (start)-[*0..$max_depth]-(node)
+ WITH nodes(path) AS path_nodes, relationships(path) AS path_rels
+ UNWIND path_nodes AS n
+ WITH collect(DISTINCT n) AS all_nodes, collect(DISTINCT path_rels) AS all_rel_lists
+ WITH all_nodes, reduce(r = [], x IN all_rel_lists | r + x) AS all_rels
+ RETURN all_nodes, all_rels
+ }
+ WITH all_nodes AS nodes, all_rels AS relationships, size(all_nodes) AS total_nodes
+
+ // Apply node limiting here
+ WITH CASE
+ WHEN total_nodes <= $max_nodes THEN nodes
+ ELSE nodes[0..$max_nodes]
+ END AS limited_nodes,
+ relationships,
+ total_nodes,
+ total_nodes > $max_nodes AS is_truncated
+ UNWIND limited_nodes AS node
+ WITH collect({node: node}) AS node_info, relationships, total_nodes, is_truncated
+ RETURN node_info, relationships, total_nodes, is_truncated
  """
+ result_set = None
+ try:
+ result_set = await session.run(
+ bfs_query,
+ {
+ "entity_id": node_label,
+ "max_depth": max_depth,
+ "max_nodes": max_nodes,
+ },
+ )
+ record = await result_set.single()
+ if not record:
+ logger.debug(f"No record found for node {node_label}")
+ return result
+
+ for node_info in record["node_info"]:
+ node = node_info["node"]
+ node_id = node.id
+ if node_id not in seen_nodes:
+ seen_nodes.add(node_id)
+ result.nodes.append(
+ KnowledgeGraphNode(
+ id=f"{node_id}",
+ labels=[node.get("entity_id")],
+ properties=dict(node),
+ )
  )
+
+ for rel in record["relationships"]:
+ edge_id = rel.id
+ if edge_id not in seen_edges:
+ seen_edges.add(edge_id)
+ start = rel.start_node
+ end = rel.end_node
+ result.edges.append(
+ KnowledgeGraphEdge(
+ id=f"{edge_id}",
+ type=rel.type,
+ source=f"{start.id}",
+ target=f"{end.id}",
+ properties=dict(rel),
+ )
+ )
+
+ logger.info(
+ f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
+ )
+
+ return result
+
+ finally:
+ if result_set:
+ await result_set.consume()
+
+ except Exception as e:
+ logger.error(f"Error getting knowledge graph: {str(e)}")
+ return result
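
For orientation, this is how the polished methods fit together at call time. A minimal usage sketch, assuming storage is an already-initialized MemgraphStorage instance (construction, configuration, and connection setup are outside this diff); note that the polished upsert_node now expects both entity_id and entity_type in the property dict.

import asyncio

async def demo(storage) -> None:
    # Upsert two nodes; upsert_node reads entity_type and requires entity_id.
    await storage.upsert_node("Alice", {"entity_id": "Alice", "entity_type": "Person"})
    await storage.upsert_node("Acme", {"entity_id": "Acme", "entity_type": "Company"})

    # Upsert an edge, then read it back; get_edge fills missing keys with defaults.
    await storage.upsert_edge("Alice", "Acme", {"weight": 1.0, "description": "works at"})
    print(await storage.has_edge("Alice", "Acme"))
    print(await storage.get_edge("Alice", "Acme"))

    # Degree helpers and subgraph retrieval around a starting label.
    print(await storage.edge_degree("Alice", "Acme"))
    kg = await storage.get_knowledge_graph("Alice", max_depth=2, max_nodes=100)
    print(len(kg.nodes), len(kg.edges), kg.is_truncated)

# asyncio.run(demo(storage))  # run against a MemgraphStorage wired to a live Memgraph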