Magicyuan commited on
Commit
d4fb0be
·
1 Parent(s): 2bd1942

feat(lightrag): 添加 查询时使用embedding缓存功能

Browse files

- 在 LightRAG 类中添加 embedding_cache_config配置项
- 实现基于 embedding 相似度的缓存查询和存储
- 添加量化和反量化函数,用于压缩 embedding 数据
- 新增示例演示 embedding 缓存的使用

README.md CHANGED
@@ -596,6 +596,7 @@ if __name__ == "__main__":
596
  | **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` |
597
  | **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese"}`: sets example limit and output language | `example_number: all examples, language: English` |
598
  | **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` |
 
599
 
600
  ## API Server Implementation
601
 
 
596
  | **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` |
597
  | **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese"}`: sets example limit and output language | `example_number: all examples, language: English` |
598
  | **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` |
599
+ | **embedding\_cache\_config** | `dict` | Configuration for embedding cache. Includes `enabled` (bool) to toggle cache and `similarity_threshold` (float) for cache retrieval | `{"enabled": False, "similarity_threshold": 0.95}` |
600
 
601
  ## API Server Implementation
602
 
examples/lightrag_openai_compatible_demo_embedding_cache.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import asyncio
3
+ from lightrag import LightRAG, QueryParam
4
+ from lightrag.llm import openai_complete_if_cache, openai_embedding
5
+ from lightrag.utils import EmbeddingFunc
6
+ import numpy as np
7
+
8
+ WORKING_DIR = "./dickens"
9
+
10
+ if not os.path.exists(WORKING_DIR):
11
+ os.mkdir(WORKING_DIR)
12
+
13
+
14
+ async def llm_model_func(
15
+ prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
16
+ ) -> str:
17
+ return await openai_complete_if_cache(
18
+ "solar-mini",
19
+ prompt,
20
+ system_prompt=system_prompt,
21
+ history_messages=history_messages,
22
+ api_key=os.getenv("UPSTAGE_API_KEY"),
23
+ base_url="https://api.upstage.ai/v1/solar",
24
+ **kwargs,
25
+ )
26
+
27
+
28
+ async def embedding_func(texts: list[str]) -> np.ndarray:
29
+ return await openai_embedding(
30
+ texts,
31
+ model="solar-embedding-1-large-query",
32
+ api_key=os.getenv("UPSTAGE_API_KEY"),
33
+ base_url="https://api.upstage.ai/v1/solar",
34
+ )
35
+
36
+
37
+ async def get_embedding_dim():
38
+ test_text = ["This is a test sentence."]
39
+ embedding = await embedding_func(test_text)
40
+ embedding_dim = embedding.shape[1]
41
+ return embedding_dim
42
+
43
+
44
+ # function test
45
+ async def test_funcs():
46
+ result = await llm_model_func("How are you?")
47
+ print("llm_model_func: ", result)
48
+
49
+ result = await embedding_func(["How are you?"])
50
+ print("embedding_func: ", result)
51
+
52
+
53
+ # asyncio.run(test_funcs())
54
+
55
+
56
+ async def main():
57
+ try:
58
+ embedding_dimension = await get_embedding_dim()
59
+ print(f"Detected embedding dimension: {embedding_dimension}")
60
+
61
+ rag = LightRAG(
62
+ working_dir=WORKING_DIR,
63
+ embedding_cache_config={
64
+ "enabled": True,
65
+ "similarity_threshold": 0.90, # 可以自定义阈值
66
+ },
67
+ llm_model_func=llm_model_func,
68
+ embedding_func=EmbeddingFunc(
69
+ embedding_dim=embedding_dimension,
70
+ max_token_size=8192,
71
+ func=embedding_func,
72
+ ),
73
+ )
74
+
75
+ with open("./book.txt", "r", encoding="utf-8") as f:
76
+ await rag.ainsert(f.read())
77
+
78
+ # Perform naive search
79
+ print(
80
+ await rag.aquery(
81
+ "What are the top themes in this story?", param=QueryParam(mode="naive")
82
+ )
83
+ )
84
+
85
+ # Perform local search
86
+ print(
87
+ await rag.aquery(
88
+ "What are the top themes in this story?", param=QueryParam(mode="local")
89
+ )
90
+ )
91
+
92
+ # Perform global search
93
+ print(
94
+ await rag.aquery(
95
+ "What are the top themes in this story?",
96
+ param=QueryParam(mode="global"),
97
+ )
98
+ )
99
+
100
+ # Perform hybrid search
101
+ print(
102
+ await rag.aquery(
103
+ "What are the top themes in this story?",
104
+ param=QueryParam(mode="hybrid"),
105
+ )
106
+ )
107
+ except Exception as e:
108
+ print(f"An error occurred: {e}")
109
+
110
+
111
+ if __name__ == "__main__":
112
+ asyncio.run(main())
lightrag/lightrag.py CHANGED
@@ -85,7 +85,10 @@ class LightRAG:
85
  working_dir: str = field(
86
  default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
87
  )
88
-
 
 
 
89
  kv_storage: str = field(default="JsonKVStorage")
90
  vector_storage: str = field(default="NanoVectorDBStorage")
91
  graph_storage: str = field(default="NetworkXStorage")
 
85
  working_dir: str = field(
86
  default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
87
  )
88
+ # Default not to use embedding cache
89
+ embedding_cache_config: dict = field(
90
+ default_factory=lambda: {"enabled": False, "similarity_threshold": 0.95}
91
+ )
92
  kv_storage: str = field(default="JsonKVStorage")
93
  vector_storage: str = field(default="NanoVectorDBStorage")
94
  graph_storage: str = field(default="NetworkXStorage")
lightrag/llm.py CHANGED
@@ -33,6 +33,8 @@ from .utils import (
33
  compute_args_hash,
34
  wrap_embedding_func_with_attrs,
35
  locate_json_string_body_from_string,
 
 
36
  )
37
 
38
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -65,10 +67,29 @@ async def openai_complete_if_cache(
65
  messages.extend(history_messages)
66
  messages.append({"role": "user", "content": prompt})
67
  if hashing_kv is not None:
68
- args_hash = compute_args_hash(model, messages)
69
- if_cache_return = await hashing_kv.get_by_id(args_hash)
70
- if if_cache_return is not None:
71
- return if_cache_return["return"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
  if "response_format" in kwargs:
74
  response = await openai_async_client.beta.chat.completions.parse(
@@ -81,10 +102,24 @@ async def openai_complete_if_cache(
81
  content = response.choices[0].message.content
82
  if r"\u" in content:
83
  content = content.encode("utf-8").decode("unicode_escape")
84
- # print(content)
85
  if hashing_kv is not None:
86
  await hashing_kv.upsert(
87
- {args_hash: {"return": response.choices[0].message.content, "model": model}}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  )
89
  return content
90
 
@@ -125,10 +160,28 @@ async def azure_openai_complete_if_cache(
125
  if prompt is not None:
126
  messages.append({"role": "user", "content": prompt})
127
  if hashing_kv is not None:
128
- args_hash = compute_args_hash(model, messages)
129
- if_cache_return = await hashing_kv.get_by_id(args_hash)
130
- if if_cache_return is not None:
131
- return if_cache_return["return"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
  response = await openai_async_client.chat.completions.create(
134
  model=model, messages=messages, **kwargs
@@ -136,7 +189,21 @@ async def azure_openai_complete_if_cache(
136
 
137
  if hashing_kv is not None:
138
  await hashing_kv.upsert(
139
- {args_hash: {"return": response.choices[0].message.content, "model": model}}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  )
141
  return response.choices[0].message.content
142
 
@@ -204,10 +271,29 @@ async def bedrock_complete_if_cache(
204
 
205
  hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
206
  if hashing_kv is not None:
207
- args_hash = compute_args_hash(model, messages)
208
- if_cache_return = await hashing_kv.get_by_id(args_hash)
209
- if if_cache_return is not None:
210
- return if_cache_return["return"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
  # Call model via Converse API
213
  session = aioboto3.Session()
@@ -223,6 +309,19 @@ async def bedrock_complete_if_cache(
223
  args_hash: {
224
  "return": response["output"]["message"]["content"][0]["text"],
225
  "model": model,
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  }
227
  }
228
  )
@@ -245,7 +344,11 @@ def initialize_hf_model(model_name):
245
 
246
 
247
  async def hf_model_if_cache(
248
- model, prompt, system_prompt=None, history_messages=[], **kwargs
 
 
 
 
249
  ) -> str:
250
  model_name = model
251
  hf_model, hf_tokenizer = initialize_hf_model(model_name)
@@ -257,10 +360,30 @@ async def hf_model_if_cache(
257
  messages.append({"role": "user", "content": prompt})
258
 
259
  if hashing_kv is not None:
260
- args_hash = compute_args_hash(model, messages)
261
- if_cache_return = await hashing_kv.get_by_id(args_hash)
262
- if if_cache_return is not None:
263
- return if_cache_return["return"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  input_prompt = ""
265
  try:
266
  input_prompt = hf_tokenizer.apply_chat_template(
@@ -305,12 +428,32 @@ async def hf_model_if_cache(
305
  output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
306
  )
307
  if hashing_kv is not None:
308
- await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
  return response_text
310
 
311
 
312
  async def ollama_model_if_cache(
313
- model, prompt, system_prompt=None, history_messages=[], **kwargs
 
 
 
 
314
  ) -> str:
315
  kwargs.pop("max_tokens", None)
316
  # kwargs.pop("response_format", None) # allow json
@@ -326,18 +469,52 @@ async def ollama_model_if_cache(
326
  messages.extend(history_messages)
327
  messages.append({"role": "user", "content": prompt})
328
  if hashing_kv is not None:
329
- args_hash = compute_args_hash(model, messages)
330
- if_cache_return = await hashing_kv.get_by_id(args_hash)
331
- if if_cache_return is not None:
332
- return if_cache_return["return"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
 
334
  response = await ollama_client.chat(model=model, messages=messages, **kwargs)
335
 
336
  result = response["message"]["content"]
337
 
338
  if hashing_kv is not None:
339
- await hashing_kv.upsert({args_hash: {"return": result, "model": model}})
340
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
  return result
342
 
343
 
@@ -444,10 +621,29 @@ async def lmdeploy_model_if_cache(
444
  messages.extend(history_messages)
445
  messages.append({"role": "user", "content": prompt})
446
  if hashing_kv is not None:
447
- args_hash = compute_args_hash(model, messages)
448
- if_cache_return = await hashing_kv.get_by_id(args_hash)
449
- if if_cache_return is not None:
450
- return if_cache_return["return"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
451
 
452
  gen_config = GenerationConfig(
453
  skip_special_tokens=skip_special_tokens,
@@ -466,7 +662,23 @@ async def lmdeploy_model_if_cache(
466
  response += res.response
467
 
468
  if hashing_kv is not None:
469
- await hashing_kv.upsert({args_hash: {"return": response, "model": model}})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
470
  return response
471
 
472
 
 
33
  compute_args_hash,
34
  wrap_embedding_func_with_attrs,
35
  locate_json_string_body_from_string,
36
+ quantize_embedding,
37
+ get_best_cached_response,
38
  )
39
 
40
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
67
  messages.extend(history_messages)
68
  messages.append({"role": "user", "content": prompt})
69
  if hashing_kv is not None:
70
+ # Get embedding cache configuration
71
+ embedding_cache_config = hashing_kv.global_config.get(
72
+ "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
73
+ )
74
+ is_embedding_cache_enabled = embedding_cache_config["enabled"]
75
+ if is_embedding_cache_enabled:
76
+ # Use embedding cache
77
+ embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
78
+ current_embedding = await embedding_model_func([prompt])
79
+ quantized, min_val, max_val = quantize_embedding(current_embedding[0])
80
+ best_cached_response = await get_best_cached_response(
81
+ hashing_kv,
82
+ current_embedding[0],
83
+ similarity_threshold=embedding_cache_config["similarity_threshold"],
84
+ )
85
+ if best_cached_response is not None:
86
+ return best_cached_response
87
+ else:
88
+ # Use regular cache
89
+ args_hash = compute_args_hash(model, messages)
90
+ if_cache_return = await hashing_kv.get_by_id(args_hash)
91
+ if if_cache_return is not None:
92
+ return if_cache_return["return"]
93
 
94
  if "response_format" in kwargs:
95
  response = await openai_async_client.beta.chat.completions.parse(
 
102
  content = response.choices[0].message.content
103
  if r"\u" in content:
104
  content = content.encode("utf-8").decode("unicode_escape")
105
+
106
  if hashing_kv is not None:
107
  await hashing_kv.upsert(
108
+ {
109
+ args_hash: {
110
+ "return": content,
111
+ "model": model,
112
+ "embedding": quantized.tobytes().hex()
113
+ if is_embedding_cache_enabled
114
+ else None,
115
+ "embedding_shape": quantized.shape
116
+ if is_embedding_cache_enabled
117
+ else None,
118
+ "embedding_min": min_val if is_embedding_cache_enabled else None,
119
+ "embedding_max": max_val if is_embedding_cache_enabled else None,
120
+ "original_prompt": prompt,
121
+ }
122
+ }
123
  )
124
  return content
125
 
 
160
  if prompt is not None:
161
  messages.append({"role": "user", "content": prompt})
162
  if hashing_kv is not None:
163
+ # Get embedding cache configuration
164
+ embedding_cache_config = hashing_kv.global_config.get(
165
+ "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
166
+ )
167
+ is_embedding_cache_enabled = embedding_cache_config["enabled"]
168
+ if is_embedding_cache_enabled:
169
+ # Use embedding cache
170
+ embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
171
+ current_embedding = await embedding_model_func([prompt])
172
+ quantized, min_val, max_val = quantize_embedding(current_embedding[0])
173
+ best_cached_response = await get_best_cached_response(
174
+ hashing_kv,
175
+ current_embedding[0],
176
+ similarity_threshold=embedding_cache_config["similarity_threshold"],
177
+ )
178
+ if best_cached_response is not None:
179
+ return best_cached_response
180
+ else:
181
+ args_hash = compute_args_hash(model, messages)
182
+ if_cache_return = await hashing_kv.get_by_id(args_hash)
183
+ if if_cache_return is not None:
184
+ return if_cache_return["return"]
185
 
186
  response = await openai_async_client.chat.completions.create(
187
  model=model, messages=messages, **kwargs
 
189
 
190
  if hashing_kv is not None:
191
  await hashing_kv.upsert(
192
+ {
193
+ args_hash: {
194
+ "return": response.choices[0].message.content,
195
+ "model": model,
196
+ "embedding": quantized.tobytes().hex()
197
+ if is_embedding_cache_enabled
198
+ else None,
199
+ "embedding_shape": quantized.shape
200
+ if is_embedding_cache_enabled
201
+ else None,
202
+ "embedding_min": min_val if is_embedding_cache_enabled else None,
203
+ "embedding_max": max_val if is_embedding_cache_enabled else None,
204
+ "original_prompt": prompt,
205
+ }
206
+ }
207
  )
208
  return response.choices[0].message.content
209
 
 
271
 
272
  hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
273
  if hashing_kv is not None:
274
+ # Get embedding cache configuration
275
+ embedding_cache_config = hashing_kv.global_config.get(
276
+ "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
277
+ )
278
+ is_embedding_cache_enabled = embedding_cache_config["enabled"]
279
+ if is_embedding_cache_enabled:
280
+ # Use embedding cache
281
+ embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
282
+ current_embedding = await embedding_model_func([prompt])
283
+ quantized, min_val, max_val = quantize_embedding(current_embedding[0])
284
+ best_cached_response = await get_best_cached_response(
285
+ hashing_kv,
286
+ current_embedding[0],
287
+ similarity_threshold=embedding_cache_config["similarity_threshold"],
288
+ )
289
+ if best_cached_response is not None:
290
+ return best_cached_response
291
+ else:
292
+ # Use regular cache
293
+ args_hash = compute_args_hash(model, messages)
294
+ if_cache_return = await hashing_kv.get_by_id(args_hash)
295
+ if if_cache_return is not None:
296
+ return if_cache_return["return"]
297
 
298
  # Call model via Converse API
299
  session = aioboto3.Session()
 
309
  args_hash: {
310
  "return": response["output"]["message"]["content"][0]["text"],
311
  "model": model,
312
+ "embedding": quantized.tobytes().hex()
313
+ if is_embedding_cache_enabled
314
+ else None,
315
+ "embedding_shape": quantized.shape
316
+ if is_embedding_cache_enabled
317
+ else None,
318
+ "embedding_min": min_val
319
+ if is_embedding_cache_enabled
320
+ else None,
321
+ "embedding_max": max_val
322
+ if is_embedding_cache_enabled
323
+ else None,
324
+ "original_prompt": prompt,
325
  }
326
  }
327
  )
 
344
 
345
 
346
  async def hf_model_if_cache(
347
+ model,
348
+ prompt,
349
+ system_prompt=None,
350
+ history_messages=[],
351
+ **kwargs,
352
  ) -> str:
353
  model_name = model
354
  hf_model, hf_tokenizer = initialize_hf_model(model_name)
 
360
  messages.append({"role": "user", "content": prompt})
361
 
362
  if hashing_kv is not None:
363
+ # Get embedding cache configuration
364
+ embedding_cache_config = hashing_kv.global_config.get(
365
+ "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
366
+ )
367
+ is_embedding_cache_enabled = embedding_cache_config["enabled"]
368
+ if is_embedding_cache_enabled:
369
+ # Use embedding cache
370
+ embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
371
+ current_embedding = await embedding_model_func([prompt])
372
+ quantized, min_val, max_val = quantize_embedding(current_embedding[0])
373
+ best_cached_response = await get_best_cached_response(
374
+ hashing_kv,
375
+ current_embedding[0],
376
+ similarity_threshold=embedding_cache_config["similarity_threshold"],
377
+ )
378
+ if best_cached_response is not None:
379
+ return best_cached_response
380
+ else:
381
+ # Use regular cache
382
+ args_hash = compute_args_hash(model, messages)
383
+ if_cache_return = await hashing_kv.get_by_id(args_hash)
384
+ if if_cache_return is not None:
385
+ return if_cache_return["return"]
386
+
387
  input_prompt = ""
388
  try:
389
  input_prompt = hf_tokenizer.apply_chat_template(
 
428
  output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
429
  )
430
  if hashing_kv is not None:
431
+ await hashing_kv.upsert(
432
+ {
433
+ args_hash: {
434
+ "return": response_text,
435
+ "model": model,
436
+ "embedding": quantized.tobytes().hex()
437
+ if is_embedding_cache_enabled
438
+ else None,
439
+ "embedding_shape": quantized.shape
440
+ if is_embedding_cache_enabled
441
+ else None,
442
+ "embedding_min": min_val if is_embedding_cache_enabled else None,
443
+ "embedding_max": max_val if is_embedding_cache_enabled else None,
444
+ "original_prompt": prompt,
445
+ }
446
+ }
447
+ )
448
  return response_text
449
 
450
 
451
  async def ollama_model_if_cache(
452
+ model,
453
+ prompt,
454
+ system_prompt=None,
455
+ history_messages=[],
456
+ **kwargs,
457
  ) -> str:
458
  kwargs.pop("max_tokens", None)
459
  # kwargs.pop("response_format", None) # allow json
 
469
  messages.extend(history_messages)
470
  messages.append({"role": "user", "content": prompt})
471
  if hashing_kv is not None:
472
+ # Get embedding cache configuration
473
+ embedding_cache_config = hashing_kv.global_config.get(
474
+ "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
475
+ )
476
+ is_embedding_cache_enabled = embedding_cache_config["enabled"]
477
+ if is_embedding_cache_enabled:
478
+ # Use embedding cache
479
+ embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
480
+ current_embedding = await embedding_model_func([prompt])
481
+ quantized, min_val, max_val = quantize_embedding(current_embedding[0])
482
+ best_cached_response = await get_best_cached_response(
483
+ hashing_kv,
484
+ current_embedding[0],
485
+ similarity_threshold=embedding_cache_config["similarity_threshold"],
486
+ )
487
+ if best_cached_response is not None:
488
+ return best_cached_response
489
+ else:
490
+ # Use regular cache
491
+ args_hash = compute_args_hash(model, messages)
492
+ if_cache_return = await hashing_kv.get_by_id(args_hash)
493
+ if if_cache_return is not None:
494
+ return if_cache_return["return"]
495
 
496
  response = await ollama_client.chat(model=model, messages=messages, **kwargs)
497
 
498
  result = response["message"]["content"]
499
 
500
  if hashing_kv is not None:
501
+ await hashing_kv.upsert(
502
+ {
503
+ args_hash: {
504
+ "return": result,
505
+ "model": model,
506
+ "embedding": quantized.tobytes().hex()
507
+ if is_embedding_cache_enabled
508
+ else None,
509
+ "embedding_shape": quantized.shape
510
+ if is_embedding_cache_enabled
511
+ else None,
512
+ "embedding_min": min_val if is_embedding_cache_enabled else None,
513
+ "embedding_max": max_val if is_embedding_cache_enabled else None,
514
+ "original_prompt": prompt,
515
+ }
516
+ }
517
+ )
518
  return result
519
 
520
 
 
621
  messages.extend(history_messages)
622
  messages.append({"role": "user", "content": prompt})
623
  if hashing_kv is not None:
624
+ # Get embedding cache configuration
625
+ embedding_cache_config = hashing_kv.global_config.get(
626
+ "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
627
+ )
628
+ is_embedding_cache_enabled = embedding_cache_config["enabled"]
629
+ if is_embedding_cache_enabled:
630
+ # Use embedding cache
631
+ embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
632
+ current_embedding = await embedding_model_func([prompt])
633
+ quantized, min_val, max_val = quantize_embedding(current_embedding[0])
634
+ best_cached_response = await get_best_cached_response(
635
+ hashing_kv,
636
+ current_embedding[0],
637
+ similarity_threshold=embedding_cache_config["similarity_threshold"],
638
+ )
639
+ if best_cached_response is not None:
640
+ return best_cached_response
641
+ else:
642
+ # Use regular cache
643
+ args_hash = compute_args_hash(model, messages)
644
+ if_cache_return = await hashing_kv.get_by_id(args_hash)
645
+ if if_cache_return is not None:
646
+ return if_cache_return["return"]
647
 
648
  gen_config = GenerationConfig(
649
  skip_special_tokens=skip_special_tokens,
 
662
  response += res.response
663
 
664
  if hashing_kv is not None:
665
+ await hashing_kv.upsert(
666
+ {
667
+ args_hash: {
668
+ "return": response,
669
+ "model": model,
670
+ "embedding": quantized.tobytes().hex()
671
+ if is_embedding_cache_enabled
672
+ else None,
673
+ "embedding_shape": quantized.shape
674
+ if is_embedding_cache_enabled
675
+ else None,
676
+ "embedding_min": min_val if is_embedding_cache_enabled else None,
677
+ "embedding_max": max_val if is_embedding_cache_enabled else None,
678
+ "original_prompt": prompt,
679
+ }
680
+ }
681
+ )
682
  return response
683
 
684
 
lightrag/utils.py CHANGED
@@ -307,3 +307,72 @@ def process_combine_contexts(hl, ll):
307
  combined_sources_result = "\n".join(combined_sources_result)
308
 
309
  return combined_sources_result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  combined_sources_result = "\n".join(combined_sources_result)
308
 
309
  return combined_sources_result
310
+
311
+
312
+ async def get_best_cached_response(
313
+ hashing_kv, current_embedding, similarity_threshold=0.95
314
+ ):
315
+ """Get the cached response with highest similarity"""
316
+ try:
317
+ # Get all keys using list_keys()
318
+ all_keys = await hashing_kv.all_keys()
319
+ max_similarity = 0
320
+ best_cached_response = None
321
+
322
+ # Get cached data one by one
323
+ for key in all_keys:
324
+ cache_data = await hashing_kv.get_by_id(key)
325
+ if cache_data is None or "embedding" not in cache_data:
326
+ continue
327
+
328
+ # Convert cached embedding list to ndarray
329
+ cached_quantized = np.frombuffer(
330
+ bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
331
+ ).reshape(cache_data["embedding_shape"])
332
+ cached_embedding = dequantize_embedding(
333
+ cached_quantized,
334
+ cache_data["embedding_min"],
335
+ cache_data["embedding_max"],
336
+ )
337
+
338
+ similarity = cosine_similarity(current_embedding, cached_embedding)
339
+ if similarity > max_similarity:
340
+ max_similarity = similarity
341
+ best_cached_response = cache_data["return"]
342
+
343
+ if max_similarity > similarity_threshold:
344
+ return best_cached_response
345
+ return None
346
+
347
+ except Exception as e:
348
+ logger.warning(f"Error in get_best_cached_response: {e}")
349
+ return None
350
+
351
+
352
+ def cosine_similarity(v1, v2):
353
+ """Calculate cosine similarity between two vectors"""
354
+ dot_product = np.dot(v1, v2)
355
+ norm1 = np.linalg.norm(v1)
356
+ norm2 = np.linalg.norm(v2)
357
+ return dot_product / (norm1 * norm2)
358
+
359
+
360
+ def quantize_embedding(embedding: np.ndarray, bits=8) -> tuple:
361
+ """Quantize embedding to specified bits"""
362
+ # Calculate min/max values for reconstruction
363
+ min_val = embedding.min()
364
+ max_val = embedding.max()
365
+
366
+ # Quantize to 0-255 range
367
+ scale = (2**bits - 1) / (max_val - min_val)
368
+ quantized = np.round((embedding - min_val) * scale).astype(np.uint8)
369
+
370
+ return quantized, min_val, max_val
371
+
372
+
373
+ def dequantize_embedding(
374
+ quantized: np.ndarray, min_val: float, max_val: float, bits=8
375
+ ) -> np.ndarray:
376
+ """Restore quantized embedding"""
377
+ scale = (max_val - min_val) / (2**bits - 1)
378
+ return (quantized * scale + min_val).astype(np.float32)