Daniel.y committed
Commit f46dd17 · unverified · 2 parents: 70ded13 cb56de7

Merge pull request #1167 from omdivyatej/om-pr


Feature: Dynamic LLM Selection via QueryParam for Optimized Performance
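
In short, a `QueryParam` can now carry its own `model_func`, so a single query can run against a stronger (or cheaper) LLM than the instance default. A minimal sketch of the call pattern, assuming an OpenAI-backed setup and a `rag` instance already initialized as in the added example below:

```python
from lightrag import QueryParam
from lightrag.llm.openai import gpt_4o_complete

# `rag` is a LightRAG instance whose default llm_model_func is a cheaper model
# (see the added example file for the initialization). For this one query,
# override the model via QueryParam.model_func.
answer = rag.query(
    "How does the character development reflect Victorian-era attitudes?",
    param=QueryParam(mode="global", model_func=gpt_4o_complete),
)
print(answer)
```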

examples/lightrag_multi_model_all_modes_demo.py ADDED
@@ -0,0 +1,88 @@
+import os
+import asyncio
+from lightrag import LightRAG, QueryParam
+from lightrag.llm.openai import gpt_4o_mini_complete, gpt_4o_complete, openai_embed
+from lightrag.kg.shared_storage import initialize_pipeline_status
+
+WORKING_DIR = "./lightrag_demo"
+
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+
+async def initialize_rag():
+    rag = LightRAG(
+        working_dir=WORKING_DIR,
+        embedding_func=openai_embed,
+        llm_model_func=gpt_4o_mini_complete,  # Default model for queries
+    )
+
+    await rag.initialize_storages()
+    await initialize_pipeline_status()
+
+    return rag
+
+
+def main():
+    # Initialize RAG instance
+    rag = asyncio.run(initialize_rag())
+
+    # Load the data
+    with open("./book.txt", "r", encoding="utf-8") as f:
+        rag.insert(f.read())
+
+    # Query with naive mode (default model)
+    print("--- NAIVE mode ---")
+    print(
+        rag.query(
+            "What are the main themes in this story?", param=QueryParam(mode="naive")
+        )
+    )
+
+    # Query with local mode (default model)
+    print("\n--- LOCAL mode ---")
+    print(
+        rag.query(
+            "What are the main themes in this story?", param=QueryParam(mode="local")
+        )
+    )
+
+    # Query with global mode (default model)
+    print("\n--- GLOBAL mode ---")
+    print(
+        rag.query(
+            "What are the main themes in this story?", param=QueryParam(mode="global")
+        )
+    )
+
+    # Query with hybrid mode (default model)
+    print("\n--- HYBRID mode ---")
+    print(
+        rag.query(
+            "What are the main themes in this story?", param=QueryParam(mode="hybrid")
+        )
+    )
+
+    # Query with mix mode (default model)
+    print("\n--- MIX mode ---")
+    print(
+        rag.query(
+            "What are the main themes in this story?", param=QueryParam(mode="mix")
+        )
+    )
+
+    # Query with a custom model (gpt-4o) for a more complex question
+    print("\n--- Using custom model for complex analysis ---")
+    print(
+        rag.query(
+            "How does the character development reflect Victorian-era attitudes?",
+            param=QueryParam(
+                mode="global",
+                model_func=gpt_4o_complete,  # Override default model with more capable one
+            ),
+        )
+    )
+
+
+if __name__ == "__main__":
+    main()
lightrag/base.py CHANGED
@@ -10,6 +10,7 @@ from typing import (
     Literal,
     TypedDict,
     TypeVar,
+    Callable,
 )
 import numpy as np
 from .utils import EmbeddingFunc
@@ -84,6 +85,12 @@ class QueryParam:
     ids: list[str] | None = None
     """List of ids to filter the results."""

+    model_func: Callable[..., object] | None = None
+    """Optional override for the LLM model function to use for this specific query.
+    If provided, this will be used instead of the global model function.
+    This allows using different models for different query modes.
+    """
+

 @dataclass
 class StorageNameSpace(ABC):
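
The semantics of the new field reduce to a per-query fallback, which `lightrag/operate.py` below applies at each call site. A minimal sketch of that resolution; `resolve_model_func` is a hypothetical helper named here only for illustration:

```python
from typing import Any, Callable


def resolve_model_func(
    model_func: Callable[..., Any] | None,
    global_config: dict[str, Any],
) -> Callable[..., Any]:
    # Prefer QueryParam.model_func when set; otherwise fall back to the
    # instance-level llm_model_func stored in the global config dict.
    return model_func if model_func else global_config["llm_model_func"]
```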
lightrag/lightrag.py CHANGED
@@ -1330,11 +1330,15 @@ class LightRAG:
         Args:
             query (str): The query to be executed.
             param (QueryParam): Configuration parameters for query execution.
+                If param.model_func is provided, it will be used instead of the global model.
             prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].

         Returns:
             str: The result of the query execution.
         """
+        # If a custom model is provided in param, temporarily update global config
+        global_config = asdict(self)
+
         if param.mode in ["local", "global", "hybrid"]:
             response = await kg_query(
                 query.strip(),
@@ -1343,7 +1347,7 @@
                 self.relationships_vdb,
                 self.text_chunks,
                 param,
-                asdict(self),
+                global_config,
                 hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
@@ -1353,7 +1357,7 @@
                 self.chunks_vdb,
                 self.text_chunks,
                 param,
-                asdict(self),
+                global_config,
                 hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
@@ -1366,7 +1370,7 @@
                 self.chunks_vdb,
                 self.text_chunks,
                 param,
-                asdict(self),
+                global_config,
                 hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
lightrag/operate.py CHANGED
@@ -705,7 +705,11 @@ async def kg_query(
     system_prompt: str | None = None,
 ) -> str | AsyncIterator[str]:
     # Handle cache
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -866,7 +870,9 @@ async def extract_keywords_only(
     logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")

     # 5. Call the LLM for keyword extraction
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = (
+        param.model_func if param.model_func else global_config["llm_model_func"]
+    )
     result = await use_model_func(kw_prompt, keyword_extraction=True)

     # 6. Parse out JSON from the LLM response
@@ -926,7 +932,11 @@ async def mix_kg_vector_query(
     3. Combining both results for comprehensive answer generation
     """
     # 1. Cache handling
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash("mix", query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, "mix", cache_type="query"
@@ -1731,7 +1741,11 @@ async def naive_query(
     system_prompt: str | None = None,
 ) -> str | AsyncIterator[str]:
     # Handle cache
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -1850,7 +1864,11 @@ async def kg_query_with_keywords(
     # ---------------------------
     # 1) Handle potential cache for query results
     # ---------------------------
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"