DavIvek commited on
Commit
7f20a21
·
1 Parent(s): 8365801

add Memgraph graph storage backend

Browse files
config.ini.example CHANGED
@@ -21,3 +21,6 @@ password = your_password
21
  database = your_database
22
  workspace = default # 可选,默认为default
23
  max_connections = 12
 
 
 
 
21
  database = your_database
22
  workspace = default # 可选,默认为default
23
  max_connections = 12
24
+
25
+ [memgraph]
26
+ uri = bolt://localhost:7687
examples/graph_visual_with_neo4j.py CHANGED
@@ -11,7 +11,7 @@ BATCH_SIZE_EDGES = 100
11
  # Neo4j connection credentials
12
  NEO4J_URI = "bolt://localhost:7687"
13
  NEO4J_USERNAME = "neo4j"
14
- NEO4J_PASSWORD = "your_password"
15
 
16
 
17
  def xml_to_json(xml_file):
 
11
  # Neo4j connection credentials
12
  NEO4J_URI = "bolt://localhost:7687"
13
  NEO4J_USERNAME = "neo4j"
14
+ NEO4J_PASSWORD = "your_password"
15
 
16
 
17
  def xml_to_json(xml_file):
examples/lightrag_openai_demo.py CHANGED
@@ -82,6 +82,7 @@ async def initialize_rag():
82
  working_dir=WORKING_DIR,
83
  embedding_func=openai_embed,
84
  llm_model_func=gpt_4o_mini_complete,
 
85
  )
86
 
87
  await rag.initialize_storages()
 
82
  working_dir=WORKING_DIR,
83
  embedding_func=openai_embed,
84
  llm_model_func=gpt_4o_mini_complete,
85
+ graph_storage="MemgraphStorage",
86
  )
87
 
88
  await rag.initialize_storages()
lightrag/kg/__init__.py CHANGED
@@ -15,6 +15,7 @@ STORAGE_IMPLEMENTATIONS = {
15
  "Neo4JStorage",
16
  "PGGraphStorage",
17
  "MongoGraphStorage",
 
18
  # "AGEStorage",
19
  # "TiDBGraphStorage",
20
  # "GremlinStorage",
@@ -56,6 +57,7 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
56
  "NetworkXStorage": [],
57
  "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
58
  "MongoGraphStorage": [],
 
59
  # "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
60
  "AGEStorage": [
61
  "AGE_POSTGRES_DB",
@@ -108,6 +110,7 @@ STORAGES = {
108
  "PGDocStatusStorage": ".kg.postgres_impl",
109
  "FaissVectorDBStorage": ".kg.faiss_impl",
110
  "QdrantVectorDBStorage": ".kg.qdrant_impl",
 
111
  }
112
 
113
 
 
15
  "Neo4JStorage",
16
  "PGGraphStorage",
17
  "MongoGraphStorage",
18
+ "MemgraphStorage",
19
  # "AGEStorage",
20
  # "TiDBGraphStorage",
21
  # "GremlinStorage",
 
57
  "NetworkXStorage": [],
58
  "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
59
  "MongoGraphStorage": [],
60
+ "MemgraphStorage": ["MEMGRAPH_URI"],
61
  # "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
62
  "AGEStorage": [
63
  "AGE_POSTGRES_DB",
 
110
  "PGDocStatusStorage": ".kg.postgres_impl",
111
  "FaissVectorDBStorage": ".kg.faiss_impl",
112
  "QdrantVectorDBStorage": ".kg.qdrant_impl",
113
+ "MemgraphStorage": ".kg.memgraph_impl",
114
  }
115
 
116
 
lightrag/kg/memgraph_impl.py ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import re
from dataclasses import dataclass
from typing import final
import configparser

from ..utils import logger
from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..constants import GRAPH_FIELD_SEP
import pipmaster as pm

if not pm.is_installed("neo4j"):
    pm.install("neo4j")

# Memgraph speaks the Bolt protocol, so the official Neo4j driver is used.
from neo4j import (
    AsyncGraphDatabase,
    AsyncManagedTransaction,
)

from dotenv import load_dotenv

# use the .env that is inside the current folder
load_dotenv(dotenv_path=".env", override=False)

# Hard cap on the number of nodes returned by get_knowledge_graph().
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))

config = configparser.ConfigParser()
config.read("config.ini", "utf-8")


@final
@dataclass
class MemgraphStorage(BaseGraphStorage):
    """Knowledge-graph storage backed by a Memgraph instance.

    Every entity node carries the ``base`` label and is keyed by its
    ``entity_id`` property; relationships are created with the ``DIRECTED``
    type. Connection settings come from the environment variables
    ``MEMGRAPH_URI`` / ``MEMGRAPH_USERNAME`` / ``MEMGRAPH_PASSWORD`` /
    ``MEMGRAPH_DATABASE``, falling back to the ``[memgraph]`` section of
    ``config.ini``.
    """

    def __init__(self, namespace, global_config, embedding_func):
        super().__init__(
            namespace=namespace,
            global_config=global_config,
            embedding_func=embedding_func,
        )
        # Async Bolt driver; created in initialize(), closed in finalize().
        self._driver = None

    async def initialize(self):
        """Create the driver, verify connectivity, and ensure the
        ``:base(entity_id)`` index exists.

        Raises:
            Exception: propagated when the database cannot be reached.
        """
        URI = os.environ.get(
            "MEMGRAPH_URI",
            config.get("memgraph", "uri", fallback="bolt://localhost:7687"),
        )
        USERNAME = os.environ.get(
            "MEMGRAPH_USERNAME", config.get("memgraph", "username", fallback="")
        )
        PASSWORD = os.environ.get(
            "MEMGRAPH_PASSWORD", config.get("memgraph", "password", fallback="")
        )
        DATABASE = os.environ.get(
            "MEMGRAPH_DATABASE", config.get("memgraph", "database", fallback="memgraph")
        )

        self._driver = AsyncGraphDatabase.driver(
            URI,
            auth=(USERNAME, PASSWORD),
        )
        self._DATABASE = DATABASE
        try:
            async with self._driver.session(database=DATABASE) as session:
                # Create index for base nodes on entity_id if it doesn't exist
                try:
                    await session.run("""CREATE INDEX ON :base(entity_id)""")
                    logger.info("Created index on :base(entity_id) in Memgraph.")
                except Exception as e:
                    # Index may already exist, which is not an error
                    logger.warning(
                        f"Index creation on :base(entity_id) may have failed or already exists: {e}"
                    )
                await session.run("RETURN 1")
                logger.info(f"Connected to Memgraph at {URI}")
        except Exception as e:
            logger.error(f"Failed to connect to Memgraph at {URI}: {e}")
            raise

    async def finalize(self):
        """Close the driver and release the connection pool."""
        if self._driver is not None:
            await self._driver.close()
            self._driver = None

    async def __aexit__(self, exc_type, exc, tb):
        await self.finalize()

    async def index_done_callback(self):
        # Memgraph handles persistence automatically
        pass

    async def has_node(self, node_id: str) -> bool:
        """Return True if a ``base`` node with the given entity_id exists."""
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists"
            result = await session.run(query, entity_id=node_id)
            single_result = await result.single()
            await result.consume()
            return single_result["node_exists"]

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        """Return True if any relationship (either direction) links the two nodes."""
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            query = (
                "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) "
                "RETURN COUNT(r) > 0 AS edgeExists"
            )
            result = await session.run(
                query,
                source_entity_id=source_node_id,
                target_entity_id=target_node_id,
            )
            single_result = await result.single()
            await result.consume()
            return single_result["edgeExists"]

    async def get_node(self, node_id: str) -> dict[str, str] | None:
        """Return the node's properties as a dict, or None if it does not exist.

        A ``labels`` property, if present, is stripped of the internal
        ``base`` label before being returned.
        """
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            query = "MATCH (n:base {entity_id: $entity_id}) RETURN n"
            result = await session.run(query, entity_id=node_id)
            # fetch(2) so duplicate matches could be detected if ever needed
            records = await result.fetch(2)
            await result.consume()
            if records:
                node = records[0]["n"]
                node_dict = dict(node)
                if "labels" in node_dict:
                    node_dict["labels"] = [
                        label for label in node_dict["labels"] if label != "base"
                    ]
                return node_dict
            return None

    async def get_all_labels(self) -> list[str]:
        """Return all distinct entity_ids of ``base`` nodes, sorted ascending."""
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            query = """
            MATCH (n:base)
            WHERE n.entity_id IS NOT NULL
            RETURN DISTINCT n.entity_id AS label
            ORDER BY label
            """
            result = await session.run(query)
            labels = []
            async for record in result:
                labels.append(record["label"])
            await result.consume()
            return labels

    async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
        """Return (source_entity_id, target_entity_id) pairs for every edge
        touching the given node; empty list when the node has no edges."""
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            query = """
            MATCH (n:base {entity_id: $entity_id})
            OPTIONAL MATCH (n)-[r]-(connected:base)
            WHERE connected.entity_id IS NOT NULL
            RETURN n, r, connected
            """
            results = await session.run(query, entity_id=source_node_id)
            edges = []
            async for record in results:
                source_node = record["n"]
                connected_node = record["connected"]
                # OPTIONAL MATCH yields null "connected" for isolated nodes
                if not source_node or not connected_node:
                    continue
                source_label = source_node.get("entity_id")
                target_label = connected_node.get("entity_id")
                if source_label and target_label:
                    edges.append((source_label, target_label))
            await results.consume()
            return edges

    async def get_edge(self, source_node_id: str, target_node_id: str) -> dict[str, str] | None:
        """Return the properties of an edge between the two nodes, or None.

        Missing standard fields (weight, source_id, description, keywords)
        are filled with defaults so callers can rely on their presence.
        """
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            query = """
            MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id})
            RETURN properties(r) as edge_properties
            """
            result = await session.run(
                query,
                source_entity_id=source_node_id,
                target_entity_id=target_node_id,
            )
            records = await result.fetch(2)
            await result.consume()
            if records:
                edge_result = dict(records[0]["edge_properties"])
                for key, default_value in {
                    "weight": 0.0,
                    "source_id": None,
                    "description": None,
                    "keywords": None,
                }.items():
                    if key not in edge_result:
                        edge_result[key] = default_value
                return edge_result
            return None

    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
        """Create or update a ``base`` node keyed by entity_id.

        The node additionally receives its ``entity_type`` value as a label.

        Raises:
            ValueError: if node_data lacks an 'entity_id' field.
        """
        properties = node_data
        entity_type = properties.get("entity_type", "base")
        if "entity_id" not in properties:
            raise ValueError("Memgraph: node properties must contain an 'entity_id' field")
        async with self._driver.session(database=self._DATABASE) as session:

            async def execute_upsert(tx: AsyncManagedTransaction):
                # NOTE(review): entity_type is interpolated into the query as a
                # label (Cypher cannot parameterize labels) — assumes it comes
                # from trusted extraction output, not raw user input.
                query = f"""
                MERGE (n:base {{entity_id: $entity_id}})
                SET n += $properties
                SET n:`{entity_type}`
                """
                result = await tx.run(query, entity_id=node_id, properties=properties)
                await result.consume()

            await session.execute_write(execute_upsert)

    async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]) -> None:
        """Create or update a DIRECTED relationship between two existing nodes.

        If either endpoint does not exist, the MATCH fails silently and no
        edge is written.
        """
        edge_properties = edge_data
        async with self._driver.session(database=self._DATABASE) as session:

            async def execute_upsert(tx: AsyncManagedTransaction):
                query = """
                MATCH (source:base {entity_id: $source_entity_id})
                WITH source
                MATCH (target:base {entity_id: $target_entity_id})
                MERGE (source)-[r:DIRECTED]-(target)
                SET r += $properties
                RETURN r, source, target
                """
                result = await tx.run(
                    query,
                    source_entity_id=source_node_id,
                    target_entity_id=target_node_id,
                    properties=edge_properties,
                )
                await result.consume()

            await session.execute_write(execute_upsert)

    async def delete_node(self, node_id: str) -> None:
        """Delete a node and all of its relationships (DETACH DELETE)."""

        async def _do_delete(tx: AsyncManagedTransaction):
            query = """
            MATCH (n:base {entity_id: $entity_id})
            DETACH DELETE n
            """
            result = await tx.run(query, entity_id=node_id)
            await result.consume()

        async with self._driver.session(database=self._DATABASE) as session:
            await session.execute_write(_do_delete)

    async def remove_nodes(self, nodes: list[str]):
        """Delete each node in the list (one transaction per node)."""
        for node in nodes:
            await self.delete_node(node)

    async def remove_edges(self, edges: list[tuple[str, str]]):
        """Delete every relationship between each (source, target) pair."""
        for source, target in edges:

            async def _do_delete_edge(tx: AsyncManagedTransaction):
                query = """
                MATCH (source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id})
                DELETE r
                """
                result = await tx.run(
                    query, source_entity_id=source, target_entity_id=target
                )
                await result.consume()

            async with self._driver.session(database=self._DATABASE) as session:
                await session.execute_write(_do_delete_edge)

    async def drop(self) -> dict[str, str]:
        """Delete ALL nodes and relationships in the database.

        Returns:
            {"status": "success"|"error", "message": ...}
        """
        try:
            async with self._driver.session(database=self._DATABASE) as session:
                query = "MATCH (n) DETACH DELETE n"
                result = await session.run(query)
                await result.consume()
                logger.info(f"Process {os.getpid()} drop Memgraph database {self._DATABASE}")
                return {"status": "success", "message": "data dropped"}
        except Exception as e:
            logger.error(f"Error dropping Memgraph database {self._DATABASE}: {e}")
            return {"status": "error", "message": str(e)}

    async def node_degree(self, node_id: str) -> int:
        """Return the number of relationships touching the node (0 if absent)."""
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            query = """
            MATCH (n:base {entity_id: $entity_id})
            OPTIONAL MATCH (n)-[r]-()
            RETURN COUNT(r) AS degree
            """
            result = await session.run(query, entity_id=node_id)
            record = await result.single()
            await result.consume()
            if not record:
                return 0
            return record["degree"]

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        """Return the sum of both endpoints' degrees."""
        src_degree = await self.node_degree(src_id)
        trg_degree = await self.node_degree(tgt_id)
        src_degree = 0 if src_degree is None else src_degree
        trg_degree = 0 if trg_degree is None else trg_degree
        return int(src_degree) + int(trg_degree)

    async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
        """Return nodes whose GRAPH_FIELD_SEP-joined source_id contains any
        of the given chunk ids; each dict gains an "id" key (= entity_id)."""
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            query = """
            UNWIND $chunk_ids AS chunk_id
            MATCH (n:base)
            WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
            RETURN DISTINCT n
            """
            result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
            nodes = []
            async for record in result:
                node = record["n"]
                node_dict = dict(node)
                node_dict["id"] = node_dict.get("entity_id")
                nodes.append(node_dict)
            await result.consume()
            return nodes

    async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
        """Return edges whose source_id contains any of the given chunk ids;
        each properties dict gains "source" and "target" entity_ids."""
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            query = """
            UNWIND $chunk_ids AS chunk_id
            MATCH (a:base)-[r]-(b:base)
            WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
            RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties
            """
            result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
            edges = []
            async for record in result:
                edge_properties = record["properties"]
                edge_properties["source"] = record["source"]
                edge_properties["target"] = record["target"]
                edges.append(edge_properties)
            await result.consume()
            return edges

    async def get_knowledge_graph(
        self,
        node_label: str,
        max_depth: int = 3,
        max_nodes: int = MAX_GRAPH_NODES,
    ) -> KnowledgeGraph:
        """Return a subgraph for visualization.

        Args:
            node_label: entity_id of the start node, or "*" to return the
                ``max_nodes`` highest-degree nodes of the whole graph.
            max_depth: BFS depth limit for the single-label case.
            max_nodes: cap on returned nodes; ``is_truncated`` is set when
                the full graph exceeds it.

        Returns:
            KnowledgeGraph with de-duplicated nodes and edges.
        """
        result = KnowledgeGraph()
        seen_nodes = set()
        seen_edges = set()
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            if node_label == "*":
                count_query = "MATCH (n) RETURN count(n) as total"
                count_result = await session.run(count_query)
                count_record = await count_result.single()
                await count_result.consume()
                if count_record and count_record["total"] > max_nodes:
                    result.is_truncated = True
                    logger.info(
                        f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}"
                    )
                # Keep the max_nodes highest-degree nodes plus every edge
                # whose two endpoints both survive the cut.
                main_query = """
                MATCH (n)
                OPTIONAL MATCH (n)-[r]-()
                WITH n, COALESCE(count(r), 0) AS degree
                ORDER BY degree DESC
                LIMIT $max_nodes
                WITH collect({node: n}) AS filtered_nodes
                UNWIND filtered_nodes AS node_info
                WITH collect(node_info.node) AS kept_nodes, filtered_nodes
                OPTIONAL MATCH (a)-[r]-(b)
                WHERE a IN kept_nodes AND b IN kept_nodes
                RETURN filtered_nodes AS node_info,
                       collect(DISTINCT r) AS relationships
                """
                result_set = await session.run(main_query, {"max_nodes": max_nodes})
                record = await result_set.single()
                await result_set.consume()
                # BUGFIX: the query result was previously discarded, so the
                # "*" case returned an empty graph (or hit a NameError on the
                # BFS-only bfs_nodes variable). Materialize it here instead.
                if record:
                    for node_info in record["node_info"]:
                        node = node_info["node"]
                        node_id = node.get("entity_id")
                        if node_id and node_id not in seen_nodes:
                            result.nodes.append(
                                KnowledgeGraphNode(
                                    id=node_id,
                                    labels=[node_id],
                                    properties=dict(node),
                                )
                            )
                            seen_nodes.add(node_id)
                    for rel in record["relationships"]:
                        if rel is None:
                            # OPTIONAL MATCH yields null when no edge exists
                            continue
                        src = rel.start_node.get("entity_id")
                        tgt = rel.end_node.get("entity_id")
                        edge_id = f"{src}-{tgt}"
                        if src and tgt and edge_id not in seen_edges:
                            result.edges.append(
                                KnowledgeGraphEdge(
                                    id=edge_id,
                                    type=rel.type,
                                    source=src,
                                    target=tgt,
                                    properties=dict(rel),
                                )
                            )
                            seen_edges.add(edge_id)
            else:
                # BFS fallback for Memgraph (no APOC procedures available)
                from collections import deque

                # Get the starting node
                start_query = "MATCH (n:base {entity_id: $entity_id}) RETURN n"
                node_result = await session.run(start_query, entity_id=node_label)
                node_record = await node_result.single()
                await node_result.consume()
                if not node_record:
                    return result
                start_node = node_record["n"]
                queue = deque([(start_node, 0)])
                visited = set()
                bfs_nodes = []
                while queue and len(bfs_nodes) < max_nodes:
                    current_node, depth = queue.popleft()
                    node_id = current_node.get("entity_id")
                    if node_id in visited:
                        continue
                    visited.add(node_id)
                    bfs_nodes.append(current_node)
                    if depth < max_depth:
                        # Get neighbors
                        neighbor_query = """
                        MATCH (n:base {entity_id: $entity_id})-[]-(m:base)
                        RETURN m
                        """
                        neighbors_result = await session.run(
                            neighbor_query, entity_id=node_id
                        )
                        # BUGFIX: iterate the async cursor directly —
                        # AsyncResult has no to_list() method.
                        neighbors = []
                        async for rec in neighbors_result:
                            neighbors.append(rec["m"])
                        await neighbors_result.consume()
                        for neighbor in neighbors:
                            if neighbor.get("entity_id") not in visited:
                                queue.append((neighbor, depth + 1))
                # Build subgraph from the BFS frontier
                subgraph_ids = [n.get("entity_id") for n in bfs_nodes]
                # Nodes
                for n in bfs_nodes:
                    node_id = n.get("entity_id")
                    if node_id not in seen_nodes:
                        result.nodes.append(
                            KnowledgeGraphNode(
                                id=node_id,
                                labels=[node_id],
                                properties=dict(n),
                            )
                        )
                        seen_nodes.add(node_id)
                # Edges among the kept nodes only
                if subgraph_ids:
                    edge_query = """
                    MATCH (a:base)-[r]-(b:base)
                    WHERE a.entity_id IN $ids AND b.entity_id IN $ids
                    RETURN DISTINCT r, a, b
                    """
                    edge_result = await session.run(edge_query, ids=subgraph_ids)
                    async for record in edge_result:
                        r = record["r"]
                        a = record["a"]
                        b = record["b"]
                        edge_id = f"{a.get('entity_id')}-{b.get('entity_id')}"
                        if edge_id not in seen_edges:
                            result.edges.append(
                                KnowledgeGraphEdge(
                                    id=edge_id,
                                    type="DIRECTED",
                                    source=a.get("entity_id"),
                                    target=b.get("entity_id"),
                                    properties=dict(r),
                                )
                            )
                            seen_edges.add(edge_id)
                    await edge_result.consume()
        logger.info(
            f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
        )
        return result