Magicyuan committed
Commit 4a5787d · 1 Parent(s): 03d23d7

feat(lightrag): Add document status tracking and checkpoint support


- Add DocStatus enum and DocProcessingStatus class for document processing state management
- Implement JsonDocStatusStorage for persistent status storage
- Add document-level deduplication in batch processing
- Add checkpoint support in the ainsert method for resumable document processing
- Add status query methods for monitoring processing progress (see the usage sketch after this list)
- Update LightRAG initialization to support document status tracking
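
For orientation, here is a minimal usage sketch of the API this commit introduces. It is not part of the diff: working_dir, the sample documents, and the default model/embedding credentials are assumptions.

import asyncio
from lightrag import LightRAG

# Illustrative setup; working_dir is a placeholder and the default LLM/embedding
# functions still require the usual API credentials.
rag = LightRAG(working_dir="./rag_storage")

# insert() now records a per-document status entry (PENDING -> PROCESSING -> PROCESSED/FAILED)
# in the new doc_status storage (JsonDocStatusStorage by default).
rag.insert(["first document text ...", "second document text ..."])

# New monitoring helpers added in this commit:
counts = asyncio.run(rag.get_processing_status())       # maps each DocStatus to a count
failed = asyncio.run(rag.doc_status.get_failed_docs())   # doc_id -> stored status record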

lightrag/base.py CHANGED
@@ -1,5 +1,6 @@
 from dataclasses import dataclass, field
-from typing import TypedDict, Union, Literal, Generic, TypeVar
+from typing import TypedDict, Union, Literal, Generic, TypeVar, Optional, Dict, Any
+from enum import Enum
 
 import numpy as np
 
@@ -129,3 +130,42 @@ class BaseGraphStorage(StorageNameSpace):
 
     async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
         raise NotImplementedError("Node embedding is not used in lightrag.")
+
+
+class DocStatus(str, Enum):
+    """Document processing status enum"""
+
+    PENDING = "pending"
+    PROCESSING = "processing"
+    PROCESSED = "processed"
+    FAILED = "failed"
+
+
+@dataclass
+class DocProcessingStatus:
+    """Document processing status data structure"""
+
+    content_summary: str  # First 100 chars of document content
+    content_length: int  # Total length of document
+    status: DocStatus  # Current processing status
+    created_at: str  # ISO format timestamp
+    updated_at: str  # ISO format timestamp
+    chunks_count: Optional[int] = None  # Number of chunks after splitting
+    error: Optional[str] = None  # Error message if failed
+    metadata: Dict[str, Any] = field(default_factory=dict)  # Additional metadata
+
+
+class DocStatusStorage(BaseKVStorage):
+    """Base class for document status storage"""
+
+    async def get_status_counts(self) -> Dict[str, int]:
+        """Get counts of documents in each status"""
+        raise NotImplementedError
+
+    async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
+        """Get all failed documents"""
+        raise NotImplementedError
+
+    async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
+        """Get all pending documents"""
+        raise NotImplementedError
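
A quick illustration of the new data model (not part of the diff; the field values are made up). Because DocStatus mixes in str, its members compare equal to their plain string values, which is what the JSON-backed status storage added below relies on:

from datetime import datetime
from lightrag.base import DocStatus, DocProcessingStatus

record = DocProcessingStatus(
    content_summary="LightRAG: Simple and Fast Retrieval-Augmented Generation ...",  # illustrative
    content_length=4212,                                                             # illustrative
    status=DocStatus.PENDING,
    created_at=datetime.now().isoformat(),
    updated_at=datetime.now().isoformat(),
)

assert record.status == "pending"                      # str-mixin enum equals its value
assert DocStatus("processed") is DocStatus.PROCESSED   # lookup by value returns the member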
lightrag/kg/age_impl.py CHANGED
@@ -1,7 +1,8 @@
 import asyncio
 import inspect
 import json
-import os, sys
+import os
+import sys
 from contextlib import asynccontextmanager
 from dataclasses import dataclass
 from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
@@ -22,8 +23,10 @@ from ..base import BaseGraphStorage
 
 if sys.platform.startswith("win"):
     import asyncio.windows_events
+
     asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
 
+
 class AGEQueryException(Exception):
     """Exception for the AGE queries."""
 
lightrag/lightrag.py CHANGED
@@ -4,7 +4,7 @@ from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Type, cast
+from typing import Type, cast, Dict
 
 from .llm import (
     gpt_4o_mini_complete,
@@ -32,12 +32,14 @@ from .base import (
     BaseVectorStorage,
     StorageNameSpace,
     QueryParam,
+    DocStatus,
 )
 
 from .storage import (
     JsonKVStorage,
     NanoVectorDBStorage,
     NetworkXStorage,
+    JsonDocStatusStorage,
 )
 
 # future KG integrations
@@ -172,6 +174,9 @@ class LightRAG:
     addon_params: dict = field(default_factory=dict)
     convert_response_to_json_func: callable = convert_response_to_json
 
+    # Add new field for document status storage type
+    doc_status_storage: str = field(default="JsonDocStatusStorage")
+
     def __post_init__(self):
         log_file = os.path.join("lightrag.log")
         set_logger(log_file)
@@ -263,7 +268,15 @@
             )
         )
 
-    def _get_storage_class(self) -> Type[BaseGraphStorage]:
+        # Initialize document status storage
+        self.doc_status_storage_cls = self._get_storage_class()[self.doc_status_storage]
+        self.doc_status = self.doc_status_storage_cls(
+            namespace="doc_status",
+            global_config=asdict(self),
+            embedding_func=None,
+        )
+
+    def _get_storage_class(self) -> dict:
         return {
             # kv storage
             "JsonKVStorage": JsonKVStorage,
@@ -284,6 +297,7 @@
             "TiDBGraphStorage": TiDBGraphStorage,
             "GremlinStorage": GremlinStorage,
             # "ArangoDBStorage": ArangoDBStorage
+            "JsonDocStatusStorage": JsonDocStatusStorage,
         }
 
     def insert(self, string_or_strings):
@@ -291,71 +305,139 @@
         return loop.run_until_complete(self.ainsert(string_or_strings))
 
     async def ainsert(self, string_or_strings):
-        update_storage = False
-        try:
-            if isinstance(string_or_strings, str):
-                string_or_strings = [string_or_strings]
-
-            new_docs = {
-                compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()}
-                for c in string_or_strings
+        """Insert documents with checkpoint support
+
+        Args:
+            string_or_strings: Single document string or list of document strings
+        """
+        if isinstance(string_or_strings, str):
+            string_or_strings = [string_or_strings]
+
+        # 1. Remove duplicate contents from the list
+        unique_contents = list(set(doc.strip() for doc in string_or_strings))
+
+        # 2. Generate document IDs and initial status
+        new_docs = {
+            compute_mdhash_id(content, prefix="doc-"): {
+                "content": content,
+                "content_summary": self._get_content_summary(content),
+                "content_length": len(content),
+                "status": DocStatus.PENDING,
+                "created_at": datetime.now().isoformat(),
+                "updated_at": datetime.now().isoformat(),
             }
-            _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
-            new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
-            if not len(new_docs):
-                logger.warning("All docs are already in the storage")
-                return
-            update_storage = True
-            logger.info(f"[New Docs] inserting {len(new_docs)} docs")
-
-            inserting_chunks = {}
-            for doc_key, doc in tqdm_async(
-                new_docs.items(), desc="Chunking documents", unit="doc"
+            for content in unique_contents
+        }
+
+        # 3. Filter out already processed documents
+        _add_doc_keys = await self.doc_status.filter_keys(list(new_docs.keys()))
+        new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
+
+        if not new_docs:
+            logger.info("All documents have been processed or are duplicates")
+            return
+
+        logger.info(f"Processing {len(new_docs)} new unique documents")
+
+        # Process documents in batches
+        batch_size = self.addon_params.get("insert_batch_size", 10)
+        for i in range(0, len(new_docs), batch_size):
+            batch_docs = dict(list(new_docs.items())[i : i + batch_size])
+
+            for doc_id, doc in tqdm_async(
+                batch_docs.items(), desc=f"Processing batch {i//batch_size + 1}"
             ):
-                chunks = {
-                    compute_mdhash_id(dp["content"], prefix="chunk-"): {
-                        **dp,
-                        "full_doc_id": doc_key,
+                try:
+                    # Update status to processing
+                    doc_status = {
+                        "content_summary": doc["content_summary"],
+                        "content_length": doc["content_length"],
+                        "status": DocStatus.PROCESSING,
+                        "created_at": doc["created_at"],
+                        "updated_at": datetime.now().isoformat(),
                     }
-                    for dp in chunking_by_token_size(
-                        doc["content"],
-                        overlap_token_size=self.chunk_overlap_token_size,
-                        max_token_size=self.chunk_token_size,
-                        tiktoken_model=self.tiktoken_model_name,
+                    await self.doc_status.upsert({doc_id: doc_status})
+
+                    # Generate chunks from document
+                    chunks = {
+                        compute_mdhash_id(dp["content"], prefix="chunk-"): {
+                            **dp,
+                            "full_doc_id": doc_id,
+                        }
+                        for dp in chunking_by_token_size(
+                            doc["content"],
+                            overlap_token_size=self.chunk_overlap_token_size,
+                            max_token_size=self.chunk_token_size,
+                            tiktoken_model=self.tiktoken_model_name,
+                        )
+                    }
+
+                    # Update status with chunks information
+                    doc_status.update(
+                        {
+                            "chunks_count": len(chunks),
+                            "updated_at": datetime.now().isoformat(),
+                        }
                     )
-                }
-                inserting_chunks.update(chunks)
-            _add_chunk_keys = await self.text_chunks.filter_keys(
-                list(inserting_chunks.keys())
-            )
-            inserting_chunks = {
-                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
-            }
-            if not len(inserting_chunks):
-                logger.warning("All chunks are already in the storage")
-                return
-            logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
-
-            await self.chunks_vdb.upsert(inserting_chunks)
-
-            logger.info("[Entity Extraction]...")
-            maybe_new_kg = await extract_entities(
-                inserting_chunks,
-                knowledge_graph_inst=self.chunk_entity_relation_graph,
-                entity_vdb=self.entities_vdb,
-                relationships_vdb=self.relationships_vdb,
-                global_config=asdict(self),
-            )
-            if maybe_new_kg is None:
-                logger.warning("No new entities and relationships found")
-                return
-            self.chunk_entity_relation_graph = maybe_new_kg
+                    await self.doc_status.upsert({doc_id: doc_status})
+
+                    try:
+                        # Store chunks in vector database
+                        await self.chunks_vdb.upsert(chunks)
+
+                        # Extract and store entities and relationships
+                        maybe_new_kg = await extract_entities(
+                            chunks,
+                            knowledge_graph_inst=self.chunk_entity_relation_graph,
+                            entity_vdb=self.entities_vdb,
+                            relationships_vdb=self.relationships_vdb,
+                            global_config=asdict(self),
+                        )
 
-            await self.full_docs.upsert(new_docs)
-            await self.text_chunks.upsert(inserting_chunks)
-        finally:
-            if update_storage:
-                await self._insert_done()
+                        if maybe_new_kg is None:
+                            raise Exception(
+                                "Failed to extract entities and relationships"
+                            )
+
+                        self.chunk_entity_relation_graph = maybe_new_kg
+
+                        # Store original document and chunks
+                        await self.full_docs.upsert(
+                            {doc_id: {"content": doc["content"]}}
+                        )
+                        await self.text_chunks.upsert(chunks)
+
+                        # Update status to processed
+                        doc_status.update(
+                            {
+                                "status": DocStatus.PROCESSED,
+                                "updated_at": datetime.now().isoformat(),
+                            }
+                        )
+                        await self.doc_status.upsert({doc_id: doc_status})
+
+                    except Exception as e:
+                        # Mark as failed if any step fails
+                        doc_status.update(
+                            {
+                                "status": DocStatus.FAILED,
+                                "error": str(e),
+                                "updated_at": datetime.now().isoformat(),
+                            }
+                        )
+                        await self.doc_status.upsert({doc_id: doc_status})
+                        raise e
+
+                except Exception as e:
+                    import traceback
+
+                    error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}"
+                    logger.error(error_msg)
+                    continue
+
+                finally:
+                    # Ensure all indexes are updated after each document
+                    await self._insert_done()
 
     async def _insert_done(self):
         tasks = []
@@ -591,3 +673,26 @@ class LightRAG:
                 continue
             tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
         await asyncio.gather(*tasks)
+
+    def _get_content_summary(self, content: str, max_length: int = 100) -> str:
+        """Get summary of document content
+
+        Args:
+            content: Original document content
+            max_length: Maximum length of summary
+
+        Returns:
+            Truncated content with ellipsis if needed
+        """
+        content = content.strip()
+        if len(content) <= max_length:
+            return content
+        return content[:max_length] + "..."
+
+    async def get_processing_status(self) -> Dict[str, int]:
+        """Get current document processing status counts
+
+        Returns:
+            Dict with counts for each status
+        """
+        return await self.doc_status.get_status_counts()
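
The checkpoint behaviour comes from step 3 of the new ainsert: document IDs already present in doc_status are filtered out before any work is done, so an interrupted batch can simply be re-submitted. A resumption sketch under stated assumptions (load_my_corpus is a hypothetical helper; working_dir is illustrative; insert_batch_size is the optional addon_params key read by this commit):

from lightrag import LightRAG

docs = load_my_corpus()   # hypothetical: returns a list of document strings

rag = LightRAG(
    working_dir="./rag_storage",                 # illustrative
    addon_params={"insert_batch_size": 10},      # batch size used by ainsert (defaults to 10)
)

# First run: interrupted part-way through the batch.
try:
    rag.insert(docs)
except KeyboardInterrupt:
    pass

# Second run: any ID already tracked in doc_status (whatever its status, including failed)
# is skipped by doc_status.filter_keys(), so only documents never seen before are processed.
rag.insert(docs)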
lightrag/storage.py CHANGED
@@ -3,7 +3,7 @@ import html
 import os
 from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import dataclass
-from typing import Any, Union, cast
+from typing import Any, Union, cast, Dict
 import networkx as nx
 import numpy as np
 from nano_vectordb import NanoVectorDB
@@ -19,6 +19,9 @@ from .base import (
     BaseGraphStorage,
     BaseKVStorage,
     BaseVectorStorage,
+    DocStatus,
+    DocProcessingStatus,
+    DocStatusStorage,
 )
 
 
@@ -315,3 +318,47 @@ class NetworkXStorage(BaseGraphStorage):
 
         nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
         return embeddings, nodes_ids
+
+
+@dataclass
+class JsonDocStatusStorage(DocStatusStorage):
+    """JSON implementation of document status storage"""
+
+    def __post_init__(self):
+        working_dir = self.global_config["working_dir"]
+        self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
+        self._data = load_json(self._file_name) or {}
+        logger.info(f"Loaded document status storage with {len(self._data)} records")
+
+    async def filter_keys(self, data: list[str]) -> set[str]:
+        """Return keys that don't exist in storage"""
+        return set([k for k in data if k not in self._data])
+
+    async def get_status_counts(self) -> Dict[str, int]:
+        """Get counts of documents in each status"""
+        counts = {status: 0 for status in DocStatus}
+        for doc in self._data.values():
+            counts[doc["status"]] += 1
+        return counts
+
+    async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
+        """Get all failed documents"""
+        return {k: v for k, v in self._data.items() if v["status"] == DocStatus.FAILED}
+
+    async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
+        """Get all pending documents"""
+        return {k: v for k, v in self._data.items() if v["status"] == DocStatus.PENDING}
+
+    async def index_done_callback(self):
+        """Save data to file after indexing"""
+        write_json(self._data, self._file_name)
+
+    async def upsert(self, data: dict[str, dict]):
+        """Update or insert document status
+
+        Args:
+            data: Dictionary of document IDs and their status data
+        """
+        self._data.update(data)
+        await self.index_done_callback()
+        return data
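
For reference, a sketch of reading back the checkpoint file that JsonDocStatusStorage persists (kv_store_doc_status.json under working_dir). The concrete values in the comment are illustrative, and the status field is expected to serialize as the enum's plain string value since DocStatus subclasses str:

import json
import os

working_dir = "./rag_storage"                       # illustrative
path = os.path.join(working_dir, "kv_store_doc_status.json")

with open(path, encoding="utf-8") as f:
    checkpoint = json.load(f)

# Each entry mirrors the doc_status dicts written in ainsert, e.g. (illustrative):
# checkpoint["doc-<md5 of content>"] == {
#     "content_summary": "first 100 chars of the document...",
#     "content_length": 1234,
#     "status": "processed",
#     "created_at": "2024-11-29T12:00:00",
#     "updated_at": "2024-11-29T12:00:05",
#     "chunks_count": 3,
# }
failed_ids = [doc_id for doc_id, rec in checkpoint.items() if rec["status"] == "failed"]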