ArnoChen committed on
Commit
5799da7
·
1 Parent(s): 8a5180d

unify doc status retrieval with get_docs_by_status

Browse files
lightrag/base.py CHANGED
@@ -249,20 +249,10 @@ class DocStatusStorage(BaseKVStorage):
249
  """Get counts of documents in each status"""
250
  raise NotImplementedError
251
 
252
- async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
253
- """Get all failed documents"""
254
- raise NotImplementedError
255
-
256
- async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
257
- """Get all pending documents"""
258
- raise NotImplementedError
259
-
260
- async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
261
- """Get all processing documents"""
262
- raise NotImplementedError
263
-
264
- async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
265
- """Get all procesed documents"""
266
  raise NotImplementedError
267
 
268
  async def update_doc_status(self, data: dict[str, Any]) -> None:
 
249
  """Get counts of documents in each status"""
250
  raise NotImplementedError
251
 
252
+ async def get_docs_by_status(
253
+ self, status: DocStatus
254
+ ) -> dict[str, DocProcessingStatus]:
255
+ """Get all documents with a specific status"""
 
 
 
 
 
 
 
 
 
 
256
  raise NotImplementedError
257
 
258
  async def update_doc_status(self, data: dict[str, Any]) -> None:
lightrag/kg/json_doc_status_impl.py CHANGED
@@ -93,36 +93,14 @@ class JsonDocStatusStorage(DocStatusStorage):
93
  counts[doc["status"]] += 1
94
  return counts
95
 
96
- async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
97
- """Get all failed documents"""
 
 
98
  return {
99
  k: DocProcessingStatus(**v)
100
  for k, v in self._data.items()
101
- if v["status"] == DocStatus.FAILED
102
- }
103
-
104
- async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
105
- """Get all pending documents"""
106
- return {
107
- k: DocProcessingStatus(**v)
108
- for k, v in self._data.items()
109
- if v["status"] == DocStatus.PENDING
110
- }
111
-
112
- async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
113
- """Get all processed documents"""
114
- return {
115
- k: DocProcessingStatus(**v)
116
- for k, v in self._data.items()
117
- if v["status"] == DocStatus.PROCESSED
118
- }
119
-
120
- async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
121
- """Get all processing documents"""
122
- return {
123
- k: DocProcessingStatus(**v)
124
- for k, v in self._data.items()
125
- if v["status"] == DocStatus.PROCESSING
126
  }
127
 
128
  async def index_done_callback(self):
 
93
  counts[doc["status"]] += 1
94
  return counts
95
 
96
+ async def get_docs_by_status(
97
+ self, status: DocStatus
98
+ ) -> dict[str, DocProcessingStatus]:
99
+ """Get all documents with a specific status"""
100
  return {
101
  k: DocProcessingStatus(**v)
102
  for k, v in self._data.items()
103
+ if v["status"] == status
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  }
105
 
106
  async def index_done_callback(self):
lightrag/kg/mongo_impl.py CHANGED
@@ -175,7 +175,7 @@ class MongoDocStatusStorage(DocStatusStorage):
175
  async def get_docs_by_status(
176
  self, status: DocStatus
177
  ) -> dict[str, DocProcessingStatus]:
178
- """Get all documents by status"""
179
  cursor = self._data.find({"status": status.value})
180
  result = await cursor.to_list()
181
  return {
@@ -191,22 +191,6 @@ class MongoDocStatusStorage(DocStatusStorage):
191
  for doc in result
192
  }
193
 
194
- async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
195
- """Get all failed documents"""
196
- return await self.get_docs_by_status(DocStatus.FAILED)
197
-
198
- async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
199
- """Get all pending documents"""
200
- return await self.get_docs_by_status(DocStatus.PENDING)
201
-
202
- async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
203
- """Get all processing documents"""
204
- return await self.get_docs_by_status(DocStatus.PROCESSING)
205
-
206
- async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
207
- """Get all procesed documents"""
208
- return await self.get_docs_by_status(DocStatus.PROCESSED)
209
-
210
 
211
  @dataclass
212
  class MongoGraphStorage(BaseGraphStorage):
 
175
  async def get_docs_by_status(
176
  self, status: DocStatus
177
  ) -> dict[str, DocProcessingStatus]:
178
+ """Get all documents with a specific status"""
179
  cursor = self._data.find({"status": status.value})
180
  result = await cursor.to_list()
181
  return {
 
191
  for doc in result
192
  }
193
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
  @dataclass
196
  class MongoGraphStorage(BaseGraphStorage):
lightrag/kg/postgres_impl.py CHANGED
@@ -468,7 +468,7 @@ class PGDocStatusStorage(DocStatusStorage):
468
  async def get_docs_by_status(
469
  self, status: DocStatus
470
  ) -> Dict[str, DocProcessingStatus]:
471
- """Get all documents by status"""
472
  sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
473
  params = {"workspace": self.db.workspace, "status": status}
474
  result = await self.db.query(sql, params, True)
@@ -485,22 +485,6 @@ class PGDocStatusStorage(DocStatusStorage):
485
  for element in result
486
  }
487
 
488
- async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
489
- """Get all failed documents"""
490
- return await self.get_docs_by_status(DocStatus.FAILED)
491
-
492
- async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
493
- """Get all pending documents"""
494
- return await self.get_docs_by_status(DocStatus.PENDING)
495
-
496
- async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
497
- """Get all processing documents"""
498
- return await self.get_docs_by_status(DocStatus.PROCESSING)
499
-
500
- async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
501
- """Get all procesed documents"""
502
- return await self.get_docs_by_status(DocStatus.PROCESSED)
503
-
504
  async def index_done_callback(self):
505
  """Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here"""
506
  logger.info("Doc status had been saved into postgresql db!")
 
468
  async def get_docs_by_status(
469
  self, status: DocStatus
470
  ) -> Dict[str, DocProcessingStatus]:
471
+ """Get all documents with a specific status"""
472
  sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
473
  params = {"workspace": self.db.workspace, "status": status}
474
  result = await self.db.query(sql, params, True)
 
485
  for element in result
486
  }
487
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
488
  async def index_done_callback(self):
489
  """Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here"""
490
  logger.info("Doc status had been saved into postgresql db!")
lightrag/lightrag.py CHANGED
@@ -89,7 +89,7 @@ STORAGE_IMPLEMENTATIONS = {
89
  "PGDocStatusStorage",
90
  "MongoDocStatusStorage",
91
  ],
92
- "required_methods": ["get_pending_docs"],
93
  },
94
  }
95
 
@@ -230,7 +230,7 @@ class LightRAG:
230
  """LightRAG: Simple and Fast Retrieval-Augmented Generation."""
231
 
232
  working_dir: str = field(
233
- default_factory=lambda: f'./lightrag_cache_{datetime.now().strftime("%Y-%m-%d-%H:%M:%S")}'
234
  )
235
  """Directory where cache and temporary files are stored."""
236
 
@@ -715,11 +715,11 @@ class LightRAG:
715
  # 1. Get all pending, failed, and abnormally terminated processing documents.
716
  to_process_docs: dict[str, DocProcessingStatus] = {}
717
 
718
- processing_docs = await self.doc_status.get_processing_docs()
719
  to_process_docs.update(processing_docs)
720
- failed_docs = await self.doc_status.get_failed_docs()
721
  to_process_docs.update(failed_docs)
722
- pendings_docs = await self.doc_status.get_pending_docs()
723
  to_process_docs.update(pendings_docs)
724
 
725
  if not to_process_docs:
 
89
  "PGDocStatusStorage",
90
  "MongoDocStatusStorage",
91
  ],
92
+ "required_methods": ["get_docs_by_status"],
93
  },
94
  }
95
 
 
230
  """LightRAG: Simple and Fast Retrieval-Augmented Generation."""
231
 
232
  working_dir: str = field(
233
+ default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
234
  )
235
  """Directory where cache and temporary files are stored."""
236
 
 
715
  # 1. Get all pending, failed, and abnormally terminated processing documents.
716
  to_process_docs: dict[str, DocProcessingStatus] = {}
717
 
718
+ processing_docs = await self.doc_status.get_docs_by_status(DocStatus.PROCESSING)
719
  to_process_docs.update(processing_docs)
720
+ failed_docs = await self.doc_status.get_docs_by_status(DocStatus.FAILED)
721
  to_process_docs.update(failed_docs)
722
+ pendings_docs = await self.doc_status.get_docs_by_status(DocStatus.PENDING)
723
  to_process_docs.update(pendings_docs)
724
 
725
  if not to_process_docs: