zrguo committed on
Commit
bf73d58
·
unverified ·
2 Parent(s): d041c62 5098dd6

Merge pull request #671 from ranfysvalle02/main

Browse files
examples/lightrag_openai_mongodb_graph_demo.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import asyncio
3
+ from lightrag import LightRAG, QueryParam
4
+ from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed
5
+ from lightrag.utils import EmbeddingFunc
6
+ import numpy as np
7
+
8
+ #########
9
+ # Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
10
+ # import nest_asyncio
11
+ # nest_asyncio.apply()
12
+ #########
13
WORKING_DIR = "./mongodb_test_dir"
# makedirs(exist_ok=True) replaces the exists()-then-mkdir() pair, which can
# fail if the directory appears between the check and the create.
os.makedirs(WORKING_DIR, exist_ok=True)


# Demo defaults. setdefault() leaves any value already exported in the
# environment untouched, so a real API key or connection string is never
# overwritten by these placeholders.
os.environ.setdefault("OPENAI_API_KEY", "sk-")
os.environ.setdefault("MONGO_URI", "mongodb://0.0.0.0:27017/?directConnection=true")
os.environ.setdefault("MONGO_DATABASE", "LightRAG")
os.environ.setdefault("MONGO_KG_COLLECTION", "MDB_KG")

# Embedding Configuration and Functions
# Both settings can be overridden through environment variables of the same name.
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large")
EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
26
+
27
+
28
async def embedding_func(texts: list[str]) -> np.ndarray:
    """Embed ``texts`` through ``openai_embed`` using the model named by ``EMBEDDING_MODEL``."""
    vectors = await openai_embed(texts, model=EMBEDDING_MODEL)
    return vectors
33
+
34
+
35
async def get_embedding_dimension():
    """
    Probe the embedding model once and return the width of its vectors.

    A single sample sentence is embedded and the size of the second axis of
    the returned array is reported.
    """
    probe = await embedding_func(["This is a test sentence."])
    return probe.shape[1]
39
+
40
+
41
async def create_embedding_function_instance():
    """
    Build the ``EmbeddingFunc`` wrapper around ``embedding_func``.

    The embedding dimension is discovered at runtime via
    ``get_embedding_dimension``; the token limit comes from
    ``EMBEDDING_MAX_TOKEN_SIZE``.
    """
    dimension = await get_embedding_dimension()
    wrapper_kwargs = {
        "embedding_dim": dimension,
        "max_token_size": EMBEDDING_MAX_TOKEN_SIZE,
        "func": embedding_func,
    }
    return EmbeddingFunc(**wrapper_kwargs)
50
+
51
+
52
async def initialize_rag():
    """
    Create the LightRAG instance used by this demo.

    The embedding wrapper is built first (it needs an async call), then
    LightRAG is configured to use ``MongoGraphStorage`` as its graph storage
    with debug-level logging.
    """
    embedder = await create_embedding_function_instance()
    settings = {
        "working_dir": WORKING_DIR,
        "llm_model_func": gpt_4o_mini_complete,
        "embedding_func": embedder,
        "graph_storage": "MongoGraphStorage",
        "log_level": "DEBUG",
    }
    return LightRAG(**settings)
62
+
63
+
64
# Build the LightRAG instance by running the async setup to completion
rag = asyncio.run(initialize_rag())

# Index the full text of the book
with open("book.txt", "r", encoding="utf-8") as book_file:
    book_text = book_file.read()
    rag.insert(book_text)

# Perform naive search
naive_params = QueryParam(mode="naive")
naive_answer = rag.query("What are the top themes in this story?", param=naive_params)
print(naive_answer)
lightrag/kg/mongo_impl.py CHANGED
@@ -2,15 +2,18 @@ import os
2
  from tqdm.asyncio import tqdm as tqdm_async
3
  from dataclasses import dataclass
4
  import pipmaster as pm
 
5
 
6
  if not pm.is_installed("pymongo"):
7
  pm.install("pymongo")
8
 
9
  from pymongo import MongoClient
10
- from typing import Union
 
11
  from lightrag.utils import logger
12
 
13
  from lightrag.base import BaseKVStorage
 
14
 
15
 
16
  @dataclass
@@ -78,3 +81,360 @@ class MongoKVStorage(BaseKVStorage):
78
  async def drop(self):
79
  """ """
80
  pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from tqdm.asyncio import tqdm as tqdm_async
3
  from dataclasses import dataclass
4
  import pipmaster as pm
5
import numpy as np
6
 
7
  if not pm.is_installed("pymongo"):
8
  pm.install("pymongo")
9
 
10
  from pymongo import MongoClient
11
+ from motor.motor_asyncio import AsyncIOMotorClient
12
+ from typing import Union, List, Tuple
13
  from lightrag.utils import logger
14
 
15
  from lightrag.base import BaseKVStorage
16
+ from lightrag.base import BaseGraphStorage
17
 
18
 
19
  @dataclass
 
81
  async def drop(self):
82
  """ """
83
  pass
84
+
85
+
86
@dataclass
class MongoGraphStorage(BaseGraphStorage):
    """
    Graph storage backed by a single MongoDB collection.

    Each node is one document keyed by ``_id``. Its outgoing edges are kept in
    an embedded ``edges`` array; every entry carries a ``target`` node id plus
    whatever extra edge attributes were supplied. ``$graphLookup`` is used for
    multi-hop traversal; single-hop questions are answered with direct queries.
    """

    def __init__(self, namespace, global_config, embedding_func):
        """
        Initialise the base storage and bind the MongoDB collection.

        Connection settings are read from the environment variables
        ``MONGO_URI``, ``MONGO_DATABASE`` and ``MONGO_KG_COLLECTION``, each
        with a built-in fallback.
        """
        super().__init__(
            namespace=namespace,
            global_config=global_config,
            embedding_func=embedding_func,
        )
        self.client = AsyncIOMotorClient(
            os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/")
        )
        self.db = self.client[os.environ.get("MONGO_DATABASE", "LightRAG")]
        self.collection = self.db[os.environ.get("MONGO_KG_COLLECTION", "MDB_KG")]

    #
    # -------------------------------------------------------------------------
    # HELPER: $graphLookup pipeline
    # -------------------------------------------------------------------------
    #

    async def _graph_lookup(
        self, start_node_id: str, max_depth: Union[int, None] = None
    ) -> List[dict]:
        """
        Run a ``$graphLookup`` starting from ``start_node_id``.

        Pipeline:
          1) ``$match`` selects the start node document by ``_id``.
          2) ``$graphLookup`` follows ``edges.target`` -> ``_id`` inside the
             same collection and stores every document it reaches in
             ``reachableNodes``, each tagged with a ``depth`` field.

        Args:
            start_node_id: ``_id`` of the node to start from.
            max_depth: optional cap on the recursion depth; ``None`` means
                unlimited.

        Returns:
            ``[]`` if the start node does not exist, otherwise a one-element
            list holding the start document extended with ``reachableNodes``.
        """
        graph_lookup = {
            "from": self.collection.name,
            "startWith": "$edges.target",
            "connectFromField": "edges.target",
            "connectToField": "_id",
            "as": "reachableNodes",
            "depthField": "depth",
        }
        # Only send maxDepth when the caller asked for a limit.
        if max_depth is not None:
            graph_lookup["maxDepth"] = max_depth

        pipeline = [
            {"$match": {"_id": start_node_id}},
            {"$graphLookup": graph_lookup},
        ]
        cursor = self.collection.aggregate(pipeline)
        return await cursor.to_list(None)

    #
    # -------------------------------------------------------------------------
    # BASIC QUERIES
    # -------------------------------------------------------------------------
    #

    async def has_node(self, node_id: str) -> bool:
        """
        Return True if a document with ``_id == node_id`` exists.
        """
        doc = await self.collection.find_one({"_id": node_id}, {"_id": 1})
        return doc is not None

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        """
        Return True if ``source_node_id`` has a direct edge to ``target_node_id``.

        The check looks at the source document's own ``edges`` array. A
        ``$graphLookup`` would only report targets that have a node document
        of their own, and ``upsert_edge`` does not create one, so an edge that
        was just stored could be reported as missing.
        """
        doc = await self.collection.find_one(
            {"_id": source_node_id, "edges.target": target_node_id},
            {"_id": 1},
        )
        return doc is not None

    #
    # -------------------------------------------------------------------------
    # DEGREES
    # -------------------------------------------------------------------------
    #

    async def node_degree(self, node_id: str) -> int:
        """
        Return the number of edges touching ``node_id`` (outbound + inbound).

        Outbound edges are the entries of the node's own ``edges`` array.
        Inbound edges are counted by scanning the ``edges`` arrays of every
        document that references ``node_id`` as a target.

        Returns 0 when the node document itself does not exist.
        """
        # --- 1) Outbound edges (direct from doc) ---
        doc = await self.collection.find_one({"_id": node_id}, {"edges": 1})
        if not doc:
            return 0
        outbound_count = len(doc.get("edges", []))

        # --- 2) Inbound edges ---
        # For each document pointing at node_id, count the matching entries
        # in its edges array, then sum those counts across documents.
        inbound_count_pipeline = [
            {"$match": {"edges.target": node_id}},
            {
                "$project": {
                    "matchingEdgesCount": {
                        "$size": {
                            "$filter": {
                                "input": "$edges",
                                "as": "edge",
                                "cond": {"$eq": ["$$edge.target", node_id]},
                            }
                        }
                    }
                }
            },
            {"$group": {"_id": None, "totalInbound": {"$sum": "$matchingEdgesCount"}}},
        ]
        inbound_cursor = self.collection.aggregate(inbound_count_pipeline)
        inbound_result = await inbound_cursor.to_list(None)
        inbound_count = inbound_result[0]["totalInbound"] if inbound_result else 0

        return outbound_count + inbound_count

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        """
        Count the edges stored on ``src_id`` whose target is ``tgt_id``.

        Returns 0 when the source node does not exist. Because ``upsert_edge``
        replaces any existing edge to the same target, the result is normally
        0 or 1.
        """
        doc = await self.collection.find_one({"_id": src_id}, {"edges": 1})
        if doc is None:
            return 0
        return sum(1 for e in doc.get("edges", []) if e.get("target") == tgt_id)

    #
    # -------------------------------------------------------------------------
    # GETTERS
    # -------------------------------------------------------------------------
    #

    async def get_node(self, node_id: str) -> Union[dict, None]:
        """
        Return the full node document (including "edges"), or None if missing.
        """
        return await self.collection.find_one({"_id": node_id})

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> Union[dict, None]:
        """
        Return the first edge dict from ``source_node_id`` to ``target_node_id``.

        Returns None when the source node is missing or has no such edge.
        """
        doc = await self.collection.find_one({"_id": source_node_id}, {"edges": 1})
        if doc is None:
            return None
        return next(
            (e for e in doc.get("edges", []) if e.get("target") == target_node_id),
            None,
        )

    async def get_node_edges(
        self, source_node_id: str
    ) -> Union[List[Tuple[str, str]], None]:
        """
        Return ``(target_id, relation)`` for every direct edge of ``source_node_id``.

        ``relation`` is None for edges stored without a "relation" attribute.
        Returns None when the source node does not exist.
        """
        doc = await self.collection.find_one(
            {"_id": source_node_id}, {"_id": 0, "edges": 1}
        )
        # Compare against None explicitly: with "_id" projected away, a node
        # without an edges field comes back as an empty (falsy) dict.
        if doc is None:
            return None
        # .get() because upsert_edge treats "relation" as optional.
        return [(e["target"], e.get("relation")) for e in doc.get("edges", [])]

    #
    # -------------------------------------------------------------------------
    # UPSERTS
    # -------------------------------------------------------------------------
    #

    async def upsert_node(self, node_id: str, node_data: dict):
        """
        Insert or update a node document, preserving any existing edges.

        On insert an empty ``edges`` array is created, unless ``node_data``
        supplies ``edges`` itself.
        """
        update_doc = {}
        # MongoDB servers before 5.0 reject an update whose $set operand is
        # empty, and upsert_edge calls this method with {}.
        if node_data:
            update_doc["$set"] = dict(node_data)
        # $set and $setOnInsert may not both touch "edges" in one update.
        if "edges" not in node_data:
            update_doc["$setOnInsert"] = {"edges": []}
        await self.collection.update_one({"_id": node_id}, update_doc, upsert=True)

    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict
    ):
        """
        Upsert an edge from ``source_node_id`` to ``target_node_id``.

        The source node is created if missing. Any existing edge to the same
        target is removed and replaced by a new entry built from ``edge_data``.
        The pull and the push are two separate updates, so the replacement is
        not atomic.
        """
        # Ensure source node exists
        await self.upsert_node(source_node_id, {})

        # Remove existing edge (if any)
        await self.collection.update_one(
            {"_id": source_node_id}, {"$pull": {"edges": {"target": target_node_id}}}
        )

        # Insert new edge
        new_edge = {"target": target_node_id, **edge_data}
        await self.collection.update_one(
            {"_id": source_node_id}, {"$push": {"edges": new_edge}}
        )

    #
    # -------------------------------------------------------------------------
    # DELETION
    # -------------------------------------------------------------------------
    #

    async def delete_node(self, node_id: str):
        """
        Delete a node and every edge that points at it.

        1) Pull inbound edges from the documents that reference ``node_id``.
        2) Remove the node's own document (and with it its outbound edges).
        """
        # The filter restricts the update to documents that actually hold an
        # inbound edge, instead of issuing it against the whole collection.
        await self.collection.update_many(
            {"edges.target": node_id},
            {"$pull": {"edges": {"target": node_id}}},
        )

        # Remove the node doc
        await self.collection.delete_one({"_id": node_id})

    #
    # -------------------------------------------------------------------------
    # EMBEDDINGS (NOT IMPLEMENTED)
    # -------------------------------------------------------------------------
    #

    async def embed_nodes(self, algorithm: str) -> Tuple[np.ndarray, List[str]]:
        """
        Placeholder; always raises NotImplementedError.
        """
        raise NotImplementedError("Node embedding is not used in lightrag.")
lightrag/lightrag.py CHANGED
@@ -48,6 +48,7 @@ STORAGES = {
48
  "OracleVectorDBStorage": ".kg.oracle_impl",
49
  "MilvusVectorDBStorge": ".kg.milvus_impl",
50
  "MongoKVStorage": ".kg.mongo_impl",
 
51
  "RedisKVStorage": ".kg.redis_impl",
52
  "ChromaVectorDBStorage": ".kg.chroma_impl",
53
  "TiDBKVStorage": ".kg.tidb_impl",
 
48
  "OracleVectorDBStorage": ".kg.oracle_impl",
49
  "MilvusVectorDBStorge": ".kg.milvus_impl",
50
  "MongoKVStorage": ".kg.mongo_impl",
51
+ "MongoGraphStorage": ".kg.mongo_impl",
52
  "RedisKVStorage": ".kg.redis_impl",
53
  "ChromaVectorDBStorage": ".kg.chroma_impl",
54
  "TiDBKVStorage": ".kg.tidb_impl",