omdivyatej committed on
Commit
af45684
·
1 Parent(s): d6709ba

specify LLM for query

Browse files
examples/lightrag_multi_model_all_modes_demo.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import asyncio
import os

from lightrag import LightRAG, QueryParam
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.llm.openai import gpt_4o_complete, gpt_4o_mini_complete, openai_embed
from lightrag.utils import setup_logger

setup_logger("lightrag", level="INFO")

WORKING_DIR = "./all_modes_demo"

# exist_ok avoids the check-then-create race of os.path.exists() + os.mkdir().
os.makedirs(WORKING_DIR, exist_ok=True)


async def initialize_rag() -> LightRAG:
    """Create a LightRAG instance with gpt-4o-mini as the default LLM.

    Returns:
        A LightRAG instance whose storages and pipeline status are initialized.
    """
    rag = LightRAG(
        working_dir=WORKING_DIR,
        embedding_func=openai_embed,
        llm_model_func=gpt_4o_mini_complete,  # Default model for most queries
    )

    await rag.initialize_storages()
    await initialize_pipeline_status()

    return rag


def main() -> None:
    """Demonstrate per-query LLM selection via QueryParam.model_func.

    Runs every retrieval mode first with the default model, then with a
    per-query gpt-4o override, then shows a mixed strategy that picks the
    model per mode.
    """
    rag = asyncio.run(initialize_rag())

    # Ingest the demo corpus.
    with open("./book.txt", "r", encoding="utf-8") as f:
        rag.insert(f.read())

    query_text = "What are the main themes in this story?"

    # 1) Default model (gpt-4o-mini) across all modes.
    print("\n===== Default Model (gpt-4o-mini) =====")
    for mode in ["local", "global", "hybrid", "naive", "mix"]:
        print(f"\n--- {mode.upper()} mode with default model ---")
        response = rag.query(query_text, param=QueryParam(mode=mode))
        print(response)

    # 2) Per-query override: gpt-4o for all modes via QueryParam.model_func.
    print("\n===== Custom Model (gpt-4o) =====")
    for mode in ["local", "global", "hybrid", "naive", "mix"]:
        print(f"\n--- {mode.upper()} mode with custom model ---")
        response = rag.query(
            query_text,
            param=QueryParam(
                mode=mode,
                model_func=gpt_4o_complete,  # Override with more capable model
            ),
        )
        print(response)

    # 3) Strategic selection: cheap model for simple retrieval, strong model
    #    for complex analysis.
    print("\n===== Strategic Model Selection =====")

    complex_query = (
        "How does the character development in the story reflect "
        "Victorian-era social values?"
    )

    # Use the default model for simpler modes.
    print("\n--- NAIVE mode with default model (suitable for simple retrieval) ---")
    response1 = rag.query(
        complex_query,
        param=QueryParam(mode="naive"),  # Use default model for basic retrieval
    )
    print(response1)

    # Use a more capable model for complex modes.
    print("\n--- HYBRID mode with more capable model (for complex analysis) ---")
    response2 = rag.query(
        complex_query,
        param=QueryParam(
            mode="hybrid",
            model_func=gpt_4o_complete,  # More capable model for complex analysis
        ),
    )
    print(response2)


if __name__ == "__main__":
    main()
lightrag/base.py CHANGED
@@ -10,6 +10,7 @@ from typing import (
10
  Literal,
11
  TypedDict,
12
  TypeVar,
 
13
  )
14
  import numpy as np
15
  from .utils import EmbeddingFunc
@@ -83,6 +84,12 @@ class QueryParam:
83
 
84
  ids: list[str] | None = None
85
  """List of ids to filter the results."""
 
 
 
 
 
 
86
 
87
 
88
  @dataclass
 
10
  Literal,
11
  TypedDict,
12
  TypeVar,
13
+ Callable,
14
  )
15
  import numpy as np
16
  from .utils import EmbeddingFunc
 
84
 
85
  ids: list[str] | None = None
86
  """List of ids to filter the results."""
87
+
88
+ model_func: Callable[..., object] | None = None
89
+ """Optional override for the LLM model function to use for this specific query.
90
+ If provided, this will be used instead of the global model function.
91
+ This allows using different models for different query modes.
92
+ """
93
 
94
 
95
  @dataclass
lightrag/lightrag.py CHANGED
@@ -1330,11 +1330,15 @@ class LightRAG:
1330
  Args:
1331
  query (str): The query to be executed.
1332
  param (QueryParam): Configuration parameters for query execution.
 
1333
  prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
1334
 
1335
  Returns:
1336
  str: The result of the query execution.
1337
  """
 
 
 
1338
  if param.mode in ["local", "global", "hybrid"]:
1339
  response = await kg_query(
1340
  query.strip(),
@@ -1343,7 +1347,7 @@ class LightRAG:
1343
  self.relationships_vdb,
1344
  self.text_chunks,
1345
  param,
1346
- asdict(self),
1347
  hashing_kv=self.llm_response_cache, # Directly use llm_response_cache
1348
  system_prompt=system_prompt,
1349
  )
@@ -1353,7 +1357,7 @@ class LightRAG:
1353
  self.chunks_vdb,
1354
  self.text_chunks,
1355
  param,
1356
- asdict(self),
1357
  hashing_kv=self.llm_response_cache, # Directly use llm_response_cache
1358
  system_prompt=system_prompt,
1359
  )
@@ -1366,7 +1370,7 @@ class LightRAG:
1366
  self.chunks_vdb,
1367
  self.text_chunks,
1368
  param,
1369
- asdict(self),
1370
  hashing_kv=self.llm_response_cache, # Directly use llm_response_cache
1371
  system_prompt=system_prompt,
1372
  )
 
1330
  Args:
1331
  query (str): The query to be executed.
1332
  param (QueryParam): Configuration parameters for query execution.
1333
+ If param.model_func is provided, it will be used instead of the global model.
1334
  prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
1335
 
1336
  Returns:
1337
  str: The result of the query execution.
1338
  """
1339
+ # If a custom model is provided in param, temporarily update global config
1340
+ global_config = asdict(self)
1341
+
1342
  if param.mode in ["local", "global", "hybrid"]:
1343
  response = await kg_query(
1344
  query.strip(),
 
1347
  self.relationships_vdb,
1348
  self.text_chunks,
1349
  param,
1350
+ global_config,
1351
  hashing_kv=self.llm_response_cache, # Directly use llm_response_cache
1352
  system_prompt=system_prompt,
1353
  )
 
1357
  self.chunks_vdb,
1358
  self.text_chunks,
1359
  param,
1360
+ global_config,
1361
  hashing_kv=self.llm_response_cache, # Directly use llm_response_cache
1362
  system_prompt=system_prompt,
1363
  )
 
1370
  self.chunks_vdb,
1371
  self.text_chunks,
1372
  param,
1373
+ global_config,
1374
  hashing_kv=self.llm_response_cache, # Directly use llm_response_cache
1375
  system_prompt=system_prompt,
1376
  )
lightrag/operate.py CHANGED
@@ -705,7 +705,7 @@ async def kg_query(
705
  system_prompt: str | None = None,
706
  ) -> str | AsyncIterator[str]:
707
  # Handle cache
708
- use_model_func = global_config["llm_model_func"]
709
  args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
710
  cached_response, quantized, min_val, max_val = await handle_cache(
711
  hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -866,7 +866,7 @@ async def extract_keywords_only(
866
  logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")
867
 
868
  # 5. Call the LLM for keyword extraction
869
- use_model_func = global_config["llm_model_func"]
870
  result = await use_model_func(kw_prompt, keyword_extraction=True)
871
 
872
  # 6. Parse out JSON from the LLM response
@@ -926,7 +926,7 @@ async def mix_kg_vector_query(
926
  3. Combining both results for comprehensive answer generation
927
  """
928
  # 1. Cache handling
929
- use_model_func = global_config["llm_model_func"]
930
  args_hash = compute_args_hash("mix", query, cache_type="query")
931
  cached_response, quantized, min_val, max_val = await handle_cache(
932
  hashing_kv, args_hash, query, "mix", cache_type="query"
@@ -1731,7 +1731,7 @@ async def naive_query(
1731
  system_prompt: str | None = None,
1732
  ) -> str | AsyncIterator[str]:
1733
  # Handle cache
1734
- use_model_func = global_config["llm_model_func"]
1735
  args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
1736
  cached_response, quantized, min_val, max_val = await handle_cache(
1737
  hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -1850,7 +1850,7 @@ async def kg_query_with_keywords(
1850
  # ---------------------------
1851
  # 1) Handle potential cache for query results
1852
  # ---------------------------
1853
- use_model_func = global_config["llm_model_func"]
1854
  args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
1855
  cached_response, quantized, min_val, max_val = await handle_cache(
1856
  hashing_kv, args_hash, query, query_param.mode, cache_type="query"
 
705
  system_prompt: str | None = None,
706
  ) -> str | AsyncIterator[str]:
707
  # Handle cache
708
+ use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"]
709
  args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
710
  cached_response, quantized, min_val, max_val = await handle_cache(
711
  hashing_kv, args_hash, query, query_param.mode, cache_type="query"
 
866
  logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")
867
 
868
  # 5. Call the LLM for keyword extraction
869
+ use_model_func = param.model_func if param.model_func else global_config["llm_model_func"]
870
  result = await use_model_func(kw_prompt, keyword_extraction=True)
871
 
872
  # 6. Parse out JSON from the LLM response
 
926
  3. Combining both results for comprehensive answer generation
927
  """
928
  # 1. Cache handling
929
+ use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"]
930
  args_hash = compute_args_hash("mix", query, cache_type="query")
931
  cached_response, quantized, min_val, max_val = await handle_cache(
932
  hashing_kv, args_hash, query, "mix", cache_type="query"
 
1731
  system_prompt: str | None = None,
1732
  ) -> str | AsyncIterator[str]:
1733
  # Handle cache
1734
+ use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"]
1735
  args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
1736
  cached_response, quantized, min_val, max_val = await handle_cache(
1737
  hashing_kv, args_hash, query, query_param.mode, cache_type="query"
 
1850
  # ---------------------------
1851
  # 1) Handle potential cache for query results
1852
  # ---------------------------
1853
+ use_model_func = query_param.model_func if query_param.model_func else global_config["llm_model_func"]
1854
  args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
1855
  cached_response, quantized, min_val, max_val = await handle_cache(
1856
  hashing_kv, args_hash, query, query_param.mode, cache_type="query"