PiochU19 committed on
Commit 423082d · 1 Parent(s): e57f9b2

add support for providing ids for document insert

Files changed (1)
  1. lightrag/lightrag.py +50 -21
lightrag/lightrag.py CHANGED
@@ -1,8 +1,8 @@
 from __future__ import annotations
 
 import asyncio
-import os
 import configparser
+import os
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
@@ -37,11 +37,11 @@ from .utils import (
     always_get_an_event_loop,
     compute_mdhash_id,
     convert_response_to_json,
+    encode_string_by_tiktoken,
     lazy_external_import,
     limit_async_func_call,
     logger,
     set_logger,
-    encode_string_by_tiktoken,
 )
 
 config = configparser.ConfigParser()
@@ -461,6 +461,7 @@ class LightRAG:
         input: str | list[str],
         split_by_character: str | None = None,
         split_by_character_only: bool = False,
+        ids: list[str] | None = None,
     ) -> None:
         """Sync Insert documents with checkpoint support
 
@@ -469,10 +470,11 @@
             split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
             split_by_character_only: if split_by_character_only is True, split the string by character only, when
             split_by_character is None, this parameter is ignored.
+            ids: list of unique document IDs, if not provided, MD5 hash IDs will be generated
         """
         loop = always_get_an_event_loop()
         loop.run_until_complete(
-            self.ainsert(input, split_by_character, split_by_character_only)
+            self.ainsert(input, split_by_character, split_by_character_only, ids)
         )
 
     async def ainsert(
@@ -480,6 +482,7 @@
         input: str | list[str],
         split_by_character: str | None = None,
         split_by_character_only: bool = False,
+        ids: list[str] | None = None,
     ) -> None:
         """Async Insert documents with checkpoint support
 
@@ -488,8 +491,9 @@
             split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
             split_by_character_only: if split_by_character_only is True, split the string by character only, when
             split_by_character is None, this parameter is ignored.
+            ids: list of unique document IDs, if not provided, MD5 hash IDs will be generated
         """
-        await self.apipeline_enqueue_documents(input)
+        await self.apipeline_enqueue_documents(input, ids)
         await self.apipeline_process_enqueue_documents(
             split_by_character, split_by_character_only
         )
@@ -546,24 +550,51 @@
         if update_storage:
             await self._insert_done()
 
-    async def apipeline_enqueue_documents(self, input: str | list[str]) -> None:
+    async def apipeline_enqueue_documents(
+        self, input: str | list[str], ids: list[str] | None
+    ) -> None:
         """
        Pipeline for Processing Documents
 
-        1. Remove duplicate contents from the list
-        2. Generate document IDs and initial status
-        3. Filter out already processed documents
-        4. Enqueue document in status
+        1. Validate ids if provided or generate MD5 hash IDs
+        2. Remove duplicate contents
+        3. Generate document initial status
+        4. Filter out already processed documents
+        5. Enqueue document in status
         """
         if isinstance(input, str):
             input = [input]
 
-        # 1. Remove duplicate contents from the list
-        unique_contents = list(set(doc.strip() for doc in input))
+        # 1. Validate ids if provided or generate MD5 hash IDs
+        if ids is not None:
+            # Check if the number of IDs matches the number of documents
+            if len(ids) != len(input):
+                raise ValueError("Number of IDs must match the number of documents")
+
+            # Check if IDs are unique
+            if len(ids) != len(set(ids)):
+                raise ValueError("IDs must be unique")
+
+            # Generate contents dict of IDs provided by user and documents
+            contents = {id_: doc.strip() for id_, doc in zip(ids, input)}
+        else:
+            # Generate contents dict of MD5 hash IDs and documents
+            contents = {
+                compute_mdhash_id(doc.strip(), prefix="doc-"): doc.strip()
+                for doc in input
+            }
+
+        # 2. Remove duplicate contents
+        unique_contents = {
+            id_: content
+            for content, id_ in {
+                content: id_ for id_, content in contents.items()
+            }.items()
+        }
 
-        # 2. Generate document IDs and initial status
+        # 3. Generate document initial status
         new_docs: dict[str, Any] = {
-            compute_mdhash_id(content, prefix="doc-"): {
+            id_: {
                 "content": content,
                 "content_summary": self._get_content_summary(content),
                 "content_length": len(content),
@@ -571,10 +602,10 @@
                 "created_at": datetime.now().isoformat(),
                 "updated_at": datetime.now().isoformat(),
             }
-            for content in unique_contents
+            for id_, content in unique_contents.items()
         }
 
-        # 3. Filter out already processed documents
+        # 4. Filter out already processed documents
         # Get docs ids
         all_new_doc_ids = set(new_docs.keys())
         # Exclude IDs of documents that are already in progress
@@ -586,7 +617,7 @@
             logger.info("No new unique documents were found.")
             return
 
-        # 4. Store status document
+        # 5. Store status document
         await self.doc_status.upsert(new_docs)
         logger.info(f"Stored {len(new_docs)} new unique documents")
 
@@ -643,8 +674,6 @@
             # 4. iterate over batch
             for doc_id_processing_status in docs_batch:
                 doc_id, status_doc = doc_id_processing_status
-                # Update status in processing
-                doc_status_id = compute_mdhash_id(status_doc.content, prefix="doc-")
                 # Generate chunks from document
                 chunks: dict[str, Any] = {
                     compute_mdhash_id(dp["content"], prefix="chunk-"): {
@@ -664,7 +693,7 @@
                 tasks = [
                     self.doc_status.upsert(
                         {
-                            doc_status_id: {
+                            doc_id: {
                                 "status": DocStatus.PROCESSING,
                                 "updated_at": datetime.now().isoformat(),
                                 "content": status_doc.content,
@@ -685,7 +714,7 @@
                 await asyncio.gather(*tasks)
                 await self.doc_status.upsert(
                     {
-                        doc_status_id: {
+                        doc_id: {
                             "status": DocStatus.PROCESSED,
                             "chunks_count": len(chunks),
                             "content": status_doc.content,
@@ -700,7 +729,7 @@
                 logger.error(f"Failed to process document {doc_id}: {str(e)}")
                 await self.doc_status.upsert(
                     {
-                        doc_status_id: {
+                        doc_id: {
                             "status": DocStatus.FAILED,
                             "error": str(e),
                             "content": status_doc.content,