zrguo committed
Commit 3fede48 · unverified · 2 Parent(s): b90ac1b 6faaceb

Merge pull request #892 from PiochU19/main

add support of providing ids for documents insert

Files changed (2)
  1. README.md +14 -0
  2. lightrag/lightrag.py +50 -21
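For orientation before the diffs: a minimal usage sketch of the new `ids` parameter, assuming `rag` is an already-configured `LightRAG` instance (the texts and IDs are illustrative). The two `ValueError` cases mirror the validation added in `lightrag/lightrag.py` below.

```python
# Hypothetical, already-configured instance; construction arguments omitted.
# rag = LightRAG(working_dir="./rag_storage", ...)

# Documents keyed by caller-chosen IDs instead of generated MD5 hashes
rag.insert(["TEXT1", "TEXT2"], ids=["ID_FOR_TEXT1", "ID_FOR_TEXT2"])

# Mismatched lengths are rejected by the new validation step
try:
    rag.insert(["TEXT1", "TEXT2"], ids=["ONLY_ONE_ID"])
except ValueError as exc:
    print(exc)  # Number of IDs must match the number of documents

# Duplicate IDs are rejected as well
try:
    rag.insert(["TEXT1", "TEXT2"], ids=["SAME_ID", "SAME_ID"])
except ValueError as exc:
    print(exc)  # IDs must be unique
```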
README.md CHANGED
@@ -545,6 +545,20 @@ The `insert_batch_size` parameter in `addon_params` controls how many documents
 
 </details>
 
+<details>
+  <summary><b>Insert with ID</b></summary>
+
+If you want to provide your own IDs for your documents, the number of documents and the number of IDs must be the same.
+
+```python
+# Insert a single text, and provide an ID for it
+rag.insert("TEXT1", ids=["ID_FOR_TEXT1"])
+
+# Insert multiple texts, and provide IDs for them
+rag.insert(["TEXT1", "TEXT2",...], ids=["ID_FOR_TEXT1", "ID_FOR_TEXT2"])
+```
+
+</details>
+
 <details>
   <summary><b>Incremental Insert</b></summary>
lightrag/lightrag.py CHANGED
@@ -1,8 +1,8 @@
 from __future__ import annotations
 
 import asyncio
-import os
 import configparser
+import os
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
@@ -41,11 +41,11 @@ from .utils import (
     always_get_an_event_loop,
     compute_mdhash_id,
     convert_response_to_json,
+    encode_string_by_tiktoken,
     lazy_external_import,
     limit_async_func_call,
     logger,
     set_logger,
-    encode_string_by_tiktoken,
 )
 from .types import KnowledgeGraph
 
@@ -479,6 +479,7 @@ class LightRAG:
         input: str | list[str],
         split_by_character: str | None = None,
         split_by_character_only: bool = False,
+        ids: list[str] | None = None,
     ) -> None:
         """Sync Insert documents with checkpoint support
 
@@ -487,10 +488,11 @@
             split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
             split_by_character_only: if split_by_character_only is True, split the string by character only, when
             split_by_character is None, this parameter is ignored.
+            ids: list of unique document IDs, if not provided, MD5 hash IDs will be generated
         """
         loop = always_get_an_event_loop()
         loop.run_until_complete(
-            self.ainsert(input, split_by_character, split_by_character_only)
+            self.ainsert(input, split_by_character, split_by_character_only, ids)
         )
 
     async def ainsert(
@@ -498,6 +500,7 @@
         input: str | list[str],
         split_by_character: str | None = None,
         split_by_character_only: bool = False,
+        ids: list[str] | None = None,
     ) -> None:
         """Async Insert documents with checkpoint support
 
@@ -506,8 +509,9 @@
             split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
             split_by_character_only: if split_by_character_only is True, split the string by character only, when
             split_by_character is None, this parameter is ignored.
+            ids: list of unique document IDs, if not provided, MD5 hash IDs will be generated
         """
-        await self.apipeline_enqueue_documents(input)
+        await self.apipeline_enqueue_documents(input, ids)
         await self.apipeline_process_enqueue_documents(
             split_by_character, split_by_character_only
         )
 
@@ -564,24 +568,51 @@
         if update_storage:
             await self._insert_done()
 
-    async def apipeline_enqueue_documents(self, input: str | list[str]) -> None:
+    async def apipeline_enqueue_documents(
+        self, input: str | list[str], ids: list[str] | None
+    ) -> None:
         """
         Pipeline for Processing Documents
 
-        1. Remove duplicate contents from the list
-        2. Generate document IDs and initial status
-        3. Filter out already processed documents
-        4. Enqueue document in status
+        1. Validate ids if provided or generate MD5 hash IDs
+        2. Remove duplicate contents
+        3. Generate document initial status
+        4. Filter out already processed documents
+        5. Enqueue document in status
         """
         if isinstance(input, str):
             input = [input]
 
-        # 1. Remove duplicate contents from the list
-        unique_contents = list(set(doc.strip() for doc in input))
+        # 1. Validate ids if provided or generate MD5 hash IDs
+        if ids is not None:
+            # Check if the number of IDs matches the number of documents
+            if len(ids) != len(input):
+                raise ValueError("Number of IDs must match the number of documents")
+
+            # Check if IDs are unique
+            if len(ids) != len(set(ids)):
+                raise ValueError("IDs must be unique")
+
+            # Generate contents dict of IDs provided by user and documents
+            contents = {id_: doc.strip() for id_, doc in zip(ids, input)}
+        else:
+            # Generate contents dict of MD5 hash IDs and documents
+            contents = {
+                compute_mdhash_id(doc.strip(), prefix="doc-"): doc.strip()
+                for doc in input
+            }
+
+        # 2. Remove duplicate contents
+        unique_contents = {
+            id_: content
+            for content, id_ in {
+                content: id_ for id_, content in contents.items()
+            }.items()
+        }
 
-        # 2. Generate document IDs and initial status
+        # 3. Generate document initial status
         new_docs: dict[str, Any] = {
-            compute_mdhash_id(content, prefix="doc-"): {
+            id_: {
                 "content": content,
                 "content_summary": self._get_content_summary(content),
                 "content_length": len(content),
@@ -589,10 +620,10 @@
                 "created_at": datetime.now().isoformat(),
                 "updated_at": datetime.now().isoformat(),
             }
-            for content in unique_contents
+            for id_, content in unique_contents.items()
         }
 
-        # 3. Filter out already processed documents
+        # 4. Filter out already processed documents
         # Get docs ids
         all_new_doc_ids = set(new_docs.keys())
         # Exclude IDs of documents that are already in progress
@@ -604,7 +635,7 @@
             logger.info("No new unique documents were found.")
             return
 
-        # 4. Store status document
+        # 5. Store status document
         await self.doc_status.upsert(new_docs)
         logger.info(f"Stored {len(new_docs)} new unique documents")
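The step-2 deduplication above is a double dict inversion: `contents` maps ID → stripped content, the inner comprehension flips it to content → ID (so a later entry overwrites an earlier one with identical text), and the outer comprehension flips it back. A self-contained sketch with made-up IDs and texts:

```python
# ID -> stripped content, as built in step 1 (sample data)
contents = {
    "ID_A": "TEXT1",
    "ID_B": "TEXT2",
    "ID_C": "TEXT1",  # duplicate of ID_A's content
}

# Same expression as in apipeline_enqueue_documents
unique_contents = {
    id_: content
    for content, id_ in {content: id_ for id_, content in contents.items()}.items()
}

print(unique_contents)
# {'ID_C': 'TEXT1', 'ID_B': 'TEXT2'} -- one entry per distinct content,
# and the duplicate text keeps only the last ID that referenced it.
```

So when two user-supplied IDs arrive with identical text, only one document is enqueued and the earlier ID is silently dropped.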
 
 
@@ -661,8 +692,6 @@
             # 4. iterate over batch
             for doc_id_processing_status in docs_batch:
                 doc_id, status_doc = doc_id_processing_status
-                # Update status in processing
-                doc_status_id = compute_mdhash_id(status_doc.content, prefix="doc-")
                 # Generate chunks from document
                 chunks: dict[str, Any] = {
                     compute_mdhash_id(dp["content"], prefix="chunk-"): {
@@ -682,7 +711,7 @@
                 tasks = [
                     self.doc_status.upsert(
                         {
-                            doc_status_id: {
+                            doc_id: {
                                 "status": DocStatus.PROCESSING,
                                 "updated_at": datetime.now().isoformat(),
                                 "content": status_doc.content,
@@ -703,7 +732,7 @@
                     await asyncio.gather(*tasks)
                     await self.doc_status.upsert(
                         {
-                            doc_status_id: {
+                            doc_id: {
                                 "status": DocStatus.PROCESSED,
                                 "chunks_count": len(chunks),
                                 "content": status_doc.content,
@@ -718,7 +747,7 @@
                     logger.error(f"Failed to process document {doc_id}: {str(e)}")
                     await self.doc_status.upsert(
                         {
-                            doc_status_id: {
+                            doc_id: {
                                 "status": DocStatus.FAILED,
                                 "error": str(e),
                                 "content": status_doc.content,