feat(lightrag): 添加 查询时使用embedding缓存功能
Browse files- 在 LightRAG 类中添加 embedding_cache_config配置项
- 实现基于 embedding 相似度的缓存查询和存储
- 添加量化和反量化函数,用于压缩 embedding 数据
- 新增示例演示 embedding 缓存的使用
- README.md +1 -0
- examples/lightrag_openai_compatible_demo_embedding_cache.py +112 -0
- lightrag/lightrag.py +4 -1
- lightrag/llm.py +245 -33
- lightrag/utils.py +69 -0
README.md
CHANGED
@@ -596,6 +596,7 @@ if __name__ == "__main__":
|
|
596 |
| **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` |
|
597 |
| **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese"}`: sets example limit and output language | `example_number: all examples, language: English` |
|
598 |
| **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` |
|
|
|
599 |
|
600 |
## API Server Implementation
|
601 |
|
|
|
596 |
| **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` |
|
597 |
| **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese"}`: sets example limit and output language | `example_number: all examples, language: English` |
|
598 |
| **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` |
|
599 |
+
| **embedding\_cache\_config** | `dict` | Configuration for embedding cache. Includes `enabled` (bool) to toggle cache and `similarity_threshold` (float) for cache retrieval | `{"enabled": False, "similarity_threshold": 0.95}` |
|
600 |
|
601 |
## API Server Implementation
|
602 |
|
examples/lightrag_openai_compatible_demo_embedding_cache.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import asyncio
|
3 |
+
from lightrag import LightRAG, QueryParam
|
4 |
+
from lightrag.llm import openai_complete_if_cache, openai_embedding
|
5 |
+
from lightrag.utils import EmbeddingFunc
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
WORKING_DIR = "./dickens"
|
9 |
+
|
10 |
+
if not os.path.exists(WORKING_DIR):
|
11 |
+
os.mkdir(WORKING_DIR)
|
12 |
+
|
13 |
+
|
14 |
+
async def llm_model_func(
|
15 |
+
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
16 |
+
) -> str:
|
17 |
+
return await openai_complete_if_cache(
|
18 |
+
"solar-mini",
|
19 |
+
prompt,
|
20 |
+
system_prompt=system_prompt,
|
21 |
+
history_messages=history_messages,
|
22 |
+
api_key=os.getenv("UPSTAGE_API_KEY"),
|
23 |
+
base_url="https://api.upstage.ai/v1/solar",
|
24 |
+
**kwargs,
|
25 |
+
)
|
26 |
+
|
27 |
+
|
28 |
+
async def embedding_func(texts: list[str]) -> np.ndarray:
|
29 |
+
return await openai_embedding(
|
30 |
+
texts,
|
31 |
+
model="solar-embedding-1-large-query",
|
32 |
+
api_key=os.getenv("UPSTAGE_API_KEY"),
|
33 |
+
base_url="https://api.upstage.ai/v1/solar",
|
34 |
+
)
|
35 |
+
|
36 |
+
|
37 |
+
async def get_embedding_dim():
|
38 |
+
test_text = ["This is a test sentence."]
|
39 |
+
embedding = await embedding_func(test_text)
|
40 |
+
embedding_dim = embedding.shape[1]
|
41 |
+
return embedding_dim
|
42 |
+
|
43 |
+
|
44 |
+
# function test
|
45 |
+
async def test_funcs():
|
46 |
+
result = await llm_model_func("How are you?")
|
47 |
+
print("llm_model_func: ", result)
|
48 |
+
|
49 |
+
result = await embedding_func(["How are you?"])
|
50 |
+
print("embedding_func: ", result)
|
51 |
+
|
52 |
+
|
53 |
+
# asyncio.run(test_funcs())
|
54 |
+
|
55 |
+
|
56 |
+
async def main():
|
57 |
+
try:
|
58 |
+
embedding_dimension = await get_embedding_dim()
|
59 |
+
print(f"Detected embedding dimension: {embedding_dimension}")
|
60 |
+
|
61 |
+
rag = LightRAG(
|
62 |
+
working_dir=WORKING_DIR,
|
63 |
+
embedding_cache_config={
|
64 |
+
"enabled": True,
|
65 |
+
"similarity_threshold": 0.90, # 可以自定义阈值
|
66 |
+
},
|
67 |
+
llm_model_func=llm_model_func,
|
68 |
+
embedding_func=EmbeddingFunc(
|
69 |
+
embedding_dim=embedding_dimension,
|
70 |
+
max_token_size=8192,
|
71 |
+
func=embedding_func,
|
72 |
+
),
|
73 |
+
)
|
74 |
+
|
75 |
+
with open("./book.txt", "r", encoding="utf-8") as f:
|
76 |
+
await rag.ainsert(f.read())
|
77 |
+
|
78 |
+
# Perform naive search
|
79 |
+
print(
|
80 |
+
await rag.aquery(
|
81 |
+
"What are the top themes in this story?", param=QueryParam(mode="naive")
|
82 |
+
)
|
83 |
+
)
|
84 |
+
|
85 |
+
# Perform local search
|
86 |
+
print(
|
87 |
+
await rag.aquery(
|
88 |
+
"What are the top themes in this story?", param=QueryParam(mode="local")
|
89 |
+
)
|
90 |
+
)
|
91 |
+
|
92 |
+
# Perform global search
|
93 |
+
print(
|
94 |
+
await rag.aquery(
|
95 |
+
"What are the top themes in this story?",
|
96 |
+
param=QueryParam(mode="global"),
|
97 |
+
)
|
98 |
+
)
|
99 |
+
|
100 |
+
# Perform hybrid search
|
101 |
+
print(
|
102 |
+
await rag.aquery(
|
103 |
+
"What are the top themes in this story?",
|
104 |
+
param=QueryParam(mode="hybrid"),
|
105 |
+
)
|
106 |
+
)
|
107 |
+
except Exception as e:
|
108 |
+
print(f"An error occurred: {e}")
|
109 |
+
|
110 |
+
|
111 |
+
if __name__ == "__main__":
|
112 |
+
asyncio.run(main())
|
lightrag/lightrag.py
CHANGED
@@ -85,7 +85,10 @@ class LightRAG:
|
|
85 |
working_dir: str = field(
|
86 |
default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
|
87 |
)
|
88 |
-
|
|
|
|
|
|
|
89 |
kv_storage: str = field(default="JsonKVStorage")
|
90 |
vector_storage: str = field(default="NanoVectorDBStorage")
|
91 |
graph_storage: str = field(default="NetworkXStorage")
|
|
|
85 |
working_dir: str = field(
|
86 |
default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
|
87 |
)
|
88 |
+
# Default not to use embedding cache
|
89 |
+
embedding_cache_config: dict = field(
|
90 |
+
default_factory=lambda: {"enabled": False, "similarity_threshold": 0.95}
|
91 |
+
)
|
92 |
kv_storage: str = field(default="JsonKVStorage")
|
93 |
vector_storage: str = field(default="NanoVectorDBStorage")
|
94 |
graph_storage: str = field(default="NetworkXStorage")
|
lightrag/llm.py
CHANGED
@@ -33,6 +33,8 @@ from .utils import (
|
|
33 |
compute_args_hash,
|
34 |
wrap_embedding_func_with_attrs,
|
35 |
locate_json_string_body_from_string,
|
|
|
|
|
36 |
)
|
37 |
|
38 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
@@ -65,10 +67,29 @@ async def openai_complete_if_cache(
|
|
65 |
messages.extend(history_messages)
|
66 |
messages.append({"role": "user", "content": prompt})
|
67 |
if hashing_kv is not None:
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
|
73 |
if "response_format" in kwargs:
|
74 |
response = await openai_async_client.beta.chat.completions.parse(
|
@@ -81,10 +102,24 @@ async def openai_complete_if_cache(
|
|
81 |
content = response.choices[0].message.content
|
82 |
if r"\u" in content:
|
83 |
content = content.encode("utf-8").decode("unicode_escape")
|
84 |
-
|
85 |
if hashing_kv is not None:
|
86 |
await hashing_kv.upsert(
|
87 |
-
{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
)
|
89 |
return content
|
90 |
|
@@ -125,10 +160,28 @@ async def azure_openai_complete_if_cache(
|
|
125 |
if prompt is not None:
|
126 |
messages.append({"role": "user", "content": prompt})
|
127 |
if hashing_kv is not None:
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
|
133 |
response = await openai_async_client.chat.completions.create(
|
134 |
model=model, messages=messages, **kwargs
|
@@ -136,7 +189,21 @@ async def azure_openai_complete_if_cache(
|
|
136 |
|
137 |
if hashing_kv is not None:
|
138 |
await hashing_kv.upsert(
|
139 |
-
{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
)
|
141 |
return response.choices[0].message.content
|
142 |
|
@@ -204,10 +271,29 @@ async def bedrock_complete_if_cache(
|
|
204 |
|
205 |
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
206 |
if hashing_kv is not None:
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
211 |
|
212 |
# Call model via Converse API
|
213 |
session = aioboto3.Session()
|
@@ -223,6 +309,19 @@ async def bedrock_complete_if_cache(
|
|
223 |
args_hash: {
|
224 |
"return": response["output"]["message"]["content"][0]["text"],
|
225 |
"model": model,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
226 |
}
|
227 |
}
|
228 |
)
|
@@ -245,7 +344,11 @@ def initialize_hf_model(model_name):
|
|
245 |
|
246 |
|
247 |
async def hf_model_if_cache(
|
248 |
-
model,
|
|
|
|
|
|
|
|
|
249 |
) -> str:
|
250 |
model_name = model
|
251 |
hf_model, hf_tokenizer = initialize_hf_model(model_name)
|
@@ -257,10 +360,30 @@ async def hf_model_if_cache(
|
|
257 |
messages.append({"role": "user", "content": prompt})
|
258 |
|
259 |
if hashing_kv is not None:
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
264 |
input_prompt = ""
|
265 |
try:
|
266 |
input_prompt = hf_tokenizer.apply_chat_template(
|
@@ -305,12 +428,32 @@ async def hf_model_if_cache(
|
|
305 |
output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
|
306 |
)
|
307 |
if hashing_kv is not None:
|
308 |
-
await hashing_kv.upsert(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
309 |
return response_text
|
310 |
|
311 |
|
312 |
async def ollama_model_if_cache(
|
313 |
-
model,
|
|
|
|
|
|
|
|
|
314 |
) -> str:
|
315 |
kwargs.pop("max_tokens", None)
|
316 |
# kwargs.pop("response_format", None) # allow json
|
@@ -326,18 +469,52 @@ async def ollama_model_if_cache(
|
|
326 |
messages.extend(history_messages)
|
327 |
messages.append({"role": "user", "content": prompt})
|
328 |
if hashing_kv is not None:
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
333 |
|
334 |
response = await ollama_client.chat(model=model, messages=messages, **kwargs)
|
335 |
|
336 |
result = response["message"]["content"]
|
337 |
|
338 |
if hashing_kv is not None:
|
339 |
-
await hashing_kv.upsert(
|
340 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
341 |
return result
|
342 |
|
343 |
|
@@ -444,10 +621,29 @@ async def lmdeploy_model_if_cache(
|
|
444 |
messages.extend(history_messages)
|
445 |
messages.append({"role": "user", "content": prompt})
|
446 |
if hashing_kv is not None:
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
451 |
|
452 |
gen_config = GenerationConfig(
|
453 |
skip_special_tokens=skip_special_tokens,
|
@@ -466,7 +662,23 @@ async def lmdeploy_model_if_cache(
|
|
466 |
response += res.response
|
467 |
|
468 |
if hashing_kv is not None:
|
469 |
-
await hashing_kv.upsert(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
470 |
return response
|
471 |
|
472 |
|
|
|
33 |
compute_args_hash,
|
34 |
wrap_embedding_func_with_attrs,
|
35 |
locate_json_string_body_from_string,
|
36 |
+
quantize_embedding,
|
37 |
+
get_best_cached_response,
|
38 |
)
|
39 |
|
40 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
|
67 |
messages.extend(history_messages)
|
68 |
messages.append({"role": "user", "content": prompt})
|
69 |
if hashing_kv is not None:
|
70 |
+
# Get embedding cache configuration
|
71 |
+
embedding_cache_config = hashing_kv.global_config.get(
|
72 |
+
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
|
73 |
+
)
|
74 |
+
is_embedding_cache_enabled = embedding_cache_config["enabled"]
|
75 |
+
if is_embedding_cache_enabled:
|
76 |
+
# Use embedding cache
|
77 |
+
embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
|
78 |
+
current_embedding = await embedding_model_func([prompt])
|
79 |
+
quantized, min_val, max_val = quantize_embedding(current_embedding[0])
|
80 |
+
best_cached_response = await get_best_cached_response(
|
81 |
+
hashing_kv,
|
82 |
+
current_embedding[0],
|
83 |
+
similarity_threshold=embedding_cache_config["similarity_threshold"],
|
84 |
+
)
|
85 |
+
if best_cached_response is not None:
|
86 |
+
return best_cached_response
|
87 |
+
else:
|
88 |
+
# Use regular cache
|
89 |
+
args_hash = compute_args_hash(model, messages)
|
90 |
+
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
91 |
+
if if_cache_return is not None:
|
92 |
+
return if_cache_return["return"]
|
93 |
|
94 |
if "response_format" in kwargs:
|
95 |
response = await openai_async_client.beta.chat.completions.parse(
|
|
|
102 |
content = response.choices[0].message.content
|
103 |
if r"\u" in content:
|
104 |
content = content.encode("utf-8").decode("unicode_escape")
|
105 |
+
|
106 |
if hashing_kv is not None:
|
107 |
await hashing_kv.upsert(
|
108 |
+
{
|
109 |
+
args_hash: {
|
110 |
+
"return": content,
|
111 |
+
"model": model,
|
112 |
+
"embedding": quantized.tobytes().hex()
|
113 |
+
if is_embedding_cache_enabled
|
114 |
+
else None,
|
115 |
+
"embedding_shape": quantized.shape
|
116 |
+
if is_embedding_cache_enabled
|
117 |
+
else None,
|
118 |
+
"embedding_min": min_val if is_embedding_cache_enabled else None,
|
119 |
+
"embedding_max": max_val if is_embedding_cache_enabled else None,
|
120 |
+
"original_prompt": prompt,
|
121 |
+
}
|
122 |
+
}
|
123 |
)
|
124 |
return content
|
125 |
|
|
|
160 |
if prompt is not None:
|
161 |
messages.append({"role": "user", "content": prompt})
|
162 |
if hashing_kv is not None:
|
163 |
+
# Get embedding cache configuration
|
164 |
+
embedding_cache_config = hashing_kv.global_config.get(
|
165 |
+
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
|
166 |
+
)
|
167 |
+
is_embedding_cache_enabled = embedding_cache_config["enabled"]
|
168 |
+
if is_embedding_cache_enabled:
|
169 |
+
# Use embedding cache
|
170 |
+
embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
|
171 |
+
current_embedding = await embedding_model_func([prompt])
|
172 |
+
quantized, min_val, max_val = quantize_embedding(current_embedding[0])
|
173 |
+
best_cached_response = await get_best_cached_response(
|
174 |
+
hashing_kv,
|
175 |
+
current_embedding[0],
|
176 |
+
similarity_threshold=embedding_cache_config["similarity_threshold"],
|
177 |
+
)
|
178 |
+
if best_cached_response is not None:
|
179 |
+
return best_cached_response
|
180 |
+
else:
|
181 |
+
args_hash = compute_args_hash(model, messages)
|
182 |
+
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
183 |
+
if if_cache_return is not None:
|
184 |
+
return if_cache_return["return"]
|
185 |
|
186 |
response = await openai_async_client.chat.completions.create(
|
187 |
model=model, messages=messages, **kwargs
|
|
|
189 |
|
190 |
if hashing_kv is not None:
|
191 |
await hashing_kv.upsert(
|
192 |
+
{
|
193 |
+
args_hash: {
|
194 |
+
"return": response.choices[0].message.content,
|
195 |
+
"model": model,
|
196 |
+
"embedding": quantized.tobytes().hex()
|
197 |
+
if is_embedding_cache_enabled
|
198 |
+
else None,
|
199 |
+
"embedding_shape": quantized.shape
|
200 |
+
if is_embedding_cache_enabled
|
201 |
+
else None,
|
202 |
+
"embedding_min": min_val if is_embedding_cache_enabled else None,
|
203 |
+
"embedding_max": max_val if is_embedding_cache_enabled else None,
|
204 |
+
"original_prompt": prompt,
|
205 |
+
}
|
206 |
+
}
|
207 |
)
|
208 |
return response.choices[0].message.content
|
209 |
|
|
|
271 |
|
272 |
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
273 |
if hashing_kv is not None:
|
274 |
+
# Get embedding cache configuration
|
275 |
+
embedding_cache_config = hashing_kv.global_config.get(
|
276 |
+
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
|
277 |
+
)
|
278 |
+
is_embedding_cache_enabled = embedding_cache_config["enabled"]
|
279 |
+
if is_embedding_cache_enabled:
|
280 |
+
# Use embedding cache
|
281 |
+
embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
|
282 |
+
current_embedding = await embedding_model_func([prompt])
|
283 |
+
quantized, min_val, max_val = quantize_embedding(current_embedding[0])
|
284 |
+
best_cached_response = await get_best_cached_response(
|
285 |
+
hashing_kv,
|
286 |
+
current_embedding[0],
|
287 |
+
similarity_threshold=embedding_cache_config["similarity_threshold"],
|
288 |
+
)
|
289 |
+
if best_cached_response is not None:
|
290 |
+
return best_cached_response
|
291 |
+
else:
|
292 |
+
# Use regular cache
|
293 |
+
args_hash = compute_args_hash(model, messages)
|
294 |
+
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
295 |
+
if if_cache_return is not None:
|
296 |
+
return if_cache_return["return"]
|
297 |
|
298 |
# Call model via Converse API
|
299 |
session = aioboto3.Session()
|
|
|
309 |
args_hash: {
|
310 |
"return": response["output"]["message"]["content"][0]["text"],
|
311 |
"model": model,
|
312 |
+
"embedding": quantized.tobytes().hex()
|
313 |
+
if is_embedding_cache_enabled
|
314 |
+
else None,
|
315 |
+
"embedding_shape": quantized.shape
|
316 |
+
if is_embedding_cache_enabled
|
317 |
+
else None,
|
318 |
+
"embedding_min": min_val
|
319 |
+
if is_embedding_cache_enabled
|
320 |
+
else None,
|
321 |
+
"embedding_max": max_val
|
322 |
+
if is_embedding_cache_enabled
|
323 |
+
else None,
|
324 |
+
"original_prompt": prompt,
|
325 |
}
|
326 |
}
|
327 |
)
|
|
|
344 |
|
345 |
|
346 |
async def hf_model_if_cache(
|
347 |
+
model,
|
348 |
+
prompt,
|
349 |
+
system_prompt=None,
|
350 |
+
history_messages=[],
|
351 |
+
**kwargs,
|
352 |
) -> str:
|
353 |
model_name = model
|
354 |
hf_model, hf_tokenizer = initialize_hf_model(model_name)
|
|
|
360 |
messages.append({"role": "user", "content": prompt})
|
361 |
|
362 |
if hashing_kv is not None:
|
363 |
+
# Get embedding cache configuration
|
364 |
+
embedding_cache_config = hashing_kv.global_config.get(
|
365 |
+
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
|
366 |
+
)
|
367 |
+
is_embedding_cache_enabled = embedding_cache_config["enabled"]
|
368 |
+
if is_embedding_cache_enabled:
|
369 |
+
# Use embedding cache
|
370 |
+
embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
|
371 |
+
current_embedding = await embedding_model_func([prompt])
|
372 |
+
quantized, min_val, max_val = quantize_embedding(current_embedding[0])
|
373 |
+
best_cached_response = await get_best_cached_response(
|
374 |
+
hashing_kv,
|
375 |
+
current_embedding[0],
|
376 |
+
similarity_threshold=embedding_cache_config["similarity_threshold"],
|
377 |
+
)
|
378 |
+
if best_cached_response is not None:
|
379 |
+
return best_cached_response
|
380 |
+
else:
|
381 |
+
# Use regular cache
|
382 |
+
args_hash = compute_args_hash(model, messages)
|
383 |
+
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
384 |
+
if if_cache_return is not None:
|
385 |
+
return if_cache_return["return"]
|
386 |
+
|
387 |
input_prompt = ""
|
388 |
try:
|
389 |
input_prompt = hf_tokenizer.apply_chat_template(
|
|
|
428 |
output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
|
429 |
)
|
430 |
if hashing_kv is not None:
|
431 |
+
await hashing_kv.upsert(
|
432 |
+
{
|
433 |
+
args_hash: {
|
434 |
+
"return": response_text,
|
435 |
+
"model": model,
|
436 |
+
"embedding": quantized.tobytes().hex()
|
437 |
+
if is_embedding_cache_enabled
|
438 |
+
else None,
|
439 |
+
"embedding_shape": quantized.shape
|
440 |
+
if is_embedding_cache_enabled
|
441 |
+
else None,
|
442 |
+
"embedding_min": min_val if is_embedding_cache_enabled else None,
|
443 |
+
"embedding_max": max_val if is_embedding_cache_enabled else None,
|
444 |
+
"original_prompt": prompt,
|
445 |
+
}
|
446 |
+
}
|
447 |
+
)
|
448 |
return response_text
|
449 |
|
450 |
|
451 |
async def ollama_model_if_cache(
|
452 |
+
model,
|
453 |
+
prompt,
|
454 |
+
system_prompt=None,
|
455 |
+
history_messages=[],
|
456 |
+
**kwargs,
|
457 |
) -> str:
|
458 |
kwargs.pop("max_tokens", None)
|
459 |
# kwargs.pop("response_format", None) # allow json
|
|
|
469 |
messages.extend(history_messages)
|
470 |
messages.append({"role": "user", "content": prompt})
|
471 |
if hashing_kv is not None:
|
472 |
+
# Get embedding cache configuration
|
473 |
+
embedding_cache_config = hashing_kv.global_config.get(
|
474 |
+
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
|
475 |
+
)
|
476 |
+
is_embedding_cache_enabled = embedding_cache_config["enabled"]
|
477 |
+
if is_embedding_cache_enabled:
|
478 |
+
# Use embedding cache
|
479 |
+
embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
|
480 |
+
current_embedding = await embedding_model_func([prompt])
|
481 |
+
quantized, min_val, max_val = quantize_embedding(current_embedding[0])
|
482 |
+
best_cached_response = await get_best_cached_response(
|
483 |
+
hashing_kv,
|
484 |
+
current_embedding[0],
|
485 |
+
similarity_threshold=embedding_cache_config["similarity_threshold"],
|
486 |
+
)
|
487 |
+
if best_cached_response is not None:
|
488 |
+
return best_cached_response
|
489 |
+
else:
|
490 |
+
# Use regular cache
|
491 |
+
args_hash = compute_args_hash(model, messages)
|
492 |
+
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
493 |
+
if if_cache_return is not None:
|
494 |
+
return if_cache_return["return"]
|
495 |
|
496 |
response = await ollama_client.chat(model=model, messages=messages, **kwargs)
|
497 |
|
498 |
result = response["message"]["content"]
|
499 |
|
500 |
if hashing_kv is not None:
|
501 |
+
await hashing_kv.upsert(
|
502 |
+
{
|
503 |
+
args_hash: {
|
504 |
+
"return": result,
|
505 |
+
"model": model,
|
506 |
+
"embedding": quantized.tobytes().hex()
|
507 |
+
if is_embedding_cache_enabled
|
508 |
+
else None,
|
509 |
+
"embedding_shape": quantized.shape
|
510 |
+
if is_embedding_cache_enabled
|
511 |
+
else None,
|
512 |
+
"embedding_min": min_val if is_embedding_cache_enabled else None,
|
513 |
+
"embedding_max": max_val if is_embedding_cache_enabled else None,
|
514 |
+
"original_prompt": prompt,
|
515 |
+
}
|
516 |
+
}
|
517 |
+
)
|
518 |
return result
|
519 |
|
520 |
|
|
|
621 |
messages.extend(history_messages)
|
622 |
messages.append({"role": "user", "content": prompt})
|
623 |
if hashing_kv is not None:
|
624 |
+
# Get embedding cache configuration
|
625 |
+
embedding_cache_config = hashing_kv.global_config.get(
|
626 |
+
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
|
627 |
+
)
|
628 |
+
is_embedding_cache_enabled = embedding_cache_config["enabled"]
|
629 |
+
if is_embedding_cache_enabled:
|
630 |
+
# Use embedding cache
|
631 |
+
embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
|
632 |
+
current_embedding = await embedding_model_func([prompt])
|
633 |
+
quantized, min_val, max_val = quantize_embedding(current_embedding[0])
|
634 |
+
best_cached_response = await get_best_cached_response(
|
635 |
+
hashing_kv,
|
636 |
+
current_embedding[0],
|
637 |
+
similarity_threshold=embedding_cache_config["similarity_threshold"],
|
638 |
+
)
|
639 |
+
if best_cached_response is not None:
|
640 |
+
return best_cached_response
|
641 |
+
else:
|
642 |
+
# Use regular cache
|
643 |
+
args_hash = compute_args_hash(model, messages)
|
644 |
+
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
645 |
+
if if_cache_return is not None:
|
646 |
+
return if_cache_return["return"]
|
647 |
|
648 |
gen_config = GenerationConfig(
|
649 |
skip_special_tokens=skip_special_tokens,
|
|
|
662 |
response += res.response
|
663 |
|
664 |
if hashing_kv is not None:
|
665 |
+
await hashing_kv.upsert(
|
666 |
+
{
|
667 |
+
args_hash: {
|
668 |
+
"return": response,
|
669 |
+
"model": model,
|
670 |
+
"embedding": quantized.tobytes().hex()
|
671 |
+
if is_embedding_cache_enabled
|
672 |
+
else None,
|
673 |
+
"embedding_shape": quantized.shape
|
674 |
+
if is_embedding_cache_enabled
|
675 |
+
else None,
|
676 |
+
"embedding_min": min_val if is_embedding_cache_enabled else None,
|
677 |
+
"embedding_max": max_val if is_embedding_cache_enabled else None,
|
678 |
+
"original_prompt": prompt,
|
679 |
+
}
|
680 |
+
}
|
681 |
+
)
|
682 |
return response
|
683 |
|
684 |
|
lightrag/utils.py
CHANGED
@@ -307,3 +307,72 @@ def process_combine_contexts(hl, ll):
|
|
307 |
combined_sources_result = "\n".join(combined_sources_result)
|
308 |
|
309 |
return combined_sources_result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
307 |
combined_sources_result = "\n".join(combined_sources_result)
|
308 |
|
309 |
return combined_sources_result
|
310 |
+
|
311 |
+
|
312 |
+
async def get_best_cached_response(
|
313 |
+
hashing_kv, current_embedding, similarity_threshold=0.95
|
314 |
+
):
|
315 |
+
"""Get the cached response with highest similarity"""
|
316 |
+
try:
|
317 |
+
# Get all keys using list_keys()
|
318 |
+
all_keys = await hashing_kv.all_keys()
|
319 |
+
max_similarity = 0
|
320 |
+
best_cached_response = None
|
321 |
+
|
322 |
+
# Get cached data one by one
|
323 |
+
for key in all_keys:
|
324 |
+
cache_data = await hashing_kv.get_by_id(key)
|
325 |
+
if cache_data is None or "embedding" not in cache_data:
|
326 |
+
continue
|
327 |
+
|
328 |
+
# Convert cached embedding list to ndarray
|
329 |
+
cached_quantized = np.frombuffer(
|
330 |
+
bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
|
331 |
+
).reshape(cache_data["embedding_shape"])
|
332 |
+
cached_embedding = dequantize_embedding(
|
333 |
+
cached_quantized,
|
334 |
+
cache_data["embedding_min"],
|
335 |
+
cache_data["embedding_max"],
|
336 |
+
)
|
337 |
+
|
338 |
+
similarity = cosine_similarity(current_embedding, cached_embedding)
|
339 |
+
if similarity > max_similarity:
|
340 |
+
max_similarity = similarity
|
341 |
+
best_cached_response = cache_data["return"]
|
342 |
+
|
343 |
+
if max_similarity > similarity_threshold:
|
344 |
+
return best_cached_response
|
345 |
+
return None
|
346 |
+
|
347 |
+
except Exception as e:
|
348 |
+
logger.warning(f"Error in get_best_cached_response: {e}")
|
349 |
+
return None
|
350 |
+
|
351 |
+
|
352 |
+
def cosine_similarity(v1, v2):
|
353 |
+
"""Calculate cosine similarity between two vectors"""
|
354 |
+
dot_product = np.dot(v1, v2)
|
355 |
+
norm1 = np.linalg.norm(v1)
|
356 |
+
norm2 = np.linalg.norm(v2)
|
357 |
+
return dot_product / (norm1 * norm2)
|
358 |
+
|
359 |
+
|
360 |
+
def quantize_embedding(embedding: np.ndarray, bits=8) -> tuple:
|
361 |
+
"""Quantize embedding to specified bits"""
|
362 |
+
# Calculate min/max values for reconstruction
|
363 |
+
min_val = embedding.min()
|
364 |
+
max_val = embedding.max()
|
365 |
+
|
366 |
+
# Quantize to 0-255 range
|
367 |
+
scale = (2**bits - 1) / (max_val - min_val)
|
368 |
+
quantized = np.round((embedding - min_val) * scale).astype(np.uint8)
|
369 |
+
|
370 |
+
return quantized, min_val, max_val
|
371 |
+
|
372 |
+
|
373 |
+
def dequantize_embedding(
|
374 |
+
quantized: np.ndarray, min_val: float, max_val: float, bits=8
|
375 |
+
) -> np.ndarray:
|
376 |
+
"""Restore quantized embedding"""
|
377 |
+
scale = (max_val - min_val) / (2**bits - 1)
|
378 |
+
return (quantized * scale + min_val).astype(np.float32)
|