omdivyatej committed
Commit cb56de7 · Parent: 1f844c6

linting errors

examples/lightrag_multi_model_all_modes_demo.py CHANGED
@@ -9,6 +9,7 @@ WORKING_DIR = "./lightrag_demo"
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)
 
+
 async def initialize_rag():
     rag = LightRAG(
         working_dir=WORKING_DIR,
@@ -21,6 +22,7 @@ async def initialize_rag():
 
     return rag
 
+
 def main():
     # Initialize RAG instance
     rag = asyncio.run(initialize_rag())
@@ -33,8 +35,7 @@ def main():
     print("--- NAIVE mode ---")
     print(
         rag.query(
-            "What are the main themes in this story?",
-            param=QueryParam(mode="naive")
+            "What are the main themes in this story?", param=QueryParam(mode="naive")
         )
     )
 
@@ -42,8 +43,7 @@ def main():
     print("\n--- LOCAL mode ---")
     print(
         rag.query(
-            "What are the main themes in this story?",
-            param=QueryParam(mode="local")
+            "What are the main themes in this story?", param=QueryParam(mode="local")
         )
     )
 
@@ -51,8 +51,7 @@ def main():
     print("\n--- GLOBAL mode ---")
     print(
         rag.query(
-            "What are the main themes in this story?",
-            param=QueryParam(mode="global")
+            "What are the main themes in this story?", param=QueryParam(mode="global")
         )
     )
 
@@ -60,8 +59,7 @@ def main():
     print("\n--- HYBRID mode ---")
     print(
         rag.query(
-            "What are the main themes in this story?",
-            param=QueryParam(mode="hybrid")
+            "What are the main themes in this story?", param=QueryParam(mode="hybrid")
         )
     )
 
@@ -69,8 +67,7 @@ def main():
     print("\n--- MIX mode ---")
     print(
         rag.query(
-            "What are the main themes in this story?",
-            param=QueryParam(mode="mix")
+            "What are the main themes in this story?", param=QueryParam(mode="mix")
         )
     )
 
@@ -81,10 +78,11 @@ def main():
             "How does the character development reflect Victorian-era attitudes?",
             param=QueryParam(
                 mode="global",
-                model_func=gpt_4o_complete  # Override default model with more capable one
-            )
+                model_func=gpt_4o_complete,  # Override default model with more capable one
+            ),
         )
     )
 
+
 if __name__ == "__main__":
-    main()
+    main()
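
The hunks above only reflow the demo; its behavior is unchanged. For reference, a minimal sketch of the pattern the demo exercises: querying every QueryParam mode with a cheap default model, then overriding model_func for one harder question. The import path lightrag.llm.openai and the gpt_4o_mini_complete default are assumptions, and a real setup also needs an embedding function:

import asyncio

from lightrag import LightRAG, QueryParam
# Assumed import path and model helpers; adjust to your install.
from lightrag.llm.openai import gpt_4o_complete, gpt_4o_mini_complete


async def initialize_rag():
    # Cheap default model; individual queries can upgrade it via model_func.
    return LightRAG(working_dir="./lightrag_demo", llm_model_func=gpt_4o_mini_complete)


def main():
    rag = asyncio.run(initialize_rag())
    for mode in ["naive", "local", "global", "hybrid", "mix"]:
        print(f"--- {mode.upper()} mode ---")
        print(
            rag.query(
                "What are the main themes in this story?", param=QueryParam(mode=mode)
            )
        )
    # Per-query override: route one harder question to a stronger model.
    print(
        rag.query(
            "How does the character development reflect Victorian-era attitudes?",
            param=QueryParam(mode="global", model_func=gpt_4o_complete),
        )
    )


if __name__ == "__main__":
    main()
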
lightrag/base.py CHANGED
@@ -84,7 +84,7 @@ class QueryParam:
 
     ids: list[str] | None = None
     """List of ids to filter the results."""
-
+
     model_func: Callable[..., object] | None = None
     """Optional override for the LLM model function to use for this specific query.
     If provided, this will be used instead of the global model function.
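
For context, model_func slots into QueryParam as an ordinary optional dataclass field, so leaving it unset costs nothing. A sketch of the surrounding shape (abridged; the mode field's default and allowed values are assumptions):

from dataclasses import dataclass
from typing import Callable


@dataclass
class QueryParam:
    mode: str = "global"  # assumed default; one of "naive", "local", "global", "hybrid", "mix"
    ids: list[str] | None = None
    """List of ids to filter the results."""
    model_func: Callable[..., object] | None = None
    """Optional per-query override for the LLM model function."""
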
lightrag/lightrag.py CHANGED
@@ -1338,7 +1338,7 @@ class LightRAG:
         """
         # If a custom model is provided in param, temporarily update global config
         global_config = asdict(self)
-
+
         if param.mode in ["local", "global", "hybrid"]:
             response = await kg_query(
                 query.strip(),
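
This hunk is whitespace-only, but it sits in LightRAG's query dispatcher, which routes on param.mode to the operate.py functions changed below. A toy sketch of that routing shape, with hypothetical stub handlers standing in for the real calls (which pass many more arguments):

import asyncio


async def kg_query(q: str) -> str:  # stub: handles local/global/hybrid
    return f"kg: {q}"


async def naive_query(q: str) -> str:  # stub: handles naive
    return f"naive: {q}"


async def mix_kg_vector_query(q: str) -> str:  # stub: handles mix
    return f"mix: {q}"


async def dispatch(query: str, mode: str) -> str:
    if mode in ["local", "global", "hybrid"]:
        return await kg_query(query.strip())
    elif mode == "naive":
        return await naive_query(query.strip())
    return await mix_kg_vector_query(query.strip())


print(asyncio.run(dispatch("What are the main themes in this story?", "naive")))
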
lightrag/operate.py CHANGED
@@ -705,7 +705,11 @@ async def kg_query(
     system_prompt: str | None = None,
 ) -> str | AsyncIterator[str]:
     # Handle cache
-    use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -866,7 +870,9 @@ async def extract_keywords_only(
     logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")
 
     # 5. Call the LLM for keyword extraction
-    use_model_func = param.model_func if param.model_func else global_config["llm_model_func"]
+    use_model_func = (
+        param.model_func if param.model_func else global_config["llm_model_func"]
+    )
     result = await use_model_func(kw_prompt, keyword_extraction=True)
 
     # 6. Parse out JSON from the LLM response
@@ -926,7 +932,11 @@ async def mix_kg_vector_query(
     3. Combining both results for comprehensive answer generation
     """
     # 1. Cache handling
-    use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash("mix", query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, "mix", cache_type="query"
@@ -1731,7 +1741,11 @@ async def naive_query(
     system_prompt: str | None = None,
 ) -> str | AsyncIterator[str]:
     # Handle cache
-    use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -1850,7 +1864,11 @@ async def kg_query_with_keywords(
     # ---------------------------
     # 1) Handle potential cache for query results
     # ---------------------------
-    use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
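
All five operate.py hunks wrap the same fallback without changing behavior: an explicit per-query model_func wins, otherwise the global default applies. Because model_func holds either None or a callable (and a callable is always truthy), the conditional is equivalent to an or-expression. A tiny self-contained check of that equivalence, using a hypothetical resolve_model_func helper and stand-in model functions:

from types import SimpleNamespace


def resolve_model_func(query_param, global_config):
    # Same semantics as the wrapped conditional in the hunks above:
    # a per-query override wins, otherwise fall back to the global default.
    return query_param.model_func or global_config["llm_model_func"]


def default_model(prompt, **kwargs):
    return f"default: {prompt}"


def override_model(prompt, **kwargs):
    return f"override: {prompt}"


cfg = {"llm_model_func": default_model}

assert resolve_model_func(SimpleNamespace(model_func=None), cfg) is default_model
assert resolve_model_func(SimpleNamespace(model_func=override_model), cfg) is override_model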