MdNazishArman committed
Commit 11e481c · 1 Parent(s): 678a87e

Added system prompt support in all modes

Files changed (3)
  1. README.md +11 -2
  2. lightrag/lightrag.py +9 -4
  3. lightrag/operate.py +18 -12
README.md CHANGED
@@ -171,7 +171,7 @@ rag = LightRAG(working_dir=WORKING_DIR)
 
 # Create query parameters
 query_param = QueryParam(
-    mode="hybrid",  # or other mode: "local", "global", "hybrid"
+    mode="hybrid",  # or other mode: "local", "global", "hybrid", "mix" and "naive"
 )
 
 # Example 1: Using the default system prompt
@@ -184,11 +184,20 @@ print(response_default)
 # Example 2: Using a custom prompt
 custom_prompt = """
 You are an expert assistant in environmental science. Provide detailed and structured answers with examples.
+---Conversation History---
+{history}
+
+---Knowledge Base---
+{context_data}
+
+---Response Rules---
+
+- Target format and length: {response_type}
 """
 response_custom = rag.query(
     "What are the primary benefits of renewable energy?",
     param=query_param,
-    prompt=custom_prompt  # Pass the custom prompt
+    system_prompt=custom_prompt  # Pass the custom prompt
 )
 print(response_custom)
 ```
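Note on the custom prompt above: in the "local", "global", and "hybrid" paths (kg_query in the operate.py hunks below), the supplied template is still run through str.format(), so {history}, {context_data}, and {response_type} are substituted at query time and any literal braces must be doubled. A minimal sketch of that contract, with toy values standing in for the retrieved context (the variable names here are illustrative, not part of the library):

```python
# Sketch of the str.format() contract implied by kg_query's handling of a
# custom system prompt. Placeholders the template omits are simply unused,
# but literal "{" / "}" must be doubled or formatting fails at query time.
template_ok = 'Answer from:\n{context_data}\nReturn JSON like {{"answer": "..."}}'
template_bad = 'Return JSON like {"answer": "..."}'  # unescaped braces

context = "solar and wind reduce emissions"  # stand-in for retrieved context
print(template_ok.format(context_data=context, history="", response_type="short"))

try:
    template_bad.format(context_data=context, history="", response_type="short")
except (KeyError, IndexError, ValueError) as exc:
    print(f"format failed: {exc!r}")  # the unescaped brace is parsed as a field
```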
lightrag/lightrag.py CHANGED
@@ -984,7 +984,10 @@ class LightRAG:
         await self._insert_done()
 
     def query(
-        self, query: str, param: QueryParam = QueryParam(), prompt: str | None = None
+        self,
+        query: str,
+        param: QueryParam = QueryParam(),
+        system_prompt: str | None = None,
     ) -> str | Iterator[str]:
         """
         Perform a sync query.
@@ -999,13 +1002,13 @@ class LightRAG:
         """
         loop = always_get_an_event_loop()
 
-        return loop.run_until_complete(self.aquery(query, param, prompt))  # type: ignore
+        return loop.run_until_complete(self.aquery(query, param, system_prompt))  # type: ignore
 
     async def aquery(
         self,
         query: str,
         param: QueryParam = QueryParam(),
-        prompt: str | None = None,
+        system_prompt: str | None = None,
     ) -> str | AsyncIterator[str]:
         """
         Perform a async query.
@@ -1037,7 +1040,7 @@ class LightRAG:
                     global_config=asdict(self),
                     embedding_func=self.embedding_func,
                 ),
-                prompt=prompt,
+                system_prompt=system_prompt,
             )
         elif param.mode == "naive":
             response = await naive_query(
@@ -1056,6 +1059,7 @@ class LightRAG:
                     global_config=asdict(self),
                     embedding_func=self.embedding_func,
                 ),
+                system_prompt=system_prompt,
             )
         elif param.mode == "mix":
             response = await mix_kg_vector_query(
@@ -1077,6 +1081,7 @@ class LightRAG:
                     global_config=asdict(self),
                     embedding_func=self.embedding_func,
                 ),
+                system_prompt=system_prompt,
            )
         else:
             raise ValueError(f"Unknown mode {param.mode}")
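Taken together, these hunks make the sync query a thin wrapper that forwards the new argument into aquery, which threads it to whichever mode-specific function handles param.mode. A self-contained sketch of that dispatch (the stub coroutines below stand in for the real operate.py functions, which also receive retrieval storages, query_param, and config):

```python
import asyncio

# Stubs standing in for lightrag.operate's query functions; the real ones
# also take knowledge-graph / vector storages, query_param, and global_config.
async def kg_query(q, system_prompt=None):
    return f"[kg] prompt={'custom' if system_prompt else 'default'}"

async def naive_query(q, system_prompt=None):
    return f"[naive] prompt={'custom' if system_prompt else 'default'}"

async def mix_kg_vector_query(q, system_prompt=None):
    return f"[mix] prompt={'custom' if system_prompt else 'default'}"

async def aquery(q, mode="hybrid", system_prompt=None):
    # Mirrors LightRAG.aquery after this commit: every mode forwards system_prompt.
    if mode in ("local", "global", "hybrid"):
        return await kg_query(q, system_prompt=system_prompt)
    if mode == "naive":
        return await naive_query(q, system_prompt=system_prompt)
    if mode == "mix":
        return await mix_kg_vector_query(q, system_prompt=system_prompt)
    raise ValueError(f"Unknown mode {mode}")

def query(q, mode="hybrid", system_prompt=None):
    # Mirrors the sync LightRAG.query wrapper (asyncio.run stands in for the
    # library's always_get_an_event_loop() + run_until_complete pattern).
    return asyncio.run(aquery(q, mode, system_prompt))

print(query("benefits of renewable energy?", mode="mix", system_prompt="..."))
```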
lightrag/operate.py CHANGED
@@ -613,7 +613,7 @@ async def kg_query(
     query_param: QueryParam,
     global_config: dict[str, str],
     hashing_kv: BaseKVStorage | None = None,
-    prompt: str | None = None,
+    system_prompt: str | None = None,
 ) -> str:
     # Handle cache
     use_model_func = global_config["llm_model_func"]
@@ -677,7 +677,7 @@ async def kg_query(
         query_param.conversation_history, query_param.history_turns
     )
 
-    sys_prompt_temp = prompt if prompt else PROMPTS["rag_response"]
+    sys_prompt_temp = system_prompt if system_prompt else PROMPTS["rag_response"]
     sys_prompt = sys_prompt_temp.format(
         context_data=context,
         response_type=query_param.response_type,
@@ -828,6 +828,7 @@ async def mix_kg_vector_query(
     query_param: QueryParam,
     global_config: dict[str, str],
     hashing_kv: BaseKVStorage | None = None,
+    system_prompt: str | None = None,
 ) -> str | AsyncIterator[str]:
     """
     Hybrid retrieval implementation combining knowledge graph and vector search.
@@ -962,15 +963,19 @@ async def mix_kg_vector_query(
         return {"kg_context": kg_context, "vector_context": vector_context}
 
     # 5. Construct hybrid prompt
-    sys_prompt = PROMPTS["mix_rag_response"].format(
-        kg_context=kg_context
-        if kg_context
-        else "No relevant knowledge graph information found",
-        vector_context=vector_context
-        if vector_context
-        else "No relevant text information found",
-        response_type=query_param.response_type,
-        history=history_context,
+    sys_prompt = (
+        system_prompt
+        if system_prompt
+        else PROMPTS["mix_rag_response"].format(
+            kg_context=kg_context
+            if kg_context
+            else "No relevant knowledge graph information found",
+            vector_context=vector_context
+            if vector_context
+            else "No relevant text information found",
+            response_type=query_param.response_type,
+            history=history_context,
+        )
     )
 
     if query_param.only_need_prompt:
@@ -1599,6 +1604,7 @@ async def naive_query(
     query_param: QueryParam,
     global_config: dict[str, str],
     hashing_kv: BaseKVStorage | None = None,
+    system_prompt: str | None = None,
 ) -> str | AsyncIterator[str]:
     # Handle cache
     use_model_func = global_config["llm_model_func"]
@@ -1651,7 +1657,7 @@ async def naive_query(
         query_param.conversation_history, query_param.history_turns
     )
 
-    sys_prompt_temp = PROMPTS["naive_rag_response"]
+    sys_prompt_temp = system_prompt if system_prompt else PROMPTS["naive_rag_response"]
     sys_prompt = sys_prompt_temp.format(
         content_data=section,
         response_type=query_param.response_type,
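One behavioral difference worth flagging in these hunks: kg_query and naive_query treat a caller-supplied system_prompt as a template and still .format() it with the retrieved context, while mix_kg_vector_query uses it verbatim; there, the PROMPTS["mix_rag_response"].format(...) call only runs in the fallback branch. A minimal sketch of the two patterns, with toy templates standing in for lightrag.prompt.PROMPTS:

```python
# Toy stand-ins for lightrag.prompt.PROMPTS; the real templates are far longer.
PROMPTS = {
    "rag_response": "Default KG template: {context_data} ({response_type})",
    "mix_rag_response": "Default mix template: {kg_context} | {vector_context}",
}

def build_kg_prompt(system_prompt=None, context="...", response_type="short"):
    # kg_query / naive_query pattern: the override replaces the *template*,
    # so it is still formatted with the retrieved context and history.
    template = system_prompt if system_prompt else PROMPTS["rag_response"]
    return template.format(
        context_data=context, response_type=response_type, history=""
    )

def build_mix_prompt(system_prompt=None, kg_context="", vector_context=""):
    # mix_kg_vector_query pattern: the override is used verbatim; only the
    # default template is formatted with the kg/vector contexts.
    return (
        system_prompt
        if system_prompt
        else PROMPTS["mix_rag_response"].format(
            kg_context=kg_context or "No relevant knowledge graph information found",
            vector_context=vector_context or "No relevant text information found",
        )
    )

print(build_kg_prompt("Custom: {context_data}", context="KB facts"))  # formatted
print(build_mix_prompt("Custom, used as-is"))  # passed through untouched
```

So, as of this commit, a custom prompt for mode="mix" gets no {kg_context} / {vector_context} substitution; the default template remains the only formatted path in that mode.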