Gurjot Singh committed on
Commit b0187f6 · 1 Parent(s): c0af224

Add custom function with separate keyword extraction for the user's query and a separate prompt

Files changed (4)
  1. lightrag/base.py +2 -0
  2. lightrag/lightrag.py +110 -0
  3. lightrag/operate.py +200 -0
  4. test.py +1 -1
lightrag/base.py CHANGED
@@ -31,6 +31,8 @@ class QueryParam:
     max_token_for_global_context: int = 4000
     # Number of tokens for the entity descriptions
     max_token_for_local_context: int = 4000
+    hl_keywords: list[str] = field(default_factory=list)
+    ll_keywords: list[str] = field(default_factory=list)


 @dataclass
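The two new QueryParam fields let callers carry pre-extracted keywords into the query pipeline. A minimal sketch of how they might be populated directly (the mode and the keyword values below are illustrative, not part of this commit):

from lightrag import QueryParam

param = QueryParam(mode="hybrid")
param.hl_keywords = ["narrative themes", "character development"]  # high-level concepts
param.ll_keywords = ["Scrooge", "Christmas Eve"]                    # concrete entities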
lightrag/lightrag.py CHANGED
@@ -17,6 +17,8 @@ from .operate import (
     kg_query,
     naive_query,
     mix_kg_vector_query,
+    extract_keywords_only,
+    kg_query_with_keywords,
 )

 from .utils import (
@@ -753,6 +755,114 @@ class LightRAG:
         await self._query_done()
         return response

+    def query_with_separate_keyword_extraction(
+        self,
+        query: str,
+        prompt: str,
+        param: QueryParam = QueryParam(),
+    ):
+        """
+        1. Extract keywords from 'query' using extract_keywords_only() in operate.py.
+        2. Run the standard aquery() flow with the final prompt (formatted_question).
+        """
+        loop = always_get_an_event_loop()
+        return loop.run_until_complete(
+            self.aquery_with_separate_keyword_extraction(query, prompt, param)
+        )
+
+    async def aquery_with_separate_keyword_extraction(
+        self,
+        query: str,
+        prompt: str,
+        param: QueryParam = QueryParam(),
+    ):
+        """
+        1. Call extract_keywords_only() to get HL/LL keywords from 'query'.
+        2. Call kg_query_with_keywords(), naive_query(), or mix_kg_vector_query() as the
+           main query, injecting the newly extracted keywords.
+        """
+
+        # ---------------------
+        # STEP 1: Keyword extraction
+        # ---------------------
+        # extract_keywords_only(...) returns (hl_keywords, ll_keywords).
+        hl_keywords, ll_keywords = await extract_keywords_only(
+            text=query,
+            param=param,
+            global_config=asdict(self),
+            hashing_kv=self.llm_response_cache
+            or self.key_string_value_json_storage_cls(
+                namespace="llm_response_cache",
+                global_config=asdict(self),
+                embedding_func=None,
+            ),
+        )
+
+        param.hl_keywords = hl_keywords
+        param.ll_keywords = ll_keywords
+
+        # ---------------------
+        # STEP 2: Final query logic
+        # ---------------------
+
+        # Create a new string with the prompt and the keywords
+        ll_keywords_str = ", ".join(ll_keywords)
+        hl_keywords_str = ", ".join(hl_keywords)
+        formatted_question = f"{prompt}\n\n### Keywords:\nHigh-level: {hl_keywords_str}\nLow-level: {ll_keywords_str}\n\n### Query:\n{query}"
+
+        if param.mode in ["local", "global", "hybrid"]:
+            response = await kg_query_with_keywords(
+                formatted_question,
+                self.chunk_entity_relation_graph,
+                self.entities_vdb,
+                self.relationships_vdb,
+                self.text_chunks,
+                param,
+                asdict(self),
+                hashing_kv=self.llm_response_cache
+                if self.llm_response_cache
+                and hasattr(self.llm_response_cache, "global_config")
+                else self.key_string_value_json_storage_cls(
+                    namespace="llm_response_cache",
+                    global_config=asdict(self),
+                    embedding_func=None,
+                ),
+            )
+        elif param.mode == "naive":
+            response = await naive_query(
+                formatted_question,
+                self.chunks_vdb,
+                self.text_chunks,
+                param,
+                asdict(self),
+                hashing_kv=self.llm_response_cache
+                if self.llm_response_cache
+                and hasattr(self.llm_response_cache, "global_config")
+                else self.key_string_value_json_storage_cls(
+                    namespace="llm_response_cache",
+                    global_config=asdict(self),
+                    embedding_func=None,
+                ),
+            )
+        elif param.mode == "mix":
+            response = await mix_kg_vector_query(
+                formatted_question,
+                self.chunk_entity_relation_graph,
+                self.entities_vdb,
+                self.relationships_vdb,
+                self.chunks_vdb,
+                self.text_chunks,
+                param,
+                asdict(self),
+                hashing_kv=self.llm_response_cache
+                if self.llm_response_cache
+                and hasattr(self.llm_response_cache, "global_config")
+                else self.key_string_value_json_storage_cls(
+                    namespace="llm_response_cache",
+                    global_config=asdict(self),
+                    embedding_func=None,
+                ),
+            )
+        else:
+            raise ValueError(f"Unknown mode {param.mode}")
+
+        await self._query_done()
+        return response
+
     async def _query_done(self):
         tasks = []
         for storage_inst in [self.llm_response_cache]:
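A minimal usage sketch of the new entry point, assuming a LightRAG instance has already been configured with its usual LLM and embedding functions (the working directory, query, and prompt strings below are placeholders):

from lightrag import LightRAG, QueryParam

rag = LightRAG(working_dir="./rag_storage")  # plus llm_model_func / embedding_func as usual

answer = rag.query_with_separate_keyword_extraction(
    query="What drives the conflict between the main characters?",            # keywords are extracted from this
    prompt="Answer in three bullet points and name the entities involved.",   # prepended to the final question
    param=QueryParam(mode="hybrid"),
)
print(answer)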
lightrag/operate.py CHANGED
@@ -680,6 +680,206 @@ async def kg_query(
     )
     return response

+async def kg_query_with_keywords(
+    query: str,
+    knowledge_graph_inst: BaseGraphStorage,
+    entities_vdb: BaseVectorStorage,
+    relationships_vdb: BaseVectorStorage,
+    text_chunks_db: BaseKVStorage[TextChunkSchema],
+    query_param: QueryParam,
+    global_config: dict,
+    hashing_kv: BaseKVStorage = None,
+) -> str:
+    """
+    Refactored kg_query that does NOT extract keywords by itself.
+    It expects hl_keywords and ll_keywords to be set in query_param, defaulting to
+    empty lists otherwise. It then uses those keywords to build the context and
+    produce the final LLM response.
+    """
+
+    # ---------------------------
+    # 0) Handle potential cache
+    # ---------------------------
+    use_model_func = global_config["llm_model_func"]
+    args_hash = compute_args_hash(query_param.mode, query)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        hashing_kv, args_hash, query, query_param.mode
+    )
+    if cached_response is not None:
+        return cached_response
+
+    # ---------------------------
+    # 1) Retrieve keywords from query_param
+    # ---------------------------
+
+    # If these fields don't exist, default to empty lists.
+    hl_keywords = getattr(query_param, "hl_keywords", []) or []
+    ll_keywords = getattr(query_param, "ll_keywords", []) or []
+
+    # If neither level has any keywords, fail early; otherwise fall back to the
+    # mode that still has keywords available.
+    if not hl_keywords and not ll_keywords:
+        logger.warning("No keywords found in query_param. Could default to global mode or fail.")
+        return PROMPTS["fail_response"]
+    if not ll_keywords and query_param.mode in ["local", "hybrid"]:
+        logger.warning("low_level_keywords is empty, switching to global mode.")
+        query_param.mode = "global"
+    if not hl_keywords and query_param.mode in ["global", "hybrid"]:
+        logger.warning("high_level_keywords is empty, switching to local mode.")
+        query_param.mode = "local"
+
+    # Flatten low-level and high-level keywords if they arrive as nested lists
+    ll_keywords_flat = (
+        [item for sublist in ll_keywords for item in sublist]
+        if any(isinstance(i, list) for i in ll_keywords)
+        else ll_keywords
+    )
+    hl_keywords_flat = (
+        [item for sublist in hl_keywords for item in sublist]
+        if any(isinstance(i, list) for i in hl_keywords)
+        else hl_keywords
+    )
+
+    # Join the flattened lists
+    ll_keywords_str = ", ".join(ll_keywords_flat) if ll_keywords_flat else ""
+    hl_keywords_str = ", ".join(hl_keywords_flat) if hl_keywords_flat else ""
+
+    keywords = [ll_keywords_str, hl_keywords_str]
+
+    logger.info("Using %s mode for query processing", query_param.mode)
+
+    # ---------------------------
+    # 2) Build context
+    # ---------------------------
+    context = await _build_query_context(
+        keywords,
+        knowledge_graph_inst,
+        entities_vdb,
+        relationships_vdb,
+        text_chunks_db,
+        query_param,
+    )
+    if not context:
+        return PROMPTS["fail_response"]
+
+    # If only the context is needed, return it
+    if query_param.only_need_context:
+        return context
+
+    # ---------------------------
+    # 3) Build the system prompt + call the LLM
+    # ---------------------------
+    sys_prompt_temp = PROMPTS["rag_response"]
+    sys_prompt = sys_prompt_temp.format(
+        context_data=context, response_type=query_param.response_type
+    )
+
+    if query_param.only_need_prompt:
+        return sys_prompt
+
+    # Now call the LLM with the final system prompt
+    response = await use_model_func(
+        query,
+        system_prompt=sys_prompt,
+        stream=query_param.stream,
+    )
+
+    # Clean up the response
+    if isinstance(response, str) and len(response) > len(sys_prompt):
+        response = (
+            response.replace(sys_prompt, "")
+            .replace("user", "")
+            .replace("model", "")
+            .replace(query, "")
+            .replace("<system>", "")
+            .replace("</system>", "")
+            .strip()
+        )
+
+    # ---------------------------
+    # 4) Save to cache
+    # ---------------------------
+    await save_to_cache(
+        hashing_kv,
+        CacheData(
+            args_hash=args_hash,
+            content=response,
+            prompt=query,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=query_param.mode,
+        ),
+    )
+    return response
+
+async def extract_keywords_only(
+    text: str,
+    param: QueryParam,
+    global_config: dict,
+    hashing_kv: BaseKVStorage = None,
+) -> tuple[list[str], list[str]]:
+    """
+    Extract high-level and low-level keywords from the given 'text' using the LLM.
+    This function does NOT build the final RAG context or produce a final answer.
+    It ONLY extracts keywords (hl_keywords, ll_keywords).
+    """
+
+    # 1. Handle cache if needed
+    args_hash = compute_args_hash(param.mode, text)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        hashing_kv, args_hash, text, param.mode
+    )
+    if cached_response is not None:
+        # Parse the cached response if it contains a JSON keywords payload,
+        # assuming it has the same structure as a fresh LLM result.
+        match = re.search(r"\{.*\}", cached_response, re.DOTALL)
+        if match:
+            keywords_data = json.loads(match.group(0))
+            hl_keywords = keywords_data.get("high_level_keywords", [])
+            ll_keywords = keywords_data.get("low_level_keywords", [])
+            return hl_keywords, ll_keywords
+        return [], []
+
+    # 2. Build the examples
+    example_number = global_config["addon_params"].get("example_number", None)
+    if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
+        examples = "\n".join(
+            PROMPTS["keywords_extraction_examples"][: int(example_number)]
+        )
+    else:
+        examples = "\n".join(PROMPTS["keywords_extraction_examples"])
+    language = global_config["addon_params"].get(
+        "language", PROMPTS["DEFAULT_LANGUAGE"]
+    )
+
+    # 3. Build the keyword-extraction prompt
+    kw_prompt_temp = PROMPTS["keywords_extraction"]
+    kw_prompt = kw_prompt_temp.format(query=text, examples=examples, language=language)
+
+    # 4. Call the LLM for keyword extraction
+    use_model_func = global_config["llm_model_func"]
+    result = await use_model_func(kw_prompt, keyword_extraction=True)
+
+    # 5. Parse out JSON from the LLM response
+    match = re.search(r"\{.*\}", result, re.DOTALL)
+    if not match:
+        logger.error("No JSON-like structure found in the result.")
+        return [], []
+    try:
+        keywords_data = json.loads(match.group(0))
+    except json.JSONDecodeError as e:
+        logger.error(f"JSON parsing error: {e}")
+        return [], []
+
+    hl_keywords = keywords_data.get("high_level_keywords", [])
+    ll_keywords = keywords_data.get("low_level_keywords", [])
+
+    # 6. Cache the result if needed
+    await save_to_cache(
+        hashing_kv,
+        CacheData(
+            args_hash=args_hash,
+            content=result,
+            prompt=text,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=param.mode,
+        ),
+    )
+    return hl_keywords, ll_keywords

 async def _build_query_context(
     query: list,
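For context, extract_keywords_only expects the LLM to return a JSON object with "high_level_keywords" and "low_level_keywords" fields; the regex-plus-json.loads step in part 5 pulls that object out even when the model wraps it in extra text. A small, self-contained illustration of that parsing step (the model output string is made up):

import json
import re

# Illustrative model output: the JSON payload may be surrounded by commentary.
result = 'Sure! {"high_level_keywords": ["story themes"], "low_level_keywords": ["Scrooge", "ghost"]} Hope that helps.'

match = re.search(r"\{.*\}", result, re.DOTALL)
if match:
    keywords_data = json.loads(match.group(0))
    print(keywords_data.get("high_level_keywords", []))  # ['story themes']
    print(keywords_data.get("low_level_keywords", []))   # ['Scrooge', 'ghost']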
test.py CHANGED
@@ -39,4 +39,4 @@ print(
 # Perform hybrid search
 print(
     rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
-)
+)