Magicyuan committed
Commit f45e51c · 1 Parent(s): abffd90

feat(cache): add LLM similarity check and optimize the caching mechanism


- Add a `use_llm_check` parameter to the embedding cache configuration
- Implement LLM similarity-check logic as a secondary verification step for cache hits
- Simplify cache handling for naive mode
- Adjust the cache data structure and remove the unnecessary `model` field

README.md CHANGED
@@ -596,11 +596,7 @@ if __name__ == "__main__":
 | **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` |
 | **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese"}`: sets example limit and output language | `example_number: all examples, language: English` |
 | **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` |
-| **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains two parameters:
-- `enabled`: Boolean value to enable/disable caching functionality. When enabled, questions and answers will be cached.
-- `similarity_threshold`: Float value (0-1), similarity threshold. When a new question's similarity with a cached question exceeds this threshold, the cached answer will be returned directly without calling the LLM.
-
-Default: `{"enabled": False, "similarity_threshold": 0.95}` | `{"enabled": False, "similarity_threshold": 0.95}` |
+| **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains three parameters:<br>- `enabled`: Boolean value to enable/disable cache lookup functionality. When enabled, the system will check cached responses before generating new answers.<br>- `similarity_threshold`: Float value (0-1), similarity threshold. When a new question's similarity with a cached question exceeds this threshold, the cached answer will be returned directly without calling the LLM.<br>- `use_llm_check`: Boolean value to enable/disable LLM similarity verification. When enabled, LLM will be used as a secondary check to verify the similarity between questions before returning cached answers. | Default: `{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False}` |

 ## API Server Implementation
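For illustration, a minimal sketch of how the updated `embedding_cache_config` documented above could be passed when constructing `LightRAG`; the three keys come from the new table row, while `working_dir` and any model functions are assumed to be configured as in your existing setup.

```python
from lightrag import LightRAG

# Sketch only: enable semantic question-answer caching plus the new LLM double-check.
rag = LightRAG(
    working_dir="./rag_storage",  # placeholder path
    embedding_cache_config={
        "enabled": True,               # look up cached answers before calling the LLM
        "similarity_threshold": 0.95,  # embedding-similarity cutoff for a cache hit
        "use_llm_check": True,         # ask the LLM to confirm the match before reuse
    },
)
```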
lightrag/lightrag.py CHANGED
@@ -87,7 +87,11 @@ class LightRAG:
     )
     # Default not to use embedding cache
     embedding_cache_config: dict = field(
-        default_factory=lambda: {"enabled": False, "similarity_threshold": 0.95}
+        default_factory=lambda: {
+            "enabled": False,
+            "similarity_threshold": 0.95,
+            "use_llm_check": False,
+        }
     )
     kv_storage: str = field(default="JsonKVStorage")
     vector_storage: str = field(default="NanoVectorDBStorage")
@@ -174,7 +178,6 @@ class LightRAG:
             if self.enable_llm_cache
             else None
         )
-
         self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
             self.embedding_func
         )
@@ -481,6 +484,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 asdict(self),
+                hashing_kv=self.llm_response_cache,
             )
         elif param.mode == "naive":
             response = await naive_query(
@@ -489,6 +493,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 asdict(self),
+                hashing_kv=self.llm_response_cache,
             )
         else:
             raise ValueError(f"Unknown mode {param.mode}")
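With the change above, both `kg_query` and `naive_query` receive `hashing_kv=self.llm_response_cache`, so query-level caching works in every mode. A hedged usage sketch, assuming the `rag` instance from the previous example:

```python
from lightrag import QueryParam

# First call goes to the LLM and the answer is written to llm_response_cache.
# In naive mode a repeat of the exact same question is served from the cache;
# in the knowledge-graph modes, sufficiently similar questions can also hit it.
answer = rag.query("What does the embedding cache do?", param=QueryParam(mode="naive"))
print(answer)
```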
lightrag/llm.py CHANGED
@@ -4,8 +4,7 @@ import json
 import os
 import struct
 from functools import lru_cache
-from typing import List, Dict, Callable, Any, Union, Optional
-from dataclasses import dataclass
+from typing import List, Dict, Callable, Any, Union
 import aioboto3
 import aiohttp
 import numpy as np
@@ -27,13 +26,9 @@ from tenacity import (
 )
 from transformers import AutoTokenizer, AutoModelForCausalLM

-from .base import BaseKVStorage
 from .utils import (
-    compute_args_hash,
     wrap_embedding_func_with_attrs,
     locate_json_string_body_from_string,
-    quantize_embedding,
-    get_best_cached_response,
 )

 import sys
@@ -66,23 +61,13 @@ async def openai_complete_if_cache(
     openai_async_client = (
         AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
     )
-
+    kwargs.pop("hashing_kv", None)
     messages = []
     if system_prompt:
         messages.append({"role": "system", "content": system_prompt})
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})

-    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
-    # Handle cache
-    mode = kwargs.pop("mode", "default")
-    args_hash = compute_args_hash(model, messages)
-    cached_response, quantized, min_val, max_val = await handle_cache(
-        hashing_kv, args_hash, prompt, mode
-    )
-    if cached_response is not None:
-        return cached_response
-
     if "response_format" in kwargs:
         response = await openai_async_client.beta.chat.completions.parse(
             model=model, messages=messages, **kwargs
@@ -95,21 +80,6 @@
     if r"\u" in content:
         content = content.encode("utf-8").decode("unicode_escape")

-    # Save to cache
-    await save_to_cache(
-        hashing_kv,
-        CacheData(
-            args_hash=args_hash,
-            content=content,
-            model=model,
-            prompt=prompt,
-            quantized=quantized,
-            min_val=min_val,
-            max_val=max_val,
-            mode=mode,
-        ),
-    )
-
     return content


@@ -140,10 +110,7 @@ async def azure_openai_complete_if_cache(
         api_key=os.getenv("AZURE_OPENAI_API_KEY"),
         api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
     )
-
-    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
-    mode = kwargs.pop("mode", "default")
-
+    kwargs.pop("hashing_kv", None)
     messages = []
     if system_prompt:
         messages.append({"role": "system", "content": system_prompt})
@@ -151,34 +118,11 @@
     if prompt is not None:
         messages.append({"role": "user", "content": prompt})

-    # Handle cache
-    args_hash = compute_args_hash(model, messages)
-    cached_response, quantized, min_val, max_val = await handle_cache(
-        hashing_kv, args_hash, prompt, mode
-    )
-    if cached_response is not None:
-        return cached_response
-
     response = await openai_async_client.chat.completions.create(
         model=model, messages=messages, **kwargs
     )
     content = response.choices[0].message.content

-    # Save to cache
-    await save_to_cache(
-        hashing_kv,
-        CacheData(
-            args_hash=args_hash,
-            content=content,
-            model=model,
-            prompt=prompt,
-            quantized=quantized,
-            min_val=min_val,
-            max_val=max_val,
-            mode=mode,
-        ),
-    )
-
     return content


@@ -210,7 +154,7 @@ async def bedrock_complete_if_cache(
     os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
         "AWS_SESSION_TOKEN", aws_session_token
     )
-
+    kwargs.pop("hashing_kv", None)
     # Fix message history format
     messages = []
     for history_message in history_messages:
@@ -220,15 +164,6 @@

     # Add user prompt
     messages.append({"role": "user", "content": [{"text": prompt}]})
-    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
-    # Handle cache
-    mode = kwargs.pop("mode", "default")
-    args_hash = compute_args_hash(model, messages)
-    cached_response, quantized, min_val, max_val = await handle_cache(
-        hashing_kv, args_hash, prompt, mode
-    )
-    if cached_response is not None:
-        return cached_response

     # Initialize Converse API arguments
     args = {"modelId": model, "messages": messages}
@@ -251,15 +186,6 @@
         args["inferenceConfig"][inference_params_map.get(param, param)] = (
             kwargs.pop(param)
         )
-    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
-    # Handle cache
-    mode = kwargs.pop("mode", "default")
-    args_hash = compute_args_hash(model, messages)
-    cached_response, quantized, min_val, max_val = await handle_cache(
-        hashing_kv, args_hash, prompt, mode
-    )
-    if cached_response is not None:
-        return cached_response

     # Call model via Converse API
     session = aioboto3.Session()
@@ -269,21 +195,6 @@
     except Exception as e:
         raise BedrockError(e)

-    # Save to cache
-    await save_to_cache(
-        hashing_kv,
-        CacheData(
-            args_hash=args_hash,
-            content=response["output"]["message"]["content"][0]["text"],
-            model=model,
-            prompt=prompt,
-            quantized=quantized,
-            min_val=min_val,
-            max_val=max_val,
-            mode=mode,
-        ),
-    )
-
     return response["output"]["message"]["content"][0]["text"]


@@ -315,22 +226,12 @@ async def hf_model_if_cache(
 ) -> str:
     model_name = model
     hf_model, hf_tokenizer = initialize_hf_model(model_name)
-    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
     messages = []
     if system_prompt:
         messages.append({"role": "system", "content": system_prompt})
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})
-
-    # Handle cache
-    mode = kwargs.pop("mode", "default")
-    args_hash = compute_args_hash(model, messages)
-    cached_response, quantized, min_val, max_val = await handle_cache(
-        hashing_kv, args_hash, prompt, mode
-    )
-    if cached_response is not None:
-        return cached_response
-
+    kwargs.pop("hashing_kv", None)
     input_prompt = ""
     try:
         input_prompt = hf_tokenizer.apply_chat_template(
@@ -375,21 +276,6 @@
         output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
     )

-    # Save to cache
-    await save_to_cache(
-        hashing_kv,
-        CacheData(
-            args_hash=args_hash,
-            content=response_text,
-            model=model,
-            prompt=prompt,
-            quantized=quantized,
-            min_val=min_val,
-            max_val=max_val,
-            mode=mode,
-        ),
-    )
-
     return response_text


@@ -410,25 +296,14 @@ async def ollama_model_if_cache(
     # kwargs.pop("response_format", None) # allow json
     host = kwargs.pop("host", None)
     timeout = kwargs.pop("timeout", None)
-
+    kwargs.pop("hashing_kv", None)
     ollama_client = ollama.AsyncClient(host=host, timeout=timeout)
     messages = []
     if system_prompt:
         messages.append({"role": "system", "content": system_prompt})
-
-    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})

-    # Handle cache
-    mode = kwargs.pop("mode", "default")
-    args_hash = compute_args_hash(model, messages)
-    cached_response, quantized, min_val, max_val = await handle_cache(
-        hashing_kv, args_hash, prompt, mode
-    )
-    if cached_response is not None:
-        return cached_response
-
     response = await ollama_client.chat(model=model, messages=messages, **kwargs)
     if stream:
         """ cannot cache stream response """
@@ -441,38 +316,7 @@
     else:
         result = response["message"]["content"]
         # Save to cache
-        await save_to_cache(
-            hashing_kv,
-            CacheData(
-                args_hash=args_hash,
-                content=result,
-                model=model,
-                prompt=prompt,
-                quantized=quantized,
-                min_val=min_val,
-                max_val=max_val,
-                mode=mode,
-            ),
-        )
         return result
-    result = response["message"]["content"]
-
-    # Save to cache
-    await save_to_cache(
-        hashing_kv,
-        CacheData(
-            args_hash=args_hash,
-            content=result,
-            model=model,
-            prompt=prompt,
-            quantized=quantized,
-            min_val=min_val,
-            max_val=max_val,
-            mode=mode,
-        ),
-    )
-
-    return result


 @lru_cache(maxsize=1)
@@ -547,7 +391,7 @@ async def lmdeploy_model_if_cache(
         from lmdeploy import version_info, GenerationConfig
     except Exception:
         raise ImportError("Please install lmdeploy before intialize lmdeploy backend.")
-
+    kwargs.pop("hashing_kv", None)
     kwargs.pop("response_format", None)
     max_new_tokens = kwargs.pop("max_tokens", 512)
     tp = kwargs.pop("tp", 1)
@@ -579,19 +423,9 @@
     if system_prompt:
         messages.append({"role": "system", "content": system_prompt})

-    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})

-    # Handle cache
-    mode = kwargs.pop("mode", "default")
-    args_hash = compute_args_hash(model, messages)
-    cached_response, quantized, min_val, max_val = await handle_cache(
-        hashing_kv, args_hash, prompt, mode
-    )
-    if cached_response is not None:
-        return cached_response
-
     gen_config = GenerationConfig(
         skip_special_tokens=skip_special_tokens,
         max_new_tokens=max_new_tokens,
@@ -607,22 +441,6 @@
         session_id=1,
     ):
         response += res.response
-
-    # Save to cache
-    await save_to_cache(
-        hashing_kv,
-        CacheData(
-            args_hash=args_hash,
-            content=response,
-            model=model,
-            prompt=prompt,
-            quantized=quantized,
-            min_val=min_val,
-            max_val=max_val,
-            mode=mode,
-        ),
-    )
-
     return response


@@ -1052,75 +870,6 @@ class MultiModel:
         return await next_model.gen_func(**args)


-async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
-    """Generic cache handling function"""
-    if hashing_kv is None:
-        return None, None, None, None
-
-    # Get embedding cache configuration
-    embedding_cache_config = hashing_kv.global_config.get(
-        "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
-    )
-    is_embedding_cache_enabled = embedding_cache_config["enabled"]
-
-    quantized = min_val = max_val = None
-    if is_embedding_cache_enabled:
-        # Use embedding cache
-        embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
-        current_embedding = await embedding_model_func([prompt])
-        quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-        best_cached_response = await get_best_cached_response(
-            hashing_kv,
-            current_embedding[0],
-            similarity_threshold=embedding_cache_config["similarity_threshold"],
-            mode=mode,
-        )
-        if best_cached_response is not None:
-            return best_cached_response, None, None, None
-    else:
-        # Use regular cache
-        mode_cache = await hashing_kv.get_by_id(mode) or {}
-        if args_hash in mode_cache:
-            return mode_cache[args_hash]["return"], None, None, None
-
-    return None, quantized, min_val, max_val
-
-
-@dataclass
-class CacheData:
-    args_hash: str
-    content: str
-    model: str
-    prompt: str
-    quantized: Optional[np.ndarray] = None
-    min_val: Optional[float] = None
-    max_val: Optional[float] = None
-    mode: str = "default"
-
-
-async def save_to_cache(hashing_kv, cache_data: CacheData):
-    if hashing_kv is None:
-        return
-
-    mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
-
-    mode_cache[cache_data.args_hash] = {
-        "return": cache_data.content,
-        "model": cache_data.model,
-        "embedding": cache_data.quantized.tobytes().hex()
-        if cache_data.quantized is not None
-        else None,
-        "embedding_shape": cache_data.quantized.shape
-        if cache_data.quantized is not None
-        else None,
-        "embedding_min": cache_data.min_val,
-        "embedding_max": cache_data.max_val,
-        "original_prompt": cache_data.prompt,
-    }
-
-    await hashing_kv.upsert({cache_data.mode: mode_cache})
-
-
 if __name__ == "__main__":
     import asyncio

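The pattern in every wrapper above is the same: the caching logic moves out of `llm.py`, and each `*_if_cache` function now just discards the `hashing_kv` kwarg forwarded by the query layer. A custom completion function would follow the same convention (sketch only; `call_my_backend` is a placeholder):

```python
async def my_llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs):
    # The query layer may forward hashing_kv; a model wrapper that does not
    # manage the cache itself should simply drop it, mirroring the built-in
    # *_if_cache functions after this commit.
    kwargs.pop("hashing_kv", None)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    return await call_my_backend(messages, **kwargs)  # placeholder backend call
```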
lightrag/operate.py CHANGED
@@ -17,6 +17,10 @@ from .utils import (
     split_string_by_multi_markers,
     truncate_list_by_token_size,
     process_combine_contexts,
+    compute_args_hash,
+    handle_cache,
+    save_to_cache,
+    CacheData,
 )
 from .base import (
     BaseGraphStorage,
@@ -452,8 +456,17 @@ async def kg_query(
     text_chunks_db: BaseKVStorage[TextChunkSchema],
     query_param: QueryParam,
     global_config: dict,
+    hashing_kv: BaseKVStorage = None,
 ) -> str:
-    context = None
+    # Handle cache
+    use_model_func = global_config["llm_model_func"]
+    args_hash = compute_args_hash(query_param.mode, query)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        hashing_kv, args_hash, query, query_param.mode
+    )
+    if cached_response is not None:
+        return cached_response
+
     example_number = global_config["addon_params"].get("example_number", None)
     if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
         examples = "\n".join(
@@ -471,12 +484,9 @@ async def kg_query(
         return PROMPTS["fail_response"]

     # LLM generate keywords
-    use_model_func = global_config["llm_model_func"]
     kw_prompt_temp = PROMPTS["keywords_extraction"]
     kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language)
-    result = await use_model_func(
-        kw_prompt, keyword_extraction=True, mode=query_param.mode
-    )
+    result = await use_model_func(kw_prompt, keyword_extraction=True)
     logger.info("kw_prompt result:")
     print(result)
     try:
@@ -537,7 +547,6 @@
         query,
         system_prompt=sys_prompt,
         stream=query_param.stream,
-        mode=query_param.mode,
     )
     if isinstance(response, str) and len(response) > len(sys_prompt):
         response = (
@@ -550,6 +559,20 @@
             .strip()
         )

+    # Save to cache
+    await save_to_cache(
+        hashing_kv,
+        CacheData(
+            args_hash=args_hash,
+            content=response,
+            prompt=query,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=query_param.mode,
+        ),
+    )
+
     return response


@@ -1013,8 +1036,17 @@ async def naive_query(
     text_chunks_db: BaseKVStorage[TextChunkSchema],
     query_param: QueryParam,
     global_config: dict,
+    hashing_kv: BaseKVStorage = None,
 ):
+    # Handle cache
     use_model_func = global_config["llm_model_func"]
+    args_hash = compute_args_hash(query_param.mode, query)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        hashing_kv, args_hash, query, query_param.mode
+    )
+    if cached_response is not None:
+        return cached_response
+
     results = await chunks_vdb.query(query, top_k=query_param.top_k)
     if not len(results):
         return PROMPTS["fail_response"]
@@ -1039,7 +1071,6 @@
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
-        mode=query_param.mode,
     )

     if len(response) > len(sys_prompt):
@@ -1054,4 +1085,18 @@
             .strip()
         )

+    # Save to cache
+    await save_to_cache(
+        hashing_kv,
+        CacheData(
+            args_hash=args_hash,
+            content=response,
+            prompt=query,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=query_param.mode,
+        ),
+    )
+
     return response
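Both query functions now follow the same wrap-around pattern: hash `(mode, query)`, try `handle_cache`, and only on a miss run retrieval plus the LLM and persist the answer with `save_to_cache`. A condensed sketch of that pattern with the names used in the diff (`generate_answer` stands in for the real retrieval-and-generation step):

```python
from lightrag.utils import CacheData, compute_args_hash, handle_cache, save_to_cache


async def cached_query(query, query_param, hashing_kv, generate_answer):
    # 1) Exact (and, outside naive mode, semantic) lookup keyed per mode.
    args_hash = compute_args_hash(query_param.mode, query)
    cached, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, query, query_param.mode
    )
    if cached is not None:
        return cached

    # 2) Cache miss: do the real work (retrieval + LLM call).
    response = await generate_answer(query, query_param)

    # 3) Persist the answer together with the quantized query embedding.
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=response,
            prompt=query,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode=query_param.mode,
        ),
    )
    return response
```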
lightrag/prompt.py CHANGED
@@ -261,3 +261,22 @@ Do not include information where the supporting evidence for it is not provided.

 Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.
 """
+
+PROMPTS[
+    "similarity_check"
+] = """Please analyze the similarity between these two questions:
+
+Question 1: {original_prompt}
+Question 2: {cached_prompt}
+
+Please evaluate:
+1. Whether these two questions are semantically similar
+2. Whether the answer to Question 2 can be used to answer Question 1
+
+Please provide a similarity score between 0 and 1, where:
+0: Completely unrelated or answer cannot be reused
+1: Identical and answer can be directly reused
+0.5: Partially related and answer needs modification to be used
+
+Return only a number between 0-1, without any additional content.
+"""
lightrag/utils.py CHANGED
@@ -15,6 +15,8 @@ import xml.etree.ElementTree as ET
 import numpy as np
 import tiktoken

+from lightrag.prompt import PROMPTS
+
 ENCODER = None

 logger = logging.getLogger("lightrag")
@@ -314,6 +316,9 @@ async def get_best_cached_response(
     current_embedding,
     similarity_threshold=0.95,
     mode="default",
+    use_llm_check=False,
+    llm_func=None,
+    original_prompt=None,
 ) -> Union[str, None]:
     # Get mode-specific cache
     mode_cache = await hashing_kv.get_by_id(mode)
@@ -348,6 +353,37 @@
             best_cache_id = cache_id

     if best_similarity > similarity_threshold:
+        # If LLM check is enabled and all required parameters are provided
+        if use_llm_check and llm_func and original_prompt and best_prompt:
+            compare_prompt = PROMPTS["similarity_check"].format(
+                original_prompt=original_prompt, cached_prompt=best_prompt
+            )
+
+            try:
+                llm_result = await llm_func(compare_prompt)
+                llm_result = llm_result.strip()
+                llm_similarity = float(llm_result)
+
+                # Replace vector similarity with LLM similarity score
+                best_similarity = llm_similarity
+                if best_similarity < similarity_threshold:
+                    log_data = {
+                        "event": "llm_check_cache_rejected",
+                        "original_question": original_prompt[:100] + "..."
+                        if len(original_prompt) > 100
+                        else original_prompt,
+                        "cached_question": best_prompt[:100] + "..."
+                        if len(best_prompt) > 100
+                        else best_prompt,
+                        "similarity_score": round(best_similarity, 4),
+                        "threshold": similarity_threshold,
+                    }
+                    logger.info(json.dumps(log_data, ensure_ascii=False))
+                    return None
+            except Exception as e:  # Catch all possible exceptions
+                logger.warning(f"LLM similarity check failed: {e}")
+                return None  # Return None directly when LLM check fails
+
         prompt_display = (
             best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt
         )
@@ -391,21 +427,33 @@ def dequantize_embedding(
     scale = (max_val - min_val) / (2**bits - 1)
     return (quantized * scale + min_val).astype(np.float32)

+
 async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
     """Generic cache handling function"""
     if hashing_kv is None:
         return None, None, None, None

+    # For naive mode, only use simple cache matching
+    if mode == "naive":
+        mode_cache = await hashing_kv.get_by_id(mode) or {}
+        if args_hash in mode_cache:
+            return mode_cache[args_hash]["return"], None, None, None
+        return None, None, None, None
+
     # Get embedding cache configuration
     embedding_cache_config = hashing_kv.global_config.get(
-        "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
+        "embedding_cache_config",
+        {"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False},
     )
     is_embedding_cache_enabled = embedding_cache_config["enabled"]
+    use_llm_check = embedding_cache_config.get("use_llm_check", False)

     quantized = min_val = max_val = None
     if is_embedding_cache_enabled:
         # Use embedding cache
         embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
+        llm_model_func = hashing_kv.global_config.get("llm_model_func")
+
         current_embedding = await embedding_model_func([prompt])
         quantized, min_val, max_val = quantize_embedding(current_embedding[0])
         best_cached_response = await get_best_cached_response(
@@ -413,6 +461,9 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
             current_embedding[0],
             similarity_threshold=embedding_cache_config["similarity_threshold"],
             mode=mode,
+            use_llm_check=use_llm_check,
+            llm_func=llm_model_func if use_llm_check else None,
+            original_prompt=prompt if use_llm_check else None,
         )
         if best_cached_response is not None:
             return best_cached_response, None, None, None
@@ -429,7 +480,6 @@
 class CacheData:
     args_hash: str
     content: str
-    model: str
     prompt: str
     quantized: Optional[np.ndarray] = None
     min_val: Optional[float] = None
@@ -445,7 +495,6 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):

     mode_cache[cache_data.args_hash] = {
         "return": cache_data.content,
-        "model": cache_data.model,
         "embedding": cache_data.quantized.tobytes().hex()
         if cache_data.quantized is not None
         else None,
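Condensed, the verification step added to `get_best_cached_response` behaves like the helper below: a low score or any parse failure rejects the cached answer and falls through to a fresh LLM call (sketch only; the real logic is inlined in `lightrag/utils.py`):

```python
from lightrag.prompt import PROMPTS


async def llm_confirms_cache_hit(llm_func, original_prompt, cached_prompt, threshold):
    """Return True only when the LLM scores the two questions above the threshold."""
    compare_prompt = PROMPTS["similarity_check"].format(
        original_prompt=original_prompt, cached_prompt=cached_prompt
    )
    try:
        score = float((await llm_func(compare_prompt)).strip())
    except Exception:
        return False  # as in the commit, a failed check rejects the cached answer
    return score > threshold
```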