Commit a9d42e4 (unverified), committed by zrguo
Parents: c7cae5d 6450edf

Merge pull request #693 from danielaskdd/fix-concurrent-problem

lightrag/api/lightrag_server.py CHANGED
@@ -13,18 +13,6 @@ from fastapi import (
 from typing import Dict
 import threading
 
-# Global progress tracker
-scan_progress: Dict = {
-    "is_scanning": False,
-    "current_file": "",
-    "indexed_count": 0,
-    "total_files": 0,
-    "progress": 0,
-}
-
-# Lock for thread-safe operations
-progress_lock = threading.Lock()
-
 import json
 import os
 
@@ -34,7 +22,7 @@ import logging
 import argparse
 import time
 import re
-from typing import List, Dict, Any, Optional, Union
+from typing import List, Any, Optional, Union
 from lightrag import LightRAG, QueryParam
 from lightrag.api import __api_version__
 
@@ -57,8 +45,21 @@ import pipmaster as pm
 
 from dotenv import load_dotenv
 
+# Load environment variables
 load_dotenv()
 
+# Global progress tracker
+scan_progress: Dict = {
+    "is_scanning": False,
+    "current_file": "",
+    "indexed_count": 0,
+    "total_files": 0,
+    "progress": 0,
+}
+
+# Lock for thread-safe operations
+progress_lock = threading.Lock()
+
 
 def estimate_tokens(text: str) -> int:
     """Estimate the number of tokens in text
@@ -918,6 +919,12 @@ def create_app(args):
             vector_db_storage_cls_kwargs={
                 "cosine_better_than_threshold": args.cosine_threshold
             },
+            enable_llm_cache_for_entity_extract=False,  # set to True for debugging to reduce llm fee
+            embedding_cache_config={
+                "enabled": True,
+                "similarity_threshold": 0.95,
+                "use_llm_check": False,
+            },
         )
     else:
         rag = LightRAG(
@@ -941,6 +948,12 @@ def create_app(args):
             vector_db_storage_cls_kwargs={
                 "cosine_better_than_threshold": args.cosine_threshold
             },
+            enable_llm_cache_for_entity_extract=False,  # set to True for debugging to reduce llm fee
+            embedding_cache_config={
+                "enabled": True,
+                "similarity_threshold": 0.95,
+                "use_llm_check": False,
+            },
         )
 
     async def index_file(file_path: Union[str, Path]) -> None:
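The relocated scan_progress dictionary is shared state that request handlers and the background scan thread can touch at the same time, so every mutation is meant to happen under progress_lock. A minimal sketch of that pattern (the update_scan_progress helper is hypothetical and not part of this commit):

import threading
from typing import Dict

# Module-level progress tracker guarded by a lock, mirroring the
# definitions added in the diff above.
scan_progress: Dict = {
    "is_scanning": False,
    "current_file": "",
    "indexed_count": 0,
    "total_files": 0,
    "progress": 0,
}
progress_lock = threading.Lock()


def update_scan_progress(current_file: str, indexed: int, total: int) -> None:
    # Hypothetical helper: all writes happen while holding the lock so
    # concurrent readers never observe a half-updated snapshot.
    with progress_lock:
        scan_progress["current_file"] = current_file
        scan_progress["indexed_count"] = indexed
        scan_progress["total_files"] = total
        scan_progress["progress"] = (indexed / total * 100) if total else 0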
lightrag/kg/nano_vector_db_impl.py CHANGED
@@ -76,6 +76,8 @@ class NanoVectorDBStorage(BaseVectorStorage):
     cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
 
     def __post_init__(self):
+        # Initialize lock only for file operations
+        self._save_lock = asyncio.Lock()
         # Use global config value if specified, otherwise use default
         config = self.global_config.get("vector_db_storage_cls_kwargs", {})
         self.cosine_better_than_threshold = config.get(
@@ -138,7 +140,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
         embedding = await self.embedding_func([query])
         embedding = embedding[0]
         logger.info(
-            f"Query: {query}, top_k: {top_k}, cosine_better_than_threshold: {self.cosine_better_than_threshold}"
+            f"Query: {query}, top_k: {top_k}, cosine: {self.cosine_better_than_threshold}"
         )
         results = self._client.query(
             query=embedding,
@@ -210,4 +212,6 @@ class NanoVectorDBStorage(BaseVectorStorage):
             logger.error(f"Error deleting relations for {entity_name}: {e}")
 
     async def index_done_callback(self):
-        self._client.save()
+        # Protect file write operation
+        async with self._save_lock:
+            self._client.save()
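The new _save_lock covers the case where several pipelines finish indexing at the same time and each triggers index_done_callback(); without it, concurrent _client.save() calls could interleave writes to the same file. A self-contained sketch of the pattern (the _FakeClient and StorageSketch classes are illustrative only):

import asyncio


class _FakeClient:
    def save(self) -> None:
        print("writing vector DB file")


class StorageSketch:
    def __init__(self) -> None:
        # One asyncio.Lock per storage instance, created in __post_init__
        # in the real class.
        self._save_lock = asyncio.Lock()
        self._client = _FakeClient()

    async def index_done_callback(self) -> None:
        # Serialize the file write; concurrent callers wait their turn.
        async with self._save_lock:
            self._client.save()


async def main() -> None:
    storage = StorageSketch()
    await asyncio.gather(*(storage.index_done_callback() for _ in range(3)))


asyncio.run(main())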
lightrag/lightrag.py CHANGED
@@ -231,7 +231,7 @@ class LightRAG:
 
         self.llm_response_cache = self.key_string_value_json_storage_cls(
             namespace="llm_response_cache",
-            embedding_func=None,
+            embedding_func=self.embedding_func,
         )
 
         ####
@@ -275,7 +275,7 @@
         else:
             hashing_kv = self.key_string_value_json_storage_cls(
                 namespace="llm_response_cache",
-                embedding_func=None,
+                embedding_func=self.embedding_func,
             )
 
         self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
@@ -916,7 +916,7 @@ class LightRAG:
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
-                    embedding_func=None,
+                    embedding_func=self.embedding_func,
                 ),
                 prompt=prompt,
             )
@@ -933,7 +933,7 @@ class LightRAG:
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
-                    embedding_func=None,
+                    embedding_func=self.embedding_func,
                 ),
             )
         elif param.mode == "mix":
@@ -952,7 +952,7 @@ class LightRAG:
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
-                    embedding_func=None,
+                    embedding_func=self.embedding_func,
                 ),
             )
         else:
@@ -993,7 +993,7 @@ class LightRAG:
                 or self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
-                    embedding_func=None,
+                    embedding_func=self.embedding_func,
                 ),
             )
 
@@ -1024,7 +1024,7 @@ class LightRAG:
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
-                    embedding_func=None,
+                    embedding_func=self.embedding_func,
                 ),
             )
         elif param.mode == "naive":
@@ -1040,7 +1040,7 @@ class LightRAG:
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
-                    embedding_func=None,
+                    embedding_func=self.embedding_func,
                 ),
             )
         elif param.mode == "mix":
@@ -1059,7 +1059,7 @@ class LightRAG:
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
-                    embedding_func=None,
+                    embedding_func=self.embedding_func,
                 ),
             )
         else:
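Passing embedding_func=self.embedding_func instead of None matters because the reworked handle_cache() in lightrag/utils.py (below) now embeds the prompt through the cache storage itself, roughly hashing_kv.embedding_func([prompt]), instead of digging the function out of global_config. A small sketch of that dependency with assumed, simplified types (JsonKVStorageSketch and fake_embed are illustrative stand-ins, not the real storage class):

import asyncio
from dataclasses import dataclass
from typing import Awaitable, Callable, List, Optional


@dataclass
class JsonKVStorageSketch:
    # Stand-in for the KV storage used as llm_response_cache.
    namespace: str
    embedding_func: Optional[Callable[[List[str]], Awaitable[List[List[float]]]]]


async def fake_embed(texts: List[str]) -> List[List[float]]:
    # Dummy embedding function standing in for self.embedding_func.
    return [[0.0] * 8 for _ in texts]


async def main() -> None:
    cache = JsonKVStorageSketch("llm_response_cache", embedding_func=fake_embed)
    # handle_cache() calls the cache's own embedding function when the
    # embedding cache is enabled; with embedding_func=None this would fail.
    vector = (await cache.embedding_func(["some query prompt"]))[0]
    print(len(vector))


asyncio.run(main())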
lightrag/operate.py CHANGED
@@ -352,16 +352,6 @@ async def extract_entities(
         input_text: str, history_messages: list[dict[str, str]] = None
     ) -> str:
         if enable_llm_cache_for_entity_extract and llm_response_cache:
-            need_to_restore = False
-            if (
-                global_config["embedding_cache_config"]
-                and global_config["embedding_cache_config"]["enabled"]
-            ):
-                new_config = global_config.copy()
-                new_config["embedding_cache_config"] = None
-                new_config["enable_llm_cache"] = True
-                llm_response_cache.global_config = new_config
-                need_to_restore = True
             if history_messages:
                 history = json.dumps(history_messages, ensure_ascii=False)
                 _prompt = history + "\n" + input_text
@@ -370,10 +360,13 @@
 
             arg_hash = compute_args_hash(_prompt)
             cached_return, _1, _2, _3 = await handle_cache(
-                llm_response_cache, arg_hash, _prompt, "default", cache_type="default"
+                llm_response_cache,
+                arg_hash,
+                _prompt,
+                "default",
+                cache_type="extract",
+                force_llm_cache=True,
             )
-            if need_to_restore:
-                llm_response_cache.global_config = global_config
             if cached_return:
                 logger.debug(f"Found cache for {arg_hash}")
                 statistic_data["llm_cache"] += 1
@@ -387,7 +380,12 @@
             res: str = await use_llm_func(input_text)
             await save_to_cache(
                 llm_response_cache,
-                CacheData(args_hash=arg_hash, content=res, prompt=_prompt),
+                CacheData(
+                    args_hash=arg_hash,
+                    content=res,
+                    prompt=_prompt,
+                    cache_type="extract",
+                ),
             )
             return res
 
@@ -740,7 +738,7 @@ async def extract_keywords_only(
     # 6. Parse out JSON from the LLM response
     match = re.search(r"\{.*\}", result, re.DOTALL)
    if not match:
-        logger.error("No JSON-like structure found in the result.")
+        logger.error("No JSON-like structure found in the LLM response.")
         return [], []
     try:
         keywords_data = json.loads(match.group(0))
@@ -752,20 +750,24 @@
     ll_keywords = keywords_data.get("low_level_keywords", [])
 
     # 7. Cache only the processed keywords with cache type
-    cache_data = {"high_level_keywords": hl_keywords, "low_level_keywords": ll_keywords}
-    await save_to_cache(
-        hashing_kv,
-        CacheData(
-            args_hash=args_hash,
-            content=json.dumps(cache_data),
-            prompt=text,
-            quantized=quantized,
-            min_val=min_val,
-            max_val=max_val,
-            mode=param.mode,
-            cache_type="keywords",
-        ),
-    )
+    if hl_keywords or ll_keywords:
+        cache_data = {
+            "high_level_keywords": hl_keywords,
+            "low_level_keywords": ll_keywords,
+        }
+        await save_to_cache(
+            hashing_kv,
+            CacheData(
+                args_hash=args_hash,
+                content=json.dumps(cache_data),
+                prompt=text,
+                quantized=quantized,
+                min_val=min_val,
+                max_val=max_val,
+                mode=param.mode,
+                cache_type="keywords",
+            ),
+        )
     return hl_keywords, ll_keywords
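With force_llm_cache=True, entity-extraction lookups go through the cache even when the global enable_llm_cache flag is off, and both the lookup and the stored entry carry cache_type="extract". A stripped-down sketch of the hash-lookup-then-store flow this enables (fake_llm and the in-memory dict are stand-ins for the real LLM call and KV storage):

import asyncio
from hashlib import md5

# In-memory stand-in for the "default" mode cache bucket.
_extract_cache: dict[str, str] = {}


def args_hash(prompt: str) -> str:
    return md5(prompt.encode()).hexdigest()


async def fake_llm(prompt: str) -> str:
    return f"<entities for: {prompt[:24]}...>"


async def cached_extract(prompt: str) -> str:
    key = args_hash(prompt)
    if key in _extract_cache:
        # Corresponds to handle_cache(..., cache_type="extract", force_llm_cache=True)
        return _extract_cache[key]
    result = await fake_llm(prompt)
    # Corresponds to save_to_cache(..., CacheData(..., cache_type="extract"))
    _extract_cache[key] = result
    return result


async def main() -> None:
    first = await cached_extract("Alice met Bob in Paris.")
    second = await cached_extract("Alice met Bob in Paris.")  # served from cache
    print(first == second)


asyncio.run(main())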
lightrag/prompt.py CHANGED
@@ -290,9 +290,8 @@ PROMPTS[
 Question 1: {original_prompt}
 Question 2: {cached_prompt}
 
-Please evaluate the following two points and provide a similarity score between 0 and 1 directly:
-1. Whether these two questions are semantically similar
-2. Whether the answer to Question 2 can be used to answer Question 1
+Please evaluate whether these two questions are semantically similar, and whether the answer to Question 2 can be used to answer Question 1, then provide a similarity score between 0 and 1 directly.
+
 Similarity score criteria:
 0: Completely unrelated or answer cannot be reused, including but not limited to:
     - The questions have different topics
lightrag/utils.py CHANGED
@@ -58,17 +58,10 @@ class EmbeddingFunc:
     embedding_dim: int
     max_token_size: int
     func: callable
-    concurrent_limit: int = 16
-
-    def __post_init__(self):
-        if self.concurrent_limit != 0:
-            self._semaphore = asyncio.Semaphore(self.concurrent_limit)
-        else:
-            self._semaphore = UnlimitedSemaphore()
+    # concurrent_limit: int = 16
 
     async def __call__(self, *args, **kwargs) -> np.ndarray:
-        async with self._semaphore:
-            return await self.func(*args, **kwargs)
+        return await self.func(*args, **kwargs)
 
 
 def locate_json_string_body_from_string(content: str) -> Union[str, None]:
@@ -112,7 +105,7 @@ def compute_args_hash(*args, cache_type: str = None) -> str:
     """Compute a hash for the given arguments.
     Args:
         *args: Arguments to hash
-        cache_type: Type of cache (e.g., 'keywords', 'query')
+        cache_type: Type of cache (e.g., 'keywords', 'query', 'extract')
     Returns:
         str: Hash string
     """
@@ -131,22 +124,17 @@
     return prefix + md5(content.encode()).hexdigest()
 
 
-def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
-    """Add restriction of maximum async calling times for a async func"""
+def limit_async_func_call(max_size: int):
+    """Add restriction of maximum concurrent async calls using asyncio.Semaphore"""
 
     def final_decro(func):
-        """Not using async.Semaphore to aovid use nest-asyncio"""
-        __current_size = 0
+        sem = asyncio.Semaphore(max_size)
 
         @wraps(func)
         async def wait_func(*args, **kwargs):
-            nonlocal __current_size
-            while __current_size >= max_size:
-                await asyncio.sleep(waitting_time)
-            __current_size += 1
-            result = await func(*args, **kwargs)
-            __current_size -= 1
-            return result
+            async with sem:
+                result = await func(*args, **kwargs)
+                return result
 
         return wait_func
 
@@ -380,6 +368,9 @@ async def get_best_cached_response(
     original_prompt=None,
     cache_type=None,
 ) -> Union[str, None]:
+    logger.debug(
+        f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
+    )
     mode_cache = await hashing_kv.get_by_id(mode)
     if not mode_cache:
         return None
@@ -470,8 +461,12 @@ def cosine_similarity(v1, v2):
     return dot_product / (norm1 * norm2)
 
 
-def quantize_embedding(embedding: np.ndarray, bits=8) -> tuple:
+def quantize_embedding(embedding: Union[np.ndarray, list], bits=8) -> tuple:
     """Quantize embedding to specified bits"""
+    # Convert list to numpy array if needed
+    if isinstance(embedding, list):
+        embedding = np.array(embedding)
+
     # Calculate min/max values for reconstruction
     min_val = embedding.min()
     max_val = embedding.max()
@@ -491,59 +486,60 @@ def dequantize_embedding(
     return (quantized * scale + min_val).astype(np.float32)
 
 
-async def handle_cache(hashing_kv, args_hash, prompt, mode="default", cache_type=None):
+async def handle_cache(
+    hashing_kv,
+    args_hash,
+    prompt,
+    mode="default",
+    cache_type=None,
+    force_llm_cache=False,
+):
     """Generic cache handling function"""
-    if hashing_kv is None or not hashing_kv.global_config.get("enable_llm_cache"):
+    if hashing_kv is None or not (
+        force_llm_cache or hashing_kv.global_config.get("enable_llm_cache")
+    ):
         return None, None, None, None
 
-    # For default mode, only use simple cache matching
-    if mode == "default":
-        if exists_func(hashing_kv, "get_by_mode_and_id"):
-            mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
-        else:
-            mode_cache = await hashing_kv.get_by_id(mode) or {}
-        if args_hash in mode_cache:
-            return mode_cache[args_hash]["return"], None, None, None
-        return None, None, None, None
-
-    # Get embedding cache configuration
-    embedding_cache_config = hashing_kv.global_config.get(
-        "embedding_cache_config",
-        {"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False},
-    )
-    is_embedding_cache_enabled = embedding_cache_config["enabled"]
-    use_llm_check = embedding_cache_config.get("use_llm_check", False)
-
-    quantized = min_val = max_val = None
-    if is_embedding_cache_enabled:
-        # Use embedding cache
-        embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
-        llm_model_func = hashing_kv.global_config.get("llm_model_func")
-
-        current_embedding = await embedding_model_func([prompt])
-        quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-        best_cached_response = await get_best_cached_response(
-            hashing_kv,
-            current_embedding[0],
-            similarity_threshold=embedding_cache_config["similarity_threshold"],
-            mode=mode,
-            use_llm_check=use_llm_check,
-            llm_func=llm_model_func if use_llm_check else None,
-            original_prompt=prompt if use_llm_check else None,
-            cache_type=cache_type,
-        )
-        if best_cached_response is not None:
-            return best_cached_response, None, None, None
-    else:
-        # Use regular cache
-        if exists_func(hashing_kv, "get_by_mode_and_id"):
-            mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
-        else:
-            mode_cache = await hashing_kv.get_by_id(mode) or {}
-        if args_hash in mode_cache:
-            return mode_cache[args_hash]["return"], None, None, None
+    if mode != "default":
+        # Get embedding cache configuration
+        embedding_cache_config = hashing_kv.global_config.get(
+            "embedding_cache_config",
+            {"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False},
+        )
+        is_embedding_cache_enabled = embedding_cache_config["enabled"]
+        use_llm_check = embedding_cache_config.get("use_llm_check", False)
+
+        quantized = min_val = max_val = None
+        if is_embedding_cache_enabled:
+            # Use embedding cache
+            current_embedding = await hashing_kv.embedding_func([prompt])
+            llm_model_func = hashing_kv.global_config.get("llm_model_func")
+            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
+            best_cached_response = await get_best_cached_response(
+                hashing_kv,
+                current_embedding[0],
+                similarity_threshold=embedding_cache_config["similarity_threshold"],
+                mode=mode,
+                use_llm_check=use_llm_check,
+                llm_func=llm_model_func if use_llm_check else None,
+                original_prompt=prompt,
+                cache_type=cache_type,
+            )
+            if best_cached_response is not None:
+                return best_cached_response, None, None, None
+            else:
+                return None, quantized, min_val, max_val
+
+    # For default mode (extract_entities or naive query), or when the embedding cache is disabled:
+    # use regular cache
+    if exists_func(hashing_kv, "get_by_mode_and_id"):
+        mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
    else:
+        mode_cache = await hashing_kv.get_by_id(mode) or {}
+    if args_hash in mode_cache:
+        return mode_cache[args_hash]["return"], None, None, None
 
-    return None, quantized, min_val, max_val
+    return None, None, None, None
 
 
 @dataclass
@@ -572,6 +568,7 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
 
     mode_cache[cache_data.args_hash] = {
         "return": cache_data.content,
+        "cache_type": cache_data.cache_type,
         "embedding": cache_data.quantized.tobytes().hex()
         if cache_data.quantized is not None
         else None,
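The rewritten limit_async_func_call replaces the old busy-wait counter (which polled with asyncio.sleep) with a plain asyncio.Semaphore. A self-contained version of the decorator plus a usage sketch (the work coroutine is illustrative):

import asyncio
from functools import wraps


def limit_async_func_call(max_size: int):
    """Restrict the number of concurrent calls to an async function."""

    def final_decro(func):
        sem = asyncio.Semaphore(max_size)

        @wraps(func)
        async def wait_func(*args, **kwargs):
            async with sem:  # at most max_size callers run func at once
                return await func(*args, **kwargs)

        return wait_func

    return final_decro


@limit_async_func_call(2)
async def work(i: int) -> int:
    await asyncio.sleep(0.1)
    return i


async def main() -> None:
    print(await asyncio.gather(*(work(i) for i in range(5))))


asyncio.run(main())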