Merge pull request #1167 from omdivyatej/om-pr
Feature: Dynamic LLM Selection via QueryParam for Optimized Performance
- examples/lightrag_multi_model_all_modes_demo.py +88 -0
- lightrag/base.py +7 -0
- lightrag/lightrag.py +7 -3
- lightrag/operate.py +23 -5
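
In short, QueryParam gains an optional model_func field that overrides the globally configured LLM for a single query, so cheap and expensive models can be mixed per call. A minimal usage sketch, condensed from the example script added below (the rag instance is assumed to be initialized as in that script):

    from lightrag import QueryParam
    from lightrag.llm.openai import gpt_4o_complete

    # Routine query: uses the default llm_model_func configured on the LightRAG instance
    print(rag.query("What are the main themes in this story?", param=QueryParam(mode="hybrid")))

    # Harder query: override with a more capable model for this call only
    print(
        rag.query(
            "How does the character development reflect Victorian-era attitudes?",
            param=QueryParam(mode="global", model_func=gpt_4o_complete),
        )
    )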
examples/lightrag_multi_model_all_modes_demo.py
ADDED
@@ -0,0 +1,88 @@
+import os
+import asyncio
+from lightrag import LightRAG, QueryParam
+from lightrag.llm.openai import gpt_4o_mini_complete, gpt_4o_complete, openai_embed
+from lightrag.kg.shared_storage import initialize_pipeline_status
+
+WORKING_DIR = "./lightrag_demo"
+
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+
+async def initialize_rag():
+    rag = LightRAG(
+        working_dir=WORKING_DIR,
+        embedding_func=openai_embed,
+        llm_model_func=gpt_4o_mini_complete,  # Default model for queries
+    )
+
+    await rag.initialize_storages()
+    await initialize_pipeline_status()
+
+    return rag
+
+
+def main():
+    # Initialize RAG instance
+    rag = asyncio.run(initialize_rag())
+
+    # Load the data
+    with open("./book.txt", "r", encoding="utf-8") as f:
+        rag.insert(f.read())
+
+    # Query with naive mode (default model)
+    print("--- NAIVE mode ---")
+    print(
+        rag.query(
+            "What are the main themes in this story?", param=QueryParam(mode="naive")
+        )
+    )
+
+    # Query with local mode (default model)
+    print("\n--- LOCAL mode ---")
+    print(
+        rag.query(
+            "What are the main themes in this story?", param=QueryParam(mode="local")
+        )
+    )
+
+    # Query with global mode (default model)
+    print("\n--- GLOBAL mode ---")
+    print(
+        rag.query(
+            "What are the main themes in this story?", param=QueryParam(mode="global")
+        )
+    )
+
+    # Query with hybrid mode (default model)
+    print("\n--- HYBRID mode ---")
+    print(
+        rag.query(
+            "What are the main themes in this story?", param=QueryParam(mode="hybrid")
+        )
+    )
+
+    # Query with mix mode (default model)
+    print("\n--- MIX mode ---")
+    print(
+        rag.query(
+            "What are the main themes in this story?", param=QueryParam(mode="mix")
+        )
+    )
+
+    # Query with a custom model (gpt-4o) for a more complex question
+    print("\n--- Using custom model for complex analysis ---")
+    print(
+        rag.query(
+            "How does the character development reflect Victorian-era attitudes?",
+            param=QueryParam(
+                mode="global",
+                model_func=gpt_4o_complete,  # Override default model with more capable one
+            ),
+        )
+    )
+
+
+if __name__ == "__main__":
+    main()
lightrag/base.py
CHANGED
@@ -10,6 +10,7 @@ from typing import (
     Literal,
     TypedDict,
     TypeVar,
+    Callable,
 )
 import numpy as np
 from .utils import EmbeddingFunc
@@ -84,6 +85,12 @@ class QueryParam:
     ids: list[str] | None = None
     """List of ids to filter the results."""
 
+    model_func: Callable[..., object] | None = None
+    """Optional override for the LLM model function to use for this specific query.
+    If provided, this will be used instead of the global model function.
+    This allows using different models for different query modes.
+    """
+
 
 @dataclass
 class StorageNameSpace(ABC):
lightrag/lightrag.py
CHANGED
@@ -1330,11 +1330,15 @@ class LightRAG:
         Args:
             query (str): The query to be executed.
             param (QueryParam): Configuration parameters for query execution.
+                If param.model_func is provided, it will be used instead of the global model.
             prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
 
         Returns:
             str: The result of the query execution.
         """
+        # If a custom model is provided in param, temporarily update global config
+        global_config = asdict(self)
+
         if param.mode in ["local", "global", "hybrid"]:
             response = await kg_query(
                 query.strip(),
@@ -1343,7 +1347,7 @@
                 self.relationships_vdb,
                 self.text_chunks,
                 param,
-                asdict(self),
+                global_config,
                 hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
@@ -1353,7 +1357,7 @@
                 self.chunks_vdb,
                 self.text_chunks,
                 param,
-                asdict(self),
+                global_config,
                 hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
@@ -1366,7 +1370,7 @@
                 self.chunks_vdb,
                 self.text_chunks,
                 param,
-                asdict(self),
+                global_config,
                 hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
lightrag/operate.py
CHANGED
@@ -705,7 +705,11 @@ async def kg_query(
     system_prompt: str | None = None,
 ) -> str | AsyncIterator[str]:
     # Handle cache
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -866,7 +870,9 @@ async def extract_keywords_only(
     logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")
 
     # 5. Call the LLM for keyword extraction
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = (
+        param.model_func if param.model_func else global_config["llm_model_func"]
+    )
     result = await use_model_func(kw_prompt, keyword_extraction=True)
 
     # 6. Parse out JSON from the LLM response
@@ -926,7 +932,11 @@ async def mix_kg_vector_query(
     3. Combining both results for comprehensive answer generation
     """
     # 1. Cache handling
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash("mix", query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, "mix", cache_type="query"
@@ -1731,7 +1741,11 @@ async def naive_query(
     system_prompt: str | None = None,
 ) -> str | AsyncIterator[str]:
     # Handle cache
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -1850,7 +1864,11 @@ async def kg_query_with_keywords(
     # ---------------------------
     # 1) Handle potential cache for query results
     # ---------------------------
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"