zrguo committed
Commit 2639305 · unverified · 2 Parent(s): 337e371 3f44683

Merge pull request #585 from gurjot-05/feature-implementation

Add a custom function with separate keyword extraction for the user's query and a separate prompt
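
At a glance, the new API takes the user's query and the response-formatting prompt as separate arguments, and keyword extraction runs only on the query. A minimal call sketch, mirroring the usage example added to the README in this PR (it assumes `rag` is an already-initialized `LightRAG` instance):

```python
from lightrag import QueryParam

response = rag.query_with_separate_keyword_extraction(
    query="Explain the law of gravity",
    prompt="Provide a detailed explanation suitable for high school students studying physics.",
    param=QueryParam(mode="hybrid"),
)
print(response)
```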

README.md CHANGED
@@ -330,6 +330,26 @@ rag = LightRAG(
  with open("./newText.txt") as f:
      rag.insert(f.read())
  ```
+ ### Separate Keyword Extraction
+ We've introduced a new function, `query_with_separate_keyword_extraction`, to enhance keyword extraction. It separates the keyword-extraction step from the user's prompt, running extraction solely on the query to improve the relevance of the extracted keywords.
+
+ ##### How It Works
+ The function operates by dividing the input into two parts:
+ - `User Query`
+ - `Prompt`
+
+ It then performs keyword extraction exclusively on the `user query`. This separation keeps the extraction focused and relevant, unaffected by any additional language in the `prompt`, and lets the `prompt` serve purely for response formatting, preserving the intent and clarity of the user's original question.
+
+ ##### Usage Example
+ This example shows how to tailor the function for educational content, focusing on detailed explanations for older students.
+
+ ```python
+ rag.query_with_separate_keyword_extraction(
+     query="Explain the law of gravity",
+     prompt="Provide a detailed explanation suitable for high school students studying physics.",
+     param=QueryParam(mode="hybrid")
+ )
+ ```

  ### Using Neo4J for Storage
 
examples/query_keyword_separation_example.py ADDED
@@ -0,0 +1,116 @@
+ import os
+ import asyncio
+ from lightrag import LightRAG, QueryParam
+ from lightrag.utils import EmbeddingFunc
+ import numpy as np
+ from dotenv import load_dotenv
+ import logging
+ from openai import AzureOpenAI
+
+ logging.basicConfig(level=logging.INFO)
+
+ load_dotenv()
+
+ AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION")
+ AZURE_OPENAI_DEPLOYMENT = os.getenv("AZURE_OPENAI_DEPLOYMENT")
+ AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
+ AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
+
+ AZURE_EMBEDDING_DEPLOYMENT = os.getenv("AZURE_EMBEDDING_DEPLOYMENT")
+ AZURE_EMBEDDING_API_VERSION = os.getenv("AZURE_EMBEDDING_API_VERSION")
+
+ WORKING_DIR = "./dickens"
+
+ if os.path.exists(WORKING_DIR):
+     import shutil
+
+     shutil.rmtree(WORKING_DIR)
+
+ os.mkdir(WORKING_DIR)
+
+
+ async def llm_model_func(
+     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+ ) -> str:
+     client = AzureOpenAI(
+         api_key=AZURE_OPENAI_API_KEY,
+         api_version=AZURE_OPENAI_API_VERSION,
+         azure_endpoint=AZURE_OPENAI_ENDPOINT,
+     )
+
+     messages = []
+     if system_prompt:
+         messages.append({"role": "system", "content": system_prompt})
+     if history_messages:
+         messages.extend(history_messages)
+     messages.append({"role": "user", "content": prompt})
+
+     chat_completion = client.chat.completions.create(
+         model=AZURE_OPENAI_DEPLOYMENT,  # model = "deployment_name".
+         messages=messages,
+         temperature=kwargs.get("temperature", 0),
+         top_p=kwargs.get("top_p", 1),
+         n=kwargs.get("n", 1),
+     )
+     return chat_completion.choices[0].message.content
+
+
+ async def embedding_func(texts: list[str]) -> np.ndarray:
+     client = AzureOpenAI(
+         api_key=AZURE_OPENAI_API_KEY,
+         api_version=AZURE_EMBEDDING_API_VERSION,
+         azure_endpoint=AZURE_OPENAI_ENDPOINT,
+     )
+     embedding = client.embeddings.create(model=AZURE_EMBEDDING_DEPLOYMENT, input=texts)
+
+     embeddings = [item.embedding for item in embedding.data]
+     return np.array(embeddings)
+
+
+ async def test_funcs():
+     result = await llm_model_func("How are you?")
+     print("llm_model_func response: ", result)
+
+     result = await embedding_func(["How are you?"])
+     print("embedding_func result shape: ", result.shape)
+     print("Embedding dimension: ", result.shape[1])
+
+
+ asyncio.run(test_funcs())
+
+ embedding_dimension = 3072
+
+ rag = LightRAG(
+     working_dir=WORKING_DIR,
+     llm_model_func=llm_model_func,
+     embedding_func=EmbeddingFunc(
+         embedding_dim=embedding_dimension,
+         max_token_size=8192,
+         func=embedding_func,
+     ),
+ )
+
+ book1 = open("./book_1.txt", encoding="utf-8")
+ book2 = open("./book_2.txt", encoding="utf-8")
+
+ rag.insert([book1.read(), book2.read()])
+
+
+ # Example function demonstrating the new query_with_separate_keyword_extraction usage
+ async def run_example():
+     query = "What are the top themes in this story?"
+     prompt = "Please simplify the response for a young audience."
+
+     # Using the new method to ensure the keyword extraction is only applied to the query
+     response = rag.query_with_separate_keyword_extraction(
+         query=query,
+         prompt=prompt,
+         param=QueryParam(mode="hybrid"),  # Adjust QueryParam mode as necessary
+     )
+
+     print("Extracted Response:", response)
+
+
+ # Run the example asynchronously
+ if __name__ == "__main__":
+     asyncio.run(run_example())
lightrag/base.py CHANGED
@@ -31,6 +31,8 @@ class QueryParam:
      max_token_for_global_context: int = 4000
      # Number of tokens for the entity descriptions
      max_token_for_local_context: int = 4000
+     hl_keywords: list[str] = field(default_factory=list)
+     ll_keywords: list[str] = field(default_factory=list)


  @dataclass
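
These two new `QueryParam` fields are how the separately extracted keywords travel into the downstream query path (they are filled in by `aquery_with_separate_keyword_extraction` in the lightrag.py changes below). A tiny, hedged sketch of setting them directly, with made-up keyword values purely for illustration:

```python
from lightrag import QueryParam

# The new fields default to empty lists; normally extract_keywords_only produces them.
param = QueryParam(mode="hybrid")
param.hl_keywords = ["classical mechanics"]   # illustrative high-level keywords
param.ll_keywords = ["gravity", "Newton"]     # illustrative low-level keywords
print(param.hl_keywords, param.ll_keywords)
```
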
lightrag/lightrag.py CHANGED
@@ -17,6 +17,8 @@ from .operate import (
      kg_query,
      naive_query,
      mix_kg_vector_query,
+     extract_keywords_only,
+     kg_query_with_keywords,
  )

  from .utils import (
@@ -753,6 +755,114 @@ class LightRAG:
          await self._query_done()
          return response

+     def query_with_separate_keyword_extraction(
+         self, query: str, prompt: str, param: QueryParam = QueryParam()
+     ):
+         """
+         1. Extract keywords from the 'query' using the new function in operate.py.
+         2. Then run the standard aquery() flow with the final prompt (formatted_question).
+         """
+         loop = always_get_an_event_loop()
+         return loop.run_until_complete(
+             self.aquery_with_separate_keyword_extraction(query, prompt, param)
+         )
+
+     async def aquery_with_separate_keyword_extraction(
+         self, query: str, prompt: str, param: QueryParam = QueryParam()
+     ):
+         """
+         1. Calls extract_keywords_only to get HL/LL keywords from 'query'.
+         2. Then calls kg_query(...) or naive_query(...), etc. as the main query,
+            injecting the newly extracted keywords where needed.
+         """
+
+         # ---------------------
+         # STEP 1: Keyword Extraction
+         # ---------------------
+         # extract_keywords_only(...) returns (hl_keywords, ll_keywords).
+         hl_keywords, ll_keywords = await extract_keywords_only(
+             text=query,
+             param=param,
+             global_config=asdict(self),
+             hashing_kv=self.llm_response_cache
+             or self.key_string_value_json_storage_cls(
+                 namespace="llm_response_cache",
+                 global_config=asdict(self),
+                 embedding_func=None,
+             ),
+         )
+
+         param.hl_keywords = hl_keywords
+         param.ll_keywords = ll_keywords
+
+         # ---------------------
+         # STEP 2: Final Query Logic
+         # ---------------------
+
+         # Create a new string with the prompt and the keywords
+         ll_keywords_str = ", ".join(ll_keywords)
+         hl_keywords_str = ", ".join(hl_keywords)
+         formatted_question = f"{prompt}\n\n### Keywords:\nHigh-level: {hl_keywords_str}\nLow-level: {ll_keywords_str}\n\n### Query:\n{query}"
+
+         if param.mode in ["local", "global", "hybrid"]:
+             response = await kg_query_with_keywords(
+                 formatted_question,
+                 self.chunk_entity_relation_graph,
+                 self.entities_vdb,
+                 self.relationships_vdb,
+                 self.text_chunks,
+                 param,
+                 asdict(self),
+                 hashing_kv=self.llm_response_cache
+                 if self.llm_response_cache
+                 and hasattr(self.llm_response_cache, "global_config")
+                 else self.key_string_value_json_storage_cls(
+                     namespace="llm_response_cache",
+                     global_config=asdict(self),
+                     embedding_func=None,
+                 ),
+             )
+         elif param.mode == "naive":
+             response = await naive_query(
+                 formatted_question,
+                 self.chunks_vdb,
+                 self.text_chunks,
+                 param,
+                 asdict(self),
+                 hashing_kv=self.llm_response_cache
+                 if self.llm_response_cache
+                 and hasattr(self.llm_response_cache, "global_config")
+                 else self.key_string_value_json_storage_cls(
+                     namespace="llm_response_cache",
+                     global_config=asdict(self),
+                     embedding_func=None,
+                 ),
+             )
+         elif param.mode == "mix":
+             response = await mix_kg_vector_query(
+                 formatted_question,
+                 self.chunk_entity_relation_graph,
+                 self.entities_vdb,
+                 self.relationships_vdb,
+                 self.chunks_vdb,
+                 self.text_chunks,
+                 param,
+                 asdict(self),
+                 hashing_kv=self.llm_response_cache
+                 if self.llm_response_cache
+                 and hasattr(self.llm_response_cache, "global_config")
+                 else self.key_string_value_json_storage_cls(
+                     namespace="llm_response_cache",
+                     global_config=asdict(self),
+                     embedding_func=None,
+                 ),
+             )
+         else:
+             raise ValueError(f"Unknown mode {param.mode}")
+
+         await self._query_done()
+         return response
+
      async def _query_done(self):
          tasks = []
          for storage_inst in [self.llm_response_cache]:
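
For reference, the final prompt passed into the downstream query functions combines the caller's prompt, the extracted keywords, and the original query. A small self-contained sketch of the string that `aquery_with_separate_keyword_extraction` builds (the keyword values below are made up for illustration; in the real flow they come from `extract_keywords_only`):

```python
query = "Explain the law of gravity"
prompt = "Provide a detailed explanation suitable for high school students studying physics."

# Illustrative keyword lists standing in for the LLM extraction result
hl_keywords = ["classical mechanics", "fundamental forces"]
ll_keywords = ["gravity", "Newton", "mass"]

ll_keywords_str = ", ".join(ll_keywords)
hl_keywords_str = ", ".join(hl_keywords)
formatted_question = (
    f"{prompt}\n\n### Keywords:\n"
    f"High-level: {hl_keywords_str}\nLow-level: {ll_keywords_str}\n\n"
    f"### Query:\n{query}"
)
print(formatted_question)
```
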
lightrag/operate.py CHANGED
@@ -681,6 +681,219 @@ async def kg_query(
      return response


+ async def kg_query_with_keywords(
+     query: str,
+     knowledge_graph_inst: BaseGraphStorage,
+     entities_vdb: BaseVectorStorage,
+     relationships_vdb: BaseVectorStorage,
+     text_chunks_db: BaseKVStorage[TextChunkSchema],
+     query_param: QueryParam,
+     global_config: dict,
+     hashing_kv: BaseKVStorage = None,
+ ) -> str:
+     """
+     Refactored kg_query that does NOT extract keywords by itself.
+     It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty.
+     It then uses those to build the context and produce the final LLM response.
+     """
+
+     # ---------------------------
+     # 0) Handle potential cache
+     # ---------------------------
+     use_model_func = global_config["llm_model_func"]
+     args_hash = compute_args_hash(query_param.mode, query)
+     cached_response, quantized, min_val, max_val = await handle_cache(
+         hashing_kv, args_hash, query, query_param.mode
+     )
+     if cached_response is not None:
+         return cached_response
+
+     # ---------------------------
+     # 1) RETRIEVE KEYWORDS FROM query_param
+     # ---------------------------
+
+     # If these fields don't exist, default to empty lists/strings.
+     hl_keywords = getattr(query_param, "hl_keywords", []) or []
+     ll_keywords = getattr(query_param, "ll_keywords", []) or []
+
+     # If neither has any keywords, fail early.
+     if not hl_keywords and not ll_keywords:
+         logger.warning(
+             "No keywords found in query_param. Could default to global mode or fail."
+         )
+         return PROMPTS["fail_response"]
+     if not ll_keywords and query_param.mode in ["local", "hybrid"]:
+         logger.warning("low_level_keywords is empty, switching to global mode.")
+         query_param.mode = "global"
+     if not hl_keywords and query_param.mode in ["global", "hybrid"]:
+         logger.warning("high_level_keywords is empty, switching to local mode.")
+         query_param.mode = "local"
+
+     # Flatten low-level and high-level keywords if needed
+     ll_keywords_flat = (
+         [item for sublist in ll_keywords for item in sublist]
+         if any(isinstance(i, list) for i in ll_keywords)
+         else ll_keywords
+     )
+     hl_keywords_flat = (
+         [item for sublist in hl_keywords for item in sublist]
+         if any(isinstance(i, list) for i in hl_keywords)
+         else hl_keywords
+     )
+
+     # Join the flattened lists
+     ll_keywords_str = ", ".join(ll_keywords_flat) if ll_keywords_flat else ""
+     hl_keywords_str = ", ".join(hl_keywords_flat) if hl_keywords_flat else ""
+
+     keywords = [ll_keywords_str, hl_keywords_str]
+
+     logger.info("Using %s mode for query processing", query_param.mode)
+
+     # ---------------------------
+     # 2) BUILD CONTEXT
+     # ---------------------------
+     context = await _build_query_context(
+         keywords,
+         knowledge_graph_inst,
+         entities_vdb,
+         relationships_vdb,
+         text_chunks_db,
+         query_param,
+     )
+     if not context:
+         return PROMPTS["fail_response"]
+
+     # If only context is needed, return it
+     if query_param.only_need_context:
+         return context
+
+     # ---------------------------
+     # 3) BUILD THE SYSTEM PROMPT + CALL LLM
+     # ---------------------------
+     sys_prompt_temp = PROMPTS["rag_response"]
+     sys_prompt = sys_prompt_temp.format(
+         context_data=context, response_type=query_param.response_type
+     )
+
+     if query_param.only_need_prompt:
+         return sys_prompt
+
+     # Now call the LLM with the final system prompt
+     response = await use_model_func(
+         query,
+         system_prompt=sys_prompt,
+         stream=query_param.stream,
+     )
+
+     # Clean up the response
+     if isinstance(response, str) and len(response) > len(sys_prompt):
+         response = (
+             response.replace(sys_prompt, "")
+             .replace("user", "")
+             .replace("model", "")
+             .replace(query, "")
+             .replace("<system>", "")
+             .replace("</system>", "")
+             .strip()
+         )
+
+     # ---------------------------
+     # 4) SAVE TO CACHE
+     # ---------------------------
+     await save_to_cache(
+         hashing_kv,
+         CacheData(
+             args_hash=args_hash,
+             content=response,
+             prompt=query,
+             quantized=quantized,
+             min_val=min_val,
+             max_val=max_val,
+             mode=query_param.mode,
+         ),
+     )
+     return response
+
+
+ async def extract_keywords_only(
+     text: str,
+     param: QueryParam,
+     global_config: dict,
+     hashing_kv: BaseKVStorage = None,
+ ) -> tuple[list[str], list[str]]:
+     """
+     Extract high-level and low-level keywords from the given 'text' using the LLM.
+     This method does NOT build the final RAG context or provide a final answer.
+     It ONLY extracts keywords (hl_keywords, ll_keywords).
+     """
+
+     # 1. Handle cache if needed
+     args_hash = compute_args_hash(param.mode, text)
+     cached_response, quantized, min_val, max_val = await handle_cache(
+         hashing_kv, args_hash, text, param.mode
+     )
+     if cached_response is not None:
+         # Parse the cached_response if it's JSON containing keywords,
+         # assuming it uses the same JSON structure as a fresh extraction:
+         match = re.search(r"\{.*\}", cached_response, re.DOTALL)
+         if match:
+             keywords_data = json.loads(match.group(0))
+             hl_keywords = keywords_data.get("high_level_keywords", [])
+             ll_keywords = keywords_data.get("low_level_keywords", [])
+             return hl_keywords, ll_keywords
+         return [], []
+
+     # 2. Build the examples
+     example_number = global_config["addon_params"].get("example_number", None)
+     if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
+         examples = "\n".join(
+             PROMPTS["keywords_extraction_examples"][: int(example_number)]
+         )
+     else:
+         examples = "\n".join(PROMPTS["keywords_extraction_examples"])
+     language = global_config["addon_params"].get(
+         "language", PROMPTS["DEFAULT_LANGUAGE"]
+     )
+
+     # 3. Build the keyword-extraction prompt
+     kw_prompt_temp = PROMPTS["keywords_extraction"]
+     kw_prompt = kw_prompt_temp.format(query=text, examples=examples, language=language)
+
+     # 4. Call the LLM for keyword extraction
+     use_model_func = global_config["llm_model_func"]
+     result = await use_model_func(kw_prompt, keyword_extraction=True)
+
+     # 5. Parse out JSON from the LLM response
+     match = re.search(r"\{.*\}", result, re.DOTALL)
+     if not match:
+         logger.error("No JSON-like structure found in the result.")
+         return [], []
+     try:
+         keywords_data = json.loads(match.group(0))
+     except json.JSONDecodeError as e:
+         logger.error(f"JSON parsing error: {e}")
+         return [], []
+
+     hl_keywords = keywords_data.get("high_level_keywords", [])
+     ll_keywords = keywords_data.get("low_level_keywords", [])
+
+     # 6. Cache the result if needed
+     await save_to_cache(
+         hashing_kv,
+         CacheData(
+             args_hash=args_hash,
+             content=result,
+             prompt=text,
+             quantized=quantized,
+             min_val=min_val,
+             max_val=max_val,
+             mode=param.mode,
+         ),
+     )
+     return hl_keywords, ll_keywords
+
+
  async def _build_query_context(
      query: list,
      knowledge_graph_inst: BaseGraphStorage,