omdivyatej
commited on
Commit
·
cb56de7
1
Parent(s):
1f844c6
linting errors
Browse files- examples/lightrag_multi_model_all_modes_demo.py +11 -13
- lightrag/base.py +1 -1
- lightrag/lightrag.py +1 -1
- lightrag/operate.py +23 -5
examples/lightrag_multi_model_all_modes_demo.py
CHANGED
@@ -9,6 +9,7 @@ WORKING_DIR = "./lightrag_demo"
|
|
9 |
if not os.path.exists(WORKING_DIR):
|
10 |
os.mkdir(WORKING_DIR)
|
11 |
|
|
|
12 |
async def initialize_rag():
|
13 |
rag = LightRAG(
|
14 |
working_dir=WORKING_DIR,
|
@@ -21,6 +22,7 @@ async def initialize_rag():
|
|
21 |
|
22 |
return rag
|
23 |
|
|
|
24 |
def main():
|
25 |
# Initialize RAG instance
|
26 |
rag = asyncio.run(initialize_rag())
|
@@ -33,8 +35,7 @@ def main():
|
|
33 |
print("--- NAIVE mode ---")
|
34 |
print(
|
35 |
rag.query(
|
36 |
-
"What are the main themes in this story?",
|
37 |
-
param=QueryParam(mode="naive")
|
38 |
)
|
39 |
)
|
40 |
|
@@ -42,8 +43,7 @@ def main():
|
|
42 |
print("\n--- LOCAL mode ---")
|
43 |
print(
|
44 |
rag.query(
|
45 |
-
"What are the main themes in this story?",
|
46 |
-
param=QueryParam(mode="local")
|
47 |
)
|
48 |
)
|
49 |
|
@@ -51,8 +51,7 @@ def main():
|
|
51 |
print("\n--- GLOBAL mode ---")
|
52 |
print(
|
53 |
rag.query(
|
54 |
-
"What are the main themes in this story?",
|
55 |
-
param=QueryParam(mode="global")
|
56 |
)
|
57 |
)
|
58 |
|
@@ -60,8 +59,7 @@ def main():
|
|
60 |
print("\n--- HYBRID mode ---")
|
61 |
print(
|
62 |
rag.query(
|
63 |
-
"What are the main themes in this story?",
|
64 |
-
param=QueryParam(mode="hybrid")
|
65 |
)
|
66 |
)
|
67 |
|
@@ -69,8 +67,7 @@ def main():
|
|
69 |
print("\n--- MIX mode ---")
|
70 |
print(
|
71 |
rag.query(
|
72 |
-
"What are the main themes in this story?",
|
73 |
-
param=QueryParam(mode="mix")
|
74 |
)
|
75 |
)
|
76 |
|
@@ -81,10 +78,11 @@ def main():
|
|
81 |
"How does the character development reflect Victorian-era attitudes?",
|
82 |
param=QueryParam(
|
83 |
mode="global",
|
84 |
-
model_func=gpt_4o_complete # Override default model with more capable one
|
85 |
-
)
|
86 |
)
|
87 |
)
|
88 |
|
|
|
89 |
if __name__ == "__main__":
|
90 |
-
main()
|
|
|
9 |
if not os.path.exists(WORKING_DIR):
|
10 |
os.mkdir(WORKING_DIR)
|
11 |
|
12 |
+
|
13 |
async def initialize_rag():
|
14 |
rag = LightRAG(
|
15 |
working_dir=WORKING_DIR,
|
|
|
22 |
|
23 |
return rag
|
24 |
|
25 |
+
|
26 |
def main():
|
27 |
# Initialize RAG instance
|
28 |
rag = asyncio.run(initialize_rag())
|
|
|
35 |
print("--- NAIVE mode ---")
|
36 |
print(
|
37 |
rag.query(
|
38 |
+
"What are the main themes in this story?", param=QueryParam(mode="naive")
|
|
|
39 |
)
|
40 |
)
|
41 |
|
|
|
43 |
print("\n--- LOCAL mode ---")
|
44 |
print(
|
45 |
rag.query(
|
46 |
+
"What are the main themes in this story?", param=QueryParam(mode="local")
|
|
|
47 |
)
|
48 |
)
|
49 |
|
|
|
51 |
print("\n--- GLOBAL mode ---")
|
52 |
print(
|
53 |
rag.query(
|
54 |
+
"What are the main themes in this story?", param=QueryParam(mode="global")
|
|
|
55 |
)
|
56 |
)
|
57 |
|
|
|
59 |
print("\n--- HYBRID mode ---")
|
60 |
print(
|
61 |
rag.query(
|
62 |
+
"What are the main themes in this story?", param=QueryParam(mode="hybrid")
|
|
|
63 |
)
|
64 |
)
|
65 |
|
|
|
67 |
print("\n--- MIX mode ---")
|
68 |
print(
|
69 |
rag.query(
|
70 |
+
"What are the main themes in this story?", param=QueryParam(mode="mix")
|
|
|
71 |
)
|
72 |
)
|
73 |
|
|
|
78 |
"How does the character development reflect Victorian-era attitudes?",
|
79 |
param=QueryParam(
|
80 |
mode="global",
|
81 |
+
model_func=gpt_4o_complete, # Override default model with more capable one
|
82 |
+
),
|
83 |
)
|
84 |
)
|
85 |
|
86 |
+
|
87 |
if __name__ == "__main__":
|
88 |
+
main()
|
lightrag/base.py
CHANGED
@@ -84,7 +84,7 @@ class QueryParam:
|
|
84 |
|
85 |
ids: list[str] | None = None
|
86 |
"""List of ids to filter the results."""
|
87 |
-
|
88 |
model_func: Callable[..., object] | None = None
|
89 |
"""Optional override for the LLM model function to use for this specific query.
|
90 |
If provided, this will be used instead of the global model function.
|
|
|
84 |
|
85 |
ids: list[str] | None = None
|
86 |
"""List of ids to filter the results."""
|
87 |
+
|
88 |
model_func: Callable[..., object] | None = None
|
89 |
"""Optional override for the LLM model function to use for this specific query.
|
90 |
If provided, this will be used instead of the global model function.
|
lightrag/lightrag.py
CHANGED
@@ -1338,7 +1338,7 @@ class LightRAG:
|
|
1338 |
"""
|
1339 |
# If a custom model is provided in param, temporarily update global config
|
1340 |
global_config = asdict(self)
|
1341 |
-
|
1342 |
if param.mode in ["local", "global", "hybrid"]:
|
1343 |
response = await kg_query(
|
1344 |
query.strip(),
|
|
|
1338 |
"""
|
1339 |
# If a custom model is provided in param, temporarily update global config
|
1340 |
global_config = asdict(self)
|
1341 |
+
|
1342 |
if param.mode in ["local", "global", "hybrid"]:
|
1343 |
response = await kg_query(
|
1344 |
query.strip(),
|
lightrag/operate.py
CHANGED
@@ -705,7 +705,11 @@ async def kg_query(
|
|
705 |
system_prompt: str | None = None,
|
706 |
) -> str | AsyncIterator[str]:
|
707 |
# Handle cache
|
708 |
-
use_model_func =
|
|
|
|
|
|
|
|
|
709 |
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
|
710 |
cached_response, quantized, min_val, max_val = await handle_cache(
|
711 |
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
|
@@ -866,7 +870,9 @@ async def extract_keywords_only(
|
|
866 |
logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")
|
867 |
|
868 |
# 5. Call the LLM for keyword extraction
|
869 |
-
use_model_func =
|
|
|
|
|
870 |
result = await use_model_func(kw_prompt, keyword_extraction=True)
|
871 |
|
872 |
# 6. Parse out JSON from the LLM response
|
@@ -926,7 +932,11 @@ async def mix_kg_vector_query(
|
|
926 |
3. Combining both results for comprehensive answer generation
|
927 |
"""
|
928 |
# 1. Cache handling
|
929 |
-
use_model_func =
|
|
|
|
|
|
|
|
|
930 |
args_hash = compute_args_hash("mix", query, cache_type="query")
|
931 |
cached_response, quantized, min_val, max_val = await handle_cache(
|
932 |
hashing_kv, args_hash, query, "mix", cache_type="query"
|
@@ -1731,7 +1741,11 @@ async def naive_query(
|
|
1731 |
system_prompt: str | None = None,
|
1732 |
) -> str | AsyncIterator[str]:
|
1733 |
# Handle cache
|
1734 |
-
use_model_func =
|
|
|
|
|
|
|
|
|
1735 |
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
|
1736 |
cached_response, quantized, min_val, max_val = await handle_cache(
|
1737 |
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
|
@@ -1850,7 +1864,11 @@ async def kg_query_with_keywords(
|
|
1850 |
# ---------------------------
|
1851 |
# 1) Handle potential cache for query results
|
1852 |
# ---------------------------
|
1853 |
-
use_model_func =
|
|
|
|
|
|
|
|
|
1854 |
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
|
1855 |
cached_response, quantized, min_val, max_val = await handle_cache(
|
1856 |
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
|
|
|
705 |
system_prompt: str | None = None,
|
706 |
) -> str | AsyncIterator[str]:
|
707 |
# Handle cache
|
708 |
+
use_model_func = (
|
709 |
+
query_param.model_func
|
710 |
+
if query_param.model_func
|
711 |
+
else global_config["llm_model_func"]
|
712 |
+
)
|
713 |
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
|
714 |
cached_response, quantized, min_val, max_val = await handle_cache(
|
715 |
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
|
|
|
870 |
logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")
|
871 |
|
872 |
# 5. Call the LLM for keyword extraction
|
873 |
+
use_model_func = (
|
874 |
+
param.model_func if param.model_func else global_config["llm_model_func"]
|
875 |
+
)
|
876 |
result = await use_model_func(kw_prompt, keyword_extraction=True)
|
877 |
|
878 |
# 6. Parse out JSON from the LLM response
|
|
|
932 |
3. Combining both results for comprehensive answer generation
|
933 |
"""
|
934 |
# 1. Cache handling
|
935 |
+
use_model_func = (
|
936 |
+
query_param.model_func
|
937 |
+
if query_param.model_func
|
938 |
+
else global_config["llm_model_func"]
|
939 |
+
)
|
940 |
args_hash = compute_args_hash("mix", query, cache_type="query")
|
941 |
cached_response, quantized, min_val, max_val = await handle_cache(
|
942 |
hashing_kv, args_hash, query, "mix", cache_type="query"
|
|
|
1741 |
system_prompt: str | None = None,
|
1742 |
) -> str | AsyncIterator[str]:
|
1743 |
# Handle cache
|
1744 |
+
use_model_func = (
|
1745 |
+
query_param.model_func
|
1746 |
+
if query_param.model_func
|
1747 |
+
else global_config["llm_model_func"]
|
1748 |
+
)
|
1749 |
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
|
1750 |
cached_response, quantized, min_val, max_val = await handle_cache(
|
1751 |
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
|
|
|
1864 |
# ---------------------------
|
1865 |
# 1) Handle potential cache for query results
|
1866 |
# ---------------------------
|
1867 |
+
use_model_func = (
|
1868 |
+
query_param.model_func
|
1869 |
+
if query_param.model_func
|
1870 |
+
else global_config["llm_model_func"]
|
1871 |
+
)
|
1872 |
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
|
1873 |
cached_response, quantized, min_val, max_val = await handle_cache(
|
1874 |
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
|