feat: Flatten LLM cache structure for improved recall efficiency
Refactored the LLM cache to a flat key-value (KV) structure, replacing the previous nested format. The old structure keyed the cache on the query mode and stored each entry as JSON nested beneath it; entries are now stored one per key under flattened {mode}:{cache_type}:{hash} keys. This change significantly improves cache recall efficiency.
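For illustration, a minimal sketch of the two layouts (the mode, cache type, and hash values are made-up placeholders; the {mode}:{cache_type}:{hash} key format is the one used throughout this commit):

# Legacy nested layout: mode -> hash -> cache entry
legacy = {
    "default": {
        "abc123": {"return": "...", "cache_type": "extract", "original_prompt": "..."},
    },
}

# New flattened layout: one KV pair per entry, keyed as {mode}:{cache_type}:{hash}
flattened = {
    "default:extract:abc123": {"return": "...", "cache_type": "extract", "original_prompt": "..."},
}

# The parts are recovered by splitting on the first two colons
mode, cache_type, hash_value = "default:extract:abc123".split(":", 2)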
- examples/unofficial-sample/copy_llm_cache_to_another_storage.py +34 -18
- lightrag/kg/__init__.py +3 -0
- lightrag/kg/chroma_impl.py +1 -2
- lightrag/kg/json_kv_impl.py +83 -28
- lightrag/kg/milvus_impl.py +1 -1
- lightrag/kg/mongo_impl.py +19 -51
- lightrag/kg/postgres_impl.py +158 -71
- lightrag/kg/qdrant_impl.py +1 -1
- lightrag/kg/redis_impl.py +477 -36
- lightrag/kg/tidb_impl.py +2 -4
- lightrag/operate.py +7 -7
- lightrag/utils.py +52 -158
examples/unofficial-sample/copy_llm_cache_to_another_storage.py
CHANGED
@@ -52,18 +52,23 @@ async def copy_from_postgres_to_json():
         embedding_func=None,
     )

+    # Get all cache data using the new flattened structure
+    all_data = await from_llm_response_cache.get_all()
+
+    # Convert flattened data to hierarchical structure for JsonKVStorage
     kv = {}
+    for flattened_key, cache_entry in all_data.items():
+        # Parse flattened key: {mode}:{cache_type}:{hash}
+        parts = flattened_key.split(":", 2)
+        if len(parts) == 3:
+            mode, cache_type, hash_value = parts
+            if mode not in kv:
+                kv[mode] = {}
+            kv[mode][hash_value] = cache_entry
+            print(f"Copying {flattened_key} -> {mode}[{hash_value}]")
+        else:
+            print(f"Skipping invalid key format: {flattened_key}")
+
     await to_llm_response_cache.upsert(kv)
     await to_llm_response_cache.index_done_callback()
     print("Mission accomplished!")
@@ -85,13 +90,24 @@ async def copy_from_json_to_postgres():
         db=postgres_db,
     )

+    # Get all cache data from JsonKVStorage (hierarchical structure)
+    all_data = await from_llm_response_cache.get_all()
+
+    # Convert hierarchical data to flattened structure for PGKVStorage
+    flattened_data = {}
+    for mode, mode_data in all_data.items():
+        print(f"Processing mode: {mode}")
+        for hash_value, cache_entry in mode_data.items():
+            # Determine cache_type from cache entry or use default
+            cache_type = cache_entry.get("cache_type", "extract")
+            # Create flattened key: {mode}:{cache_type}:{hash}
+            flattened_key = f"{mode}:{cache_type}:{hash_value}"
+            flattened_data[flattened_key] = cache_entry
+            print(f"\tConverting {mode}[{hash_value}] -> {flattened_key}")
+
+    # Upsert the flattened data
+    await to_llm_response_cache.upsert(flattened_data)
+    print("Mission accomplished!")


 if __name__ == "__main__":

lightrag/kg/__init__.py
CHANGED
@@ -37,6 +37,7 @@ STORAGE_IMPLEMENTATIONS = {
     "DOC_STATUS_STORAGE": {
         "implementations": [
             "JsonDocStatusStorage",
+            "RedisDocStatusStorage",
             "PGDocStatusStorage",
             "MongoDocStatusStorage",
         ],
@@ -79,6 +80,7 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
     "MongoVectorDBStorage": [],
     # Document Status Storage Implementations
     "JsonDocStatusStorage": [],
+    "RedisDocStatusStorage": ["REDIS_URI"],
     "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
     "MongoDocStatusStorage": [],
 }
@@ -96,6 +98,7 @@ STORAGES = {
     "MongoGraphStorage": ".kg.mongo_impl",
     "MongoVectorDBStorage": ".kg.mongo_impl",
     "RedisKVStorage": ".kg.redis_impl",
+    "RedisDocStatusStorage": ".kg.redis_impl",
     "ChromaVectorDBStorage": ".kg.chroma_impl",
     # "TiDBKVStorage": ".kg.tidb_impl",
     # "TiDBVectorDBStorage": ".kg.tidb_impl",

lightrag/kg/chroma_impl.py
CHANGED
@@ -109,7 +109,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
             raise

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        logger.debug(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return

@@ -234,7 +234,6 @@ class ChromaVectorDBStorage(BaseVectorStorage):
             ids: List of vector IDs to be deleted
         """
         try:
-            logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
             self._collection.delete(ids=ids)
             logger.debug(
                 f"Successfully deleted {len(ids)} vectors from {self.namespace}"

lightrag/kg/json_kv_impl.py
CHANGED
@@ -42,19 +42,14 @@ class JsonKVStorage(BaseKVStorage):
         if need_init:
             loaded_data = load_json(self._file_name) or {}
             async with self._storage_lock:
-                # For cache namespaces, sum the cache entries across all cache types
-                data_count = sum(
-                    len(first_level_dict)
-                    for first_level_dict in loaded_data.values()
-                    if isinstance(first_level_dict, dict)
-                )
+                # Migrate legacy cache structure if needed
+                if self.namespace.endswith("_cache"):
+                    loaded_data = await self._migrate_legacy_cache_structure(
+                        loaded_data
+                    )
+
+                self._data.update(loaded_data)
+                data_count = len(loaded_data)

             logger.info(
                 f"Process {os.getpid()} KV load {self.namespace} with {data_count} records"
@@ -67,17 +62,8 @@ class JsonKVStorage(BaseKVStorage):
             dict(self._data) if hasattr(self._data, "_getvalue") else self._data
         )

-        # Calculate data count
-        if self.namespace.endswith("_cache"):
-            # # For cache namespaces, sum the cache entries across all cache types
-            data_count = sum(
-                len(first_level_dict)
-                for first_level_dict in data_dict.values()
-                if isinstance(first_level_dict, dict)
-            )
-        else:
-            # For non-cache namespaces, use the original count method
-            data_count = len(data_dict)
+        # Calculate data count - all data is now flattened
+        data_count = len(data_dict)

         logger.debug(
             f"Process {os.getpid()} KV writting {data_count} records to {self.namespace}"
@@ -150,14 +136,14 @@ class JsonKVStorage(BaseKVStorage):
         await set_all_update_flags(self.namespace)

     async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
-        """Delete specific records from storage by
+        """Delete specific records from storage by cache mode

         Importance notes for in-memory storage:
         1. Changes will be persisted to disk during the next index_done_callback
         2. update flags to notify other processes that data persistence is needed

         Args:
+            modes (list[str]): List of cache modes to be dropped from storage

         Returns:
             True: if the cache drop successfully
@@ -167,9 +153,29 @@ class JsonKVStorage(BaseKVStorage):
             return False

         try:
+            async with self._storage_lock:
+                keys_to_delete = []
+                modes_set = set(modes)  # Convert to set for efficient lookup
+
+                for key in list(self._data.keys()):
+                    # Parse flattened cache key: mode:cache_type:hash
+                    parts = key.split(":", 2)
+                    if len(parts) == 3 and parts[0] in modes_set:
+                        keys_to_delete.append(key)
+
+                # Batch delete
+                for key in keys_to_delete:
+                    self._data.pop(key, None)
+
+                if keys_to_delete:
+                    await set_all_update_flags(self.namespace)
+                    logger.info(
+                        f"Dropped {len(keys_to_delete)} cache entries for modes: {modes}"
+                    )
+
             return True
-        except Exception:
+        except Exception as e:
+            logger.error(f"Error dropping cache by modes: {e}")
             return False

     # async def drop_cache_by_chunk_ids(self, chunk_ids: list[str] | None = None) -> bool:
@@ -245,9 +251,58 @@ class JsonKVStorage(BaseKVStorage):
             logger.error(f"Error dropping {self.namespace}: {e}")
             return {"status": "error", "message": str(e)}

+    async def _migrate_legacy_cache_structure(self, data: dict) -> dict:
+        """Migrate legacy nested cache structure to flattened structure
+
+        Args:
+            data: Original data dictionary that may contain legacy structure
+
+        Returns:
+            Migrated data dictionary with flattened cache keys
+        """
+        from lightrag.utils import generate_cache_key
+
+        # Early return if data is empty
+        if not data:
+            return data
+
+        # Check first entry to see if it's already in new format
+        first_key = next(iter(data.keys()))
+        if ":" in first_key and len(first_key.split(":")) == 3:
+            # Already in flattened format, return as-is
+            return data
+
+        migrated_data = {}
+        migration_count = 0
+
+        for key, value in data.items():
+            # Check if this is a legacy nested cache structure
+            if isinstance(value, dict) and all(
+                isinstance(v, dict) and "return" in v for v in value.values()
+            ):
+                # This looks like a legacy cache mode with nested structure
+                mode = key
+                for cache_hash, cache_entry in value.items():
+                    cache_type = cache_entry.get("cache_type", "extract")
+                    flattened_key = generate_cache_key(mode, cache_type, cache_hash)
+                    migrated_data[flattened_key] = cache_entry
+                    migration_count += 1
+            else:
+                # Keep non-cache data or already flattened cache data as-is
+                migrated_data[key] = value
+
+        if migration_count > 0:
+            logger.info(
+                f"Migrated {migration_count} legacy cache entries to flattened structure"
+            )
+            # Persist migrated data immediately
+            write_json(migrated_data, self._file_name)
+
+        return migrated_data
+
     async def finalize(self):
         """Finalize storage resources
         Persistence cache data to disk before exiting
         """
-        if self.namespace.endswith("
+        if self.namespace.endswith("_cache"):
             await self.index_done_callback()

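To see what the JSON migration above does end to end, here is a small standalone sketch of the same detect-and-flatten pass. It assumes generate_cache_key simply joins the three parts with colons (the diff only shows it imported from lightrag.utils, but the {mode}:{cache_type}:{hash} format is stated in the comments); the sample data is made up.

def generate_cache_key(mode: str, cache_type: str, cache_hash: str) -> str:
    # Assumed helper: joins the parts into the flattened key format
    return f"{mode}:{cache_type}:{cache_hash}"

data = {"default": {"abc123": {"return": "answer", "cache_type": "extract"}}}

migrated = {}
for key, value in data.items():
    # Same heuristic as the diff: a legacy mode bucket is a dict of dicts with "return"
    if isinstance(value, dict) and all(
        isinstance(v, dict) and "return" in v for v in value.values()
    ):
        for cache_hash, entry in value.items():
            migrated[generate_cache_key(key, entry.get("cache_type", "extract"), cache_hash)] = entry
    else:
        migrated[key] = value

assert "default:extract:abc123" in migrated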
lightrag/kg/milvus_impl.py
CHANGED
@@ -75,7 +75,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
         )

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        logger.debug(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return

lightrag/kg/mongo_impl.py
CHANGED
@@ -15,7 +15,6 @@ from ..base import (
     DocStatus,
     DocStatusStorage,
 )
-from ..namespace import NameSpace, is_namespace
 from ..utils import logger, compute_mdhash_id
 from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
 from ..constants import GRAPH_FIELD_SEP
@@ -98,17 +97,8 @@ class MongoKVStorage(BaseKVStorage):
         self._data = None

     async def get_by_id(self, id: str) -> dict[str, Any] | None:
-            cursor = self._data.find({"_id": {"$regex": "^default_"}})
-            result = {}
-            async for doc in cursor:
-                # Use the complete _id as key
-                result[doc["_id"]] = doc
-            return result if result else None
-        else:
-            # Original behavior for non-"default" ids
-            return await self._data.find_one({"_id": id})
+        # Unified handling for flattened keys
+        return await self._data.find_one({"_id": id})

     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         cursor = self._data.find({"_id": {"$in": ids}})
@@ -133,43 +123,21 @@ class MongoKVStorage(BaseKVStorage):
         return result

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        logger.debug(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return

-            await asyncio.gather(*update_tasks)
-        else:
-            update_tasks = []
-            for k, v in data.items():
-                data[k]["_id"] = k
-                update_tasks.append(
-                    self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
-                )
-            await asyncio.gather(*update_tasks)
-
-    async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
-        if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
-            res = {}
-            v = await self._data.find_one({"_id": mode + "_" + id})
-            if v:
-                res[id] = v
-                logger.debug(f"llm_response_cache find one by:{id}")
-                return res
-            else:
-                return None
-        else:
-            return None
+        # Unified handling for all namespaces with flattened keys
+        # Use bulk_write for better performance
+        from pymongo import UpdateOne
+
+        operations = []
+        for k, v in data.items():
+            v["_id"] = k  # Use flattened key as _id
+            operations.append(UpdateOne({"_id": k}, {"$set": v}, upsert=True))
+
+        if operations:
+            await self._data.bulk_write(operations)

     async def index_done_callback(self) -> None:
         # Mongo handles persistence automatically
@@ -209,8 +177,8 @@ class MongoKVStorage(BaseKVStorage):
             return False

         try:
-            # Build regex pattern to match
-            pattern = f"^({'|'.join(modes)})
+            # Build regex pattern to match flattened key format: mode:cache_type:hash
+            pattern = f"^({'|'.join(modes)}):"
             result = await self._data.delete_many({"_id": {"$regex": pattern}})
             logger.info(f"Deleted {result.deleted_count} documents by modes: {modes}")
             return True
@@ -274,7 +242,7 @@ class MongoDocStatusStorage(DocStatusStorage):
         return data - existing_ids

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        logger.debug(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
         update_tasks: list[Any] = []
@@ -1282,7 +1250,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
         logger.debug("vector index already exist")

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        logger.debug(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return

@@ -1371,7 +1339,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
         Args:
             ids: List of vector IDs to be deleted
         """
-        logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
+        logger.debug(f"Deleting {len(ids)} vectors from {self.namespace}")
         if not ids:
             return

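The Mongo drop_cache_by_modes above relies on a regex anchored to the new key layout. A quick standalone check of that pattern, with made-up ids:

import re

modes = ["default", "global"]
pattern = re.compile(f"^({'|'.join(modes)}):")

ids = ["default:extract:abc123", "local:query:def456", "global:query:789"]
matched = [i for i in ids if pattern.match(i)]
assert matched == ["default:extract:abc123", "global:query:789"]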
lightrag/kg/postgres_impl.py
CHANGED
@@ -247,6 +247,116 @@ class PostgreSQLDB:
             logger.error(f"Failed during data migration to LIGHTRAG_VDB_CHUNKS: {e}")
             # Do not re-raise, to allow the application to start

+    async def _check_llm_cache_needs_migration(self):
+        """Check if LLM cache data needs migration by examining the first record"""
+        try:
+            # Only query the first record to determine format
+            check_sql = """
+                SELECT id FROM LIGHTRAG_LLM_CACHE
+                ORDER BY create_time ASC
+                LIMIT 1
+            """
+            result = await self.query(check_sql)
+
+            if result and result.get("id"):
+                # If id doesn't contain colon, it's old format
+                return ":" not in result["id"]
+
+            return False  # No data or already new format
+        except Exception as e:
+            logger.warning(f"Failed to check LLM cache migration status: {e}")
+            return False
+
+    async def _migrate_llm_cache_to_flattened_keys(self):
+        """Migrate LLM cache to flattened key format, recalculating hash values"""
+        try:
+            # Get all old format data
+            old_data_sql = """
+                SELECT id, mode, original_prompt, return_value, chunk_id,
+                       create_time, update_time
+                FROM LIGHTRAG_LLM_CACHE
+                WHERE id NOT LIKE '%:%'
+            """
+
+            old_records = await self.query(old_data_sql, multirows=True)
+
+            if not old_records:
+                logger.info("No old format LLM cache data found, skipping migration")
+                return
+
+            logger.info(
+                f"Found {len(old_records)} old format cache records, starting migration..."
+            )
+
+            # Import hash calculation function
+            from ..utils import compute_args_hash
+
+            migrated_count = 0
+
+            # Migrate data in batches
+            for record in old_records:
+                try:
+                    # Recalculate hash using correct method
+                    new_hash = compute_args_hash(
+                        record["mode"], record["original_prompt"]
+                    )
+
+                    # Generate new flattened key
+                    cache_type = "extract"  # Default type
+                    new_key = f"{record['mode']}:{cache_type}:{new_hash}"
+
+                    # Insert new format data
+                    insert_sql = """
+                        INSERT INTO LIGHTRAG_LLM_CACHE
+                        (workspace, id, mode, original_prompt, return_value, chunk_id, create_time, update_time)
+                        VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
+                        ON CONFLICT (workspace, mode, id) DO NOTHING
+                    """
+
+                    await self.execute(
+                        insert_sql,
+                        {
+                            "workspace": self.workspace,
+                            "id": new_key,
+                            "mode": record["mode"],
+                            "original_prompt": record["original_prompt"],
+                            "return_value": record["return_value"],
+                            "chunk_id": record["chunk_id"],
+                            "create_time": record["create_time"],
+                            "update_time": record["update_time"],
+                        },
+                    )
+
+                    # Delete old data
+                    delete_sql = """
+                        DELETE FROM LIGHTRAG_LLM_CACHE
+                        WHERE workspace=$1 AND mode=$2 AND id=$3
+                    """
+                    await self.execute(
+                        delete_sql,
+                        {
+                            "workspace": self.workspace,
+                            "mode": record["mode"],
+                            "id": record["id"],  # Old id
+                        },
+                    )
+
+                    migrated_count += 1
+
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to migrate cache record {record['id']}: {e}"
+                    )
+                    continue
+
+            logger.info(
+                f"Successfully migrated {migrated_count} cache records to flattened format"
+            )
+
+        except Exception as e:
+            logger.error(f"LLM cache migration failed: {e}")
+            # Don't raise exception, allow system to continue startup
+
     async def check_tables(self):
         # First create all tables
         for k, v in TABLES.items():
@@ -304,6 +414,13 @@ class PostgreSQLDB:
         except Exception as e:
             logger.error(f"PostgreSQL, Failed to migrate doc_chunks to vdb_chunks: {e}")

+        # Check and migrate LLM cache to flattened keys if needed
+        try:
+            if await self._check_llm_cache_needs_migration():
+                await self._migrate_llm_cache_to_flattened_keys()
+        except Exception as e:
+            logger.error(f"PostgreSQL, LLM cache migration failed: {e}")
+
     async def query(
         self,
         sql: str,
@@ -486,77 +603,48 @@ class PGKVStorage(BaseKVStorage):

         try:
             results = await self.db.query(sql, params, multirows=True)
+
+            # Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
             if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
+                processed_results = {}
                 for row in results:
+                    # Parse flattened key to extract cache_type
+                    key_parts = row["id"].split(":")
+                    cache_type = key_parts[1] if len(key_parts) >= 3 else "unknown"
+
+                    # Map field names and add cache_type for compatibility
+                    processed_row = {
+                        **row,
+                        "return": row.get("return_value", ""),  # Map return_value to return
+                        "cache_type": cache_type,  # Add cache_type from key
+                        "original_prompt": row.get("original_prompt", ""),
+                        "chunk_id": row.get("chunk_id"),
+                        "mode": row.get("mode", "default"),
+                    }
+                    processed_results[row["id"]] = processed_row
+                return processed_results
+
+            # For other namespaces, return as-is
+            return {row["id"]: row for row in results}
         except Exception as e:
             logger.error(f"Error retrieving all data from {self.namespace}: {e}")
             return {}

     async def get_by_id(self, id: str) -> dict[str, Any] | None:
-        """Get
+        """Get data by id."""
         sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
-            array_res = await self.db.query(sql, params, multirows=True)
-            res = {}
-            for row in array_res:
-                # Dynamically add cache_type field based on mode
-                row_with_cache_type = dict(row)
-                if id == "default":
-                    row_with_cache_type["cache_type"] = "extract"
-                else:
-                    row_with_cache_type["cache_type"] = "unknown"
-                res[row["id"]] = row_with_cache_type
-            return res if res else None
-        else:
-            params = {"workspace": self.db.workspace, "id": id}
-            response = await self.db.query(sql, params)
-            return response if response else None
-
-    async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
-        """Specifically for llm_response_cache."""
-        sql = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
-        params = {"workspace": self.db.workspace, "mode": mode, "id": id}
-        if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
-            array_res = await self.db.query(sql, params, multirows=True)
-            res = {}
-            for row in array_res:
-                res[row["id"]] = row
-            return res
-        else:
-            return None
+        params = {"workspace": self.db.workspace, "id": id}
+        response = await self.db.query(sql, params)
+        return response if response else None

     # Query by id
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
-        """Get
+        """Get data by ids"""
         sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
             ids=",".join([f"'{id}'" for id in ids])
         )
         params = {"workspace": self.db.workspace}
-            array_res = await self.db.query(sql, params, multirows=True)
-            modes = set()
-            dict_res: dict[str, dict] = {}
-            for row in array_res:
-                modes.add(row["mode"])
-            for mode in modes:
-                if mode not in dict_res:
-                    dict_res[mode] = {}
-            for row in array_res:
-                dict_res[row["mode"]][row["id"]] = row
-            return [{k: v} for k, v in dict_res.items()]
-        else:
-            return await self.db.query(sql, params, multirows=True)
+        return await self.db.query(sql, params, multirows=True)

     async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
         """Specifically for llm_response_cache."""
@@ -617,19 +705,18 @@ class PGKVStorage(BaseKVStorage):
             }
             await self.db.execute(upsert_sql, _data)
         elif is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
-            for
+            for k, v in data.items():
+                upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
+                _data = {
+                    "workspace": self.db.workspace,
+                    "id": k,  # Use flattened key as id
+                    "original_prompt": v["original_prompt"],
+                    "return_value": v["return"],
+                    "mode": v.get("mode", "default"),  # Get mode from data
+                    "chunk_id": v.get("chunk_id"),
+                }

+                await self.db.execute(upsert_sql, _data)

     async def index_done_callback(self) -> None:
         # PG handles persistence automatically
@@ -1035,8 +1122,8 @@ class PGDocStatusStorage(DocStatusStorage):
             else:
                 exist_keys = []
             new_keys = set([s for s in keys if s not in exist_keys])
-            print(f"keys: {keys}")
-            print(f"new_keys: {new_keys}")
+            # print(f"keys: {keys}")
+            # print(f"new_keys: {new_keys}")
             return new_keys
         except Exception as e:
             logger.error(
@@ -2621,7 +2708,7 @@ SQL_TEMPLATES = {
     FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
     """,
     "get_by_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id
-    FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND
+    FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id=$2
     """,
     "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id
     FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 AND id=$3

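The PostgreSQL migration gate above boils down to one string test on the oldest cache id; a minimal sketch of that rule (the ids are made up):

def needs_migration(first_id):
    # Old-format ids are bare hashes; flattened ids contain colons
    return first_id is not None and ":" not in first_id

assert needs_migration("8f0c2a")                      # legacy bare hash
assert not needs_migration("default:extract:8f0c2a")  # already flattened
assert not needs_migration(None)                      # empty table: nothing to migrate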
lightrag/kg/qdrant_impl.py
CHANGED
@@ -85,7 +85,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
         )

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        logger.debug(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return

lightrag/kg/redis_impl.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1 |
import os
|
2 |
-
from typing import Any, final
|
3 |
from dataclasses import dataclass
|
4 |
import pipmaster as pm
|
5 |
import configparser
|
6 |
from contextlib import asynccontextmanager
|
|
|
7 |
|
8 |
if not pm.is_installed("redis"):
|
9 |
pm.install("redis")
|
@@ -13,7 +14,7 @@ from redis.asyncio import Redis, ConnectionPool # type: ignore
|
|
13 |
from redis.exceptions import RedisError, ConnectionError # type: ignore
|
14 |
from lightrag.utils import logger
|
15 |
|
16 |
-
from lightrag.base import BaseKVStorage
|
17 |
import json
|
18 |
|
19 |
|
@@ -26,6 +27,41 @@ SOCKET_TIMEOUT = 5.0
|
|
26 |
SOCKET_CONNECT_TIMEOUT = 3.0
|
27 |
|
28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
@final
|
30 |
@dataclass
|
31 |
class RedisKVStorage(BaseKVStorage):
|
@@ -33,19 +69,28 @@ class RedisKVStorage(BaseKVStorage):
|
|
33 |
redis_url = os.environ.get(
|
34 |
"REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
|
35 |
)
|
36 |
-
#
|
37 |
-
self._pool =
|
38 |
-
redis_url,
|
39 |
-
max_connections=MAX_CONNECTIONS,
|
40 |
-
decode_responses=True,
|
41 |
-
socket_timeout=SOCKET_TIMEOUT,
|
42 |
-
socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
|
43 |
-
)
|
44 |
self._redis = Redis(connection_pool=self._pool)
|
45 |
logger.info(
|
46 |
-
f"Initialized Redis
|
47 |
)
|
48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
@asynccontextmanager
|
50 |
async def _get_redis_connection(self):
|
51 |
"""Safe context manager for Redis operations."""
|
@@ -99,21 +144,57 @@ class RedisKVStorage(BaseKVStorage):
|
|
99 |
logger.error(f"JSON decode error in batch get: {e}")
|
100 |
return [None] * len(ids)
|
101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
async def filter_keys(self, keys: set[str]) -> set[str]:
|
103 |
async with self._get_redis_connection() as redis:
|
104 |
pipe = redis.pipeline()
|
105 |
-
|
|
|
106 |
pipe.exists(f"{self.namespace}:{key}")
|
107 |
results = await pipe.execute()
|
108 |
|
109 |
-
existing_ids = {
|
110 |
return set(keys) - existing_ids
|
111 |
|
112 |
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
113 |
if not data:
|
114 |
return
|
115 |
-
|
116 |
-
logger.info(f"Inserting {len(data)} items to {self.namespace}")
|
117 |
async with self._get_redis_connection() as redis:
|
118 |
try:
|
119 |
pipe = redis.pipeline()
|
@@ -148,13 +229,13 @@ class RedisKVStorage(BaseKVStorage):
|
|
148 |
)
|
149 |
|
150 |
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
|
151 |
-
"""Delete specific records from storage by
|
152 |
|
153 |
Importance notes for Redis storage:
|
154 |
1. This will immediately delete the specified cache modes from Redis
|
155 |
|
156 |
Args:
|
157 |
-
modes (list[str]): List of cache
|
158 |
|
159 |
Returns:
|
160 |
True: if the cache drop successfully
|
@@ -164,9 +245,43 @@ class RedisKVStorage(BaseKVStorage):
|
|
164 |
return False
|
165 |
|
166 |
try:
|
167 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
return True
|
169 |
-
except Exception:
|
|
|
170 |
return False
|
171 |
|
172 |
async def drop(self) -> dict[str, str]:
|
@@ -177,24 +292,350 @@ class RedisKVStorage(BaseKVStorage):
|
|
177 |
"""
|
178 |
async with self._get_redis_connection() as redis:
|
179 |
try:
|
180 |
-
keys
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
197 |
|
198 |
except Exception as e:
|
199 |
logger.error(f"Error dropping keys from {self.namespace}: {e}")
|
200 |
return {"status": "error", "message": str(e)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
+
from typing import Any, final, Union
|
3 |
from dataclasses import dataclass
|
4 |
import pipmaster as pm
|
5 |
import configparser
|
6 |
from contextlib import asynccontextmanager
|
7 |
+
import threading
|
8 |
|
9 |
if not pm.is_installed("redis"):
|
10 |
pm.install("redis")
|
|
|
14 |
from redis.exceptions import RedisError, ConnectionError # type: ignore
|
15 |
from lightrag.utils import logger
|
16 |
|
17 |
+
from lightrag.base import BaseKVStorage, DocStatusStorage, DocStatus, DocProcessingStatus
|
18 |
import json
|
19 |
|
20 |
|
|
|
27 |
SOCKET_CONNECT_TIMEOUT = 3.0
|
28 |
|
29 |
|
30 |
+
class RedisConnectionManager:
|
31 |
+
"""Shared Redis connection pool manager to avoid creating multiple pools for the same Redis URI"""
|
32 |
+
|
33 |
+
_pools = {}
|
34 |
+
_lock = threading.Lock()
|
35 |
+
|
36 |
+
@classmethod
|
37 |
+
def get_pool(cls, redis_url: str) -> ConnectionPool:
|
38 |
+
"""Get or create a connection pool for the given Redis URL"""
|
39 |
+
if redis_url not in cls._pools:
|
40 |
+
with cls._lock:
|
41 |
+
if redis_url not in cls._pools:
|
42 |
+
cls._pools[redis_url] = ConnectionPool.from_url(
|
43 |
+
redis_url,
|
44 |
+
max_connections=MAX_CONNECTIONS,
|
45 |
+
decode_responses=True,
|
46 |
+
socket_timeout=SOCKET_TIMEOUT,
|
47 |
+
socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
|
48 |
+
)
|
49 |
+
logger.info(f"Created shared Redis connection pool for {redis_url}")
|
50 |
+
return cls._pools[redis_url]
|
51 |
+
|
52 |
+
@classmethod
|
53 |
+
def close_all_pools(cls):
|
54 |
+
"""Close all connection pools (for cleanup)"""
|
55 |
+
with cls._lock:
|
56 |
+
for url, pool in cls._pools.items():
|
57 |
+
try:
|
58 |
+
pool.disconnect()
|
59 |
+
logger.info(f"Closed Redis connection pool for {url}")
|
60 |
+
except Exception as e:
|
61 |
+
logger.error(f"Error closing Redis pool for {url}: {e}")
|
62 |
+
cls._pools.clear()
|
63 |
+
|
64 |
+
|
65 |
@final
|
66 |
@dataclass
|
67 |
class RedisKVStorage(BaseKVStorage):
|
|
|
69 |
redis_url = os.environ.get(
|
70 |
"REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
|
71 |
)
|
72 |
+
# Use shared connection pool
|
73 |
+
self._pool = RedisConnectionManager.get_pool(redis_url)
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
self._redis = Redis(connection_pool=self._pool)
|
75 |
logger.info(
|
76 |
+
f"Initialized Redis KV storage for {self.namespace} using shared connection pool"
|
77 |
)
|
78 |
|
79 |
+
async def initialize(self):
|
80 |
+
"""Initialize Redis connection and migrate legacy cache structure if needed"""
|
81 |
+
# Test connection
|
82 |
+
try:
|
83 |
+
async with self._get_redis_connection() as redis:
|
84 |
+
await redis.ping()
|
85 |
+
logger.info(f"Connected to Redis for namespace {self.namespace}")
|
86 |
+
except Exception as e:
|
87 |
+
logger.error(f"Failed to connect to Redis: {e}")
|
88 |
+
raise
|
89 |
+
|
90 |
+
# Migrate legacy cache structure if this is a cache namespace
|
91 |
+
if self.namespace.endswith("_cache"):
|
92 |
+
await self._migrate_legacy_cache_structure()
|
93 |
+
|
94 |
@asynccontextmanager
|
95 |
async def _get_redis_connection(self):
|
96 |
"""Safe context manager for Redis operations."""
|
|
|
144 |
logger.error(f"JSON decode error in batch get: {e}")
|
145 |
return [None] * len(ids)
|
146 |
|
147 |
+
async def get_all(self) -> dict[str, Any]:
|
148 |
+
"""Get all data from storage
|
149 |
+
|
150 |
+
Returns:
|
151 |
+
Dictionary containing all stored data
|
152 |
+
"""
|
153 |
+
async with self._get_redis_connection() as redis:
|
154 |
+
try:
|
155 |
+
# Get all keys for this namespace
|
156 |
+
keys = await redis.keys(f"{self.namespace}:*")
|
157 |
+
|
158 |
+
if not keys:
|
159 |
+
return {}
|
160 |
+
|
161 |
+
# Get all values in batch
|
162 |
+
pipe = redis.pipeline()
|
163 |
+
for key in keys:
|
164 |
+
pipe.get(key)
|
165 |
+
values = await pipe.execute()
|
166 |
+
|
167 |
+
# Build result dictionary
|
168 |
+
result = {}
|
169 |
+
for key, value in zip(keys, values):
|
170 |
+
if value:
|
171 |
+
# Extract the ID part (after namespace:)
|
172 |
+
key_id = key.split(":", 1)[1]
|
173 |
+
try:
|
174 |
+
result[key_id] = json.loads(value)
|
175 |
+
except json.JSONDecodeError as e:
|
176 |
+
logger.error(f"JSON decode error for key {key}: {e}")
|
177 |
+
continue
|
178 |
+
|
179 |
+
return result
|
180 |
+
except Exception as e:
|
181 |
+
logger.error(f"Error getting all data from Redis: {e}")
|
182 |
+
return {}
|
183 |
+
|
184 |
async def filter_keys(self, keys: set[str]) -> set[str]:
|
185 |
async with self._get_redis_connection() as redis:
|
186 |
pipe = redis.pipeline()
|
187 |
+
keys_list = list(keys) # Convert set to list for indexing
|
188 |
+
for key in keys_list:
|
189 |
pipe.exists(f"{self.namespace}:{key}")
|
190 |
results = await pipe.execute()
|
191 |
|
192 |
+
existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists}
|
193 |
return set(keys) - existing_ids
|
194 |
|
195 |
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
196 |
if not data:
|
197 |
return
|
|
|
|
|
198 |
async with self._get_redis_connection() as redis:
|
199 |
try:
|
200 |
pipe = redis.pipeline()
|
|
|
229 |
)
|
230 |
|
231 |
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
|
232 |
+
"""Delete specific records from storage by cache mode
|
233 |
|
234 |
Importance notes for Redis storage:
|
235 |
1. This will immediately delete the specified cache modes from Redis
|
236 |
|
237 |
Args:
|
238 |
+
modes (list[str]): List of cache modes to be dropped from storage
|
239 |
|
240 |
Returns:
|
241 |
True: if the cache drop successfully
|
|
|
245 |
return False
|
246 |
|
247 |
try:
|
248 |
+
async with self._get_redis_connection() as redis:
|
249 |
+
keys_to_delete = []
|
250 |
+
|
251 |
+
# Find matching keys for each mode using SCAN
|
252 |
+
for mode in modes:
|
253 |
+
# Use correct pattern to match flattened cache key format {namespace}:{mode}:{cache_type}:{hash}
|
254 |
+
pattern = f"{self.namespace}:{mode}:*"
|
255 |
+
cursor = 0
|
256 |
+
mode_keys = []
|
257 |
+
|
258 |
+
while True:
|
259 |
+
cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
|
260 |
+
if keys:
|
261 |
+
mode_keys.extend(keys)
|
262 |
+
|
263 |
+
if cursor == 0:
|
264 |
+
break
|
265 |
+
|
266 |
+
keys_to_delete.extend(mode_keys)
|
267 |
+
logger.info(f"Found {len(mode_keys)} keys for mode '{mode}' with pattern '{pattern}'")
|
268 |
+
|
269 |
+
if keys_to_delete:
|
270 |
+
# Batch delete
|
271 |
+
pipe = redis.pipeline()
|
272 |
+
for key in keys_to_delete:
|
273 |
+
pipe.delete(key)
|
274 |
+
results = await pipe.execute()
|
275 |
+
deleted_count = sum(results)
|
276 |
+
logger.info(
|
277 |
+
f"Dropped {deleted_count} cache entries for modes: {modes}"
|
278 |
+
)
|
279 |
+
else:
|
280 |
+
logger.warning(f"No cache entries found for modes: {modes}")
|
281 |
+
|
282 |
return True
|
283 |
+
except Exception as e:
|
284 |
+
logger.error(f"Error dropping cache by modes in Redis: {e}")
|
285 |
return False
|
286 |
|
287 |
async def drop(self) -> dict[str, str]:
|
|
|
292 |
"""
|
293 |
async with self._get_redis_connection() as redis:
|
294 |
try:
|
295 |
+
# Use SCAN to find all keys with the namespace prefix
|
296 |
+
pattern = f"{self.namespace}:*"
|
297 |
+
cursor = 0
|
298 |
+
deleted_count = 0
|
299 |
+
|
300 |
+
while True:
|
301 |
+
cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
|
302 |
+
if keys:
|
303 |
+
# Delete keys in batches
|
304 |
+
pipe = redis.pipeline()
|
305 |
+
for key in keys:
|
306 |
+
pipe.delete(key)
|
307 |
+
results = await pipe.execute()
|
308 |
+
deleted_count += sum(results)
|
309 |
+
|
310 |
+
if cursor == 0:
|
311 |
+
break
|
312 |
+
|
313 |
+
logger.info(f"Dropped {deleted_count} keys from {self.namespace}")
|
314 |
+
return {
|
315 |
+
"status": "success",
|
316 |
+
"message": f"{deleted_count} keys dropped",
|
317 |
+
}
|
318 |
|
319 |
except Exception as e:
|
320 |
logger.error(f"Error dropping keys from {self.namespace}: {e}")
|
321 |
return {"status": "error", "message": str(e)}
|
322 |
+
|
323 |
+
async def _migrate_legacy_cache_structure(self):
|
324 |
+
"""Migrate legacy nested cache structure to flattened structure for Redis
|
325 |
+
|
326 |
+
Redis already stores data in a flattened way, but we need to check for
|
327 |
+
legacy keys that might contain nested JSON structures and migrate them.
|
328 |
+
|
329 |
+
Early exit if any flattened key is found (indicating migration already done).
|
330 |
+
"""
|
331 |
+
from lightrag.utils import generate_cache_key
|
332 |
+
|
333 |
+
async with self._get_redis_connection() as redis:
|
334 |
+
# Get all keys for this namespace
|
335 |
+
keys = await redis.keys(f"{self.namespace}:*")
|
336 |
+
|
337 |
+
if not keys:
|
338 |
+
return
|
339 |
+
|
340 |
+
# Check if we have any flattened keys already - if so, skip migration
|
341 |
+
has_flattened_keys = False
|
342 |
+
keys_to_migrate = []
|
343 |
+
|
344 |
+
for key in keys:
|
345 |
+
# Extract the ID part (after namespace:)
|
346 |
+
key_id = key.split(":", 1)[1]
|
347 |
+
|
348 |
+
# Check if already in flattened format (contains exactly 2 colons for mode:cache_type:hash)
|
349 |
+
if ":" in key_id and len(key_id.split(":")) == 3:
|
350 |
+
has_flattened_keys = True
|
351 |
+
break # Early exit - migration already done
|
352 |
+
|
353 |
+
# Get the data to check if it's a legacy nested structure
|
354 |
+
data = await redis.get(key)
|
355 |
+
if data:
|
356 |
+
try:
|
357 |
+
parsed_data = json.loads(data)
|
358 |
+
# Check if this looks like a legacy cache mode with nested structure
|
359 |
+
if isinstance(parsed_data, dict) and all(
|
360 |
+
isinstance(v, dict) and "return" in v
|
361 |
+
for v in parsed_data.values()
|
362 |
+
):
|
363 |
+
keys_to_migrate.append((key, key_id, parsed_data))
|
364 |
+
except json.JSONDecodeError:
|
365 |
+
continue
|
366 |
+
|
367 |
+
# If we found any flattened keys, assume migration is already done
|
368 |
+
if has_flattened_keys:
|
369 |
+
logger.debug(
|
370 |
+
f"Found flattened cache keys in {self.namespace}, skipping migration"
|
371 |
+
)
|
372 |
+
return
|
373 |
+
|
374 |
+
if not keys_to_migrate:
|
375 |
+
return
|
376 |
+
|
377 |
+
# Perform migration
|
378 |
+
pipe = redis.pipeline()
|
379 |
+
migration_count = 0
|
380 |
+
|
381 |
+
for old_key, mode, nested_data in keys_to_migrate:
|
382 |
+
# Delete the old key
|
383 |
+
pipe.delete(old_key)
|
384 |
+
|
385 |
+
# Create new flattened keys
|
386 |
+
for cache_hash, cache_entry in nested_data.items():
|
387 |
+
cache_type = cache_entry.get("cache_type", "extract")
|
388 |
+
flattened_key = generate_cache_key(mode, cache_type, cache_hash)
|
389 |
+
full_key = f"{self.namespace}:{flattened_key}"
|
390 |
+
pipe.set(full_key, json.dumps(cache_entry))
|
391 |
+
migration_count += 1
|
392 |
+
|
393 |
+
await pipe.execute()
|
394 |
+
|
395 |
+
if migration_count > 0:
|
396 |
+
logger.info(
|
397 |
+
f"Migrated {migration_count} legacy cache entries to flattened structure in Redis"
|
398 |
+
)
|
399 |
+
|
400 |
+
|
401 |
+
@final
|
402 |
+
@dataclass
|
403 |
+
class RedisDocStatusStorage(DocStatusStorage):
|
404 |
+
"""Redis implementation of document status storage"""
|
405 |
+
|
406 |
+
def __post_init__(self):
|
407 |
+
redis_url = os.environ.get(
|
408 |
+
"REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
|
409 |
+
)
|
410 |
+
# Use shared connection pool
|
411 |
+
self._pool = RedisConnectionManager.get_pool(redis_url)
|
412 |
+
self._redis = Redis(connection_pool=self._pool)
|
413 |
+
logger.info(
|
414 |
+
f"Initialized Redis doc status storage for {self.namespace} using shared connection pool"
|
415 |
+
)
|
416 |
+
|
417 |
+
async def initialize(self):
|
418 |
+
"""Initialize Redis connection"""
|
419 |
+
try:
|
420 |
+
async with self._get_redis_connection() as redis:
|
421 |
+
await redis.ping()
|
422 |
+
logger.info(f"Connected to Redis for doc status namespace {self.namespace}")
|
423 |
+
except Exception as e:
|
424 |
+
logger.error(f"Failed to connect to Redis for doc status: {e}")
|
425 |
+
raise
|
426 |
+
|
427 |
+
@asynccontextmanager
|
428 |
+
async def _get_redis_connection(self):
|
429 |
+
"""Safe context manager for Redis operations."""
|
430 |
+
try:
|
431 |
+
yield self._redis
|
432 |
+
except ConnectionError as e:
|
433 |
+
logger.error(f"Redis connection error in doc status {self.namespace}: {e}")
|
434 |
+
raise
|
435 |
+
except RedisError as e:
|
436 |
+
logger.error(f"Redis operation error in doc status {self.namespace}: {e}")
|
437 |
+
raise
|
438 |
+
except Exception as e:
|
439 |
+
logger.error(
|
440 |
+
f"Unexpected error in Redis doc status operation for {self.namespace}: {e}"
|
441 |
+
)
|
442 |
+
raise
|
443 |
+
|
444 |
+
async def close(self):
|
445 |
+
"""Close the Redis connection."""
|
446 |
+
if hasattr(self, "_redis") and self._redis:
|
447 |
+
await self._redis.close()
|
448 |
+
logger.debug(f"Closed Redis connection for doc status {self.namespace}")
|
449 |
+
|
450 |
+
async def __aenter__(self):
|
451 |
+
"""Support for async context manager."""
|
452 |
+
return self
|
453 |
+
|
454 |
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
455 |
+
"""Ensure Redis resources are cleaned up when exiting context."""
|
456 |
+
await self.close()
|
457 |
+
|
458 |
+
async def filter_keys(self, keys: set[str]) -> set[str]:
|
459 |
+
"""Return keys that should be processed (not in storage or not successfully processed)"""
|
460 |
+
async with self._get_redis_connection() as redis:
|
461 |
+
pipe = redis.pipeline()
|
462 |
+
keys_list = list(keys)
|
463 |
+
for key in keys_list:
|
464 |
+
pipe.exists(f"{self.namespace}:{key}")
|
465 |
+
results = await pipe.execute()
|
466 |
+
|
467 |
+
existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists}
|
468 |
+
return set(keys) - existing_ids
|
469 |
+
|
470 |
+
+    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
+        result: list[dict[str, Any]] = []
+        async with self._get_redis_connection() as redis:
+            try:
+                pipe = redis.pipeline()
+                for id in ids:
+                    pipe.get(f"{self.namespace}:{id}")
+                results = await pipe.execute()
+
+                for result_data in results:
+                    if result_data:
+                        try:
+                            result.append(json.loads(result_data))
+                        except json.JSONDecodeError as e:
+                            logger.error(f"JSON decode error in get_by_ids: {e}")
+                            continue
+            except Exception as e:
+                logger.error(f"Error in get_by_ids: {e}")
+        return result
+
+    async def get_status_counts(self) -> dict[str, int]:
+        """Get counts of documents in each status"""
+        counts = {status.value: 0 for status in DocStatus}
+        async with self._get_redis_connection() as redis:
+            try:
+                # Use SCAN to iterate through all keys in the namespace
+                cursor = 0
+                while True:
+                    cursor, keys = await redis.scan(cursor, match=f"{self.namespace}:*", count=1000)
+                    if keys:
+                        # Get all values in batch
+                        pipe = redis.pipeline()
+                        for key in keys:
+                            pipe.get(key)
+                        values = await pipe.execute()
+
+                        # Count statuses
+                        for value in values:
+                            if value:
+                                try:
+                                    doc_data = json.loads(value)
+                                    status = doc_data.get("status")
+                                    if status in counts:
+                                        counts[status] += 1
+                                except json.JSONDecodeError:
+                                    continue
+
+                    if cursor == 0:
+                        break
+            except Exception as e:
+                logger.error(f"Error getting status counts: {e}")
+
+        return counts
+
+    async def get_docs_by_status(
+        self, status: DocStatus
+    ) -> dict[str, DocProcessingStatus]:
+        """Get all documents with a specific status"""
+        result = {}
+        async with self._get_redis_connection() as redis:
+            try:
+                # Use SCAN to iterate through all keys in the namespace
+                cursor = 0
+                while True:
+                    cursor, keys = await redis.scan(cursor, match=f"{self.namespace}:*", count=1000)
+                    if keys:
+                        # Get all values in batch
+                        pipe = redis.pipeline()
+                        for key in keys:
+                            pipe.get(key)
+                        values = await pipe.execute()
+
+                        # Filter by status and create DocProcessingStatus objects
+                        for key, value in zip(keys, values):
+                            if value:
+                                try:
+                                    doc_data = json.loads(value)
+                                    if doc_data.get("status") == status.value:
+                                        # Extract document ID from key
+                                        doc_id = key.split(":", 1)[1]
+
+                                        # Make a copy of the data to avoid modifying the original
+                                        data = doc_data.copy()
+                                        # If content is missing, use content_summary as content
+                                        if "content" not in data and "content_summary" in data:
+                                            data["content"] = data["content_summary"]
+                                        # If file_path is not in data, use document id as file path
+                                        if "file_path" not in data:
+                                            data["file_path"] = "no-file-path"
+
+                                        result[doc_id] = DocProcessingStatus(**data)
+                                except (json.JSONDecodeError, KeyError) as e:
+                                    logger.error(f"Error processing document {key}: {e}")
+                                    continue
+
+                    if cursor == 0:
+                        break
+            except Exception as e:
+                logger.error(f"Error getting docs by status: {e}")
+
+        return result
+
+    async def index_done_callback(self) -> None:
+        """Redis handles persistence automatically"""
+        pass
+
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
+        """Insert or update document status data"""
+        if not data:
+            return
+
+        logger.debug(f"Inserting {len(data)} records to {self.namespace}")
+        async with self._get_redis_connection() as redis:
+            try:
+                pipe = redis.pipeline()
+                for k, v in data.items():
+                    pipe.set(f"{self.namespace}:{k}", json.dumps(v))
+                await pipe.execute()
+            except (TypeError, ValueError) as e:
+                # json.dumps raises TypeError/ValueError on unserializable data;
+                # the json module has no JSONEncodeError
+                logger.error(f"JSON encode error during upsert: {e}")
+                raise
+
+    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+        async with self._get_redis_connection() as redis:
+            try:
+                data = await redis.get(f"{self.namespace}:{id}")
+                return json.loads(data) if data else None
+            except json.JSONDecodeError as e:
+                logger.error(f"JSON decode error for id {id}: {e}")
+                return None
+
+    async def delete(self, doc_ids: list[str]) -> None:
+        """Delete specific records from storage by their IDs"""
+        if not doc_ids:
+            return
+
+        async with self._get_redis_connection() as redis:
+            pipe = redis.pipeline()
+            for doc_id in doc_ids:
+                pipe.delete(f"{self.namespace}:{doc_id}")
+
+            results = await pipe.execute()
+            deleted_count = sum(results)
+            logger.info(f"Deleted {deleted_count} of {len(doc_ids)} doc status entries from {self.namespace}")
+
+    async def drop(self) -> dict[str, str]:
+        """Drop all document status data from storage and clean up resources"""
+        try:
+            async with self._get_redis_connection() as redis:
+                # Use SCAN to find all keys with the namespace prefix
+                pattern = f"{self.namespace}:*"
+                cursor = 0
+                deleted_count = 0
+
+                while True:
+                    cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
+                    if keys:
+                        # Delete keys in batches
+                        pipe = redis.pipeline()
+                        for key in keys:
+                            pipe.delete(key)
+                        results = await pipe.execute()
+                        deleted_count += sum(results)
+
+                    if cursor == 0:
+                        break
+
+                logger.info(f"Dropped {deleted_count} doc status keys from {self.namespace}")
+                return {"status": "success", "message": "data dropped"}
+        except Exception as e:
+            logger.error(f"Error dropping doc status {self.namespace}: {e}")
+            return {"status": "error", "message": str(e)}
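The doc-status methods above lean on two redis-py idioms: pipelines to batch GET/SET/DELETE round trips, and cursor-based SCAN to walk a key namespace without blocking the server the way a bare KEYS call would. A minimal self-contained sketch of the same pattern (not part of this commit; it assumes a local Redis, redis-py 5.x with asyncio support, and an illustrative doc_status namespace):

import asyncio
import json

from redis.asyncio import Redis  # redis-py asyncio client


async def main() -> None:
    redis = Redis.from_url("redis://localhost:6379", decode_responses=True)
    namespace = "doc_status"  # hypothetical namespace for this sketch

    # Batched writes: one SET per document id, executed in a single round trip
    pipe = redis.pipeline()
    for doc_id, payload in {
        "doc-1": {"status": "pending"},
        "doc-2": {"status": "processed"},
    }.items():
        pipe.set(f"{namespace}:{doc_id}", json.dumps(payload))
    await pipe.execute()

    # Cursor-based SCAN: iterates the namespace incrementally
    cursor = 0
    while True:
        cursor, keys = await redis.scan(cursor, match=f"{namespace}:*", count=1000)
        for key in keys:
            print(key, await redis.get(key))
        if cursor == 0:
            break

    await redis.aclose()  # redis-py >= 5; older versions use close()


asyncio.run(main())

The count=1000 hint matches the storage class above; it trades fewer round trips for slightly larger replies per SCAN step.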
lightrag/kg/tidb_impl.py
CHANGED
@@ -257,7 +257,7 @@ class TiDBKVStorage(BaseKVStorage):
 
     ################ INSERT full_doc AND chunks ################
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        logger.debug(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
         left_data = {k: v for k, v in data.items() if k not in self._data}
@@ -454,11 +454,9 @@ class TiDBVectorDBStorage(BaseVectorStorage):
 
     ###### INSERT entities And relationships ######
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
-
-        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
+        logger.debug(f"Inserting {len(data)} vectors to {self.namespace}")
 
         # Get current time as UNIX timestamp
         import time
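Both hunks demote per-upsert logging from info to debug, and the vector-storage variant also moves the message below the empty-input guard so that no-op calls stay silent. The guard-then-log ordering in miniature (an illustrative sketch, not library code):

import logging

logger = logging.getLogger("lightrag")


def upsert(data: dict) -> None:
    if not data:
        return  # nothing to insert, so nothing worth logging
    logger.debug("Inserting %d records", len(data))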
lightrag/operate.py
CHANGED
@@ -399,10 +399,10 @@ async def _get_cached_extraction_results(
     """
     cached_results = {}
 
-    # Get all cached data
-    …
+    # Get all cached data (flattened cache structure)
+    all_cache = await llm_response_cache.get_all()
 
-    for cache_key, cache_entry in …
+    for cache_key, cache_entry in all_cache.items():
         if (
             isinstance(cache_entry, dict)
             and cache_entry.get("cache_type") == "extract"
@@ -1387,7 +1387,7 @@ async def kg_query(
     use_model_func = partial(use_model_func, _priority=5)
 
     # Handle cache
-    args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
+    args_hash = compute_args_hash(query_param.mode, query)
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
     )
@@ -1546,7 +1546,7 @@ async def extract_keywords_only(
     """
 
     # 1. Handle cache if needed - add cache type for keywords
-    args_hash = compute_args_hash(param.mode, text, cache_type="keywords")
+    args_hash = compute_args_hash(param.mode, text)
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, text, param.mode, cache_type="keywords"
     )
@@ -2413,7 +2413,7 @@ async def naive_query(
     use_model_func = partial(use_model_func, _priority=5)
 
     # Handle cache
-    args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
+    args_hash = compute_args_hash(query_param.mode, query)
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
     )
@@ -2529,7 +2529,7 @@ async def kg_query_with_keywords(
     # Apply higher priority (5) to query relation LLM function
     use_model_func = partial(use_model_func, _priority=5)
 
-    args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
+    args_hash = compute_args_hash(query_param.mode, query)
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
     )
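Because cache_type no longer feeds the hash, a single args_hash can be filed under several cache types purely through the key; that is why the call sites above now pass cache_type only to handle_cache. A standalone illustration (the helper mirrors the new compute_args_hash, and the sample values are hypothetical):

import hashlib


def compute_args_hash(*args) -> str:
    # Mirrors the new signature: the hash covers only the call arguments
    return hashlib.md5("".join(str(arg) for arg in args).encode()).hexdigest()


mode, query = "hybrid", "What is LightRAG?"
args_hash = compute_args_hash(mode, query)

# The same hash, addressed under different cache types via the flattened key:
for cache_type in ("query", "keywords"):
    print(f"{mode}:{cache_type}:{args_hash}")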
lightrag/utils.py
CHANGED
@@ -14,7 +14,6 @@ from functools import wraps
 from hashlib import md5
 from typing import Any, Protocol, Callable, TYPE_CHECKING, List
 import numpy as np
-from lightrag.prompt import PROMPTS
 from dotenv import load_dotenv
 from lightrag.constants import (
     DEFAULT_LOG_MAX_BYTES,
@@ -278,11 +277,10 @@ def convert_response_to_json(response: str) -> dict[str, Any]:
         raise e from None
 
 
-def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:
+def compute_args_hash(*args: Any) -> str:
     """Compute a hash for the given arguments.
     Args:
         *args: Arguments to hash
-        cache_type: Type of cache (e.g., 'keywords', 'query', 'extract')
     Returns:
         str: Hash string
     """
@@ -290,13 +288,40 @@ def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:
 
     # Convert all arguments to strings and join them
     args_str = "".join([str(arg) for arg in args])
-    if cache_type:
-        args_str = f"{cache_type}:{args_str}"
 
     # Compute MD5 hash
     return hashlib.md5(args_str.encode()).hexdigest()
 
 
+def generate_cache_key(mode: str, cache_type: str, hash_value: str) -> str:
+    """Generate a flattened cache key in the format {mode}:{cache_type}:{hash}
+
+    Args:
+        mode: Cache mode (e.g., 'default', 'local', 'global')
+        cache_type: Type of cache (e.g., 'extract', 'query', 'keywords')
+        hash_value: Hash value from compute_args_hash
+
+    Returns:
+        str: Flattened cache key
+    """
+    return f"{mode}:{cache_type}:{hash_value}"
+
+
+def parse_cache_key(cache_key: str) -> tuple[str, str, str] | None:
+    """Parse a flattened cache key back into its components
+
+    Args:
+        cache_key: Flattened cache key in format {mode}:{cache_type}:{hash}
+
+    Returns:
+        tuple[str, str, str] | None: (mode, cache_type, hash) or None if invalid format
+    """
+    parts = cache_key.split(":", 2)
+    if len(parts) == 3:
+        return parts[0], parts[1], parts[2]
+    return None
+
+
 def compute_mdhash_id(content: str, prefix: str = "") -> str:
     """
     Compute a unique ID for a given content string.
@@ -783,131 +808,6 @@ def process_combine_contexts(*context_lists):
     return combined_data
 
 
-async def get_best_cached_response(
-    hashing_kv,
-    current_embedding,
-    similarity_threshold=0.95,
-    mode="default",
-    use_llm_check=False,
-    llm_func=None,
-    original_prompt=None,
-    cache_type=None,
-) -> str | None:
-    logger.debug(
-        f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
-    )
-    mode_cache = await hashing_kv.get_by_id(mode)
-    if not mode_cache:
-        return None
-
-    best_similarity = -1
-    best_response = None
-    best_prompt = None
-    best_cache_id = None
-
-    # Only iterate through cache entries for this mode
-    for cache_id, cache_data in mode_cache.items():
-        # Skip if cache_type doesn't match
-        if cache_type and cache_data.get("cache_type") != cache_type:
-            continue
-
-        # Check if cache data is valid
-        if cache_data["embedding"] is None:
-            continue
-
-        try:
-            # Safely convert cached embedding
-            cached_quantized = np.frombuffer(
-                bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
-            ).reshape(cache_data["embedding_shape"])
-
-            # Ensure min_val and max_val are valid float values
-            embedding_min = cache_data.get("embedding_min")
-            embedding_max = cache_data.get("embedding_max")
-
-            if (
-                embedding_min is None
-                or embedding_max is None
-                or embedding_min >= embedding_max
-            ):
-                logger.warning(
-                    f"Invalid embedding min/max values: min={embedding_min}, max={embedding_max}"
-                )
-                continue
-
-            cached_embedding = dequantize_embedding(
-                cached_quantized,
-                embedding_min,
-                embedding_max,
-            )
-        except Exception as e:
-            logger.warning(f"Error processing cached embedding: {str(e)}")
-            continue
-
-        similarity = cosine_similarity(current_embedding, cached_embedding)
-        if similarity > best_similarity:
-            best_similarity = similarity
-            best_response = cache_data["return"]
-            best_prompt = cache_data["original_prompt"]
-            best_cache_id = cache_id
-
-    if best_similarity > similarity_threshold:
-        # If LLM check is enabled and all required parameters are provided
-        if (
-            use_llm_check
-            and llm_func
-            and original_prompt
-            and best_prompt
-            and best_response is not None
-        ):
-            compare_prompt = PROMPTS["similarity_check"].format(
-                original_prompt=original_prompt, cached_prompt=best_prompt
-            )
-
-            try:
-                llm_result = await llm_func(compare_prompt)
-                llm_result = llm_result.strip()
-                llm_similarity = float(llm_result)
-
-                # Replace vector similarity with LLM similarity score
-                best_similarity = llm_similarity
-                if best_similarity < similarity_threshold:
-                    log_data = {
-                        "event": "cache_rejected_by_llm",
-                        "type": cache_type,
-                        "mode": mode,
-                        "original_question": original_prompt[:100] + "..."
-                        if len(original_prompt) > 100
-                        else original_prompt,
-                        "cached_question": best_prompt[:100] + "..."
-                        if len(best_prompt) > 100
-                        else best_prompt,
-                        "similarity_score": round(best_similarity, 4),
-                        "threshold": similarity_threshold,
-                    }
-                    logger.debug(json.dumps(log_data, ensure_ascii=False))
-                    logger.info(f"Cache rejected by LLM(mode:{mode} tpye:{cache_type})")
-                    return None
-            except Exception as e:  # Catch all possible exceptions
-                logger.warning(f"LLM similarity check failed: {e}")
-                return None  # Return None directly when LLM check fails
-
-        prompt_display = (
-            best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt
-        )
-        log_data = {
-            "event": "cache_hit",
-            "type": cache_type,
-            "mode": mode,
-            "similarity": round(best_similarity, 4),
-            "cache_id": best_cache_id,
-            "original_prompt": prompt_display,
-        }
-        logger.debug(json.dumps(log_data, ensure_ascii=False))
-        return best_response
-    return None
-
-
 def cosine_similarity(v1, v2):
     """Calculate cosine similarity between two vectors"""
     dot_product = np.dot(v1, v2)
@@ -957,7 +857,7 @@ async def handle_cache(
     mode="default",
     cache_type=None,
 ):
-    """Generic cache handling function"""
+    """Generic cache handling function with flattened cache keys"""
     if hashing_kv is None:
         return None, None, None, None
 
@@ -968,15 +868,14 @@ async def handle_cache(
         if not hashing_kv.global_config.get("enable_llm_cache_for_entity_extract"):
             return None, None, None, None
 
-    …
-        return mode_cache[args_hash]["return"], None, None, None
+    # Use flattened cache key format: {mode}:{cache_type}:{hash}
+    flattened_key = generate_cache_key(mode, cache_type, args_hash)
+    cache_entry = await hashing_kv.get_by_id(flattened_key)
+    if cache_entry:
+        logger.debug(f"Flattened cache hit(key:{flattened_key})")
+        return cache_entry["return"], None, None, None
 
-    logger.debug(f"…
+    logger.debug(f"Cache missed(mode:{mode} type:{cache_type})")
     return None, None, None, None
 
 
@@ -994,7 +893,7 @@ class CacheData:
 
 
 async def save_to_cache(hashing_kv, cache_data: CacheData):
-    """Save data to cache …
+    """Save data to cache using flattened key structure.
 
     Args:
         hashing_kv: The key-value storage for caching
@@ -1009,26 +908,21 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
         logger.debug("Streaming response detected, skipping cache")
         return
 
-    # …
-    …
-            or {}
-        )
-    else:
-        mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
+    # Use flattened cache key format: {mode}:{cache_type}:{hash}
+    flattened_key = generate_cache_key(
+        cache_data.mode, cache_data.cache_type, cache_data.args_hash
+    )
 
     # Check if we already have identical content cached
-    …
+    existing_cache = await hashing_kv.get_by_id(flattened_key)
+    if existing_cache:
+        existing_content = existing_cache.get("return")
         if existing_content == cache_data.content:
-            logger.info(
-                f"Cache content unchanged for {cache_data.args_hash}, skipping update"
-            )
+            logger.info(f"Cache content unchanged for {flattened_key}, skipping update")
             return
 
-    # …
-    …
+    # Create cache entry with flattened structure
+    cache_entry = {
         "return": cache_data.content,
         "cache_type": cache_data.cache_type,
         "chunk_id": cache_data.chunk_id if cache_data.chunk_id is not None else None,
@@ -1043,10 +937,10 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
         "original_prompt": cache_data.prompt,
     }
 
-    logger.info(f" == LLM cache == saving …
+    logger.info(f" == LLM cache == saving: {flattened_key}")
 
-    # …
-    await hashing_kv.upsert({…
+    # Save using flattened key
+    await hashing_kv.upsert({flattened_key: cache_entry})
 
 
 def safe_unicode_decode(content):
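Taken together, the new helpers turn cache addressing into a pure string operation, and both the write path and the recall path become single get_by_id/upsert round trips. An end-to-end sketch against an in-memory stand-in for a KV storage (the DictKV stub and the sample values are illustrative, not LightRAG classes):

import asyncio
import hashlib


def compute_args_hash(*args) -> str:
    # Same shape as the new lightrag.utils.compute_args_hash
    return hashlib.md5("".join(str(arg) for arg in args).encode()).hexdigest()


def generate_cache_key(mode: str, cache_type: str, hash_value: str) -> str:
    return f"{mode}:{cache_type}:{hash_value}"


def parse_cache_key(cache_key: str):
    parts = cache_key.split(":", 2)
    return tuple(parts) if len(parts) == 3 else None


class DictKV:
    """Minimal async stand-in for a BaseKVStorage implementation."""

    def __init__(self) -> None:
        self._data: dict[str, dict] = {}

    async def get_by_id(self, key: str):
        return self._data.get(key)

    async def upsert(self, data: dict) -> None:
        self._data.update(data)


async def demo() -> None:
    kv = DictKV()
    mode, cache_type, query = "local", "query", "What is LightRAG?"
    key = generate_cache_key(mode, cache_type, compute_args_hash(mode, query))

    # save_to_cache path: skip the write when the cached answer is unchanged
    existing = await kv.get_by_id(key)
    if not (existing and existing.get("return") == "42"):
        await kv.upsert({key: {"return": "42", "cache_type": cache_type}})

    # handle_cache path: one get_by_id instead of fetching a whole mode bucket
    entry = await kv.get_by_id(key)
    print(entry["return"] if entry else "miss")

    # Keys stay introspectable; maxsplit=2 keeps any ':' inside the hash intact
    print(parse_cache_key(key))


asyncio.run(demo())

Because no mode bucket has to be fetched and rewritten wholesale, recall cost stops growing with the number of entries stored under a mode.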