Daniel.y committed on
Commit
0b82367
·
unverified ·
2 Parent(s): 471380d 1601d3e

Merge pull request #1758 from HKUDS/memgraph

Browse files
README.md CHANGED
@@ -860,6 +860,41 @@ rag = LightRAG(
860
 
861
  </details>
862
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
863
  ## Edit Entities and Relations
864
 
865
  LightRAG now supports comprehensive knowledge graph management capabilities, allowing you to create, edit, and delete entities and relationships within your knowledge graph.
 
860
 
861
  </details>
862
 
863
+ <details>
864
+ <summary> <b>Using Memgraph for Storage</b> </summary>
865
+
866
+ * Memgraph is a high-performance, in-memory graph database compatible with the Neo4j Bolt protocol.
867
+ * You can run Memgraph locally using Docker for easy testing:
868
+ * See: https://memgraph.com/download
869
+
870
+ ```python
871
+ os.environ["MEMGRAPH_URI"] = "bolt://localhost:7687"  # or: export MEMGRAPH_URI in your shell
872
+
873
+ # Setup logger for LightRAG
874
+ setup_logger("lightrag", level="INFO")
875
+
876
+ # When you launch the project, override the default KG: NetworkX
877
+ # by specifying kg="MemgraphStorage".
878
+
879
+ # Note: Default settings use NetworkX
880
+ # Initialize LightRAG with Memgraph implementation.
881
+ async def initialize_rag():
882
+ rag = LightRAG(
883
+ working_dir=WORKING_DIR,
884
+ llm_model_func=gpt_4o_mini_complete, # Use gpt_4o_mini_complete LLM model
885
+ graph_storage="MemgraphStorage", #<-----------override KG default
886
+ )
887
+
888
+ # Initialize database connections
889
+ await rag.initialize_storages()
890
+ # Initialize pipeline status for document processing
891
+ await initialize_pipeline_status()
892
+
893
+ return rag
894
+ ```
895
+
896
+ </details>
897
+
898
  ## Edit Entities and Relations
899
 
900
  LightRAG now supports comprehensive knowledge graph management capabilities, allowing you to create, edit, and delete entities and relationships within your knowledge graph.
config.ini.example CHANGED
@@ -21,3 +21,6 @@ password = your_password
21
  database = your_database
22
  workspace = default # 可选,默认为default
23
  max_connections = 12
 
 
 
 
21
  database = your_database
22
  workspace = default # 可选,默认为default
23
  max_connections = 12
24
+
25
+ [memgraph]
26
+ uri = bolt://localhost:7687
env.example CHANGED
@@ -134,13 +134,14 @@ EMBEDDING_BINDING_HOST=http://localhost:11434
134
  # LIGHTRAG_VECTOR_STORAGE=QdrantVectorDBStorage
135
  ### Graph Storage (Recommended for production deployment)
136
  # LIGHTRAG_GRAPH_STORAGE=Neo4JStorage
 
137
 
138
  ####################################################################
139
  ### Default workspace for all storage types
140
  ### For the purpose of isolation of data for each LightRAG instance
141
  ### Valid characters: a-z, A-Z, 0-9, and _
142
  ####################################################################
143
- # WORKSPACE=doc—
144
 
145
  ### PostgreSQL Configuration
146
  POSTGRES_HOST=localhost
@@ -179,3 +180,10 @@ QDRANT_URL=http://localhost:6333
179
  ### Redis
180
  REDIS_URI=redis://localhost:6379
181
  # REDIS_WORKSPACE=forced_workspace_name
 
 
 
 
 
 
 
 
134
  # LIGHTRAG_VECTOR_STORAGE=QdrantVectorDBStorage
135
  ### Graph Storage (Recommended for production deployment)
136
  # LIGHTRAG_GRAPH_STORAGE=Neo4JStorage
137
+ # LIGHTRAG_GRAPH_STORAGE=MemgraphStorage
138
 
139
  ####################################################################
140
  ### Default workspace for all storage types
141
  ### For the purpose of isolation of data for each LightRAG instance
142
  ### Valid characters: a-z, A-Z, 0-9, and _
143
  ####################################################################
144
+ # WORKSPACE=space1
145
 
146
  ### PostgreSQL Configuration
147
  POSTGRES_HOST=localhost
 
180
  ### Redis
181
  REDIS_URI=redis://localhost:6379
182
  # REDIS_WORKSPACE=forced_workspace_name
183
+
184
+ ### Memgraph Configuration
185
+ MEMGRAPH_URI=bolt://localhost:7687
186
+ MEMGRAPH_USERNAME=
187
+ MEMGRAPH_PASSWORD=
188
+ MEMGRAPH_DATABASE=memgraph
189
+ # MEMGRAPH_WORKSPACE=forced_workspace_name
lightrag/kg/__init__.py CHANGED
@@ -15,6 +15,7 @@ STORAGE_IMPLEMENTATIONS = {
15
  "Neo4JStorage",
16
  "PGGraphStorage",
17
  "MongoGraphStorage",
 
18
  # "AGEStorage",
19
  # "TiDBGraphStorage",
20
  # "GremlinStorage",
@@ -57,6 +58,7 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
57
  "NetworkXStorage": [],
58
  "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
59
  "MongoGraphStorage": [],
 
60
  # "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
61
  "AGEStorage": [
62
  "AGE_POSTGRES_DB",
@@ -111,6 +113,7 @@ STORAGES = {
111
  "PGDocStatusStorage": ".kg.postgres_impl",
112
  "FaissVectorDBStorage": ".kg.faiss_impl",
113
  "QdrantVectorDBStorage": ".kg.qdrant_impl",
 
114
  }
115
 
116
 
 
15
  "Neo4JStorage",
16
  "PGGraphStorage",
17
  "MongoGraphStorage",
18
+ "MemgraphStorage",
19
  # "AGEStorage",
20
  # "TiDBGraphStorage",
21
  # "GremlinStorage",
 
58
  "NetworkXStorage": [],
59
  "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
60
  "MongoGraphStorage": [],
61
+ "MemgraphStorage": ["MEMGRAPH_URI"],
62
  # "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
63
  "AGEStorage": [
64
  "AGE_POSTGRES_DB",
 
113
  "PGDocStatusStorage": ".kg.postgres_impl",
114
  "FaissVectorDBStorage": ".kg.faiss_impl",
115
  "QdrantVectorDBStorage": ".kg.qdrant_impl",
116
+ "MemgraphStorage": ".kg.memgraph_impl",
117
  }
118
 
119
 
lightrag/kg/memgraph_impl.py ADDED
@@ -0,0 +1,906 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dataclasses import dataclass
3
+ from typing import final
4
+ import configparser
5
+
6
+ from ..utils import logger
7
+ from ..base import BaseGraphStorage
8
+ from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
9
+ from ..constants import GRAPH_FIELD_SEP
10
+ import pipmaster as pm
11
+
12
+ if not pm.is_installed("neo4j"):
13
+ pm.install("neo4j")
14
+
15
+ from neo4j import (
16
+ AsyncGraphDatabase,
17
+ AsyncManagedTransaction,
18
+ )
19
+
20
+ from dotenv import load_dotenv
21
+
22
+ # use the .env that is inside the current folder
23
+ load_dotenv(dotenv_path=".env", override=False)
24
+
25
+ MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
26
+
27
+ config = configparser.ConfigParser()
28
+ config.read("config.ini", "utf-8")
29
+
30
+
31
+ @final
32
+ @dataclass
33
+ class MemgraphStorage(BaseGraphStorage):
34
    def __init__(self, namespace, global_config, embedding_func, workspace=None):
        """Create a Memgraph-backed graph storage instance.

        Args:
            namespace: Storage namespace identifier, passed through to the base class.
            global_config: Global LightRAG configuration dict.
            embedding_func: Embedding function, passed through to the base class.
            workspace: Optional workspace name; the MEMGRAPH_WORKSPACE environment
                variable, when set and non-blank, overrides this argument.
        """
        # A non-empty MEMGRAPH_WORKSPACE env var takes precedence over the argument
        memgraph_workspace = os.environ.get("MEMGRAPH_WORKSPACE")
        if memgraph_workspace and memgraph_workspace.strip():
            workspace = memgraph_workspace
        super().__init__(
            namespace=namespace,
            workspace=workspace or "",
            global_config=global_config,
            embedding_func=embedding_func,
        )
        # The driver is created lazily in initialize(); None means "not connected"
        self._driver = None
45
+
46
+ def _get_workspace_label(self) -> str:
47
+ """Get workspace label, return 'base' for compatibility when workspace is empty"""
48
+ workspace = getattr(self, "workspace", None)
49
+ return workspace if workspace else "base"
50
+
51
+ async def initialize(self):
52
+ URI = os.environ.get(
53
+ "MEMGRAPH_URI",
54
+ config.get("memgraph", "uri", fallback="bolt://localhost:7687"),
55
+ )
56
+ USERNAME = os.environ.get(
57
+ "MEMGRAPH_USERNAME", config.get("memgraph", "username", fallback="")
58
+ )
59
+ PASSWORD = os.environ.get(
60
+ "MEMGRAPH_PASSWORD", config.get("memgraph", "password", fallback="")
61
+ )
62
+ DATABASE = os.environ.get(
63
+ "MEMGRAPH_DATABASE", config.get("memgraph", "database", fallback="memgraph")
64
+ )
65
+
66
+ self._driver = AsyncGraphDatabase.driver(
67
+ URI,
68
+ auth=(USERNAME, PASSWORD),
69
+ )
70
+ self._DATABASE = DATABASE
71
+ try:
72
+ async with self._driver.session(database=DATABASE) as session:
73
+ # Create index for base nodes on entity_id if it doesn't exist
74
+ try:
75
+ workspace_label = self._get_workspace_label()
76
+ await session.run(
77
+ f"""CREATE INDEX ON :{workspace_label}(entity_id)"""
78
+ )
79
+ logger.info(
80
+ f"Created index on :{workspace_label}(entity_id) in Memgraph."
81
+ )
82
+ except Exception as e:
83
+ # Index may already exist, which is not an error
84
+ logger.warning(
85
+ f"Index creation on :{workspace_label}(entity_id) may have failed or already exists: {e}"
86
+ )
87
+ await session.run("RETURN 1")
88
+ logger.info(f"Connected to Memgraph at {URI}")
89
+ except Exception as e:
90
+ logger.error(f"Failed to connect to Memgraph at {URI}: {e}")
91
+ raise
92
+
93
+ async def finalize(self):
94
+ if self._driver is not None:
95
+ await self._driver.close()
96
+ self._driver = None
97
+
98
    async def __aexit__(self, exc_type, exc, tb):
        # Async context-manager exit: always release the driver, even on error.
        await self.finalize()
100
+
101
    async def index_done_callback(self):
        """No-op: Memgraph persists writes itself, so nothing to flush here."""
        # Memgraph handles persistence automatically
        pass
104
+
105
+ async def has_node(self, node_id: str) -> bool:
106
+ """
107
+ Check if a node exists in the graph.
108
+
109
+ Args:
110
+ node_id: The ID of the node to check.
111
+
112
+ Returns:
113
+ bool: True if the node exists, False otherwise.
114
+
115
+ Raises:
116
+ Exception: If there is an error checking the node existence.
117
+ """
118
+ if self._driver is None:
119
+ raise RuntimeError(
120
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
121
+ )
122
+ async with self._driver.session(
123
+ database=self._DATABASE, default_access_mode="READ"
124
+ ) as session:
125
+ try:
126
+ workspace_label = self._get_workspace_label()
127
+ query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
128
+ result = await session.run(query, entity_id=node_id)
129
+ single_result = await result.single()
130
+ await result.consume() # Ensure result is fully consumed
131
+ return (
132
+ single_result["node_exists"] if single_result is not None else False
133
+ )
134
+ except Exception as e:
135
+ logger.error(f"Error checking node existence for {node_id}: {str(e)}")
136
+ await result.consume() # Ensure the result is consumed even on error
137
+ raise
138
+
139
+ async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
140
+ """
141
+ Check if an edge exists between two nodes in the graph.
142
+
143
+ Args:
144
+ source_node_id: The ID of the source node.
145
+ target_node_id: The ID of the target node.
146
+
147
+ Returns:
148
+ bool: True if the edge exists, False otherwise.
149
+
150
+ Raises:
151
+ Exception: If there is an error checking the edge existence.
152
+ """
153
+ if self._driver is None:
154
+ raise RuntimeError(
155
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
156
+ )
157
+ async with self._driver.session(
158
+ database=self._DATABASE, default_access_mode="READ"
159
+ ) as session:
160
+ try:
161
+ workspace_label = self._get_workspace_label()
162
+ query = (
163
+ f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) "
164
+ "RETURN COUNT(r) > 0 AS edgeExists"
165
+ )
166
+ result = await session.run(
167
+ query,
168
+ source_entity_id=source_node_id,
169
+ target_entity_id=target_node_id,
170
+ ) # type: ignore
171
+ single_result = await result.single()
172
+ await result.consume() # Ensure result is fully consumed
173
+ return (
174
+ single_result["edgeExists"] if single_result is not None else False
175
+ )
176
+ except Exception as e:
177
+ logger.error(
178
+ f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
179
+ )
180
+ await result.consume() # Ensure the result is consumed even on error
181
+ raise
182
+
183
    async def get_node(self, node_id: str) -> dict[str, str] | None:
        """Get a node by its entity_id and return its property map.

        Args:
            node_id: The entity_id of the node to look up.

        Returns:
            dict: Node properties if found (workspace label filtered out of any
                'labels' property).
            None: If no node matches.

        Raises:
            RuntimeError: If the driver has not been initialized.
            Exception: If there is an error executing the query.
        """
        if self._driver is None:
            raise RuntimeError(
                "Memgraph driver is not initialized. Call 'await initialize()' first."
            )
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            try:
                workspace_label = self._get_workspace_label()
                query = (
                    f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN n"
                )
                result = await session.run(query, entity_id=node_id)
                try:
                    # Fetch up to 2 records purely to detect duplicate entity_ids
                    records = await result.fetch(2)

                    if len(records) > 1:
                        logger.warning(
                            f"Multiple nodes found with label '{node_id}'. Using first node."
                        )
                    if records:
                        node = records[0]["n"]
                        node_dict = dict(node)
                        # Remove workspace label from labels list if it exists
                        if "labels" in node_dict:
                            node_dict["labels"] = [
                                label
                                for label in node_dict["labels"]
                                if label != workspace_label
                            ]
                        return node_dict
                    return None
                finally:
                    await result.consume()  # Ensure result is fully consumed
            except Exception as e:
                logger.error(f"Error getting node for {node_id}: {str(e)}")
                raise
235
+
236
+ async def node_degree(self, node_id: str) -> int:
237
+ """Get the degree (number of relationships) of a node with the given label.
238
+ If multiple nodes have the same label, returns the degree of the first node.
239
+ If no node is found, returns 0.
240
+
241
+ Args:
242
+ node_id: The label of the node
243
+
244
+ Returns:
245
+ int: The number of relationships the node has, or 0 if no node found
246
+
247
+ Raises:
248
+ Exception: If there is an error executing the query
249
+ """
250
+ if self._driver is None:
251
+ raise RuntimeError(
252
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
253
+ )
254
+ async with self._driver.session(
255
+ database=self._DATABASE, default_access_mode="READ"
256
+ ) as session:
257
+ try:
258
+ workspace_label = self._get_workspace_label()
259
+ query = f"""
260
+ MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
261
+ OPTIONAL MATCH (n)-[r]-()
262
+ RETURN COUNT(r) AS degree
263
+ """
264
+ result = await session.run(query, entity_id=node_id)
265
+ try:
266
+ record = await result.single()
267
+
268
+ if not record:
269
+ logger.warning(f"No node found with label '{node_id}'")
270
+ return 0
271
+
272
+ degree = record["degree"]
273
+ return degree
274
+ finally:
275
+ await result.consume() # Ensure result is fully consumed
276
+ except Exception as e:
277
+ logger.error(f"Error getting node degree for {node_id}: {str(e)}")
278
+ raise
279
+
280
+ async def get_all_labels(self) -> list[str]:
281
+ """
282
+ Get all existing node labels in the database
283
+ Returns:
284
+ ["Person", "Company", ...] # Alphabetically sorted label list
285
+
286
+ Raises:
287
+ Exception: If there is an error executing the query
288
+ """
289
+ if self._driver is None:
290
+ raise RuntimeError(
291
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
292
+ )
293
+ async with self._driver.session(
294
+ database=self._DATABASE, default_access_mode="READ"
295
+ ) as session:
296
+ try:
297
+ workspace_label = self._get_workspace_label()
298
+ query = f"""
299
+ MATCH (n:`{workspace_label}`)
300
+ WHERE n.entity_id IS NOT NULL
301
+ RETURN DISTINCT n.entity_id AS label
302
+ ORDER BY label
303
+ """
304
+ result = await session.run(query)
305
+ labels = []
306
+ async for record in result:
307
+ labels.append(record["label"])
308
+ await result.consume()
309
+ return labels
310
+ except Exception as e:
311
+ logger.error(f"Error getting all labels: {str(e)}")
312
+ await result.consume() # Ensure the result is consumed even on error
313
+ raise
314
+
315
+ async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
316
+ """Retrieves all edges (relationships) for a particular node identified by its label.
317
+
318
+ Args:
319
+ source_node_id: Label of the node to get edges for
320
+
321
+ Returns:
322
+ list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges
323
+ None: If no edges found
324
+
325
+ Raises:
326
+ Exception: If there is an error executing the query
327
+ """
328
+ if self._driver is None:
329
+ raise RuntimeError(
330
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
331
+ )
332
+ try:
333
+ async with self._driver.session(
334
+ database=self._DATABASE, default_access_mode="READ"
335
+ ) as session:
336
+ try:
337
+ workspace_label = self._get_workspace_label()
338
+ query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
339
+ OPTIONAL MATCH (n)-[r]-(connected:`{workspace_label}`)
340
+ WHERE connected.entity_id IS NOT NULL
341
+ RETURN n, r, connected"""
342
+ results = await session.run(query, entity_id=source_node_id)
343
+
344
+ edges = []
345
+ async for record in results:
346
+ source_node = record["n"]
347
+ connected_node = record["connected"]
348
+
349
+ # Skip if either node is None
350
+ if not source_node or not connected_node:
351
+ continue
352
+
353
+ source_label = (
354
+ source_node.get("entity_id")
355
+ if source_node.get("entity_id")
356
+ else None
357
+ )
358
+ target_label = (
359
+ connected_node.get("entity_id")
360
+ if connected_node.get("entity_id")
361
+ else None
362
+ )
363
+
364
+ if source_label and target_label:
365
+ edges.append((source_label, target_label))
366
+
367
+ await results.consume() # Ensure results are consumed
368
+ return edges
369
+ except Exception as e:
370
+ logger.error(
371
+ f"Error getting edges for node {source_node_id}: {str(e)}"
372
+ )
373
+ await results.consume() # Ensure results are consumed even on error
374
+ raise
375
+ except Exception as e:
376
+ logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
377
+ raise
378
+
379
+ async def get_edge(
380
+ self, source_node_id: str, target_node_id: str
381
+ ) -> dict[str, str] | None:
382
+ """Get edge properties between two nodes.
383
+
384
+ Args:
385
+ source_node_id: Label of the source node
386
+ target_node_id: Label of the target node
387
+
388
+ Returns:
389
+ dict: Edge properties if found, default properties if not found or on error
390
+
391
+ Raises:
392
+ Exception: If there is an error executing the query
393
+ """
394
+ if self._driver is None:
395
+ raise RuntimeError(
396
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
397
+ )
398
+ async with self._driver.session(
399
+ database=self._DATABASE, default_access_mode="READ"
400
+ ) as session:
401
+ try:
402
+ workspace_label = self._get_workspace_label()
403
+ query = f"""
404
+ MATCH (start:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(end:`{workspace_label}` {{entity_id: $target_entity_id}})
405
+ RETURN properties(r) as edge_properties
406
+ """
407
+ result = await session.run(
408
+ query,
409
+ source_entity_id=source_node_id,
410
+ target_entity_id=target_node_id,
411
+ )
412
+ records = await result.fetch(2)
413
+ await result.consume()
414
+ if records:
415
+ edge_result = dict(records[0]["edge_properties"])
416
+ for key, default_value in {
417
+ "weight": 0.0,
418
+ "source_id": None,
419
+ "description": None,
420
+ "keywords": None,
421
+ }.items():
422
+ if key not in edge_result:
423
+ edge_result[key] = default_value
424
+ logger.warning(
425
+ f"Edge between {source_node_id} and {target_node_id} is missing property: {key}. Using default value: {default_value}"
426
+ )
427
+ return edge_result
428
+ return None
429
+ except Exception as e:
430
+ logger.error(
431
+ f"Error getting edge between {source_node_id} and {target_node_id}: {str(e)}"
432
+ )
433
+ await result.consume() # Ensure the result is consumed even on error
434
+ raise
435
+
436
+ async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
437
+ """
438
+ Upsert a node in the Neo4j database.
439
+
440
+ Args:
441
+ node_id: The unique identifier for the node (used as label)
442
+ node_data: Dictionary of node properties
443
+ """
444
+ if self._driver is None:
445
+ raise RuntimeError(
446
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
447
+ )
448
+ properties = node_data
449
+ entity_type = properties["entity_type"]
450
+ if "entity_id" not in properties:
451
+ raise ValueError("Neo4j: node properties must contain an 'entity_id' field")
452
+
453
+ try:
454
+ async with self._driver.session(database=self._DATABASE) as session:
455
+ workspace_label = self._get_workspace_label()
456
+
457
+ async def execute_upsert(tx: AsyncManagedTransaction):
458
+ query = f"""
459
+ MERGE (n:`{workspace_label}` {{entity_id: $entity_id}})
460
+ SET n += $properties
461
+ SET n:`{entity_type}`
462
+ """
463
+ result = await tx.run(
464
+ query, entity_id=node_id, properties=properties
465
+ )
466
+ await result.consume() # Ensure result is fully consumed
467
+
468
+ await session.execute_write(execute_upsert)
469
+ except Exception as e:
470
+ logger.error(f"Error during upsert: {str(e)}")
471
+ raise
472
+
473
    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ) -> None:
        """
        Upsert an edge and its properties between two nodes identified by entity_id.

        Both endpoint nodes must already exist; if either MATCH finds nothing the
        MERGE never runs and no edge is created.

        Args:
            source_node_id (str): entity_id of the source node
            target_node_id (str): entity_id of the target node
            edge_data (dict): Dictionary of properties to set on the edge

        Raises:
            RuntimeError: If the driver has not been initialized.
            Exception: If there is an error executing the query
        """
        if self._driver is None:
            raise RuntimeError(
                "Memgraph driver is not initialized. Call 'await initialize()' first."
            )
        try:
            edge_properties = edge_data
            async with self._driver.session(database=self._DATABASE) as session:

                async def execute_upsert(tx: AsyncManagedTransaction):
                    workspace_label = self._get_workspace_label()
                    # NOTE(review): the MERGE pattern is undirected while the
                    # relationship type is hard-coded to DIRECTED, so an existing
                    # edge in either direction is reused — confirm this matches the
                    # other graph backends' semantics.
                    query = f"""
                    MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})
                    WITH source
                    MATCH (target:`{workspace_label}` {{entity_id: $target_entity_id}})
                    MERGE (source)-[r:DIRECTED]-(target)
                    SET r += $properties
                    RETURN r, source, target
                    """
                    result = await tx.run(
                        query,
                        source_entity_id=source_node_id,
                        target_entity_id=target_node_id,
                        properties=edge_properties,
                    )
                    try:
                        await result.fetch(2)
                    finally:
                        await result.consume()  # Ensure result is consumed

                await session.execute_write(execute_upsert)
        except Exception as e:
            logger.error(f"Error during edge upsert: {str(e)}")
            raise
522
+
523
+ async def delete_node(self, node_id: str) -> None:
524
+ """Delete a node with the specified label
525
+
526
+ Args:
527
+ node_id: The label of the node to delete
528
+
529
+ Raises:
530
+ Exception: If there is an error executing the query
531
+ """
532
+ if self._driver is None:
533
+ raise RuntimeError(
534
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
535
+ )
536
+
537
+ async def _do_delete(tx: AsyncManagedTransaction):
538
+ workspace_label = self._get_workspace_label()
539
+ query = f"""
540
+ MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
541
+ DETACH DELETE n
542
+ """
543
+ result = await tx.run(query, entity_id=node_id)
544
+ logger.debug(f"Deleted node with label {node_id}")
545
+ await result.consume()
546
+
547
+ try:
548
+ async with self._driver.session(database=self._DATABASE) as session:
549
+ await session.execute_write(_do_delete)
550
+ except Exception as e:
551
+ logger.error(f"Error during node deletion: {str(e)}")
552
+ raise
553
+
554
+ async def remove_nodes(self, nodes: list[str]):
555
+ """Delete multiple nodes
556
+
557
+ Args:
558
+ nodes: List of node labels to be deleted
559
+ """
560
+ if self._driver is None:
561
+ raise RuntimeError(
562
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
563
+ )
564
+ for node in nodes:
565
+ await self.delete_node(node)
566
+
567
+ async def remove_edges(self, edges: list[tuple[str, str]]):
568
+ """Delete multiple edges
569
+
570
+ Args:
571
+ edges: List of edges to be deleted, each edge is a (source, target) tuple
572
+
573
+ Raises:
574
+ Exception: If there is an error executing the query
575
+ """
576
+ if self._driver is None:
577
+ raise RuntimeError(
578
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
579
+ )
580
+ for source, target in edges:
581
+
582
+ async def _do_delete_edge(tx: AsyncManagedTransaction):
583
+ workspace_label = self._get_workspace_label()
584
+ query = f"""
585
+ MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(target:`{workspace_label}` {{entity_id: $target_entity_id}})
586
+ DELETE r
587
+ """
588
+ result = await tx.run(
589
+ query, source_entity_id=source, target_entity_id=target
590
+ )
591
+ logger.debug(f"Deleted edge from '{source}' to '{target}'")
592
+ await result.consume() # Ensure result is fully consumed
593
+
594
+ try:
595
+ async with self._driver.session(database=self._DATABASE) as session:
596
+ await session.execute_write(_do_delete_edge)
597
+ except Exception as e:
598
+ logger.error(f"Error during edge deletion: {str(e)}")
599
+ raise
600
+
601
+ async def drop(self) -> dict[str, str]:
602
+ """Drop all data from the current workspace and clean up resources
603
+
604
+ This method will delete all nodes and relationships in the Memgraph database.
605
+
606
+ Returns:
607
+ dict[str, str]: Operation status and message
608
+ - On success: {"status": "success", "message": "data dropped"}
609
+ - On failure: {"status": "error", "message": "<error details>"}
610
+
611
+ Raises:
612
+ Exception: If there is an error executing the query
613
+ """
614
+ if self._driver is None:
615
+ raise RuntimeError(
616
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
617
+ )
618
+ try:
619
+ async with self._driver.session(database=self._DATABASE) as session:
620
+ workspace_label = self._get_workspace_label()
621
+ query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
622
+ result = await session.run(query)
623
+ await result.consume()
624
+ logger.info(
625
+ f"Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}"
626
+ )
627
+ return {"status": "success", "message": "workspace data dropped"}
628
+ except Exception as e:
629
+ logger.error(
630
+ f"Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}"
631
+ )
632
+ return {"status": "error", "message": str(e)}
633
+
634
+ async def edge_degree(self, src_id: str, tgt_id: str) -> int:
635
+ """Get the total degree (sum of relationships) of two nodes.
636
+
637
+ Args:
638
+ src_id: Label of the source node
639
+ tgt_id: Label of the target node
640
+
641
+ Returns:
642
+ int: Sum of the degrees of both nodes
643
+ """
644
+ if self._driver is None:
645
+ raise RuntimeError(
646
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
647
+ )
648
+ src_degree = await self.node_degree(src_id)
649
+ trg_degree = await self.node_degree(tgt_id)
650
+
651
+ # Convert None to 0 for addition
652
+ src_degree = 0 if src_degree is None else src_degree
653
+ trg_degree = 0 if trg_degree is None else trg_degree
654
+
655
+ degrees = int(src_degree) + int(trg_degree)
656
+ return degrees
657
+
658
+ async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
659
+ """Get all nodes that are associated with the given chunk_ids.
660
+
661
+ Args:
662
+ chunk_ids: List of chunk IDs to find associated nodes for
663
+
664
+ Returns:
665
+ list[dict]: A list of nodes, where each node is a dictionary of its properties.
666
+ An empty list if no matching nodes are found.
667
+ """
668
+ if self._driver is None:
669
+ raise RuntimeError(
670
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
671
+ )
672
+ workspace_label = self._get_workspace_label()
673
+ async with self._driver.session(
674
+ database=self._DATABASE, default_access_mode="READ"
675
+ ) as session:
676
+ query = f"""
677
+ UNWIND $chunk_ids AS chunk_id
678
+ MATCH (n:`{workspace_label}`)
679
+ WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
680
+ RETURN DISTINCT n
681
+ """
682
+ result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
683
+ nodes = []
684
+ async for record in result:
685
+ node = record["n"]
686
+ node_dict = dict(node)
687
+ node_dict["id"] = node_dict.get("entity_id")
688
+ nodes.append(node_dict)
689
+ await result.consume()
690
+ return nodes
691
+
692
    async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
        """Get all edges whose source_id references any of the given chunk_ids.

        Args:
            chunk_ids: Chunk IDs matched against each edge's GRAPH_FIELD_SEP-
                delimited source_id property.

        Returns:
            list[dict]: Edge property maps, each with added 'source' and 'target'
                keys (endpoint entity_ids, ordered so source <= target).
                An empty list if no matching edges are found.
        """
        if self._driver is None:
            raise RuntimeError(
                "Memgraph driver is not initialized. Call 'await initialize()' first."
            )
        workspace_label = self._get_workspace_label()
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            # The undirected match yields each edge twice (once per direction), so
            # endpoints are canonically ordered and DISTINCT collapses duplicates.
            query = f"""
            UNWIND $chunk_ids AS chunk_id
            MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`)
            WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
            WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id
            // Ensure we only return each unique edge once by ordering the source and target
            WITH a, b, r,
                 CASE WHEN source_id <= target_id THEN source_id ELSE target_id END AS ordered_source,
                 CASE WHEN source_id <= target_id THEN target_id ELSE source_id END AS ordered_target
            RETURN DISTINCT ordered_source AS source, ordered_target AS target, properties(r) AS properties
            """
            result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
            edges = []
            async for record in result:
                edge_properties = record["properties"]
                edge_properties["source"] = record["source"]
                edge_properties["target"] = record["target"]
                edges.append(edge_properties)
            await result.consume()
            return edges
730
+
731
+ async def get_knowledge_graph(
732
+ self,
733
+ node_label: str,
734
+ max_depth: int = 3,
735
+ max_nodes: int = MAX_GRAPH_NODES,
736
+ ) -> KnowledgeGraph:
737
+ """
738
+ Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
739
+
740
+ Args:
741
+ node_label: Label of the starting node, * means all nodes
742
+ max_depth: Maximum depth of the subgraph, Defaults to 3
743
+ max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000
744
+
745
+ Returns:
746
+ KnowledgeGraph object containing nodes and edges, with an is_truncated flag
747
+ indicating whether the graph was truncated due to max_nodes limit
748
+
749
+ Raises:
750
+ Exception: If there is an error executing the query
751
+ """
752
+ if self._driver is None:
753
+ raise RuntimeError(
754
+ "Memgraph driver is not initialized. Call 'await initialize()' first."
755
+ )
756
+
757
+ result = KnowledgeGraph()
758
+ seen_nodes = set()
759
+ seen_edges = set()
760
+ workspace_label = self._get_workspace_label()
761
+ async with self._driver.session(
762
+ database=self._DATABASE, default_access_mode="READ"
763
+ ) as session:
764
+ try:
765
+ if node_label == "*":
766
+ # First check if database has any nodes
767
+ count_query = "MATCH (n) RETURN count(n) as total"
768
+ count_result = None
769
+ total_count = 0
770
+ try:
771
+ count_result = await session.run(count_query)
772
+ count_record = await count_result.single()
773
+ if count_record:
774
+ total_count = count_record["total"]
775
+ if total_count == 0:
776
+ logger.debug("No nodes found in database")
777
+ return result
778
+ if total_count > max_nodes:
779
+ result.is_truncated = True
780
+ logger.info(
781
+ f"Graph truncated: {total_count} nodes found, limited to {max_nodes}"
782
+ )
783
+ finally:
784
+ if count_result:
785
+ await count_result.consume()
786
+
787
+ # Run the main query to get nodes with highest degree
788
+ main_query = f"""
789
+ MATCH (n:`{workspace_label}`)
790
+ OPTIONAL MATCH (n)-[r]-()
791
+ WITH n, COALESCE(count(r), 0) AS degree
792
+ ORDER BY degree DESC
793
+ LIMIT $max_nodes
794
+ WITH collect(n) AS kept_nodes
795
+ MATCH (a)-[r]-(b)
796
+ WHERE a IN kept_nodes AND b IN kept_nodes
797
+ RETURN [node IN kept_nodes | {{node: node}}] AS node_info,
798
+ collect(DISTINCT r) AS relationships
799
+ """
800
+ result_set = None
801
+ try:
802
+ result_set = await session.run(
803
+ main_query, {"max_nodes": max_nodes}
804
+ )
805
+ record = await result_set.single()
806
+ if not record:
807
+ logger.debug("No record returned from main query")
808
+ return result
809
+ finally:
810
+ if result_set:
811
+ await result_set.consume()
812
+
813
+ else:
814
+ bfs_query = f"""
815
+ MATCH (start:`{workspace_label}`)
816
+ WHERE start.entity_id = $entity_id
817
+ WITH start
818
+ CALL {{
819
+ WITH start
820
+ MATCH path = (start)-[*0..{max_depth}]-(node)
821
+ WITH nodes(path) AS path_nodes, relationships(path) AS path_rels
822
+ UNWIND path_nodes AS n
823
+ WITH collect(DISTINCT n) AS all_nodes, collect(DISTINCT path_rels) AS all_rel_lists
824
+ WITH all_nodes, reduce(r = [], x IN all_rel_lists | r + x) AS all_rels
825
+ RETURN all_nodes, all_rels
826
+ }}
827
+ WITH all_nodes AS nodes, all_rels AS relationships, size(all_nodes) AS total_nodes
828
+ WITH
829
+ CASE
830
+ WHEN total_nodes <= {max_nodes} THEN nodes
831
+ ELSE nodes[0..{max_nodes}]
832
+ END AS limited_nodes,
833
+ relationships,
834
+ total_nodes,
835
+ total_nodes > {max_nodes} AS is_truncated
836
+ RETURN
837
+ [node IN limited_nodes | {{node: node}}] AS node_info,
838
+ relationships,
839
+ total_nodes,
840
+ is_truncated
841
+ """
842
+ result_set = None
843
+ try:
844
+ result_set = await session.run(
845
+ bfs_query,
846
+ {
847
+ "entity_id": node_label,
848
+ },
849
+ )
850
+ record = await result_set.single()
851
+ if not record:
852
+ logger.debug(f"No nodes found for entity_id: {node_label}")
853
+ return result
854
+
855
+ # Check if the query indicates truncation
856
+ if "is_truncated" in record and record["is_truncated"]:
857
+ result.is_truncated = True
858
+ logger.info(
859
+ f"Graph truncated: breadth-first search limited to {max_nodes} nodes"
860
+ )
861
+
862
+ finally:
863
+ if result_set:
864
+ await result_set.consume()
865
+
866
+ # Process the record if it exists
867
+ if record and record["node_info"]:
868
+ for node_info in record["node_info"]:
869
+ node = node_info["node"]
870
+ node_id = node.id
871
+ if node_id not in seen_nodes:
872
+ seen_nodes.add(node_id)
873
+ result.nodes.append(
874
+ KnowledgeGraphNode(
875
+ id=f"{node_id}",
876
+ labels=[node.get("entity_id")],
877
+ properties=dict(node),
878
+ )
879
+ )
880
+
881
+ for rel in record["relationships"]:
882
+ edge_id = rel.id
883
+ if edge_id not in seen_edges:
884
+ seen_edges.add(edge_id)
885
+ start = rel.start_node
886
+ end = rel.end_node
887
+ result.edges.append(
888
+ KnowledgeGraphEdge(
889
+ id=f"{edge_id}",
890
+ type=rel.type,
891
+ source=f"{start.id}",
892
+ target=f"{end.id}",
893
+ properties=dict(rel),
894
+ )
895
+ )
896
+
897
+ logger.info(
898
+ f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
899
+ )
900
+
901
+ except Exception as e:
902
+ logger.error(f"Error getting knowledge graph: {str(e)}")
903
+ # Return empty but properly initialized KnowledgeGraph on error
904
+ return KnowledgeGraph()
905
+
906
+ return result
tests/test_graph_storage.py CHANGED
@@ -10,6 +10,7 @@
10
  - Neo4JStorage
11
  - MongoDBStorage
12
  - PGGraphStorage
 
13
  """
14
 
15
  import asyncio
 
10
  - Neo4JStorage
11
  - MongoDBStorage
12
  - PGGraphStorage
13
+ - MemgraphStorage
14
  """
15
 
16
  import asyncio