Merge pull request #693 from danielaskdd/fix-concurrent-problem
Changed files:
- lightrag/api/lightrag_server.py  +26 -13
- lightrag/kg/nano_vector_db_impl.py  +6 -2
- lightrag/lightrag.py  +9 -9
- lightrag/operate.py  +31 -29
- lightrag/prompt.py  +2 -3
- lightrag/utils.py  +66 -69
lightrag/api/lightrag_server.py  CHANGED

@@ -13,18 +13,6 @@ from fastapi import (
 from typing import Dict
 import threading
 
-# Global progress tracker
-scan_progress: Dict = {
-    "is_scanning": False,
-    "current_file": "",
-    "indexed_count": 0,
-    "total_files": 0,
-    "progress": 0,
-}
-
-# Lock for thread-safe operations
-progress_lock = threading.Lock()
-
 import json
 import os

@@ -34,7 +22,7 @@ import logging
 import argparse
 import time
 import re
-from typing import List,
+from typing import List, Any, Optional, Union
 from lightrag import LightRAG, QueryParam
 from lightrag.api import __api_version__
 
@@ -57,8 +45,21 @@ import pipmaster as pm
 
 from dotenv import load_dotenv
 
+# Load environment variables
 load_dotenv()
 
+# Global progress tracker
+scan_progress: Dict = {
+    "is_scanning": False,
+    "current_file": "",
+    "indexed_count": 0,
+    "total_files": 0,
+    "progress": 0,
+}
+
+# Lock for thread-safe operations
+progress_lock = threading.Lock()
+
 
 def estimate_tokens(text: str) -> int:
     """Estimate the number of tokens in text

@@ -918,6 +919,12 @@ def create_app(args):
            vector_db_storage_cls_kwargs={
                "cosine_better_than_threshold": args.cosine_threshold
            },
+            enable_llm_cache_for_entity_extract=False,  # set to True for debuging to reduce llm fee
+            embedding_cache_config={
+                "enabled": True,
+                "similarity_threshold": 0.95,
+                "use_llm_check": False,
+            },
        )
    else:
        rag = LightRAG(

@@ -941,6 +948,12 @@
            vector_db_storage_cls_kwargs={
                "cosine_better_than_threshold": args.cosine_threshold
            },
+            enable_llm_cache_for_entity_extract=False,  # set to True for debuging to reduce llm fee
+            embedding_cache_config={
+                "enabled": True,
+                "similarity_threshold": 0.95,
+                "use_llm_check": False,
+            },
        )
 
    async def index_file(file_path: Union[str, Path]) -> None:
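The progress tracker now lives after load_dotenv() and is guarded by a threading.Lock. A minimal sketch of how a scanning routine might update it safely; the update_scan_progress helper is illustrative and not part of the PR:

```python
import threading
from typing import Dict

scan_progress: Dict = {
    "is_scanning": False,
    "current_file": "",
    "indexed_count": 0,
    "total_files": 0,
    "progress": 0,
}
progress_lock = threading.Lock()


def update_scan_progress(current_file: str, indexed: int, total: int) -> None:
    # Hypothetical helper: hold the lock so concurrent scanner threads never
    # interleave partial updates of the shared dict.
    with progress_lock:
        scan_progress["current_file"] = current_file
        scan_progress["indexed_count"] = indexed
        scan_progress["total_files"] = total
        scan_progress["progress"] = (indexed / total * 100) if total else 0


update_scan_progress("docs/a.txt", 3, 10)
print(scan_progress)
```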
lightrag/kg/nano_vector_db_impl.py  CHANGED

@@ -76,6 +76,8 @@ class NanoVectorDBStorage(BaseVectorStorage):
     cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
 
     def __post_init__(self):
+        # Initialize lock only for file operations
+        self._save_lock = asyncio.Lock()
         # Use global config value if specified, otherwise use default
         config = self.global_config.get("vector_db_storage_cls_kwargs", {})
         self.cosine_better_than_threshold = config.get(

@@ -138,7 +140,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
         embedding = await self.embedding_func([query])
         embedding = embedding[0]
         logger.info(
-            f"Query: {query}, top_k: {top_k},
+            f"Query: {query}, top_k: {top_k}, cosine: {self.cosine_better_than_threshold}"
         )
         results = self._client.query(
             query=embedding,

@@ -210,4 +212,6 @@ class NanoVectorDBStorage(BaseVectorStorage):
            logger.error(f"Error deleting relations for {entity_name}: {e}")
 
     async def index_done_callback(self):
-        self._client.save()
+        # Protect file write operation
+        async with self._save_lock:
+            self._client.save()
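The new _save_lock serializes concurrent flushes of the vector store file. A minimal sketch of the pattern with a stand-in client object; FakeClient and Storage here are illustrative, the real class wraps NanoVectorDB:

```python
import asyncio


class FakeClient:
    # Stand-in for the real vector DB client; save() writes the index file.
    def save(self) -> None:
        print("writing vector index to disk")


class Storage:
    def __init__(self) -> None:
        self._client = FakeClient()
        # Initialize lock only for file operations, as in the PR
        self._save_lock = asyncio.Lock()

    async def index_done_callback(self) -> None:
        # Concurrent callers queue up here instead of writing the file at once
        async with self._save_lock:
            self._client.save()


async def main() -> None:
    s = Storage()
    # Two pipelines finishing at the same time no longer race on the file
    await asyncio.gather(s.index_done_callback(), s.index_done_callback())


asyncio.run(main())
```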
lightrag/lightrag.py  CHANGED

@@ -231,7 +231,7 @@ class LightRAG:
 
         self.llm_response_cache = self.key_string_value_json_storage_cls(
             namespace="llm_response_cache",
-            embedding_func=
+            embedding_func=self.embedding_func,
         )
 
         ####

@@ -275,7 +275,7 @@
         else:
             hashing_kv = self.key_string_value_json_storage_cls(
                 namespace="llm_response_cache",
-                embedding_func=
+                embedding_func=self.embedding_func,
             )
 
         self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(

@@ -916,7 +916,7 @@
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
-                    embedding_func=
+                    embedding_func=self.embedding_func,
                 ),
                 prompt=prompt,
             )

@@ -933,7 +933,7 @@
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
-                    embedding_func=
+                    embedding_func=self.embedding_func,
                 ),
             )
         elif param.mode == "mix":

@@ -952,7 +952,7 @@
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
-                    embedding_func=
+                    embedding_func=self.embedding_func,
                 ),
             )
         else:

@@ -993,7 +993,7 @@
                 or self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
-                    embedding_func=
+                    embedding_func=self.embedding_func,
                 ),
             )
 

@@ -1024,7 +1024,7 @@
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
-                    embedding_func=
+                    embedding_func=self.embedding_func,
                 ),
             )
         elif param.mode == "naive":

@@ -1040,7 +1040,7 @@
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
-                    embedding_func=
+                    embedding_func=self.embedding_func,
                 ),
             )
         elif param.mode == "mix":

@@ -1059,7 +1059,7 @@
                 else self.key_string_value_json_storage_cls(
                     namespace="llm_response_cache",
                     global_config=asdict(self),
-                    embedding_func=
+                    embedding_func=self.embedding_func,
                 ),
             )
         else:
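Every llm_response_cache storage is now constructed with embedding_func=self.embedding_func, which is what allows handle_cache in utils.py to call hashing_kv.embedding_func([prompt]) for similarity-based lookups. A minimal sketch of that dependency; KVStore and fake_embed are illustrative stand-ins, not the project's classes:

```python
import asyncio
from dataclasses import dataclass
from typing import Awaitable, Callable, List


@dataclass
class KVStore:
    # Illustrative stand-in for a key_string_value_json_storage_cls instance
    namespace: str
    # Carrying the embedding function lets cache lookups embed the prompt
    # without reaching back into the LightRAG object.
    embedding_func: Callable[[List[str]], Awaitable[List[List[float]]]]


async def fake_embed(texts: List[str]) -> List[List[float]]:
    # Hypothetical embedder returning toy length-based vectors
    return [[float(len(t)), 1.0] for t in texts]


async def main() -> None:
    cache = KVStore(namespace="llm_response_cache", embedding_func=fake_embed)
    # Mirrors what handle_cache does with hashing_kv.embedding_func([prompt])
    vec = (await cache.embedding_func(["What is LightRAG?"]))[0]
    print(vec)


asyncio.run(main())
```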
lightrag/operate.py  CHANGED

@@ -352,16 +352,6 @@ async def extract_entities(
         input_text: str, history_messages: list[dict[str, str]] = None
     ) -> str:
         if enable_llm_cache_for_entity_extract and llm_response_cache:
-            need_to_restore = False
-            if (
-                global_config["embedding_cache_config"]
-                and global_config["embedding_cache_config"]["enabled"]
-            ):
-                new_config = global_config.copy()
-                new_config["embedding_cache_config"] = None
-                new_config["enable_llm_cache"] = True
-                llm_response_cache.global_config = new_config
-                need_to_restore = True
             if history_messages:
                 history = json.dumps(history_messages, ensure_ascii=False)
                 _prompt = history + "\n" + input_text

@@ -370,10 +360,13 @@
 
             arg_hash = compute_args_hash(_prompt)
             cached_return, _1, _2, _3 = await handle_cache(
-                llm_response_cache,
+                llm_response_cache,
+                arg_hash,
+                _prompt,
+                "default",
+                cache_type="extract",
+                force_llm_cache=True,
             )
-            if need_to_restore:
-                llm_response_cache.global_config = global_config
             if cached_return:
                 logger.debug(f"Found cache for {arg_hash}")
                 statistic_data["llm_cache"] += 1

@@ -387,7 +380,12 @@
             res: str = await use_llm_func(input_text)
             await save_to_cache(
                 llm_response_cache,
-                CacheData(
+                CacheData(
+                    args_hash=arg_hash,
+                    content=res,
+                    prompt=_prompt,
+                    cache_type="extract",
+                ),
             )
             return res
 

@@ -740,7 +738,7 @@ async def extract_keywords_only(
     # 6. Parse out JSON from the LLM response
     match = re.search(r"\{.*\}", result, re.DOTALL)
     if not match:
-        logger.error("No JSON-like structure found in the
+        logger.error("No JSON-like structure found in the LLM respond.")
         return [], []
     try:
         keywords_data = json.loads(match.group(0))

@@ -752,20 +750,24 @@
     ll_keywords = keywords_data.get("low_level_keywords", [])
 
     # 7. Cache only the processed keywords with cache type
-    … (14 removed lines not captured in this view)
+    if hl_keywords or ll_keywords:
+        cache_data = {
+            "high_level_keywords": hl_keywords,
+            "low_level_keywords": ll_keywords,
+        }
+        await save_to_cache(
+            hashing_kv,
+            CacheData(
+                args_hash=args_hash,
+                content=json.dumps(cache_data),
+                prompt=text,
+                quantized=quantized,
+                min_val=min_val,
+                max_val=max_val,
+                mode=param.mode,
+                cache_type="keywords",
+            ),
+        )
     return hl_keywords, ll_keywords
 
 
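Entity extraction now always goes through the regular LLM cache (force_llm_cache=True, cache_type="extract") instead of temporarily patching and restoring global_config. A minimal sketch of that check-then-store flow using an in-memory dict; the helpers below are simplified stand-ins for handle_cache and save_to_cache, not the library's implementations:

```python
import asyncio
import hashlib

_cache: dict = {}  # illustrative stand-in for llm_response_cache


def compute_args_hash(prompt: str) -> str:
    return hashlib.md5(prompt.encode()).hexdigest()


async def use_llm_func(prompt: str) -> str:
    # Hypothetical LLM call
    return f"entities extracted from: {prompt[:20]}"


async def extract_with_cache(prompt: str) -> str:
    # 1. Look up the cache first (mirrors handle_cache with cache_type="extract")
    arg_hash = compute_args_hash(prompt)
    if arg_hash in _cache:
        return _cache[arg_hash]
    # 2. Cache miss: call the LLM
    res = await use_llm_func(prompt)
    # 3. Store the result (mirrors save_to_cache with CacheData(cache_type="extract"))
    _cache[arg_hash] = res
    return res


print(asyncio.run(extract_with_cache("chunk text to analyse ...")))
```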
lightrag/prompt.py  CHANGED

@@ -290,9 +290,8 @@ PROMPTS[
 Question 1: {original_prompt}
 Question 2: {cached_prompt}
 
-Please evaluate
-1. Whether these two questions are semantically similar
-2. Whether the answer to Question 2 can be used to answer Question 1
+Please evaluate whether these two questions are semantically similar, and whether the answer to Question 2 can be used to answer Question 1, provide a similarity score between 0 and 1 directly.
+
 Similarity score criteria:
 0: Completely unrelated or answer cannot be reused, including but not limited to:
     - The questions have different topics
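The two evaluation criteria are merged into a single instruction so the model answers with a bare score. A sketch of how such a template might be filled and its reply parsed on the use_llm_check path; the template name and the parsing fallback are assumptions, not taken from the diff:

```python
# Hypothetical copy of the similarity-check template from prompt.py
SIMILARITY_CHECK_TEMPLATE = """Question 1: {original_prompt}
Question 2: {cached_prompt}

Please evaluate whether these two questions are semantically similar, and whether the answer to Question 2 can be used to answer Question 1, provide a similarity score between 0 and 1 directly."""


def parse_similarity(llm_reply: str) -> float:
    # Expect a bare number; fall back to 0.0 ("do not reuse") if parsing fails
    try:
        return max(0.0, min(1.0, float(llm_reply.strip())))
    except ValueError:
        return 0.0


prompt = SIMILARITY_CHECK_TEMPLATE.format(
    original_prompt="How does LightRAG cache embeddings?",
    cached_prompt="What does the embedding cache in LightRAG do?",
)
print(parse_similarity("0.85"))
```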
lightrag/utils.py  CHANGED

@@ -58,17 +58,10 @@ class EmbeddingFunc:
     embedding_dim: int
     max_token_size: int
     func: callable
-    concurrent_limit: int = 16
-
-    def __post_init__(self):
-        if self.concurrent_limit != 0:
-            self._semaphore = asyncio.Semaphore(self.concurrent_limit)
-        else:
-            self._semaphore = UnlimitedSemaphore()
+    # concurrent_limit: int = 16
 
     async def __call__(self, *args, **kwargs) -> np.ndarray:
-        async with self._semaphore:
-            return await self.func(*args, **kwargs)
+        return await self.func(*args, **kwargs)
 
 
 def locate_json_string_body_from_string(content: str) -> Union[str, None]:

@@ -112,7 +105,7 @@ def compute_args_hash(*args, cache_type: str = None) -> str:
     """Compute a hash for the given arguments.
     Args:
         *args: Arguments to hash
-        cache_type: Type of cache (e.g., 'keywords', 'query')
+        cache_type: Type of cache (e.g., 'keywords', 'query', 'extract')
     Returns:
         str: Hash string
     """

@@ -131,22 +124,17 @@ def compute_mdhash_id(content, prefix: str = ""):
     return prefix + md5(content.encode()).hexdigest()
 
 
-def limit_async_func_call(max_size: int
-    """Add restriction of maximum async
+def limit_async_func_call(max_size: int):
+    """Add restriction of maximum concurrent async calls using asyncio.Semaphore"""
 
     def final_decro(func):
-        … (one removed line not captured in this view)
-        __current_size = 0
+        sem = asyncio.Semaphore(max_size)
 
         @wraps(func)
         async def wait_func(*args, **kwargs):
-            … (three removed lines not captured in this view)
-            __current_size += 1
-            result = await func(*args, **kwargs)
-            __current_size -= 1
-            return result
+            async with sem:
+                result = await func(*args, **kwargs)
+                return result
 
         return wait_func
 

@@ -380,6 +368,9 @@ async def get_best_cached_response(
     original_prompt=None,
     cache_type=None,
 ) -> Union[str, None]:
+    logger.debug(
+        f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
+    )
     mode_cache = await hashing_kv.get_by_id(mode)
     if not mode_cache:
         return None

@@ -470,8 +461,12 @@ def cosine_similarity(v1, v2):
     return dot_product / (norm1 * norm2)
 
 
-def quantize_embedding(embedding: np.ndarray, bits=8) -> tuple:
+def quantize_embedding(embedding: Union[np.ndarray, list], bits=8) -> tuple:
     """Quantize embedding to specified bits"""
+    # Convert list to numpy array if needed
+    if isinstance(embedding, list):
+        embedding = np.array(embedding)
+
     # Calculate min/max values for reconstruction
     min_val = embedding.min()
     max_val = embedding.max()

@@ -491,59 +486,60 @@ def dequantize_embedding(
     return (quantized * scale + min_val).astype(np.float32)
 
 
-async def handle_cache(
+async def handle_cache(
+    hashing_kv,
+    args_hash,
+    prompt,
+    mode="default",
+    cache_type=None,
+    force_llm_cache=False,
+):
     """Generic cache handling function"""
-    if hashing_kv is None or not
+    if hashing_kv is None or not (
+        force_llm_cache or hashing_kv.global_config.get("enable_llm_cache")
+    ):
         return None, None, None, None
 
-    … (five removed lines not captured in this view)
-        mode_cache = await hashing_kv.get_by_id(mode) or {}
-        if args_hash in mode_cache:
-            return mode_cache[args_hash]["return"], None, None, None
-        return None, None, None, None
-
-    # Get embedding cache configuration
-    embedding_cache_config = hashing_kv.global_config.get(
-        "embedding_cache_config",
-        {"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False},
-    )
-    is_embedding_cache_enabled = embedding_cache_config["enabled"]
-    use_llm_check = embedding_cache_config.get("use_llm_check", False)
-
-    quantized = min_val = max_val = None
-    if is_embedding_cache_enabled:
-        # Use embedding cache
-        embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
-        llm_model_func = hashing_kv.global_config.get("llm_model_func")
-
-        current_embedding = await embedding_model_func([prompt])
-        quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-        best_cached_response = await get_best_cached_response(
-            hashing_kv,
-            current_embedding[0],
-            similarity_threshold=embedding_cache_config["similarity_threshold"],
-            mode=mode,
-            use_llm_check=use_llm_check,
-            llm_func=llm_model_func if use_llm_check else None,
-            original_prompt=prompt if use_llm_check else None,
-            cache_type=cache_type,
+    if mode != "default":
+        # Get embedding cache configuration
+        embedding_cache_config = hashing_kv.global_config.get(
+            "embedding_cache_config",
+            {"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False},
         )
-
-    … (one removed line not captured in this view)
+        is_embedding_cache_enabled = embedding_cache_config["enabled"]
+        use_llm_check = embedding_cache_config.get("use_llm_check", False)
+
+        quantized = min_val = max_val = None
+        if is_embedding_cache_enabled:
+            # Use embedding cache
+            current_embedding = await hashing_kv.embedding_func([prompt])
+            llm_model_func = hashing_kv.global_config.get("llm_model_func")
+            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
+            best_cached_response = await get_best_cached_response(
+                hashing_kv,
+                current_embedding[0],
+                similarity_threshold=embedding_cache_config["similarity_threshold"],
+                mode=mode,
+                use_llm_check=use_llm_check,
+                llm_func=llm_model_func if use_llm_check else None,
+                original_prompt=prompt,
+                cache_type=cache_type,
+            )
+            if best_cached_response is not None:
+                return best_cached_response, None, None, None
+            else:
+                return None, quantized, min_val, max_val
+
+    # For default mode(extract_entities or naive query) or is_embedding_cache_enabled is False
+    # Use regular cache
+    if exists_func(hashing_kv, "get_by_mode_and_id"):
+        mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
     else:
-        … (three removed lines not captured in this view)
-    else:
-        mode_cache = await hashing_kv.get_by_id(mode) or {}
-        if args_hash in mode_cache:
-            return mode_cache[args_hash]["return"], None, None, None
+        mode_cache = await hashing_kv.get_by_id(mode) or {}
+    if args_hash in mode_cache:
+        return mode_cache[args_hash]["return"], None, None, None
 
-    return None,
+    return None, None, None, None
 
 
 @dataclass

@@ -572,6 +568,7 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
 
     mode_cache[cache_data.args_hash] = {
         "return": cache_data.content,
+        "cache_type": cache_data.cache_type,
         "embedding": cache_data.quantized.tobytes().hex()
         if cache_data.quantized is not None
         else None,