缓存计算函数迁移到工具类
Browse files- lightrag/utils.py +69 -1
lightrag/utils.py
CHANGED
@@ -9,7 +9,7 @@ import re
|
|
9 |
from dataclasses import dataclass
|
10 |
from functools import wraps
|
11 |
from hashlib import md5
|
12 |
-
from typing import Any, Union, List
|
13 |
import xml.etree.ElementTree as ET
|
14 |
|
15 |
import numpy as np
|
@@ -390,3 +390,71 @@ def dequantize_embedding(
|
|
390 |
"""Restore quantized embedding"""
|
391 |
scale = (max_val - min_val) / (2**bits - 1)
|
392 |
return (quantized * scale + min_val).astype(np.float32)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
from dataclasses import dataclass
|
10 |
from functools import wraps
|
11 |
from hashlib import md5
|
12 |
+
from typing import Any, Union, List, Optional
|
13 |
import xml.etree.ElementTree as ET
|
14 |
|
15 |
import numpy as np
|
|
|
390 |
"""Restore quantized embedding"""
|
391 |
scale = (max_val - min_val) / (2**bits - 1)
|
392 |
return (quantized * scale + min_val).astype(np.float32)
|
393 |
+
|
394 |
+
async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
    """Generic cache handling function.

    Looks up a previously stored response for *prompt*, either by
    embedding-similarity search (when the embedding cache is enabled in
    ``hashing_kv.global_config``) or by an exact ``args_hash`` lookup in
    the per-mode cache.

    Args:
        hashing_kv: Key-value storage exposing ``global_config`` and
            ``get_by_id``; pass ``None`` to disable caching entirely.
        args_hash: Exact-match key used by the regular (non-embedding) cache.
        prompt: Prompt text; embedded for similarity lookup when enabled.
        mode: Cache namespace — each mode keeps its own entry dict.

    Returns:
        Tuple ``(cached_response, quantized, min_val, max_val)``.
        ``cached_response`` is ``None`` on a miss; the remaining three carry
        the quantized prompt embedding (for a later ``save_to_cache``) only
        when the embedding cache is enabled and no similar entry was found.
    """
    if hashing_kv is None:
        return None, None, None, None

    # Get embedding cache configuration. Use .get() with defaults so a
    # partially specified user config (e.g. only {"enabled": True}) cannot
    # raise KeyError — the fallbacks mirror the documented defaults.
    embedding_cache_config = hashing_kv.global_config.get(
        "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
    )
    is_embedding_cache_enabled = embedding_cache_config.get("enabled", False)

    quantized = min_val = max_val = None
    if is_embedding_cache_enabled:
        # Use embedding cache: embed the prompt and search for a similar entry.
        embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
        current_embedding = await embedding_model_func([prompt])
        quantized, min_val, max_val = quantize_embedding(current_embedding[0])
        best_cached_response = await get_best_cached_response(
            hashing_kv,
            current_embedding[0],
            similarity_threshold=embedding_cache_config.get(
                "similarity_threshold", 0.95
            ),
            mode=mode,
        )
        if best_cached_response is not None:
            return best_cached_response, None, None, None
    else:
        # Use regular cache: exact args_hash lookup inside the mode namespace.
        mode_cache = await hashing_kv.get_by_id(mode) or {}
        if args_hash in mode_cache:
            return mode_cache[args_hash]["return"], None, None, None

    # Miss: return the quantized embedding (if computed) so the caller can
    # pass it to save_to_cache after generating a fresh response.
    return None, quantized, min_val, max_val
|
426 |
+
|
427 |
+
|
428 |
+
@dataclass
class CacheData:
    """Payload for one cache entry, consumed by ``save_to_cache``.

    The first four fields are required; the embedding-related fields stay
    ``None`` when the embedding cache is disabled or no embedding was
    computed for this entry.
    """

    # Exact-match key for the regular cache (hash of the call arguments).
    args_hash: str
    # The response text to cache (stored under the "return" key).
    content: str
    # Name of the model that produced the response.
    model: str
    # Original prompt, kept alongside the entry for similarity lookups.
    prompt: str
    # Quantized prompt embedding; serialized to hex bytes on save.
    quantized: Optional[np.ndarray] = None
    # De-quantization range of the embedding — TODO confirm these come from
    # quantize_embedding's (quantized, min_val, max_val) return.
    min_val: Optional[float] = None
    max_val: Optional[float] = None
    # Cache namespace; each mode has its own entry dict in storage.
    mode: str = "default"
|
438 |
+
|
439 |
+
|
440 |
+
async def save_to_cache(hashing_kv, cache_data: CacheData):
    """Persist one ``CacheData`` entry into the per-mode cache of *hashing_kv*.

    No-op when *hashing_kv* is ``None``. An optional quantized embedding is
    serialized to a hex string together with its shape and min/max values so
    it can be restored later.
    """
    if hashing_kv is None:
        return

    # Serialize the optional quantized embedding once, up front, instead of
    # repeating the None-check inside the entry literal.
    has_embedding = cache_data.quantized is not None
    embedding_hex = cache_data.quantized.tobytes().hex() if has_embedding else None
    embedding_shape = cache_data.quantized.shape if has_embedding else None

    # Merge the new entry into whatever is already stored for this mode.
    mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
    mode_cache[cache_data.args_hash] = {
        "return": cache_data.content,
        "model": cache_data.model,
        "embedding": embedding_hex,
        "embedding_shape": embedding_shape,
        "embedding_min": cache_data.min_val,
        "embedding_max": cache_data.max_val,
        "original_prompt": cache_data.prompt,
    }

    await hashing_kv.upsert({cache_data.mode: mode_cache})
|