gzdaniel committed on
Commit
6320c9d
1 Parent(s): 480d89b

feat: Flatten LLM cache structure for improved recall efficiency


Refactored the LLM cache to a flat key-value (KV) structure, replacing the previous nested format in which the 'mode' served as the top-level key and the cache content was stored as nested JSON beneath it. Each entry is now addressed by a single flattened key of the form {mode}:{cache_type}:{hash}, which makes cache recall significantly more efficient.
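For illustration, a minimal sketch of the two layouts (the key format {mode}:{cache_type}:{hash} is taken from this changeset; the field values are made up):

# Legacy nested layout: mode -> hash -> entry
legacy = {
    "default": {
        "abc123": {"return": "...", "cache_type": "extract"},
    }
}

# New flattened layout: one KV pair per entry, keyed {mode}:{cache_type}:{hash}
flat = {
    "default:extract:abc123": {"return": "...", "cache_type": "extract"},
}

# Recall becomes a single key lookup instead of a nested JSON traversal
entry = flat["default:extract:abc123"]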

examples/unofficial-sample/copy_llm_cache_to_another_storage.py CHANGED
@@ -52,18 +52,23 @@ async def copy_from_postgres_to_json():
52
  embedding_func=None,
53
  )
54
 
55
  kv = {}
56
- for c_id in await from_llm_response_cache.all_keys():
57
- print(f"Copying {c_id}")
58
- workspace = c_id["workspace"]
59
- mode = c_id["mode"]
60
- _id = c_id["id"]
61
- postgres_db.workspace = workspace
62
- obj = await from_llm_response_cache.get_by_mode_and_id(mode, _id)
63
- if mode not in kv:
64
- kv[mode] = {}
65
- kv[mode][_id] = obj[_id]
66
- print(f"Object {obj}")
 
67
  await to_llm_response_cache.upsert(kv)
68
  await to_llm_response_cache.index_done_callback()
69
  print("Mission accomplished!")
@@ -85,13 +90,24 @@ async def copy_from_json_to_postgres():
85
  db=postgres_db,
86
  )
87
 
88
- for mode in await from_llm_response_cache.all_keys():
89
- print(f"Copying {mode}")
90
- caches = await from_llm_response_cache.get_by_id(mode)
91
- for k, v in caches.items():
92
- item = {mode: {k: v}}
93
- print(f"\tCopying {item}")
94
- await to_llm_response_cache.upsert(item)
 
95
 
96
 
97
  if __name__ == "__main__":
 
52
  embedding_func=None,
53
  )
54
 
55
+ # Get all cache data using the new flattened structure
56
+ all_data = await from_llm_response_cache.get_all()
57
+
58
+ # Convert flattened data to hierarchical structure for JsonKVStorage
59
  kv = {}
60
+ for flattened_key, cache_entry in all_data.items():
61
+ # Parse flattened key: {mode}:{cache_type}:{hash}
62
+ parts = flattened_key.split(":", 2)
63
+ if len(parts) == 3:
64
+ mode, cache_type, hash_value = parts
65
+ if mode not in kv:
66
+ kv[mode] = {}
67
+ kv[mode][hash_value] = cache_entry
68
+ print(f"Copying {flattened_key} -> {mode}[{hash_value}]")
69
+ else:
70
+ print(f"Skipping invalid key format: {flattened_key}")
71
+
72
  await to_llm_response_cache.upsert(kv)
73
  await to_llm_response_cache.index_done_callback()
74
  print("Mission accomplished!")
 
90
  db=postgres_db,
91
  )
92
 
93
+ # Get all cache data from JsonKVStorage (hierarchical structure)
94
+ all_data = await from_llm_response_cache.get_all()
95
+
96
+ # Convert hierarchical data to flattened structure for PGKVStorage
97
+ flattened_data = {}
98
+ for mode, mode_data in all_data.items():
99
+ print(f"Processing mode: {mode}")
100
+ for hash_value, cache_entry in mode_data.items():
101
+ # Determine cache_type from cache entry or use default
102
+ cache_type = cache_entry.get("cache_type", "extract")
103
+ # Create flattened key: {mode}:{cache_type}:{hash}
104
+ flattened_key = f"{mode}:{cache_type}:{hash_value}"
105
+ flattened_data[flattened_key] = cache_entry
106
+ print(f"\tConverting {mode}[{hash_value}] -> {flattened_key}")
107
+
108
+ # Upsert the flattened data
109
+ await to_llm_response_cache.upsert(flattened_data)
110
+ print("Mission accomplished!")
111
 
112
 
113
  if __name__ == "__main__":
lightrag/kg/__init__.py CHANGED
@@ -37,6 +37,7 @@ STORAGE_IMPLEMENTATIONS = {
37
  "DOC_STATUS_STORAGE": {
38
  "implementations": [
39
  "JsonDocStatusStorage",
 
40
  "PGDocStatusStorage",
41
  "MongoDocStatusStorage",
42
  ],
@@ -79,6 +80,7 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
79
  "MongoVectorDBStorage": [],
80
  # Document Status Storage Implementations
81
  "JsonDocStatusStorage": [],
 
82
  "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
83
  "MongoDocStatusStorage": [],
84
  }
@@ -96,6 +98,7 @@ STORAGES = {
96
  "MongoGraphStorage": ".kg.mongo_impl",
97
  "MongoVectorDBStorage": ".kg.mongo_impl",
98
  "RedisKVStorage": ".kg.redis_impl",
 
99
  "ChromaVectorDBStorage": ".kg.chroma_impl",
100
  # "TiDBKVStorage": ".kg.tidb_impl",
101
  # "TiDBVectorDBStorage": ".kg.tidb_impl",
 
37
  "DOC_STATUS_STORAGE": {
38
  "implementations": [
39
  "JsonDocStatusStorage",
40
+ "RedisDocStatusStorage",
41
  "PGDocStatusStorage",
42
  "MongoDocStatusStorage",
43
  ],
 
80
  "MongoVectorDBStorage": [],
81
  # Document Status Storage Implementations
82
  "JsonDocStatusStorage": [],
83
+ "RedisDocStatusStorage": ["REDIS_URI"],
84
  "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
85
  "MongoDocStatusStorage": [],
86
  }
 
98
  "MongoGraphStorage": ".kg.mongo_impl",
99
  "MongoVectorDBStorage": ".kg.mongo_impl",
100
  "RedisKVStorage": ".kg.redis_impl",
101
+ "RedisDocStatusStorage": ".kg.redis_impl",
102
  "ChromaVectorDBStorage": ".kg.chroma_impl",
103
  # "TiDBKVStorage": ".kg.tidb_impl",
104
  # "TiDBVectorDBStorage": ".kg.tidb_impl",
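As a usage sketch, the new backend is selected like the other doc-status implementations; per STORAGE_ENV_REQUIREMENTS above only REDIS_URI is required. The LIGHTRAG_DOC_STATUS_STORAGE selector name follows LightRAG's usual convention but is an assumption here, as is the URI value:

import os

# Hypothetical configuration for the new Redis-backed doc status storage
os.environ["LIGHTRAG_DOC_STATUS_STORAGE"] = "RedisDocStatusStorage"
os.environ["REDIS_URI"] = "redis://localhost:6379"  # only required variable per STORAGE_ENV_REQUIREMENTS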
lightrag/kg/chroma_impl.py CHANGED
@@ -109,7 +109,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
109
  raise
110
 
111
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
112
- logger.info(f"Inserting {len(data)} to {self.namespace}")
113
  if not data:
114
  return
115
 
@@ -234,7 +234,6 @@ class ChromaVectorDBStorage(BaseVectorStorage):
234
  ids: List of vector IDs to be deleted
235
  """
236
  try:
237
- logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
238
  self._collection.delete(ids=ids)
239
  logger.debug(
240
  f"Successfully deleted {len(ids)} vectors from {self.namespace}"
 
109
  raise
110
 
111
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
112
+ logger.debug(f"Inserting {len(data)} to {self.namespace}")
113
  if not data:
114
  return
115
 
 
234
  ids: List of vector IDs to be deleted
235
  """
236
  try:
 
237
  self._collection.delete(ids=ids)
238
  logger.debug(
239
  f"Successfully deleted {len(ids)} vectors from {self.namespace}"
lightrag/kg/json_kv_impl.py CHANGED
@@ -42,19 +42,14 @@ class JsonKVStorage(BaseKVStorage):
42
  if need_init:
43
  loaded_data = load_json(self._file_name) or {}
44
  async with self._storage_lock:
45
- self._data.update(loaded_data)
46
-
47
- # Calculate data count based on namespace
48
- if self.namespace.endswith("cache"):
49
- # For cache namespaces, sum the cache entries across all cache types
50
- data_count = sum(
51
- len(first_level_dict)
52
- for first_level_dict in loaded_data.values()
53
- if isinstance(first_level_dict, dict)
54
  )
55
- else:
56
- # For non-cache namespaces, use the original count method
57
- data_count = len(loaded_data)
58
 
59
  logger.info(
60
  f"Process {os.getpid()} KV load {self.namespace} with {data_count} records"
@@ -67,17 +62,8 @@ class JsonKVStorage(BaseKVStorage):
67
  dict(self._data) if hasattr(self._data, "_getvalue") else self._data
68
  )
69
 
70
- # Calculate data count based on namespace
71
- if self.namespace.endswith("cache"):
72
- # # For cache namespaces, sum the cache entries across all cache types
73
- data_count = sum(
74
- len(first_level_dict)
75
- for first_level_dict in data_dict.values()
76
- if isinstance(first_level_dict, dict)
77
- )
78
- else:
79
- # For non-cache namespaces, use the original count method
80
- data_count = len(data_dict)
81
 
82
  logger.debug(
83
  f"Process {os.getpid()} KV writing {data_count} records to {self.namespace}"
@@ -150,14 +136,14 @@ class JsonKVStorage(BaseKVStorage):
150
  await set_all_update_flags(self.namespace)
151
 
152
  async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
153
- """Delete specific records from storage by by cache mode
154
 
155
  Important notes for in-memory storage:
156
  1. Changes will be persisted to disk during the next index_done_callback
157
  2. Update flags to notify other processes that data persistence is needed
158
 
159
  Args:
160
- ids (list[str]): List of cache mode to be drop from storage
161
 
162
  Returns:
163
  True: if the cache was dropped successfully
@@ -167,9 +153,29 @@ class JsonKVStorage(BaseKVStorage):
167
  return False
168
 
169
  try:
170
- await self.delete(modes)
 
171
  return True
172
- except Exception:
 
173
  return False
174
 
175
  # async def drop_cache_by_chunk_ids(self, chunk_ids: list[str] | None = None) -> bool:
@@ -245,9 +251,58 @@ class JsonKVStorage(BaseKVStorage):
245
  logger.error(f"Error dropping {self.namespace}: {e}")
246
  return {"status": "error", "message": str(e)}
247
 
248
  async def finalize(self):
249
  """Finalize storage resources
250
  Persist cache data to disk before exiting
251
  """
252
- if self.namespace.endswith("cache"):
253
  await self.index_done_callback()
 
42
  if need_init:
43
  loaded_data = load_json(self._file_name) or {}
44
  async with self._storage_lock:
45
+ # Migrate legacy cache structure if needed
46
+ if self.namespace.endswith("_cache"):
47
+ loaded_data = await self._migrate_legacy_cache_structure(
48
+ loaded_data
 
49
  )
50
+
51
+ self._data.update(loaded_data)
52
+ data_count = len(loaded_data)
53
 
54
  logger.info(
55
  f"Process {os.getpid()} KV load {self.namespace} with {data_count} records"
 
62
  dict(self._data) if hasattr(self._data, "_getvalue") else self._data
63
  )
64
 
65
+ # Calculate data count - all data is now flattened
66
+ data_count = len(data_dict)
 
67
 
68
  logger.debug(
69
  f"Process {os.getpid()} KV writing {data_count} records to {self.namespace}"
 
136
  await set_all_update_flags(self.namespace)
137
 
138
  async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
139
+ """Delete specific records from storage by cache mode
140
 
141
  Important notes for in-memory storage:
142
  1. Changes will be persisted to disk during the next index_done_callback
143
  2. Update flags to notify other processes that data persistence is needed
144
 
145
  Args:
146
+ modes (list[str]): List of cache modes to be dropped from storage
147
 
148
  Returns:
149
  True: if the cache was dropped successfully
 
153
  return False
154
 
155
  try:
156
+ async with self._storage_lock:
157
+ keys_to_delete = []
158
+ modes_set = set(modes) # Convert to set for efficient lookup
159
+
160
+ for key in list(self._data.keys()):
161
+ # Parse flattened cache key: mode:cache_type:hash
162
+ parts = key.split(":", 2)
163
+ if len(parts) == 3 and parts[0] in modes_set:
164
+ keys_to_delete.append(key)
165
+
166
+ # Batch delete
167
+ for key in keys_to_delete:
168
+ self._data.pop(key, None)
169
+
170
+ if keys_to_delete:
171
+ await set_all_update_flags(self.namespace)
172
+ logger.info(
173
+ f"Dropped {len(keys_to_delete)} cache entries for modes: {modes}"
174
+ )
175
+
176
  return True
177
+ except Exception as e:
178
+ logger.error(f"Error dropping cache by modes: {e}")
179
  return False
180
 
181
  # async def drop_cache_by_chunk_ids(self, chunk_ids: list[str] | None = None) -> bool:
 
251
  logger.error(f"Error dropping {self.namespace}: {e}")
252
  return {"status": "error", "message": str(e)}
253
 
254
+ async def _migrate_legacy_cache_structure(self, data: dict) -> dict:
255
+ """Migrate legacy nested cache structure to flattened structure
256
+
257
+ Args:
258
+ data: Original data dictionary that may contain legacy structure
259
+
260
+ Returns:
261
+ Migrated data dictionary with flattened cache keys
262
+ """
263
+ from lightrag.utils import generate_cache_key
264
+
265
+ # Early return if data is empty
266
+ if not data:
267
+ return data
268
+
269
+ # Check first entry to see if it's already in new format
270
+ first_key = next(iter(data.keys()))
271
+ if ":" in first_key and len(first_key.split(":")) == 3:
272
+ # Already in flattened format, return as-is
273
+ return data
274
+
275
+ migrated_data = {}
276
+ migration_count = 0
277
+
278
+ for key, value in data.items():
279
+ # Check if this is a legacy nested cache structure
280
+ if isinstance(value, dict) and all(
281
+ isinstance(v, dict) and "return" in v for v in value.values()
282
+ ):
283
+ # This looks like a legacy cache mode with nested structure
284
+ mode = key
285
+ for cache_hash, cache_entry in value.items():
286
+ cache_type = cache_entry.get("cache_type", "extract")
287
+ flattened_key = generate_cache_key(mode, cache_type, cache_hash)
288
+ migrated_data[flattened_key] = cache_entry
289
+ migration_count += 1
290
+ else:
291
+ # Keep non-cache data or already flattened cache data as-is
292
+ migrated_data[key] = value
293
+
294
+ if migration_count > 0:
295
+ logger.info(
296
+ f"Migrated {migration_count} legacy cache entries to flattened structure"
297
+ )
298
+ # Persist migrated data immediately
299
+ write_json(migrated_data, self._file_name)
300
+
301
+ return migrated_data
302
+
303
  async def finalize(self):
304
  """Finalize storage resources
305
  Persist cache data to disk before exiting
306
  """
307
+ if self.namespace.endswith("_cache"):
308
  await self.index_done_callback()
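The migration path imports generate_cache_key from lightrag.utils, whose definition is not part of this diff; judging by how flattened keys are split elsewhere in the changeset, it presumably reduces to a sketch like:

def generate_cache_key(mode: str, cache_type: str, cache_hash: str) -> str:
    # Assumed shape of the helper: join the three parts into one flat key
    return f"{mode}:{cache_type}:{cache_hash}"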
lightrag/kg/milvus_impl.py CHANGED
@@ -75,7 +75,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
75
  )
76
 
77
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
78
- logger.info(f"Inserting {len(data)} to {self.namespace}")
79
  if not data:
80
  return
81
 
 
75
  )
76
 
77
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
78
+ logger.debug(f"Inserting {len(data)} to {self.namespace}")
79
  if not data:
80
  return
81
 
lightrag/kg/mongo_impl.py CHANGED
@@ -15,7 +15,6 @@ from ..base import (
15
  DocStatus,
16
  DocStatusStorage,
17
  )
18
- from ..namespace import NameSpace, is_namespace
19
  from ..utils import logger, compute_mdhash_id
20
  from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
21
  from ..constants import GRAPH_FIELD_SEP
@@ -98,17 +97,8 @@ class MongoKVStorage(BaseKVStorage):
98
  self._data = None
99
 
100
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
101
- if id == "default":
102
- # Find all documents with _id starting with "default_"
103
- cursor = self._data.find({"_id": {"$regex": "^default_"}})
104
- result = {}
105
- async for doc in cursor:
106
- # Use the complete _id as key
107
- result[doc["_id"]] = doc
108
- return result if result else None
109
- else:
110
- # Original behavior for non-"default" ids
111
- return await self._data.find_one({"_id": id})
112
 
113
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
114
  cursor = self._data.find({"_id": {"$in": ids}})
@@ -133,43 +123,21 @@ class MongoKVStorage(BaseKVStorage):
133
  return result
134
 
135
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
136
- logger.info(f"Inserting {len(data)} to {self.namespace}")
137
  if not data:
138
  return
139
 
140
- if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
141
- update_tasks: list[Any] = []
142
- for mode, items in data.items():
143
- for k, v in items.items():
144
- key = f"{mode}_{k}"
145
- data[mode][k]["_id"] = f"{mode}_{k}"
146
- update_tasks.append(
147
- self._data.update_one(
148
- {"_id": key}, {"$setOnInsert": v}, upsert=True
149
- )
150
- )
151
- await asyncio.gather(*update_tasks)
152
- else:
153
- update_tasks = []
154
- for k, v in data.items():
155
- data[k]["_id"] = k
156
- update_tasks.append(
157
- self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
158
- )
159
- await asyncio.gather(*update_tasks)
160
-
161
- async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
162
- if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
163
- res = {}
164
- v = await self._data.find_one({"_id": mode + "_" + id})
165
- if v:
166
- res[id] = v
167
- logger.debug(f"llm_response_cache find one by:{id}")
168
- return res
169
- else:
170
- return None
171
- else:
172
- return None
173
 
174
  async def index_done_callback(self) -> None:
175
  # Mongo handles persistence automatically
@@ -209,8 +177,8 @@ class MongoKVStorage(BaseKVStorage):
209
  return False
210
 
211
  try:
212
- # Build regex pattern to match documents with the specified modes
213
- pattern = f"^({'|'.join(modes)})_"
214
  result = await self._data.delete_many({"_id": {"$regex": pattern}})
215
  logger.info(f"Deleted {result.deleted_count} documents by modes: {modes}")
216
  return True
@@ -274,7 +242,7 @@ class MongoDocStatusStorage(DocStatusStorage):
274
  return data - existing_ids
275
 
276
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
277
- logger.info(f"Inserting {len(data)} to {self.namespace}")
278
  if not data:
279
  return
280
  update_tasks: list[Any] = []
@@ -1282,7 +1250,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
1282
  logger.debug("vector index already exists")
1283
 
1284
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
1285
- logger.info(f"Inserting {len(data)} to {self.namespace}")
1286
  if not data:
1287
  return
1288
 
@@ -1371,7 +1339,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
1371
  Args:
1372
  ids: List of vector IDs to be deleted
1373
  """
1374
- logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
1375
  if not ids:
1376
  return
1377
 
 
15
  DocStatus,
16
  DocStatusStorage,
17
  )
 
18
  from ..utils import logger, compute_mdhash_id
19
  from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
20
  from ..constants import GRAPH_FIELD_SEP
 
97
  self._data = None
98
 
99
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
100
+ # Unified handling for flattened keys
101
+ return await self._data.find_one({"_id": id})
 
102
 
103
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
104
  cursor = self._data.find({"_id": {"$in": ids}})
 
123
  return result
124
 
125
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
126
+ logger.debug(f"Inserting {len(data)} to {self.namespace}")
127
  if not data:
128
  return
129
 
130
+ # Unified handling for all namespaces with flattened keys
131
+ # Use bulk_write for better performance
132
+ from pymongo import UpdateOne
133
+
134
+ operations = []
135
+ for k, v in data.items():
136
+ v["_id"] = k # Use flattened key as _id
137
+ operations.append(UpdateOne({"_id": k}, {"$set": v}, upsert=True))
138
+
139
+ if operations:
140
+ await self._data.bulk_write(operations)
 
141
 
142
  async def index_done_callback(self) -> None:
143
  # Mongo handles persistence automatically
 
177
  return False
178
 
179
  try:
180
+ # Build regex pattern to match flattened key format: mode:cache_type:hash
181
+ pattern = f"^({'|'.join(modes)}):"
182
  result = await self._data.delete_many({"_id": {"$regex": pattern}})
183
  logger.info(f"Deleted {result.deleted_count} documents by modes: {modes}")
184
  return True
 
242
  return data - existing_ids
243
 
244
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
245
+ logger.debug(f"Inserting {len(data)} to {self.namespace}")
246
  if not data:
247
  return
248
  update_tasks: list[Any] = []
 
1250
  logger.debug("vector index already exists")
1251
 
1252
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
1253
+ logger.debug(f"Inserting {len(data)} to {self.namespace}")
1254
  if not data:
1255
  return
1256
 
 
1339
  Args:
1340
  ids: List of vector IDs to be deleted
1341
  """
1342
+ logger.debug(f"Deleting {len(ids)} vectors from {self.namespace}")
1343
  if not ids:
1344
  return
1345
 
lightrag/kg/postgres_impl.py CHANGED
@@ -247,6 +247,116 @@ class PostgreSQLDB:
247
  logger.error(f"Failed during data migration to LIGHTRAG_VDB_CHUNKS: {e}")
248
  # Do not re-raise, to allow the application to start
249
 
 
250
  async def check_tables(self):
251
  # First create all tables
252
  for k, v in TABLES.items():
@@ -304,6 +414,13 @@ class PostgreSQLDB:
304
  except Exception as e:
305
  logger.error(f"PostgreSQL, Failed to migrate doc_chunks to vdb_chunks: {e}")
306
 
 
307
  async def query(
308
  self,
309
  sql: str,
@@ -486,77 +603,48 @@ class PGKVStorage(BaseKVStorage):
486
 
487
  try:
488
  results = await self.db.query(sql, params, multirows=True)
489
-
 
490
  if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
491
- result_dict = {}
492
  for row in results:
493
- mode = row["mode"]
494
- if mode not in result_dict:
495
- result_dict[mode] = {}
496
- result_dict[mode][row["id"]] = row
497
- return result_dict
498
- else:
499
- return {row["id"]: row for row in results}
 
500
  except Exception as e:
501
  logger.error(f"Error retrieving all data from {self.namespace}: {e}")
502
  return {}
503
 
504
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
505
- """Get doc_full data by id."""
506
  sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
507
- if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
508
- # For LLM cache, the id parameter actually represents the mode
509
- params = {"workspace": self.db.workspace, "mode": id}
510
- array_res = await self.db.query(sql, params, multirows=True)
511
- res = {}
512
- for row in array_res:
513
- # Dynamically add cache_type field based on mode
514
- row_with_cache_type = dict(row)
515
- if id == "default":
516
- row_with_cache_type["cache_type"] = "extract"
517
- else:
518
- row_with_cache_type["cache_type"] = "unknown"
519
- res[row["id"]] = row_with_cache_type
520
- return res if res else None
521
- else:
522
- params = {"workspace": self.db.workspace, "id": id}
523
- response = await self.db.query(sql, params)
524
- return response if response else None
525
-
526
- async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
527
- """Specifically for llm_response_cache."""
528
- sql = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
529
- params = {"workspace": self.db.workspace, "mode": mode, "id": id}
530
- if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
531
- array_res = await self.db.query(sql, params, multirows=True)
532
- res = {}
533
- for row in array_res:
534
- res[row["id"]] = row
535
- return res
536
- else:
537
- return None
538
 
539
  # Query by id
540
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
541
- """Get doc_chunks data by id"""
542
  sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
543
  ids=",".join([f"'{id}'" for id in ids])
544
  )
545
  params = {"workspace": self.db.workspace}
546
- if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
547
- array_res = await self.db.query(sql, params, multirows=True)
548
- modes = set()
549
- dict_res: dict[str, dict] = {}
550
- for row in array_res:
551
- modes.add(row["mode"])
552
- for mode in modes:
553
- if mode not in dict_res:
554
- dict_res[mode] = {}
555
- for row in array_res:
556
- dict_res[row["mode"]][row["id"]] = row
557
- return [{k: v} for k, v in dict_res.items()]
558
- else:
559
- return await self.db.query(sql, params, multirows=True)
560
 
561
  async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
562
  """Specifically for llm_response_cache."""
@@ -617,19 +705,18 @@ class PGKVStorage(BaseKVStorage):
617
  }
618
  await self.db.execute(upsert_sql, _data)
619
  elif is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
620
- for mode, items in data.items():
621
- for k, v in items.items():
622
- upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
623
- _data = {
624
- "workspace": self.db.workspace,
625
- "id": k,
626
- "original_prompt": v["original_prompt"],
627
- "return_value": v["return"],
628
- "mode": mode,
629
- "chunk_id": v.get("chunk_id"),
630
- }
631
 
632
- await self.db.execute(upsert_sql, _data)
633
 
634
  async def index_done_callback(self) -> None:
635
  # PG handles persistence automatically
@@ -1035,8 +1122,8 @@ class PGDocStatusStorage(DocStatusStorage):
1035
  else:
1036
  exist_keys = []
1037
  new_keys = set([s for s in keys if s not in exist_keys])
1038
- print(f"keys: {keys}")
1039
- print(f"new_keys: {new_keys}")
1040
  return new_keys
1041
  except Exception as e:
1042
  logger.error(
@@ -2621,7 +2708,7 @@ SQL_TEMPLATES = {
2621
  FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
2622
  """,
2623
  "get_by_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id
2624
- FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2
2625
  """,
2626
  "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id
2627
  FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 AND id=$3
 
247
  logger.error(f"Failed during data migration to LIGHTRAG_VDB_CHUNKS: {e}")
248
  # Do not re-raise, to allow the application to start
249
 
250
+ async def _check_llm_cache_needs_migration(self):
251
+ """Check if LLM cache data needs migration by examining the first record"""
252
+ try:
253
+ # Only query the first record to determine format
254
+ check_sql = """
255
+ SELECT id FROM LIGHTRAG_LLM_CACHE
256
+ ORDER BY create_time ASC
257
+ LIMIT 1
258
+ """
259
+ result = await self.query(check_sql)
260
+
261
+ if result and result.get("id"):
262
+ # If id doesn't contain colon, it's old format
263
+ return ":" not in result["id"]
264
+
265
+ return False # No data or already new format
266
+ except Exception as e:
267
+ logger.warning(f"Failed to check LLM cache migration status: {e}")
268
+ return False
269
+
270
+ async def _migrate_llm_cache_to_flattened_keys(self):
271
+ """Migrate LLM cache to flattened key format, recalculating hash values"""
272
+ try:
273
+ # Get all old format data
274
+ old_data_sql = """
275
+ SELECT id, mode, original_prompt, return_value, chunk_id,
276
+ create_time, update_time
277
+ FROM LIGHTRAG_LLM_CACHE
278
+ WHERE id NOT LIKE '%:%'
279
+ """
280
+
281
+ old_records = await self.query(old_data_sql, multirows=True)
282
+
283
+ if not old_records:
284
+ logger.info("No old format LLM cache data found, skipping migration")
285
+ return
286
+
287
+ logger.info(
288
+ f"Found {len(old_records)} old format cache records, starting migration..."
289
+ )
290
+
291
+ # Import hash calculation function
292
+ from ..utils import compute_args_hash
293
+
294
+ migrated_count = 0
295
+
296
+ # Migrate data in batches
297
+ for record in old_records:
298
+ try:
299
+ # Recalculate hash using correct method
300
+ new_hash = compute_args_hash(
301
+ record["mode"], record["original_prompt"]
302
+ )
303
+
304
+ # Generate new flattened key
305
+ cache_type = "extract" # Default type
306
+ new_key = f"{record['mode']}:{cache_type}:{new_hash}"
307
+
308
+ # Insert new format data
309
+ insert_sql = """
310
+ INSERT INTO LIGHTRAG_LLM_CACHE
311
+ (workspace, id, mode, original_prompt, return_value, chunk_id, create_time, update_time)
312
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
313
+ ON CONFLICT (workspace, mode, id) DO NOTHING
314
+ """
315
+
316
+ await self.execute(
317
+ insert_sql,
318
+ {
319
+ "workspace": self.workspace,
320
+ "id": new_key,
321
+ "mode": record["mode"],
322
+ "original_prompt": record["original_prompt"],
323
+ "return_value": record["return_value"],
324
+ "chunk_id": record["chunk_id"],
325
+ "create_time": record["create_time"],
326
+ "update_time": record["update_time"],
327
+ },
328
+ )
329
+
330
+ # Delete old data
331
+ delete_sql = """
332
+ DELETE FROM LIGHTRAG_LLM_CACHE
333
+ WHERE workspace=$1 AND mode=$2 AND id=$3
334
+ """
335
+ await self.execute(
336
+ delete_sql,
337
+ {
338
+ "workspace": self.workspace,
339
+ "mode": record["mode"],
340
+ "id": record["id"], # Old id
341
+ },
342
+ )
343
+
344
+ migrated_count += 1
345
+
346
+ except Exception as e:
347
+ logger.warning(
348
+ f"Failed to migrate cache record {record['id']}: {e}"
349
+ )
350
+ continue
351
+
352
+ logger.info(
353
+ f"Successfully migrated {migrated_count} cache records to flattened format"
354
+ )
355
+
356
+ except Exception as e:
357
+ logger.error(f"LLM cache migration failed: {e}")
358
+ # Don't raise exception, allow system to continue startup
359
+
360
  async def check_tables(self):
361
  # First create all tables
362
  for k, v in TABLES.items():
 
414
  except Exception as e:
415
  logger.error(f"PostgreSQL, Failed to migrate doc_chunks to vdb_chunks: {e}")
416
 
417
+ # Check and migrate LLM cache to flattened keys if needed
418
+ try:
419
+ if await self._check_llm_cache_needs_migration():
420
+ await self._migrate_llm_cache_to_flattened_keys()
421
+ except Exception as e:
422
+ logger.error(f"PostgreSQL, LLM cache migration failed: {e}")
423
+
424
  async def query(
425
  self,
426
  sql: str,
 
603
 
604
  try:
605
  results = await self.db.query(sql, params, multirows=True)
606
+
607
+ # Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
608
  if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
609
+ processed_results = {}
610
  for row in results:
611
+ # Parse flattened key to extract cache_type
612
+ key_parts = row["id"].split(":")
613
+ cache_type = key_parts[1] if len(key_parts) >= 3 else "unknown"
614
+
615
+ # Map field names and add cache_type for compatibility
616
+ processed_row = {
617
+ **row,
618
+ "return": row.get("return_value", ""), # Map return_value to return
619
+ "cache_type": cache_type, # Add cache_type from key
620
+ "original_prompt": row.get("original_prompt", ""),
621
+ "chunk_id": row.get("chunk_id"),
622
+ "mode": row.get("mode", "default")
623
+ }
624
+ processed_results[row["id"]] = processed_row
625
+ return processed_results
626
+
627
+ # For other namespaces, return as-is
628
+ return {row["id"]: row for row in results}
629
  except Exception as e:
630
  logger.error(f"Error retrieving all data from {self.namespace}: {e}")
631
  return {}
632
 
633
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
634
+ """Get data by id."""
635
  sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
636
+ params = {"workspace": self.db.workspace, "id": id}
637
+ response = await self.db.query(sql, params)
638
+ return response if response else None
 
639
 
640
  # Query by id
641
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
642
+ """Get data by ids"""
643
  sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
644
  ids=",".join([f"'{id}'" for id in ids])
645
  )
646
  params = {"workspace": self.db.workspace}
647
+ return await self.db.query(sql, params, multirows=True)
 
648
 
649
  async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
650
  """Specifically for llm_response_cache."""
 
705
  }
706
  await self.db.execute(upsert_sql, _data)
707
  elif is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
708
+ for k, v in data.items():
709
+ upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
710
+ _data = {
711
+ "workspace": self.db.workspace,
712
+ "id": k, # Use flattened key as id
713
+ "original_prompt": v["original_prompt"],
714
+ "return_value": v["return"],
715
+ "mode": v.get("mode", "default"), # Get mode from data
716
+ "chunk_id": v.get("chunk_id"),
717
+ }
 
718
 
719
+ await self.db.execute(upsert_sql, _data)
720
 
721
  async def index_done_callback(self) -> None:
722
  # PG handles persistence automatically
 
1122
  else:
1123
  exist_keys = []
1124
  new_keys = set([s for s in keys if s not in exist_keys])
1125
+ # print(f"keys: {keys}")
1126
+ # print(f"new_keys: {new_keys}")
1127
  return new_keys
1128
  except Exception as e:
1129
  logger.error(
 
2708
  FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
2709
  """,
2710
  "get_by_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id
2711
+ FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id=$2
2712
  """,
2713
  "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id
2714
  FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 AND id=$3
lightrag/kg/qdrant_impl.py CHANGED
@@ -85,7 +85,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
85
  )
86
 
87
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
88
- logger.info(f"Inserting {len(data)} to {self.namespace}")
89
  if not data:
90
  return
91
 
 
85
  )
86
 
87
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
88
+ logger.debug(f"Inserting {len(data)} to {self.namespace}")
89
  if not data:
90
  return
91
 
lightrag/kg/redis_impl.py CHANGED
@@ -1,9 +1,10 @@
1
  import os
2
- from typing import Any, final
3
  from dataclasses import dataclass
4
  import pipmaster as pm
5
  import configparser
6
  from contextlib import asynccontextmanager
 
7
 
8
  if not pm.is_installed("redis"):
9
  pm.install("redis")
@@ -13,7 +14,7 @@ from redis.asyncio import Redis, ConnectionPool # type: ignore
13
  from redis.exceptions import RedisError, ConnectionError # type: ignore
14
  from lightrag.utils import logger
15
 
16
- from lightrag.base import BaseKVStorage
17
  import json
18
 
19
 
@@ -26,6 +27,41 @@ SOCKET_TIMEOUT = 5.0
26
  SOCKET_CONNECT_TIMEOUT = 3.0
27
 
28
 
 
29
  @final
30
  @dataclass
31
  class RedisKVStorage(BaseKVStorage):
@@ -33,19 +69,28 @@ class RedisKVStorage(BaseKVStorage):
33
  redis_url = os.environ.get(
34
  "REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
35
  )
36
- # Create a connection pool with limits
37
- self._pool = ConnectionPool.from_url(
38
- redis_url,
39
- max_connections=MAX_CONNECTIONS,
40
- decode_responses=True,
41
- socket_timeout=SOCKET_TIMEOUT,
42
- socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
43
- )
44
  self._redis = Redis(connection_pool=self._pool)
45
  logger.info(
46
- f"Initialized Redis connection pool for {self.namespace} with max {MAX_CONNECTIONS} connections"
47
  )
48
 
 
 
49
  @asynccontextmanager
50
  async def _get_redis_connection(self):
51
  """Safe context manager for Redis operations."""
@@ -99,21 +144,57 @@ class RedisKVStorage(BaseKVStorage):
99
  logger.error(f"JSON decode error in batch get: {e}")
100
  return [None] * len(ids)
101
 
 
102
  async def filter_keys(self, keys: set[str]) -> set[str]:
103
  async with self._get_redis_connection() as redis:
104
  pipe = redis.pipeline()
105
- for key in keys:
 
106
  pipe.exists(f"{self.namespace}:{key}")
107
  results = await pipe.execute()
108
 
109
- existing_ids = {keys[i] for i, exists in enumerate(results) if exists}
110
  return set(keys) - existing_ids
111
 
112
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
113
  if not data:
114
  return
115
-
116
- logger.info(f"Inserting {len(data)} items to {self.namespace}")
117
  async with self._get_redis_connection() as redis:
118
  try:
119
  pipe = redis.pipeline()
@@ -148,13 +229,13 @@ class RedisKVStorage(BaseKVStorage):
148
  )
149
 
150
  async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
151
- """Delete specific records from storage by by cache mode
152
 
153
  Important notes for Redis storage:
154
  1. This will immediately delete the specified cache modes from Redis
155
 
156
  Args:
157
- modes (list[str]): List of cache mode to be drop from storage
158
 
159
  Returns:
160
  True: if the cache was dropped successfully
@@ -164,9 +245,43 @@ class RedisKVStorage(BaseKVStorage):
164
  return False
165
 
166
  try:
167
- await self.delete(modes)
 
168
  return True
169
- except Exception:
 
170
  return False
171
 
172
  async def drop(self) -> dict[str, str]:
@@ -177,24 +292,350 @@ class RedisKVStorage(BaseKVStorage):
177
  """
178
  async with self._get_redis_connection() as redis:
179
  try:
180
- keys = await redis.keys(f"{self.namespace}:*")
181
-
182
- if keys:
183
- pipe = redis.pipeline()
184
- for key in keys:
185
- pipe.delete(key)
186
- results = await pipe.execute()
187
- deleted_count = sum(results)
188
-
189
- logger.info(f"Dropped {deleted_count} keys from {self.namespace}")
190
- return {
191
- "status": "success",
192
- "message": f"{deleted_count} keys dropped",
193
- }
194
- else:
195
- logger.info(f"No keys found to drop in {self.namespace}")
196
- return {"status": "success", "message": "no keys to drop"}
 
197
 
198
  except Exception as e:
199
  logger.error(f"Error dropping keys from {self.namespace}: {e}")
200
  return {"status": "error", "message": str(e)}
 
 
1
  import os
2
+ from typing import Any, final, Union
3
  from dataclasses import dataclass
4
  import pipmaster as pm
5
  import configparser
6
  from contextlib import asynccontextmanager
7
+ import threading
8
 
9
  if not pm.is_installed("redis"):
10
  pm.install("redis")
 
14
  from redis.exceptions import RedisError, ConnectionError # type: ignore
15
  from lightrag.utils import logger
16
 
17
+ from lightrag.base import BaseKVStorage, DocStatusStorage, DocStatus, DocProcessingStatus
18
  import json
19
 
20
 
 
27
  SOCKET_CONNECT_TIMEOUT = 3.0
28
 
29
 
30
+ class RedisConnectionManager:
31
+ """Shared Redis connection pool manager to avoid creating multiple pools for the same Redis URI"""
32
+
33
+ _pools = {}
34
+ _lock = threading.Lock()
35
+
36
+ @classmethod
37
+ def get_pool(cls, redis_url: str) -> ConnectionPool:
38
+ """Get or create a connection pool for the given Redis URL"""
39
+ if redis_url not in cls._pools:
40
+ with cls._lock:
41
+ if redis_url not in cls._pools:
42
+ cls._pools[redis_url] = ConnectionPool.from_url(
43
+ redis_url,
44
+ max_connections=MAX_CONNECTIONS,
45
+ decode_responses=True,
46
+ socket_timeout=SOCKET_TIMEOUT,
47
+ socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
48
+ )
49
+ logger.info(f"Created shared Redis connection pool for {redis_url}")
50
+ return cls._pools[redis_url]
51
+
52
+ @classmethod
53
+ def close_all_pools(cls):
54
+ """Close all connection pools (for cleanup)"""
55
+ with cls._lock:
56
+ for url, pool in cls._pools.items():
57
+ try:
58
+ pool.disconnect()
59
+ logger.info(f"Closed Redis connection pool for {url}")
60
+ except Exception as e:
61
+ logger.error(f"Error closing Redis pool for {url}: {e}")
62
+ cls._pools.clear()
63
+
64
+
65
  @final
66
  @dataclass
67
  class RedisKVStorage(BaseKVStorage):
 
69
  redis_url = os.environ.get(
70
  "REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
71
  )
72
+ # Use shared connection pool
73
+ self._pool = RedisConnectionManager.get_pool(redis_url)
 
 
74
  self._redis = Redis(connection_pool=self._pool)
75
  logger.info(
76
+ f"Initialized Redis KV storage for {self.namespace} using shared connection pool"
77
  )
78
 
79
+ async def initialize(self):
80
+ """Initialize Redis connection and migrate legacy cache structure if needed"""
81
+ # Test connection
82
+ try:
83
+ async with self._get_redis_connection() as redis:
84
+ await redis.ping()
85
+ logger.info(f"Connected to Redis for namespace {self.namespace}")
86
+ except Exception as e:
87
+ logger.error(f"Failed to connect to Redis: {e}")
88
+ raise
89
+
90
+ # Migrate legacy cache structure if this is a cache namespace
91
+ if self.namespace.endswith("_cache"):
92
+ await self._migrate_legacy_cache_structure()
93
+
94
  @asynccontextmanager
95
  async def _get_redis_connection(self):
96
  """Safe context manager for Redis operations."""
 
144
  logger.error(f"JSON decode error in batch get: {e}")
145
  return [None] * len(ids)
146
 
147
+ async def get_all(self) -> dict[str, Any]:
148
+ """Get all data from storage
149
+
150
+ Returns:
151
+ Dictionary containing all stored data
152
+ """
153
+ async with self._get_redis_connection() as redis:
154
+ try:
155
+ # Get all keys for this namespace
156
+ keys = await redis.keys(f"{self.namespace}:*")
157
+
158
+ if not keys:
159
+ return {}
160
+
161
+ # Get all values in batch
162
+ pipe = redis.pipeline()
163
+ for key in keys:
164
+ pipe.get(key)
165
+ values = await pipe.execute()
166
+
167
+ # Build result dictionary
168
+ result = {}
169
+ for key, value in zip(keys, values):
170
+ if value:
171
+ # Extract the ID part (after namespace:)
172
+ key_id = key.split(":", 1)[1]
173
+ try:
174
+ result[key_id] = json.loads(value)
175
+ except json.JSONDecodeError as e:
176
+ logger.error(f"JSON decode error for key {key}: {e}")
177
+ continue
178
+
179
+ return result
180
+ except Exception as e:
181
+ logger.error(f"Error getting all data from Redis: {e}")
182
+ return {}
183
+
184
  async def filter_keys(self, keys: set[str]) -> set[str]:
185
  async with self._get_redis_connection() as redis:
186
  pipe = redis.pipeline()
187
+ keys_list = list(keys) # Convert set to list for indexing
188
+ for key in keys_list:
189
  pipe.exists(f"{self.namespace}:{key}")
190
  results = await pipe.execute()
191
 
192
+ existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists}
193
  return set(keys) - existing_ids
194
 
195
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
196
  if not data:
197
  return
 
 
198
  async with self._get_redis_connection() as redis:
199
  try:
200
  pipe = redis.pipeline()
 
229
  )
230
 
231
  async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
232
+ """Delete specific records from storage by cache mode
233
 
234
  Important notes for Redis storage:
235
  1. This will immediately delete the specified cache modes from Redis
236
 
237
  Args:
238
+ modes (list[str]): List of cache modes to be dropped from storage
239
 
240
  Returns:
241
  True: if the cache was dropped successfully
 
245
  return False
246
 
247
  try:
248
+ async with self._get_redis_connection() as redis:
249
+ keys_to_delete = []
250
+
251
+ # Find matching keys for each mode using SCAN
252
+ for mode in modes:
253
+ # Use correct pattern to match flattened cache key format {namespace}:{mode}:{cache_type}:{hash}
254
+ pattern = f"{self.namespace}:{mode}:*"
255
+ cursor = 0
256
+ mode_keys = []
257
+
258
+ while True:
259
+ cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
260
+ if keys:
261
+ mode_keys.extend(keys)
262
+
263
+ if cursor == 0:
264
+ break
265
+
266
+ keys_to_delete.extend(mode_keys)
267
+ logger.info(f"Found {len(mode_keys)} keys for mode '{mode}' with pattern '{pattern}'")
268
+
269
+ if keys_to_delete:
270
+ # Batch delete
271
+ pipe = redis.pipeline()
272
+ for key in keys_to_delete:
273
+ pipe.delete(key)
274
+ results = await pipe.execute()
275
+ deleted_count = sum(results)
276
+ logger.info(
277
+ f"Dropped {deleted_count} cache entries for modes: {modes}"
278
+ )
279
+ else:
280
+ logger.warning(f"No cache entries found for modes: {modes}")
281
+
282
  return True
283
+ except Exception as e:
284
+ logger.error(f"Error dropping cache by modes in Redis: {e}")
285
  return False
286
 
287
  async def drop(self) -> dict[str, str]:
 
292
  """
293
  async with self._get_redis_connection() as redis:
294
  try:
295
+ # Use SCAN to find all keys with the namespace prefix
296
+ pattern = f"{self.namespace}:*"
297
+ cursor = 0
298
+ deleted_count = 0
299
+
300
+ while True:
301
+ cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
302
+ if keys:
303
+ # Delete keys in batches
304
+ pipe = redis.pipeline()
305
+ for key in keys:
306
+ pipe.delete(key)
307
+ results = await pipe.execute()
308
+ deleted_count += sum(results)
309
+
310
+ if cursor == 0:
311
+ break
312
+
313
+ logger.info(f"Dropped {deleted_count} keys from {self.namespace}")
314
+ return {
315
+ "status": "success",
316
+ "message": f"{deleted_count} keys dropped",
317
+ }
318
 
319
  except Exception as e:
320
  logger.error(f"Error dropping keys from {self.namespace}: {e}")
321
  return {"status": "error", "message": str(e)}
322
+
323
+ async def _migrate_legacy_cache_structure(self):
324
+ """Migrate legacy nested cache structure to flattened structure for Redis
325
+
326
+ Redis already stores data in a flattened way, but we need to check for
327
+ legacy keys that might contain nested JSON structures and migrate them.
328
+
329
+ Early exit if any flattened key is found (indicating migration already done).
330
+ """
331
+ from lightrag.utils import generate_cache_key
332
+
333
+ async with self._get_redis_connection() as redis:
334
+ # Get all keys for this namespace
335
+ keys = await redis.keys(f"{self.namespace}:*")
336
+
337
+ if not keys:
338
+ return
339
+
340
+ # Check if we have any flattened keys already - if so, skip migration
341
+ has_flattened_keys = False
342
+ keys_to_migrate = []
343
+
344
+ for key in keys:
345
+ # Extract the ID part (after namespace:)
346
+ key_id = key.split(":", 1)[1]
347
+
348
+ # Check if already in flattened format (contains exactly 2 colons for mode:cache_type:hash)
349
+ if ":" in key_id and len(key_id.split(":")) == 3:
350
+ has_flattened_keys = True
351
+ break # Early exit - migration already done
352
+
353
+ # Get the data to check if it's a legacy nested structure
354
+ data = await redis.get(key)
355
+ if data:
356
+ try:
357
+ parsed_data = json.loads(data)
358
+ # Check if this looks like a legacy cache mode with nested structure
359
+ if isinstance(parsed_data, dict) and all(
360
+ isinstance(v, dict) and "return" in v
361
+ for v in parsed_data.values()
362
+ ):
363
+ keys_to_migrate.append((key, key_id, parsed_data))
364
+ except json.JSONDecodeError:
365
+ continue
366
+
367
+ # If we found any flattened keys, assume migration is already done
368
+ if has_flattened_keys:
369
+ logger.debug(
370
+ f"Found flattened cache keys in {self.namespace}, skipping migration"
371
+ )
372
+ return
373
+
374
+ if not keys_to_migrate:
375
+ return
376
+
377
+ # Perform migration
378
+ pipe = redis.pipeline()
379
+ migration_count = 0
380
+
381
+ for old_key, mode, nested_data in keys_to_migrate:
382
+ # Delete the old key
383
+ pipe.delete(old_key)
384
+
385
+ # Create new flattened keys
386
+ for cache_hash, cache_entry in nested_data.items():
387
+ cache_type = cache_entry.get("cache_type", "extract")
388
+ flattened_key = generate_cache_key(mode, cache_type, cache_hash)
389
+ full_key = f"{self.namespace}:{flattened_key}"
390
+ pipe.set(full_key, json.dumps(cache_entry))
391
+ migration_count += 1
392
+
393
+ await pipe.execute()
394
+
395
+ if migration_count > 0:
396
+ logger.info(
397
+ f"Migrated {migration_count} legacy cache entries to flattened structure in Redis"
398
+ )
399
+
400
+
401
+ @final
402
+ @dataclass
403
+ class RedisDocStatusStorage(DocStatusStorage):
404
+ """Redis implementation of document status storage"""
405
+
406
+ def __post_init__(self):
407
+ redis_url = os.environ.get(
408
+ "REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
409
+ )
410
+ # Use shared connection pool
411
+ self._pool = RedisConnectionManager.get_pool(redis_url)
412
+ self._redis = Redis(connection_pool=self._pool)
413
+ logger.info(
414
+ f"Initialized Redis doc status storage for {self.namespace} using shared connection pool"
415
+ )
416
+
417
+ async def initialize(self):
418
+ """Initialize Redis connection"""
419
+ try:
420
+ async with self._get_redis_connection() as redis:
421
+ await redis.ping()
422
+ logger.info(f"Connected to Redis for doc status namespace {self.namespace}")
423
+ except Exception as e:
424
+ logger.error(f"Failed to connect to Redis for doc status: {e}")
425
+ raise
426
+
427
+ @asynccontextmanager
428
+ async def _get_redis_connection(self):
429
+ """Safe context manager for Redis operations."""
430
+ try:
431
+ yield self._redis
432
+ except ConnectionError as e:
433
+ logger.error(f"Redis connection error in doc status {self.namespace}: {e}")
434
+ raise
435
+ except RedisError as e:
436
+ logger.error(f"Redis operation error in doc status {self.namespace}: {e}")
437
+ raise
438
+ except Exception as e:
439
+ logger.error(
440
+ f"Unexpected error in Redis doc status operation for {self.namespace}: {e}"
441
+ )
442
+ raise
443
+
444
+ async def close(self):
445
+ """Close the Redis connection."""
446
+ if hasattr(self, "_redis") and self._redis:
447
+ await self._redis.close()
448
+ logger.debug(f"Closed Redis connection for doc status {self.namespace}")
449
+
450
+ async def __aenter__(self):
451
+ """Support for async context manager."""
452
+ return self
453
+
454
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
455
+ """Ensure Redis resources are cleaned up when exiting context."""
456
+ await self.close()
457
+
458
+ async def filter_keys(self, keys: set[str]) -> set[str]:
459
+ """Return keys that should be processed (not in storage or not successfully processed)"""
460
+ async with self._get_redis_connection() as redis:
461
+ pipe = redis.pipeline()
462
+ keys_list = list(keys)
463
+ for key in keys_list:
464
+ pipe.exists(f"{self.namespace}:{key}")
465
+ results = await pipe.execute()
466
+
467
+ existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists}
468
+ return set(keys) - existing_ids
469
+
470
+ async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
471
+ result: list[dict[str, Any]] = []
472
+ async with self._get_redis_connection() as redis:
473
+ try:
474
+ pipe = redis.pipeline()
475
+ for id in ids:
476
+ pipe.get(f"{self.namespace}:{id}")
477
+ results = await pipe.execute()
478
+
479
+ for result_data in results:
480
+ if result_data:
481
+ try:
482
+ result.append(json.loads(result_data))
483
+ except json.JSONDecodeError as e:
484
+ logger.error(f"JSON decode error in get_by_ids: {e}")
485
+ continue
486
+ except Exception as e:
487
+ logger.error(f"Error in get_by_ids: {e}")
488
+ return result
489
+
490
+ async def get_status_counts(self) -> dict[str, int]:
491
+ """Get counts of documents in each status"""
492
+ counts = {status.value: 0 for status in DocStatus}
493
+ async with self._get_redis_connection() as redis:
494
+ try:
495
+ # Use SCAN to iterate through all keys in the namespace
496
+ cursor = 0
497
+ while True:
498
+ cursor, keys = await redis.scan(cursor, match=f"{self.namespace}:*", count=1000)
499
+ if keys:
500
+ # Get all values in batch
501
+ pipe = redis.pipeline()
502
+ for key in keys:
503
+ pipe.get(key)
504
+ values = await pipe.execute()
505
+
506
+ # Count statuses
507
+ for value in values:
508
+ if value:
509
+ try:
510
+ doc_data = json.loads(value)
511
+ status = doc_data.get("status")
512
+ if status in counts:
513
+ counts[status] += 1
514
+ except json.JSONDecodeError:
515
+ continue
516
+
517
+ if cursor == 0:
518
+ break
519
+ except Exception as e:
520
+ logger.error(f"Error getting status counts: {e}")
521
+
522
+ return counts
523
+
524
+ async def get_docs_by_status(
525
+ self, status: DocStatus
526
+ ) -> dict[str, DocProcessingStatus]:
527
+ """Get all documents with a specific status"""
528
+ result = {}
529
+ async with self._get_redis_connection() as redis:
530
+ try:
531
+ # Use SCAN to iterate through all keys in the namespace
532
+ cursor = 0
533
+ while True:
534
+ cursor, keys = await redis.scan(cursor, match=f"{self.namespace}:*", count=1000)
535
+ if keys:
536
+ # Get all values in batch
537
+ pipe = redis.pipeline()
538
+ for key in keys:
539
+ pipe.get(key)
540
+ values = await pipe.execute()
541
+
542
+ # Filter by status and create DocProcessingStatus objects
543
+ for key, value in zip(keys, values):
544
+ if value:
545
+ try:
546
+ doc_data = json.loads(value)
547
+ if doc_data.get("status") == status.value:
548
+ # Extract document ID from key
549
+ doc_id = key.split(":", 1)[1]
550
+
551
+ # Make a copy of the data to avoid modifying the original
552
+ data = doc_data.copy()
553
+ # If content is missing, use content_summary as content
554
+ if "content" not in data and "content_summary" in data:
555
+ data["content"] = data["content_summary"]
556
+ # If file_path is not in data, fall back to a placeholder value
557
+ if "file_path" not in data:
558
+ data["file_path"] = "no-file-path"
559
+
560
+ result[doc_id] = DocProcessingStatus(**data)
561
+ except (json.JSONDecodeError, KeyError) as e:
562
+ logger.error(f"Error processing document {key}: {e}")
563
+ continue
564
+
565
+ if cursor == 0:
566
+ break
567
+ except Exception as e:
568
+ logger.error(f"Error getting docs by status: {e}")
569
+
570
+ return result
571
+
572
+ async def index_done_callback(self) -> None:
573
+ """Redis handles persistence automatically"""
574
+ pass
575
+
576
+ async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
577
+ """Insert or update document status data"""
578
+ if not data:
579
+ return
580
+
581
+ logger.debug(f"Inserting {len(data)} records to {self.namespace}")
582
+ async with self._get_redis_connection() as redis:
583
+ try:
584
+ pipe = redis.pipeline()
585
+ for k, v in data.items():
586
+ pipe.set(f"{self.namespace}:{k}", json.dumps(v))
587
+ await pipe.execute()
588
+ except (TypeError, ValueError) as e:  # json.dumps raises TypeError/ValueError; json has no JSONEncodeError
589
+ logger.error(f"JSON encode error during upsert: {e}")
590
+ raise
591
+
592
+ async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
593
+ async with self._get_redis_connection() as redis:
594
+ try:
595
+ data = await redis.get(f"{self.namespace}:{id}")
596
+ return json.loads(data) if data else None
597
+ except json.JSONDecodeError as e:
598
+ logger.error(f"JSON decode error for id {id}: {e}")
599
+ return None
600
+
601
+ async def delete(self, doc_ids: list[str]) -> None:
602
+ """Delete specific records from storage by their IDs"""
603
+ if not doc_ids:
604
+ return
605
+
606
+ async with self._get_redis_connection() as redis:
607
+ pipe = redis.pipeline()
608
+ for doc_id in doc_ids:
609
+ pipe.delete(f"{self.namespace}:{doc_id}")
610
+
611
+ results = await pipe.execute()
612
+ deleted_count = sum(results)
613
+ logger.info(f"Deleted {deleted_count} of {len(doc_ids)} doc status entries from {self.namespace}")
614
+
615
+ async def drop(self) -> dict[str, str]:
616
+ """Drop all document status data from storage and clean up resources"""
617
+ try:
618
+ async with self._get_redis_connection() as redis:
619
+ # Use SCAN to find all keys with the namespace prefix
620
+ pattern = f"{self.namespace}:*"
621
+ cursor = 0
622
+ deleted_count = 0
623
+
624
+ while True:
625
+ cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
626
+ if keys:
627
+ # Delete keys in batches
628
+ pipe = redis.pipeline()
629
+ for key in keys:
630
+ pipe.delete(key)
631
+ results = await pipe.execute()
632
+ deleted_count += sum(results)
633
+
634
+ if cursor == 0:
635
+ break
636
+
637
+ logger.info(f"Dropped {deleted_count} doc status keys from {self.namespace}")
638
+ return {"status": "success", "message": "data dropped"}
639
+ except Exception as e:
640
+ logger.error(f"Error dropping doc status {self.namespace}: {e}")
641
+ return {"status": "error", "message": str(e)}
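Since both Redis storages now obtain their pool from RedisConnectionManager, any two instances pointing at the same URI share a single ConnectionPool; a minimal sketch of that guarantee (the URI is an example value):

pool_a = RedisConnectionManager.get_pool("redis://localhost:6379")
pool_b = RedisConnectionManager.get_pool("redis://localhost:6379")
assert pool_a is pool_b  # same URI -> one shared pool, per get_pool above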
lightrag/kg/tidb_impl.py CHANGED
@@ -257,7 +257,7 @@ class TiDBKVStorage(BaseKVStorage):
257
 
258
  ################ INSERT full_doc AND chunks ################
259
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
260
- logger.info(f"Inserting {len(data)} to {self.namespace}")
261
  if not data:
262
  return
263
  left_data = {k: v for k, v in data.items() if k not in self._data}
@@ -454,11 +454,9 @@ class TiDBVectorDBStorage(BaseVectorStorage):
454
 
455
  ###### INSERT entities And relationships ######
456
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
457
- logger.info(f"Inserting {len(data)} to {self.namespace}")
458
  if not data:
459
  return
460
-
461
- logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
462
 
463
  # Get current time as UNIX timestamp
464
  import time
 
lightrag/operate.py CHANGED
@@ -399,10 +399,10 @@ async def _get_cached_extraction_results(
     """
     cached_results = {}
 
-    # Get all cached data for "default" mode (entity extraction cache)
-    default_cache = await llm_response_cache.get_by_id("default") or {}
+    # Get all cached data (flattened cache structure)
+    all_cache = await llm_response_cache.get_all()
 
-    for cache_key, cache_entry in default_cache.items():
+    for cache_key, cache_entry in all_cache.items():
         if (
             isinstance(cache_entry, dict)
             and cache_entry.get("cache_type") == "extract"
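
For orientation, this is roughly what the scan over the flattened store looks like once get_all() returns the flat KV mapping. The sample keys and entries below are fabricated for illustration, and keying the results by chunk_id is an assumption about the surrounding function, not a quote of it:

all_cache = {
    "default:extract:9f8a...": {"cache_type": "extract", "chunk_id": "chunk-1", "return": "..."},
    "local:query:77bc...": {"cache_type": "query", "return": "..."},
}

cached_results = {}
for cache_key, cache_entry in all_cache.items():
    # Only entity-extraction entries participate in re-extraction reuse
    if isinstance(cache_entry, dict) and cache_entry.get("cache_type") == "extract":
        cached_results[cache_entry["chunk_id"]] = cache_entry["return"]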
@@ -1387,7 +1387,7 @@ async def kg_query(
     use_model_func = partial(use_model_func, _priority=5)
 
     # Handle cache
-    args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
+    args_hash = compute_args_hash(query_param.mode, query)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, query, query_param.mode, cache_type="query"
    )
@@ -1546,7 +1546,7 @@ async def extract_keywords_only(
     """
 
     # 1. Handle cache if needed - add cache type for keywords
-    args_hash = compute_args_hash(param.mode, text, cache_type="keywords")
+    args_hash = compute_args_hash(param.mode, text)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, text, param.mode, cache_type="keywords"
    )
@@ -2413,7 +2413,7 @@ async def naive_query(
     use_model_func = partial(use_model_func, _priority=5)
 
     # Handle cache
-    args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
+    args_hash = compute_args_hash(query_param.mode, query)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, query, query_param.mode, cache_type="query"
    )
@@ -2529,7 +2529,7 @@ async def kg_query_with_keywords(
     # Apply higher priority (5) to query relation LLM function
     use_model_func = partial(use_model_func, _priority=5)
 
-    args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
+    args_hash = compute_args_hash(query_param.mode, query)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, query, query_param.mode, cache_type="query"
    )
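
The net effect across all four call sites: cache_type no longer participates in the hash itself; it is passed separately to handle_cache, which folds it into the flattened key. A before/after sketch of the key derivation (hash inputs taken from the diff; the example query string is made up):

from lightrag.utils import compute_args_hash

# Before: the type was mixed into the MD5 input
# args_hash = compute_args_hash("local", "What is LightRAG?", cache_type="query")
#   -> md5("query:localWhat is LightRAG?")

# After: the hash covers only mode + query ...
args_hash = compute_args_hash("local", "What is LightRAG?")
#   -> md5("localWhat is LightRAG?")
# ... and the type lives in the storage key instead:
#   "local:query:<args_hash>"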
lightrag/utils.py CHANGED
@@ -14,7 +14,6 @@ from functools import wraps
 from hashlib import md5
 from typing import Any, Protocol, Callable, TYPE_CHECKING, List
 import numpy as np
-from lightrag.prompt import PROMPTS
 from dotenv import load_dotenv
 from lightrag.constants import (
     DEFAULT_LOG_MAX_BYTES,
@@ -278,11 +277,10 @@ def convert_response_to_json(response: str) -> dict[str, Any]:
     raise e from None
 
 
-def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:
+def compute_args_hash(*args: Any) -> str:
     """Compute a hash for the given arguments.
     Args:
         *args: Arguments to hash
-        cache_type: Type of cache (e.g., 'keywords', 'query', 'extract')
     Returns:
         str: Hash string
     """
@@ -290,13 +288,40 @@ def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:
 
     # Convert all arguments to strings and join them
     args_str = "".join([str(arg) for arg in args])
-    if cache_type:
-        args_str = f"{cache_type}:{args_str}"
 
     # Compute MD5 hash
     return hashlib.md5(args_str.encode()).hexdigest()
 
 
+def generate_cache_key(mode: str, cache_type: str, hash_value: str) -> str:
+    """Generate a flattened cache key in the format {mode}:{cache_type}:{hash}
+
+    Args:
+        mode: Cache mode (e.g., 'default', 'local', 'global')
+        cache_type: Type of cache (e.g., 'extract', 'query', 'keywords')
+        hash_value: Hash value from compute_args_hash
+
+    Returns:
+        str: Flattened cache key
+    """
+    return f"{mode}:{cache_type}:{hash_value}"
+
+
+def parse_cache_key(cache_key: str) -> tuple[str, str, str] | None:
+    """Parse a flattened cache key back into its components
+
+    Args:
+        cache_key: Flattened cache key in format {mode}:{cache_type}:{hash}
+
+    Returns:
+        tuple[str, str, str] | None: (mode, cache_type, hash) or None if invalid format
+    """
+    parts = cache_key.split(":", 2)
+    if len(parts) == 3:
+        return parts[0], parts[1], parts[2]
+    return None
+
+
 def compute_mdhash_id(content: str, prefix: str = "") -> str:
     """
     Compute a unique ID for a given content string.
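
A quick round-trip of the two helpers added above (the query string and the malformed key are illustrative values):

from lightrag.utils import compute_args_hash, generate_cache_key, parse_cache_key

h = compute_args_hash("local", "What is LightRAG?")
key = generate_cache_key("local", "query", h)  # "local:query:<md5>"

mode, cache_type, hash_value = parse_cache_key(key)
assert (mode, cache_type, hash_value) == ("local", "query", h)
assert parse_cache_key("malformed-key") is None  # no ":"-separated triple

Because parse_cache_key splits on at most two colons, anything after the second colon (including a hash that itself contains ":") survives intact.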
@@ -783,131 +808,6 @@ def process_combine_contexts(*context_lists):
     return combined_data
 
 
-async def get_best_cached_response(
-    hashing_kv,
-    current_embedding,
-    similarity_threshold=0.95,
-    mode="default",
-    use_llm_check=False,
-    llm_func=None,
-    original_prompt=None,
-    cache_type=None,
-) -> str | None:
-    logger.debug(
-        f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
-    )
-    mode_cache = await hashing_kv.get_by_id(mode)
-    if not mode_cache:
-        return None
-
-    best_similarity = -1
-    best_response = None
-    best_prompt = None
-    best_cache_id = None
-
-    # Only iterate through cache entries for this mode
-    for cache_id, cache_data in mode_cache.items():
-        # Skip if cache_type doesn't match
-        if cache_type and cache_data.get("cache_type") != cache_type:
-            continue
-
-        # Check if cache data is valid
-        if cache_data["embedding"] is None:
-            continue
-
-        try:
-            # Safely convert cached embedding
-            cached_quantized = np.frombuffer(
-                bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
-            ).reshape(cache_data["embedding_shape"])
-
-            # Ensure min_val and max_val are valid float values
-            embedding_min = cache_data.get("embedding_min")
-            embedding_max = cache_data.get("embedding_max")
-
-            if (
-                embedding_min is None
-                or embedding_max is None
-                or embedding_min >= embedding_max
-            ):
-                logger.warning(
-                    f"Invalid embedding min/max values: min={embedding_min}, max={embedding_max}"
-                )
-                continue
-
-            cached_embedding = dequantize_embedding(
-                cached_quantized,
-                embedding_min,
-                embedding_max,
-            )
-        except Exception as e:
-            logger.warning(f"Error processing cached embedding: {str(e)}")
-            continue
-
-        similarity = cosine_similarity(current_embedding, cached_embedding)
-        if similarity > best_similarity:
-            best_similarity = similarity
-            best_response = cache_data["return"]
-            best_prompt = cache_data["original_prompt"]
-            best_cache_id = cache_id
-
-    if best_similarity > similarity_threshold:
-        # If LLM check is enabled and all required parameters are provided
-        if (
-            use_llm_check
-            and llm_func
-            and original_prompt
-            and best_prompt
-            and best_response is not None
-        ):
-            compare_prompt = PROMPTS["similarity_check"].format(
-                original_prompt=original_prompt, cached_prompt=best_prompt
-            )
-
-            try:
-                llm_result = await llm_func(compare_prompt)
-                llm_result = llm_result.strip()
-                llm_similarity = float(llm_result)
-
-                # Replace vector similarity with LLM similarity score
-                best_similarity = llm_similarity
-                if best_similarity < similarity_threshold:
-                    log_data = {
-                        "event": "cache_rejected_by_llm",
-                        "type": cache_type,
-                        "mode": mode,
-                        "original_question": original_prompt[:100] + "..."
-                        if len(original_prompt) > 100
-                        else original_prompt,
-                        "cached_question": best_prompt[:100] + "..."
-                        if len(best_prompt) > 100
-                        else best_prompt,
-                        "similarity_score": round(best_similarity, 4),
-                        "threshold": similarity_threshold,
-                    }
-                    logger.debug(json.dumps(log_data, ensure_ascii=False))
-                    logger.info(f"Cache rejected by LLM(mode:{mode} tpye:{cache_type})")
-                    return None
-            except Exception as e:  # Catch all possible exceptions
-                logger.warning(f"LLM similarity check failed: {e}")
-                return None  # Return None directly when LLM check fails
-
-        prompt_display = (
-            best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt
-        )
-        log_data = {
-            "event": "cache_hit",
-            "type": cache_type,
-            "mode": mode,
-            "similarity": round(best_similarity, 4),
-            "cache_id": best_cache_id,
-            "original_prompt": prompt_display,
-        }
-        logger.debug(json.dumps(log_data, ensure_ascii=False))
-        return best_response
-    return None
-
-
 def cosine_similarity(v1, v2):
     """Calculate cosine similarity between two vectors"""
     dot_product = np.dot(v1, v2)
@@ -957,7 +857,7 @@ async def handle_cache(
     mode="default",
     cache_type=None,
 ):
-    """Generic cache handling function"""
+    """Generic cache handling function with flattened cache keys"""
     if hashing_kv is None:
         return None, None, None, None
 
@@ -968,15 +868,14 @@ async def handle_cache(
     if not hashing_kv.global_config.get("enable_llm_cache_for_entity_extract"):
         return None, None, None, None
 
-    if exists_func(hashing_kv, "get_by_mode_and_id"):
-        mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
-    else:
-        mode_cache = await hashing_kv.get_by_id(mode) or {}
-    if args_hash in mode_cache:
-        logger.debug(f"Non-embedding cached hit(mode:{mode} type:{cache_type})")
-        return mode_cache[args_hash]["return"], None, None, None
+    # Use flattened cache key format: {mode}:{cache_type}:{hash}
+    flattened_key = generate_cache_key(mode, cache_type, args_hash)
+    cache_entry = await hashing_kv.get_by_id(flattened_key)
+    if cache_entry:
+        logger.debug(f"Flattened cache hit(key:{flattened_key})")
+        return cache_entry["return"], None, None, None
 
-    logger.debug(f"Non-embedding cached missed(mode:{mode} type:{cache_type})")
+    logger.debug(f"Cache missed(mode:{mode} type:{cache_type})")
     return None, None, None, None
 
 
@@ -994,7 +893,7 @@ class CacheData:
 
 
 async def save_to_cache(hashing_kv, cache_data: CacheData):
-    """Save data to cache, with improved handling for streaming responses and duplicate content.
+    """Save data to cache using flattened key structure.
 
     Args:
         hashing_kv: The key-value storage for caching
@@ -1009,26 +908,21 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
         logger.debug("Streaming response detected, skipping cache")
         return
 
-    # Get existing cache data
-    if exists_func(hashing_kv, "get_by_mode_and_id"):
-        mode_cache = (
-            await hashing_kv.get_by_mode_and_id(cache_data.mode, cache_data.args_hash)
-            or {}
-        )
-    else:
-        mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
+    # Use flattened cache key format: {mode}:{cache_type}:{hash}
+    flattened_key = generate_cache_key(
+        cache_data.mode, cache_data.cache_type, cache_data.args_hash
+    )
 
     # Check if we already have identical content cached
-    if cache_data.args_hash in mode_cache:
-        existing_content = mode_cache[cache_data.args_hash].get("return")
+    existing_cache = await hashing_kv.get_by_id(flattened_key)
+    if existing_cache:
+        existing_content = existing_cache.get("return")
         if existing_content == cache_data.content:
-            logger.info(
-                f"Cache content unchanged for {cache_data.args_hash}, skipping update"
-            )
+            logger.info(f"Cache content unchanged for {flattened_key}, skipping update")
             return
 
-    # Update cache with new content
-    mode_cache[cache_data.args_hash] = {
+    # Create cache entry with flattened structure
+    cache_entry = {
         "return": cache_data.content,
         "cache_type": cache_data.cache_type,
         "chunk_id": cache_data.chunk_id if cache_data.chunk_id is not None else None,
@@ -1043,10 +937,10 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
         "original_prompt": cache_data.prompt,
     }
 
-    logger.info(f" == LLM cache == saving {cache_data.mode}: {cache_data.args_hash}")
+    logger.info(f" == LLM cache == saving: {flattened_key}")
 
-    # Only upsert if there's actual new content
-    await hashing_kv.upsert({cache_data.mode: mode_cache})
+    # Save using flattened key
+    await hashing_kv.upsert({flattened_key: cache_entry})
 
 
 def safe_unicode_decode(content):
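
Taken together, save_to_cache and handle_cache now agree on a single flat key, so any KV backend that supports plain get/upsert by id can serve the cache. A compact sketch of the new read/write contract, with an in-memory dict standing in for the storage (this mirrors the flow above but is not LightRAG code; names and values are illustrative):

store: dict[str, dict] = {}

def save(mode: str, cache_type: str, args_hash: str, content: str) -> None:
    key = f"{mode}:{cache_type}:{args_hash}"  # what generate_cache_key produces
    existing = store.get(key)
    if existing and existing.get("return") == content:
        return  # identical content already cached: skip the write
    store[key] = {"return": content, "cache_type": cache_type}

def lookup(mode: str, cache_type: str, args_hash: str) -> str | None:
    entry = store.get(f"{mode}:{cache_type}:{args_hash}")
    return entry["return"] if entry else None  # one O(1) get, no nested scan

save("default", "extract", "abc123", "entities...")
assert lookup("default", "extract", "abc123") == "entities..."

This is also where the recall-efficiency gain described in the commit message comes from: the old layout had to fetch the whole per-mode JSON blob and search inside it, while the flat layout resolves a hit with a single key lookup.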