zrguo commited on
Commit
fc5f4ac
·
unverified ·
2 Parent(s): e708e45 ac355a1

Merge pull request #751 from ArnoChenFx/add-MongoDocStatusStorage

Browse files
lightrag/api/lightrag_server.py CHANGED
@@ -130,7 +130,7 @@ if mongo_uri:
130
  os.environ["MONGO_URI"] = mongo_uri
131
  os.environ["MONGO_DATABASE"] = mongo_database
132
  rag_storage_config.KV_STORAGE = "MongoKVStorage"
133
- rag_storage_config.DOC_STATUS_STORAGE = "MongoKVStorage"
134
  if mongo_graph:
135
  rag_storage_config.GRAPH_STORAGE = "MongoGraphStorage"
136
 
 
130
  os.environ["MONGO_URI"] = mongo_uri
131
  os.environ["MONGO_DATABASE"] = mongo_database
132
  rag_storage_config.KV_STORAGE = "MongoKVStorage"
133
+ rag_storage_config.DOC_STATUS_STORAGE = "MongoDocStatusStorage"
134
  if mongo_graph:
135
  rag_storage_config.GRAPH_STORAGE = "MongoGraphStorage"
136
 
lightrag/base.py CHANGED
@@ -227,6 +227,14 @@ class DocStatusStorage(BaseKVStorage):
227
  """Get all pending documents"""
228
  raise NotImplementedError
229
 
 
 
 
 
 
 
 
 
230
  async def update_doc_status(self, data: dict[str, Any]) -> None:
231
  """Updates the status of a document. By default, it calls upsert."""
232
  await self.upsert(data)
 
227
  """Get all pending documents"""
228
  raise NotImplementedError
229
 
230
+ async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
231
+ """Get all processing documents"""
232
+ raise NotImplementedError
233
+
234
+ async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
235
+ """Get all procesed documents"""
236
+ raise NotImplementedError
237
+
238
  async def update_doc_status(self, data: dict[str, Any]) -> None:
239
  """Updates the status of a document. By default, it calls upsert."""
240
  await self.upsert(data)
lightrag/kg/mongo_impl.py CHANGED
@@ -16,7 +16,13 @@ from typing import Any, List, Tuple, Union
16
  from motor.motor_asyncio import AsyncIOMotorClient
17
  from pymongo import MongoClient
18
 
19
- from ..base import BaseGraphStorage, BaseKVStorage
 
 
 
 
 
 
20
  from ..namespace import NameSpace, is_namespace
21
  from ..utils import logger
22
 
@@ -39,7 +45,8 @@ class MongoKVStorage(BaseKVStorage):
39
 
40
  async def filter_keys(self, data: set[str]) -> set[str]:
41
  existing_ids = [
42
- str(x["_id"]) for x in self._data.find({"_id": {"$in": data}}, {"_id": 1})
 
43
  ]
44
  return set([s for s in data if s not in existing_ids])
45
 
@@ -77,6 +84,82 @@ class MongoKVStorage(BaseKVStorage):
77
  await self._data.drop()
78
 
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  @dataclass
81
  class MongoGraphStorage(BaseGraphStorage):
82
  """
 
16
  from motor.motor_asyncio import AsyncIOMotorClient
17
  from pymongo import MongoClient
18
 
19
+ from ..base import (
20
+ BaseGraphStorage,
21
+ BaseKVStorage,
22
+ DocProcessingStatus,
23
+ DocStatus,
24
+ DocStatusStorage,
25
+ )
26
  from ..namespace import NameSpace, is_namespace
27
  from ..utils import logger
28
 
 
45
 
46
  async def filter_keys(self, data: set[str]) -> set[str]:
47
  existing_ids = [
48
+ str(x["_id"])
49
+ for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
50
  ]
51
  return set([s for s in data if s not in existing_ids])
52
 
 
84
  await self._data.drop()
85
 
86
 
87
+ @dataclass
88
+ class MongoDocStatusStorage(DocStatusStorage):
89
+ def __post_init__(self):
90
+ client = MongoClient(
91
+ os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/")
92
+ )
93
+ database = client.get_database(os.environ.get("MONGO_DATABASE", "LightRAG"))
94
+ self._data = database.get_collection(self.namespace)
95
+ logger.info(f"Use MongoDB as doc status {self.namespace}")
96
+
97
+ async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
98
+ return self._data.find_one({"_id": id})
99
+
100
+ async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
101
+ return list(self._data.find({"_id": {"$in": ids}}))
102
+
103
+ async def filter_keys(self, data: set[str]) -> set[str]:
104
+ existing_ids = [
105
+ str(x["_id"])
106
+ for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
107
+ ]
108
+ return set([s for s in data if s not in existing_ids])
109
+
110
+ async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
111
+ for k, v in data.items():
112
+ self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
113
+ data[k]["_id"] = k
114
+
115
+ async def drop(self) -> None:
116
+ """Drop the collection"""
117
+ await self._data.drop()
118
+
119
+ async def get_status_counts(self) -> dict[str, int]:
120
+ """Get counts of documents in each status"""
121
+ pipeline = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}]
122
+ result = list(self._data.aggregate(pipeline))
123
+ counts = {}
124
+ for doc in result:
125
+ counts[doc["_id"]] = doc["count"]
126
+ return counts
127
+
128
+ async def get_docs_by_status(
129
+ self, status: DocStatus
130
+ ) -> dict[str, DocProcessingStatus]:
131
+ """Get all documents by status"""
132
+ result = list(self._data.find({"status": status.value}))
133
+ return {
134
+ doc["_id"]: DocProcessingStatus(
135
+ content=doc["content"],
136
+ content_summary=doc.get("content_summary"),
137
+ content_length=doc["content_length"],
138
+ status=doc["status"],
139
+ created_at=doc.get("created_at"),
140
+ updated_at=doc.get("updated_at"),
141
+ chunks_count=doc.get("chunks_count", -1),
142
+ )
143
+ for doc in result
144
+ }
145
+
146
+ async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
147
+ """Get all failed documents"""
148
+ return await self.get_docs_by_status(DocStatus.FAILED)
149
+
150
+ async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
151
+ """Get all pending documents"""
152
+ return await self.get_docs_by_status(DocStatus.PENDING)
153
+
154
+ async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
155
+ """Get all processing documents"""
156
+ return await self.get_docs_by_status(DocStatus.PROCESSING)
157
+
158
+ async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
159
+ """Get all procesed documents"""
160
+ return await self.get_docs_by_status(DocStatus.PROCESSED)
161
+
162
+
163
  @dataclass
164
  class MongoGraphStorage(BaseGraphStorage):
165
  """
lightrag/kg/postgres_impl.py CHANGED
@@ -495,6 +495,14 @@ class PGDocStatusStorage(DocStatusStorage):
495
  """Get all pending documents"""
496
  return await self.get_docs_by_status(DocStatus.PENDING)
497
 
 
 
 
 
 
 
 
 
498
  async def index_done_callback(self):
499
  """Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here"""
500
  logger.info("Doc status had been saved into postgresql db!")
 
495
  """Get all pending documents"""
496
  return await self.get_docs_by_status(DocStatus.PENDING)
497
 
498
+ async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
499
+ """Get all processing documents"""
500
+ return await self.get_docs_by_status(DocStatus.PROCESSING)
501
+
502
+ async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
503
+ """Get all procesed documents"""
504
+ return await self.get_docs_by_status(DocStatus.PROCESSED)
505
+
506
  async def index_done_callback(self):
507
  """Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here"""
508
  logger.info("Doc status had been saved into postgresql db!")
lightrag/kg/qdrant_impl.py CHANGED
@@ -70,7 +70,6 @@ class QdrantVectorDBStorage(BaseVectorStorage):
70
  )
71
 
72
  async def upsert(self, data: dict[str, dict]):
73
- logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
74
  if not len(data):
75
  logger.warning("You insert an empty data to vector DB")
76
  return []
@@ -123,5 +122,4 @@ class QdrantVectorDBStorage(BaseVectorStorage):
123
  limit=top_k,
124
  with_payload=True,
125
  )
126
- logger.debug(f"query result: {results}")
127
  return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in results]
 
70
  )
71
 
72
  async def upsert(self, data: dict[str, dict]):
 
73
  if not len(data):
74
  logger.warning("You insert an empty data to vector DB")
75
  return []
 
122
  limit=top_k,
123
  with_payload=True,
124
  )
 
125
  return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in results]
lightrag/lightrag.py CHANGED
@@ -46,6 +46,7 @@ STORAGES = {
46
  "OracleVectorDBStorage": ".kg.oracle_impl",
47
  "MilvusVectorDBStorge": ".kg.milvus_impl",
48
  "MongoKVStorage": ".kg.mongo_impl",
 
49
  "MongoGraphStorage": ".kg.mongo_impl",
50
  "RedisKVStorage": ".kg.redis_impl",
51
  "ChromaVectorDBStorage": ".kg.chroma_impl",
 
46
  "OracleVectorDBStorage": ".kg.oracle_impl",
47
  "MilvusVectorDBStorge": ".kg.milvus_impl",
48
  "MongoKVStorage": ".kg.mongo_impl",
49
+ "MongoDocStatusStorage": ".kg.mongo_impl",
50
  "MongoGraphStorage": ".kg.mongo_impl",
51
  "RedisKVStorage": ".kg.redis_impl",
52
  "ChromaVectorDBStorage": ".kg.chroma_impl",