yangdx commited on
Commit
c62fd8f
·
2 Parent(s): 82a1a8d bf73d58

Merge from main

Browse files
examples/lightrag_openai_mongodb_graph_demo.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed
from lightrag.utils import EmbeddingFunc
import numpy as np

#########
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
# import nest_asyncio
# nest_asyncio.apply()
#########
WORKING_DIR = "./mongodb_test_dir"
# makedirs(..., exist_ok=True) is atomic w.r.t. concurrent creation, unlike
# the check-then-mkdir pattern, and is a no-op when the directory exists.
os.makedirs(WORKING_DIR, exist_ok=True)


# Demo credentials/connection settings; replace "sk-" with a real key.
os.environ["OPENAI_API_KEY"] = "sk-"
os.environ["MONGO_URI"] = "mongodb://0.0.0.0:27017/?directConnection=true"
os.environ["MONGO_DATABASE"] = "LightRAG"
os.environ["MONGO_KG_COLLECTION"] = "MDB_KG"

# Embedding Configuration and Functions
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large")
EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
26
+
27
+
28
async def embedding_func(texts: list[str]) -> np.ndarray:
    """Embed *texts* with the configured OpenAI embedding model."""
    return await openai_embed(texts, model=EMBEDDING_MODEL)
33
+
34
+
35
async def get_embedding_dimension():
    """Probe the embedding endpoint once and return the vector width."""
    probe = await embedding_func(["This is a test sentence."])
    return probe.shape[1]
39
+
40
+
41
async def create_embedding_function_instance():
    """Build an EmbeddingFunc wired to the live embedding dimension."""
    # Ask the model for its actual output width instead of hard-coding it.
    dim = await get_embedding_dimension()
    return EmbeddingFunc(
        embedding_dim=dim,
        max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
        func=embedding_func,
    )
50
+
51
+
52
async def initialize_rag():
    """Construct a LightRAG instance backed by MongoGraphStorage."""
    emb = await create_embedding_function_instance()
    settings = dict(
        working_dir=WORKING_DIR,
        llm_model_func=gpt_4o_mini_complete,
        embedding_func=emb,
        graph_storage="MongoGraphStorage",
        log_level="DEBUG",
    )
    return LightRAG(**settings)
62
+
63
+
64
# Run the initialization
rag = asyncio.run(initialize_rag())

# Ingest the demo corpus.
with open("book.txt", "r", encoding="utf-8") as f:
    book_text = f.read()
rag.insert(book_text)

# Perform naive search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)
lightrag/api/lightrag_server.py CHANGED
@@ -48,18 +48,23 @@ def estimate_tokens(text: str) -> int:
48
  return int(tokens)
49
 
50
 
51
- # Constants for emulated Ollama model information
52
- LIGHTRAG_NAME = "lightrag"
53
- LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
54
- LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
55
- LIGHTRAG_SIZE = 7365960935 # it's a dummy value
56
- LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
57
- LIGHTRAG_DIGEST = "sha256:lightrag"
58
-
59
- KV_STORAGE = "JsonKVStorage"
60
- DOC_STATUS_STORAGE = "JsonDocStatusStorage"
61
- GRAPH_STORAGE = "NetworkXStorage"
62
- VECTOR_STORAGE = "NanoVectorDBStorage"
 
 
 
 
 
63
 
64
  # read config.ini
65
  config = configparser.ConfigParser()
@@ -68,8 +73,8 @@ config.read("config.ini", "utf-8")
68
  redis_uri = config.get("redis", "uri", fallback=None)
69
  if redis_uri:
70
  os.environ["REDIS_URI"] = redis_uri
71
- KV_STORAGE = "RedisKVStorage"
72
- DOC_STATUS_STORAGE = "RedisKVStorage"
73
 
74
  # Neo4j config
75
  neo4j_uri = config.get("neo4j", "uri", fallback=None)
@@ -79,7 +84,7 @@ if neo4j_uri:
79
  os.environ["NEO4J_URI"] = neo4j_uri
80
  os.environ["NEO4J_USERNAME"] = neo4j_username
81
  os.environ["NEO4J_PASSWORD"] = neo4j_password
82
- GRAPH_STORAGE = "Neo4JStorage"
83
 
84
  # Milvus config
85
  milvus_uri = config.get("milvus", "uri", fallback=None)
@@ -91,7 +96,7 @@ if milvus_uri:
91
  os.environ["MILVUS_USER"] = milvus_user
92
  os.environ["MILVUS_PASSWORD"] = milvus_password
93
  os.environ["MILVUS_DB_NAME"] = milvus_db_name
94
- VECTOR_STORAGE = "MilvusVectorDBStorge"
95
 
96
  # MongoDB config
97
  mongo_uri = config.get("mongodb", "uri", fallback=None)
@@ -99,8 +104,8 @@ mongo_database = config.get("mongodb", "LightRAG", fallback=None)
99
  if mongo_uri:
100
  os.environ["MONGO_URI"] = mongo_uri
101
  os.environ["MONGO_DATABASE"] = mongo_database
102
- KV_STORAGE = "MongoKVStorage"
103
- DOC_STATUS_STORAGE = "MongoKVStorage"
104
 
105
 
106
  def get_default_host(binding_type: str) -> str:
@@ -217,7 +222,7 @@ def display_splash_screen(args: argparse.Namespace) -> None:
217
  # System Configuration
218
  ASCIIColors.magenta("\n🛠️ System Configuration:")
219
  ASCIIColors.white(" ├─ Ollama Emulating Model: ", end="")
220
- ASCIIColors.yellow(f"{LIGHTRAG_MODEL}")
221
  ASCIIColors.white(" ├─ Log Level: ", end="")
222
  ASCIIColors.yellow(f"{args.log_level}")
223
  ASCIIColors.white(" ├─ Timeout: ", end="")
@@ -502,8 +507,19 @@ def parse_args() -> argparse.Namespace:
502
  help="Cosine similarity threshold (default: from env or 0.4)",
503
  )
504
 
 
 
 
 
 
 
 
 
 
505
  args = parser.parse_args()
506
 
 
 
507
  return args
508
 
509
 
@@ -556,7 +572,7 @@ class OllamaMessage(BaseModel):
556
 
557
 
558
  class OllamaChatRequest(BaseModel):
559
- model: str = LIGHTRAG_MODEL
560
  messages: List[OllamaMessage]
561
  stream: bool = True # Default to streaming mode
562
  options: Optional[Dict[str, Any]] = None
@@ -571,7 +587,7 @@ class OllamaChatResponse(BaseModel):
571
 
572
 
573
  class OllamaGenerateRequest(BaseModel):
574
- model: str = LIGHTRAG_MODEL
575
  prompt: str
576
  system: Optional[str] = None
577
  stream: bool = False
@@ -860,10 +876,10 @@ def create_app(args):
860
  if args.llm_binding == "lollms" or args.llm_binding == "ollama"
861
  else {},
862
  embedding_func=embedding_func,
863
- kv_storage=KV_STORAGE,
864
- graph_storage=GRAPH_STORAGE,
865
- vector_storage=VECTOR_STORAGE,
866
- doc_status_storage=DOC_STATUS_STORAGE,
867
  vector_db_storage_cls_kwargs={
868
  "cosine_better_than_threshold": args.cosine_threshold
869
  },
@@ -883,10 +899,10 @@ def create_app(args):
883
  llm_model_max_async=args.max_async,
884
  llm_model_max_token_size=args.max_tokens,
885
  embedding_func=embedding_func,
886
- kv_storage=KV_STORAGE,
887
- graph_storage=GRAPH_STORAGE,
888
- vector_storage=VECTOR_STORAGE,
889
- doc_status_storage=DOC_STATUS_STORAGE,
890
  vector_db_storage_cls_kwargs={
891
  "cosine_better_than_threshold": args.cosine_threshold
892
  },
@@ -1452,16 +1468,16 @@ def create_app(args):
1452
  return OllamaTagResponse(
1453
  models=[
1454
  {
1455
- "name": LIGHTRAG_MODEL,
1456
- "model": LIGHTRAG_MODEL,
1457
- "size": LIGHTRAG_SIZE,
1458
- "digest": LIGHTRAG_DIGEST,
1459
- "modified_at": LIGHTRAG_CREATED_AT,
1460
  "details": {
1461
  "parent_model": "",
1462
  "format": "gguf",
1463
- "family": LIGHTRAG_NAME,
1464
- "families": [LIGHTRAG_NAME],
1465
  "parameter_size": "13B",
1466
  "quantization_level": "Q4_0",
1467
  },
@@ -1524,8 +1540,8 @@ def create_app(args):
1524
  total_response = response
1525
 
1526
  data = {
1527
- "model": LIGHTRAG_MODEL,
1528
- "created_at": LIGHTRAG_CREATED_AT,
1529
  "response": response,
1530
  "done": False,
1531
  }
@@ -1537,8 +1553,8 @@ def create_app(args):
1537
  eval_time = last_chunk_time - first_chunk_time
1538
 
1539
  data = {
1540
- "model": LIGHTRAG_MODEL,
1541
- "created_at": LIGHTRAG_CREATED_AT,
1542
  "done": True,
1543
  "total_duration": total_time,
1544
  "load_duration": 0,
@@ -1558,8 +1574,8 @@ def create_app(args):
1558
 
1559
  total_response += chunk
1560
  data = {
1561
- "model": LIGHTRAG_MODEL,
1562
- "created_at": LIGHTRAG_CREATED_AT,
1563
  "response": chunk,
1564
  "done": False,
1565
  }
@@ -1571,8 +1587,8 @@ def create_app(args):
1571
  eval_time = last_chunk_time - first_chunk_time
1572
 
1573
  data = {
1574
- "model": LIGHTRAG_MODEL,
1575
- "created_at": LIGHTRAG_CREATED_AT,
1576
  "done": True,
1577
  "total_duration": total_time,
1578
  "load_duration": 0,
@@ -1616,8 +1632,8 @@ def create_app(args):
1616
  eval_time = last_chunk_time - first_chunk_time
1617
 
1618
  return {
1619
- "model": LIGHTRAG_MODEL,
1620
- "created_at": LIGHTRAG_CREATED_AT,
1621
  "response": str(response_text),
1622
  "done": True,
1623
  "total_duration": total_time,
@@ -1690,8 +1706,8 @@ def create_app(args):
1690
  total_response = response
1691
 
1692
  data = {
1693
- "model": LIGHTRAG_MODEL,
1694
- "created_at": LIGHTRAG_CREATED_AT,
1695
  "message": {
1696
  "role": "assistant",
1697
  "content": response,
@@ -1707,8 +1723,8 @@ def create_app(args):
1707
  eval_time = last_chunk_time - first_chunk_time
1708
 
1709
  data = {
1710
- "model": LIGHTRAG_MODEL,
1711
- "created_at": LIGHTRAG_CREATED_AT,
1712
  "done": True,
1713
  "total_duration": total_time,
1714
  "load_duration": 0,
@@ -1728,8 +1744,8 @@ def create_app(args):
1728
 
1729
  total_response += chunk
1730
  data = {
1731
- "model": LIGHTRAG_MODEL,
1732
- "created_at": LIGHTRAG_CREATED_AT,
1733
  "message": {
1734
  "role": "assistant",
1735
  "content": chunk,
@@ -1745,8 +1761,8 @@ def create_app(args):
1745
  eval_time = last_chunk_time - first_chunk_time
1746
 
1747
  data = {
1748
- "model": LIGHTRAG_MODEL,
1749
- "created_at": LIGHTRAG_CREATED_AT,
1750
  "done": True,
1751
  "total_duration": total_time,
1752
  "load_duration": 0,
@@ -1801,8 +1817,8 @@ def create_app(args):
1801
  eval_time = last_chunk_time - first_chunk_time
1802
 
1803
  return {
1804
- "model": LIGHTRAG_MODEL,
1805
- "created_at": LIGHTRAG_CREATED_AT,
1806
  "message": {
1807
  "role": "assistant",
1808
  "content": str(response_text),
@@ -1845,10 +1861,10 @@ def create_app(args):
1845
  "embedding_binding_host": args.embedding_binding_host,
1846
  "embedding_model": args.embedding_model,
1847
  "max_tokens": args.max_tokens,
1848
- "kv_storage": KV_STORAGE,
1849
- "doc_status_storage": DOC_STATUS_STORAGE,
1850
- "graph_storage": GRAPH_STORAGE,
1851
- "vector_storage": VECTOR_STORAGE,
1852
  },
1853
  }
1854
 
 
48
  return int(tokens)
49
 
50
 
51
class OllamaServerInfos:
    """Runtime configuration for the emulated Ollama server.

    The class attributes are defaults; the config.ini handling below
    overrides the storage backends on the shared ``ollama_server_infos``
    instance (instance attributes shadow the class defaults).
    """

    # Constants for emulated Ollama model information
    LIGHTRAG_NAME = "lightrag"
    LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
    LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
    LIGHTRAG_SIZE = 7365960935  # it's a dummy value
    LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
    LIGHTRAG_DIGEST = "sha256:lightrag"

    # Default storage backends; overridden per config.ini sections below.
    KV_STORAGE = "JsonKVStorage"
    DOC_STATUS_STORAGE = "JsonDocStatusStorage"
    GRAPH_STORAGE = "NetworkXStorage"
    VECTOR_STORAGE = "NanoVectorDBStorage"


# Add infos
ollama_server_infos = OllamaServerInfos()
68
 
69
  # read config.ini
70
  config = configparser.ConfigParser()
 
73
  redis_uri = config.get("redis", "uri", fallback=None)
74
  if redis_uri:
75
  os.environ["REDIS_URI"] = redis_uri
76
+ ollama_server_infos.KV_STORAGE = "RedisKVStorage"
77
+ ollama_server_infos.DOC_STATUS_STORAGE = "RedisKVStorage"
78
 
79
  # Neo4j config
80
  neo4j_uri = config.get("neo4j", "uri", fallback=None)
 
84
  os.environ["NEO4J_URI"] = neo4j_uri
85
  os.environ["NEO4J_USERNAME"] = neo4j_username
86
  os.environ["NEO4J_PASSWORD"] = neo4j_password
87
+ ollama_server_infos.GRAPH_STORAGE = "Neo4JStorage"
88
 
89
  # Milvus config
90
  milvus_uri = config.get("milvus", "uri", fallback=None)
 
96
  os.environ["MILVUS_USER"] = milvus_user
97
  os.environ["MILVUS_PASSWORD"] = milvus_password
98
  os.environ["MILVUS_DB_NAME"] = milvus_db_name
99
+ ollama_server_infos.VECTOR_STORAGE = "MilvusVectorDBStorge"
100
 
101
  # MongoDB config
102
  mongo_uri = config.get("mongodb", "uri", fallback=None)
 
104
  if mongo_uri:
105
  os.environ["MONGO_URI"] = mongo_uri
106
  os.environ["MONGO_DATABASE"] = mongo_database
107
+ ollama_server_infos.KV_STORAGE = "MongoKVStorage"
108
+ ollama_server_infos.DOC_STATUS_STORAGE = "MongoKVStorage"
109
 
110
 
111
  def get_default_host(binding_type: str) -> str:
 
222
  # System Configuration
223
  ASCIIColors.magenta("\n🛠️ System Configuration:")
224
  ASCIIColors.white(" ├─ Ollama Emulating Model: ", end="")
225
+ ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}")
226
  ASCIIColors.white(" ├─ Log Level: ", end="")
227
  ASCIIColors.yellow(f"{args.log_level}")
228
  ASCIIColors.white(" ├─ Timeout: ", end="")
 
507
  help="Cosine similarity threshold (default: from env or 0.4)",
508
  )
509
 
510
+ parser.add_argument(
511
+ "--simulated-model-name",
512
+ type=str,
513
+ default=get_env_value(
514
+ "SIMULATED_MODEL_NAME", ollama_server_infos.LIGHTRAG_MODEL
515
+ ),
516
+ help="Model name reported by the emulated Ollama endpoints (default: from env or lightrag:latest)",
517
+ )
518
+
519
  args = parser.parse_args()
520
 
521
+ ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
522
+
523
  return args
524
 
525
 
 
572
 
573
 
574
  class OllamaChatRequest(BaseModel):
575
+ model: str = ollama_server_infos.LIGHTRAG_MODEL
576
  messages: List[OllamaMessage]
577
  stream: bool = True # Default to streaming mode
578
  options: Optional[Dict[str, Any]] = None
 
587
 
588
 
589
  class OllamaGenerateRequest(BaseModel):
590
+ model: str = ollama_server_infos.LIGHTRAG_MODEL
591
  prompt: str
592
  system: Optional[str] = None
593
  stream: bool = False
 
876
  if args.llm_binding == "lollms" or args.llm_binding == "ollama"
877
  else {},
878
  embedding_func=embedding_func,
879
+ kv_storage=ollama_server_infos.KV_STORAGE,
880
+ graph_storage=ollama_server_infos.GRAPH_STORAGE,
881
+ vector_storage=ollama_server_infos.VECTOR_STORAGE,
882
+ doc_status_storage=ollama_server_infos.DOC_STATUS_STORAGE,
883
  vector_db_storage_cls_kwargs={
884
  "cosine_better_than_threshold": args.cosine_threshold
885
  },
 
899
  llm_model_max_async=args.max_async,
900
  llm_model_max_token_size=args.max_tokens,
901
  embedding_func=embedding_func,
902
+ kv_storage=ollama_server_infos.KV_STORAGE,
903
+ graph_storage=ollama_server_infos.GRAPH_STORAGE,
904
+ vector_storage=ollama_server_infos.VECTOR_STORAGE,
905
+ doc_status_storage=ollama_server_infos.DOC_STATUS_STORAGE,
906
  vector_db_storage_cls_kwargs={
907
  "cosine_better_than_threshold": args.cosine_threshold
908
  },
 
1468
  return OllamaTagResponse(
1469
  models=[
1470
  {
1471
+ "name": ollama_server_infos.LIGHTRAG_MODEL,
1472
+ "model": ollama_server_infos.LIGHTRAG_MODEL,
1473
+ "size": ollama_server_infos.LIGHTRAG_SIZE,
1474
+ "digest": ollama_server_infos.LIGHTRAG_DIGEST,
1475
+ "modified_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1476
  "details": {
1477
  "parent_model": "",
1478
  "format": "gguf",
1479
+ "family": ollama_server_infos.LIGHTRAG_NAME,
1480
+ "families": [ollama_server_infos.LIGHTRAG_NAME],
1481
  "parameter_size": "13B",
1482
  "quantization_level": "Q4_0",
1483
  },
 
1540
  total_response = response
1541
 
1542
  data = {
1543
+ "model": ollama_server_infos.LIGHTRAG_MODEL,
1544
+ "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1545
  "response": response,
1546
  "done": False,
1547
  }
 
1553
  eval_time = last_chunk_time - first_chunk_time
1554
 
1555
  data = {
1556
+ "model": ollama_server_infos.LIGHTRAG_MODEL,
1557
+ "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1558
  "done": True,
1559
  "total_duration": total_time,
1560
  "load_duration": 0,
 
1574
 
1575
  total_response += chunk
1576
  data = {
1577
+ "model": ollama_server_infos.LIGHTRAG_MODEL,
1578
+ "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1579
  "response": chunk,
1580
  "done": False,
1581
  }
 
1587
  eval_time = last_chunk_time - first_chunk_time
1588
 
1589
  data = {
1590
+ "model": ollama_server_infos.LIGHTRAG_MODEL,
1591
+ "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1592
  "done": True,
1593
  "total_duration": total_time,
1594
  "load_duration": 0,
 
1632
  eval_time = last_chunk_time - first_chunk_time
1633
 
1634
  return {
1635
+ "model": ollama_server_infos.LIGHTRAG_MODEL,
1636
+ "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1637
  "response": str(response_text),
1638
  "done": True,
1639
  "total_duration": total_time,
 
1706
  total_response = response
1707
 
1708
  data = {
1709
+ "model": ollama_server_infos.LIGHTRAG_MODEL,
1710
+ "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1711
  "message": {
1712
  "role": "assistant",
1713
  "content": response,
 
1723
  eval_time = last_chunk_time - first_chunk_time
1724
 
1725
  data = {
1726
+ "model": ollama_server_infos.LIGHTRAG_MODEL,
1727
+ "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1728
  "done": True,
1729
  "total_duration": total_time,
1730
  "load_duration": 0,
 
1744
 
1745
  total_response += chunk
1746
  data = {
1747
+ "model": ollama_server_infos.LIGHTRAG_MODEL,
1748
+ "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1749
  "message": {
1750
  "role": "assistant",
1751
  "content": chunk,
 
1761
  eval_time = last_chunk_time - first_chunk_time
1762
 
1763
  data = {
1764
+ "model": ollama_server_infos.LIGHTRAG_MODEL,
1765
+ "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1766
  "done": True,
1767
  "total_duration": total_time,
1768
  "load_duration": 0,
 
1817
  eval_time = last_chunk_time - first_chunk_time
1818
 
1819
  return {
1820
+ "model": ollama_server_infos.LIGHTRAG_MODEL,
1821
+ "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
1822
  "message": {
1823
  "role": "assistant",
1824
  "content": str(response_text),
 
1861
  "embedding_binding_host": args.embedding_binding_host,
1862
  "embedding_model": args.embedding_model,
1863
  "max_tokens": args.max_tokens,
1864
+ "kv_storage": ollama_server_infos.KV_STORAGE,
1865
+ "doc_status_storage": ollama_server_infos.DOC_STATUS_STORAGE,
1866
+ "graph_storage": ollama_server_infos.GRAPH_STORAGE,
1867
+ "vector_storage": ollama_server_infos.VECTOR_STORAGE,
1868
  },
1869
  }
1870
 
lightrag/kg/mongo_impl.py CHANGED
@@ -2,15 +2,18 @@ import os
2
  from tqdm.asyncio import tqdm as tqdm_async
3
  from dataclasses import dataclass
4
  import pipmaster as pm
 
5
 
6
  if not pm.is_installed("pymongo"):
7
  pm.install("pymongo")
8
 
9
  from pymongo import MongoClient
10
- from typing import Union
 
11
  from lightrag.utils import logger
12
 
13
  from lightrag.base import BaseKVStorage
 
14
 
15
 
16
  @dataclass
@@ -78,3 +81,360 @@ class MongoKVStorage(BaseKVStorage):
78
  async def drop(self):
79
  """ """
80
  pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass
import pipmaster as pm
import numpy as np  # was "import np", which is not numpy; embed_nodes' annotation needs the real numpy alias

if not pm.is_installed("pymongo"):
    pm.install("pymongo")

from pymongo import MongoClient
from motor.motor_asyncio import AsyncIOMotorClient
from typing import Union, List, Tuple
from lightrag.utils import logger

from lightrag.base import BaseKVStorage
from lightrag.base import BaseGraphStorage
 
18
 
19
  @dataclass
 
81
  async def drop(self):
82
  """ """
83
  pass
84
+
85
+
86
@dataclass
class MongoGraphStorage(BaseGraphStorage):
    """
    A concrete implementation using MongoDB's $graphLookup to demonstrate multi-hop queries.

    Storage layout: one document per node, where ``_id`` is the node id and
    ``edges`` is an array of ``{"target": <node_id>, ...}`` sub-documents
    describing outbound edges (extra keys such as ``relation`` come from the
    ``edge_data`` passed to :meth:`upsert_edge`).
    """

    def __init__(self, namespace, global_config, embedding_func):
        super().__init__(
            namespace=namespace,
            global_config=global_config,
            embedding_func=embedding_func,
        )
        # Async (motor) client: every collection call below must be awaited.
        self.client = AsyncIOMotorClient(
            os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/")
        )
        self.db = self.client[os.environ.get("MONGO_DATABASE", "LightRAG")]
        self.collection = self.db[os.environ.get("MONGO_KG_COLLECTION", "MDB_KG")]

    #
    # -------------------------------------------------------------------------
    # HELPER: $graphLookup pipeline
    # -------------------------------------------------------------------------
    #

    async def _graph_lookup(
        self, start_node_id: str, max_depth: int = None
    ) -> List[dict]:
        """
        Performs a $graphLookup starting from 'start_node_id' and returns
        all reachable documents (including the start node itself).

        Pipeline Explanation:
        - 1) $match: We match the start node document by _id = start_node_id.
        - 2) $graphLookup:
             "from": same collection,
             "startWith": "$edges.target" (the immediate neighbors in 'edges'),
             "connectFromField": "edges.target",
             "connectToField": "_id",
             "as": "reachableNodes",
             "maxDepth": max_depth (if provided),
             "depthField": "depth" (used for debugging or filtering).
        """
        pipeline = [
            {"$match": {"_id": start_node_id}},
            {
                "$graphLookup": {
                    "from": self.collection.name,
                    "startWith": "$edges.target",
                    "connectFromField": "edges.target",
                    "connectToField": "_id",
                    "as": "reachableNodes",
                    "depthField": "depth",
                }
            },
        ]

        # If you want a limited depth (e.g., only 1 or 2 hops), set maxDepth
        if max_depth is not None:
            pipeline[1]["$graphLookup"]["maxDepth"] = max_depth

        # Return the matching doc plus a field "reachableNodes"
        cursor = self.collection.aggregate(pipeline)
        results = await cursor.to_list(None)

        # If there's no matching node, results = [].
        # Otherwise, results[0] is the start node doc,
        # plus results[0]["reachableNodes"] is the array of connected docs.
        return results

    #
    # -------------------------------------------------------------------------
    # BASIC QUERIES
    # -------------------------------------------------------------------------
    #

    async def has_node(self, node_id: str) -> bool:
        """
        Check if node_id is present in the collection by looking up its doc.
        No real need for $graphLookup here; a projected find_one is direct.
        """
        doc = await self.collection.find_one({"_id": node_id}, {"_id": 1})
        return doc is not None

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        """
        Check if there's a direct single-hop edge from source_node_id to target_node_id.

        Demonstration approach: a $graphLookup with maxDepth=0 expands only the
        immediate neighbors, so target membership in "reachableNodes" means a
        direct edge exists. (A plain find_one on edges.target would also work.)
        """
        pipeline = [
            {"$match": {"_id": source_node_id}},
            {
                "$graphLookup": {
                    "from": self.collection.name,
                    "startWith": "$edges.target",
                    "connectFromField": "edges.target",
                    "connectToField": "_id",
                    "as": "reachableNodes",
                    "depthField": "depth",
                    "maxDepth": 0,  # means: do not follow beyond immediate edges
                }
            },
            {
                "$project": {
                    "_id": 0,
                    "reachableNodes._id": 1,  # only keep the _id from the subdocs
                }
            },
        ]
        cursor = self.collection.aggregate(pipeline)
        results = await cursor.to_list(None)
        if not results:
            return False

        # results[0]["reachableNodes"] are the immediate neighbors
        reachable_ids = [d["_id"] for d in results[0].get("reachableNodes", [])]
        return target_node_id in reachable_ids

    #
    # -------------------------------------------------------------------------
    # DEGREES
    # -------------------------------------------------------------------------
    #

    async def node_degree(self, node_id: str) -> int:
        """
        Returns the total number of edges connected to node_id (both inbound
        and outbound).

        Outbound edges are counted directly from the node's own 'edges' array;
        inbound edges are counted with an aggregation that sums, per document,
        how many of its edges target node_id.
        """
        # --- 1) Outbound edges (direct from doc) ---
        doc = await self.collection.find_one({"_id": node_id}, {"edges": 1})
        if not doc:
            return 0
        outbound_count = len(doc.get("edges", []))

        # --- 2) Inbound edges: match docs referencing node_id and sum the
        # matching entries in each doc's edges array (a doc may hold several).
        inbound_count_pipeline = [
            {"$match": {"edges.target": node_id}},
            {
                "$project": {
                    "matchingEdgesCount": {
                        "$size": {
                            "$filter": {
                                "input": "$edges",
                                "as": "edge",
                                "cond": {"$eq": ["$$edge.target", node_id]},
                            }
                        }
                    }
                }
            },
            {"$group": {"_id": None, "totalInbound": {"$sum": "$matchingEdgesCount"}}},
        ]
        inbound_cursor = self.collection.aggregate(inbound_count_pipeline)
        inbound_result = await inbound_cursor.to_list(None)
        inbound_count = inbound_result[0]["totalInbound"] if inbound_result else 0

        return outbound_count + inbound_count

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        """
        Count edges from src_id to tgt_id.

        If the graph can hold multiple edges from the same src to the same tgt
        (e.g. different 'relation' values), they are summed; otherwise this is
        typically 1 or 0. Uses a single-hop $graphLookup for demonstration,
        then counts matching entries in the source doc's edges array.
        """
        pipeline = [
            {"$match": {"_id": src_id}},
            {
                "$graphLookup": {
                    "from": self.collection.name,
                    "startWith": "$edges.target",
                    "connectFromField": "edges.target",
                    "connectToField": "_id",
                    "as": "neighbors",
                    "depthField": "depth",
                    "maxDepth": 0,
                }
            },
            {"$project": {"edges": 1, "neighbors._id": 1, "neighbors.type": 1}},
        ]
        cursor = self.collection.aggregate(pipeline)
        results = await cursor.to_list(None)
        if not results:
            return 0

        # Count how many edges in `results[0].edges` have target == tgt_id.
        edges = results[0].get("edges", [])
        count = sum(1 for e in edges if e.get("target") == tgt_id)
        return count

    #
    # -------------------------------------------------------------------------
    # GETTERS
    # -------------------------------------------------------------------------
    #

    async def get_node(self, node_id: str) -> Union[dict, None]:
        """
        Return the full node document (including "edges"), or None if missing.
        """
        return await self.collection.find_one({"_id": node_id})

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> Union[dict, None]:
        """
        Return the first edge dict from source_node_id to target_node_id if it exists.
        Uses a single-hop $graphLookup as demonstration, though a direct find is simpler.
        """
        pipeline = [
            {"$match": {"_id": source_node_id}},
            {
                "$graphLookup": {
                    "from": self.collection.name,
                    "startWith": "$edges.target",
                    "connectFromField": "edges.target",
                    "connectToField": "_id",
                    "as": "neighbors",
                    "depthField": "depth",
                    "maxDepth": 0,
                }
            },
            {"$project": {"edges": 1}},
        ]
        cursor = self.collection.aggregate(pipeline)
        docs = await cursor.to_list(None)
        if not docs:
            return None

        for e in docs[0].get("edges", []):
            if e.get("target") == target_node_id:
                return e
        return None

    async def get_node_edges(
        self, source_node_id: str
    ) -> Union[List[Tuple[str, str]], None]:
        """
        Return a list of (target_id, relation) for direct edges from source_node_id.
        Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler.

        Edges stored without a 'relation' key yield (target_id, None) instead
        of raising KeyError — upsert_edge only guarantees the 'target' key.
        """
        pipeline = [
            {"$match": {"_id": source_node_id}},
            {
                "$graphLookup": {
                    "from": self.collection.name,
                    "startWith": "$edges.target",
                    "connectFromField": "edges.target",
                    "connectToField": "_id",
                    "as": "neighbors",
                    "depthField": "depth",
                    "maxDepth": 0,
                }
            },
            {"$project": {"_id": 0, "edges": 1}},
        ]
        cursor = self.collection.aggregate(pipeline)
        result = await cursor.to_list(None)
        if not result:
            return None

        edges = result[0].get("edges", [])
        # .get avoids KeyError for edges whose edge_data lacked 'relation'.
        return [(e["target"], e.get("relation")) for e in edges]

    #
    # -------------------------------------------------------------------------
    # UPSERTS
    # -------------------------------------------------------------------------
    #

    async def upsert_node(self, node_id: str, node_data: dict):
        """
        Insert or update a node document. If new, create an empty edges array.
        """
        # By default, preserve existing 'edges'.
        # We'll only set 'edges' to [] on insert (no overwrite).
        update_doc = {"$set": {**node_data}, "$setOnInsert": {"edges": []}}
        await self.collection.update_one({"_id": node_id}, update_doc, upsert=True)

    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict
    ):
        """
        Upsert an edge from source_node_id -> target_node_id with optional 'relation'.
        If an edge with the same target exists, we remove it and re-insert with updated data.
        """
        # Ensure source node exists
        await self.upsert_node(source_node_id, {})

        # Remove existing edge (if any)
        await self.collection.update_one(
            {"_id": source_node_id}, {"$pull": {"edges": {"target": target_node_id}}}
        )

        # Insert new edge
        new_edge = {"target": target_node_id}
        new_edge.update(edge_data)
        await self.collection.update_one(
            {"_id": source_node_id}, {"$push": {"edges": new_edge}}
        )

    #
    # -------------------------------------------------------------------------
    # DELETION
    # -------------------------------------------------------------------------
    #

    async def delete_node(self, node_id: str):
        """
        1) Remove inbound edges from any doc that references node_id.
        2) Remove the node's doc entirely.
        """
        # Remove inbound edges from all other docs
        await self.collection.update_many({}, {"$pull": {"edges": {"target": node_id}}})

        # Remove the node doc
        await self.collection.delete_one({"_id": node_id})

    #
    # -------------------------------------------------------------------------
    # EMBEDDINGS (NOT IMPLEMENTED)
    # -------------------------------------------------------------------------
    #

    async def embed_nodes(self, algorithm: str) -> Tuple[np.ndarray, List[str]]:
        """
        Placeholder for demonstration, raises NotImplementedError.
        """
        raise NotImplementedError("Node embedding is not used in lightrag.")
lightrag/lightrag.py CHANGED
@@ -48,6 +48,7 @@ STORAGES = {
48
  "OracleVectorDBStorage": ".kg.oracle_impl",
49
  "MilvusVectorDBStorge": ".kg.milvus_impl",
50
  "MongoKVStorage": ".kg.mongo_impl",
 
51
  "RedisKVStorage": ".kg.redis_impl",
52
  "ChromaVectorDBStorage": ".kg.chroma_impl",
53
  "TiDBKVStorage": ".kg.tidb_impl",
 
48
  "OracleVectorDBStorage": ".kg.oracle_impl",
49
  "MilvusVectorDBStorge": ".kg.milvus_impl",
50
  "MongoKVStorage": ".kg.mongo_impl",
51
+ "MongoGraphStorage": ".kg.mongo_impl",
52
  "RedisKVStorage": ".kg.redis_impl",
53
  "ChromaVectorDBStorage": ".kg.chroma_impl",
54
  "TiDBKVStorage": ".kg.tidb_impl",