Daniel.y committed
Commit 3ab6fa0 · unverified · 2 parents: 60cd0b2 3ca4de2

Merge pull request #1719 from danielaskdd/fix-redis-doc-delete


Fix LLM cache handling for Redis to address document deletion scenarios

Files changed (1)
  1. lightrag/kg/redis_impl.py +53 -7
lightrag/kg/redis_impl.py CHANGED
@@ -79,13 +79,59 @@ class RedisKVStorage(BaseKVStorage):
         await self.close()
 
     async def get_by_id(self, id: str) -> dict[str, Any] | None:
-        async with self._get_redis_connection() as redis:
-            try:
-                data = await redis.get(f"{self.namespace}:{id}")
-                return json.loads(data) if data else None
-            except json.JSONDecodeError as e:
-                logger.error(f"JSON decode error for id {id}: {e}")
-                return None
+        if id == "default":
+            # Find all cache entries with cache_type == "extract"
+            async with self._get_redis_connection() as redis:
+                try:
+                    result = {}
+                    pattern = f"{self.namespace}:*"
+                    cursor = 0
+
+                    while True:
+                        cursor, keys = await redis.scan(
+                            cursor, match=pattern, count=100
+                        )
+
+                        if keys:
+                            # Batch get values for these keys
+                            pipe = redis.pipeline()
+                            for key in keys:
+                                pipe.get(key)
+                            values = await pipe.execute()
+
+                            # Check each value for cache_type == "extract"
+                            for key, value in zip(keys, values):
+                                if value:
+                                    try:
+                                        data = json.loads(value)
+                                        if (
+                                            isinstance(data, dict)
+                                            and data.get("cache_type") == "extract"
+                                        ):
+                                            # Extract cache key (remove namespace prefix)
+                                            cache_key = key.replace(
+                                                f"{self.namespace}:", ""
+                                            )
+                                            result[cache_key] = data
+                                    except json.JSONDecodeError:
+                                        continue
+
+                        if cursor == 0:
+                            break
+
+                    return result if result else None
+                except Exception as e:
+                    logger.error(f"Error scanning Redis for extract cache entries: {e}")
+                    return None
+        else:
+            # Original behavior for non-"default" ids
+            async with self._get_redis_connection() as redis:
+                try:
+                    data = await redis.get(f"{self.namespace}:{id}")
+                    return json.loads(data) if data else None
+                except json.JSONDecodeError as e:
+                    logger.error(f"JSON decode error for id {id}: {e}")
+                    return None
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         async with self._get_redis_connection() as redis:
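
For reference, here is a minimal standalone sketch of the scan-and-filter pattern the new `id == "default"` branch relies on, written directly against redis.asyncio (redis-py >= 5). The connection URL and the llm_response_cache namespace are illustrative assumptions, not values taken from this PR; the real method runs inside RedisKVStorage and uses its own connection handling.

# Illustrative sketch only: mirrors the scan-and-filter logic of the new
# get_by_id("default") branch outside of LightRAG. URL and namespace are assumed.
import asyncio
import json

from redis.asyncio import Redis


async def collect_extract_entries(namespace: str) -> dict:
    redis = Redis.from_url("redis://localhost:6379", decode_responses=True)
    result = {}
    cursor = 0
    try:
        while True:
            # SCAN walks the keyspace incrementally; a returned cursor of 0 ends the scan
            cursor, keys = await redis.scan(cursor, match=f"{namespace}:*", count=100)
            if keys:
                # Batch the GETs for this page of keys through a pipeline
                pipe = redis.pipeline()
                for key in keys:
                    pipe.get(key)
                values = await pipe.execute()
                for key, value in zip(keys, values):
                    if not value:
                        continue
                    try:
                        data = json.loads(value)
                    except json.JSONDecodeError:
                        continue
                    # Keep only extraction-stage cache entries, keyed without the namespace prefix
                    if isinstance(data, dict) and data.get("cache_type") == "extract":
                        result[key.replace(f"{namespace}:", "")] = data
            if cursor == 0:
                break
    finally:
        await redis.aclose()
    return result


if __name__ == "__main__":
    entries = asyncio.run(collect_extract_entries("llm_response_cache"))
    print(f"found {len(entries)} extract cache entries")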