gzdaniel committed on
Commit 9535cef · 1 Parent(s): 73852ac

Refac: Enhance KG rebuild stability by incorporating `create_time` into the LLM cache

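The core of the change: every KV backend now stamps cache records with create_time and update_time, and the rebuild path sorts each chunk's cached extraction results by create_time instead of relying on storage iteration order. A minimal sketch of that ordering step, with made-up toy entries that mirror the chunk_id / return / create_time fields read by _get_cached_extraction_results:

# Toy cache entries (made-up values); records written before this change have no create_time.
cache_entries = {
    "cache-b": {"chunk_id": "chunk-1", "return": "second extraction", "create_time": 1700000200},
    "cache-a": {"chunk_id": "chunk-1", "return": "first extraction", "create_time": 1700000100},
    "cache-c": {"chunk_id": "chunk-1", "return": "legacy extraction"},
}

cached_results: dict[str, list] = {}
for entry in cache_entries.values():
    # Missing create_time defaults to 0, matching the setdefault(..., 0) added in the storage backends
    pair = (entry["return"], entry.get("create_time", 0))
    cached_results.setdefault(entry["chunk_id"], []).append(pair)

# Sorting by create_time makes the merge order deterministic regardless of dict/DB iteration order
for chunk_id, pairs in cached_results.items():
    pairs.sort(key=lambda p: p[1])
    cached_results[chunk_id] = [result for result, _ in pairs]

print(cached_results["chunk-1"])
# ['legacy extraction', 'first extraction', 'second extraction']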
lightrag/kg/json_kv_impl.py CHANGED
@@ -78,22 +78,49 @@ class JsonKVStorage(BaseKVStorage):
            Dictionary containing all stored data
        """
        async with self._storage_lock:
-            return dict(self._data)
+            result = {}
+            for key, value in self._data.items():
+                if value:
+                    # Create a copy to avoid modifying the original data
+                    data = dict(value)
+                    # Ensure time fields are present, provide default values for old data
+                    data.setdefault("create_time", 0)
+                    data.setdefault("update_time", 0)
+                    result[key] = data
+                else:
+                    result[key] = value
+            return result

    async def get_by_id(self, id: str) -> dict[str, Any] | None:
        async with self._storage_lock:
-            return self._data.get(id)
+            result = self._data.get(id)
+            if result:
+                # Create a copy to avoid modifying the original data
+                result = dict(result)
+                # Ensure time fields are present, provide default values for old data
+                result.setdefault("create_time", 0)
+                result.setdefault("update_time", 0)
+                # Ensure _id field contains the clean ID
+                result["_id"] = id
+            return result

    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
        async with self._storage_lock:
-            return [
-                (
-                    {k: v for k, v in self._data[id].items()}
-                    if self._data.get(id, None)
-                    else None
-                )
-                for id in ids
-            ]
+            results = []
+            for id in ids:
+                data = self._data.get(id, None)
+                if data:
+                    # Create a copy to avoid modifying the original data
+                    result = {k: v for k, v in data.items()}
+                    # Ensure time fields are present, provide default values for old data
+                    result.setdefault("create_time", 0)
+                    result.setdefault("update_time", 0)
+                    # Ensure _id field contains the clean ID
+                    result["_id"] = id
+                    results.append(result)
+                else:
+                    results.append(None)
+            return results

    async def filter_keys(self, keys: set[str]) -> set[str]:
        async with self._storage_lock:
@@ -107,13 +134,29 @@ class JsonKVStorage(BaseKVStorage):
        """
        if not data:
            return
+
+        import time
+
+        current_time = int(time.time())  # Get current Unix timestamp
+
        logger.debug(f"Inserting {len(data)} records to {self.namespace}")
        async with self._storage_lock:
-            # For text_chunks namespace, ensure llm_cache_list field exists
-            if "text_chunks" in self.namespace:
-                for chunk_id, chunk_data in data.items():
-                    if "llm_cache_list" not in chunk_data:
-                        chunk_data["llm_cache_list"] = []
+            # Add timestamps to data based on whether key exists
+            for k, v in data.items():
+                # For text_chunks namespace, ensure llm_cache_list field exists
+                if "text_chunks" in self.namespace:
+                    if "llm_cache_list" not in v:
+                        v["llm_cache_list"] = []
+
+                # Add timestamps based on whether key exists
+                if k in self._data:  # Key exists, only update update_time
+                    v["update_time"] = current_time
+                else:  # New key, set both create_time and update_time
+                    v["create_time"] = current_time
+                    v["update_time"] = current_time
+
+                v["_id"] = k
+
            self._data.update(data)
            await set_all_update_flags(self.namespace)
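To make the new upsert rule concrete, here is a small standalone sketch (not project code) of the timestamp logic above: the first write of a key stamps create_time and update_time, a later write of the same key stamps only update_time on the incoming record.

import time

store: dict[str, dict] = {}  # stand-in for JsonKVStorage._data

def upsert(data: dict[str, dict]) -> None:
    now = int(time.time())
    for k, v in data.items():
        if k in store:
            v["update_time"] = now   # existing key: only refresh update_time
        else:
            v["create_time"] = now   # new key: stamp both fields
            v["update_time"] = now
        v["_id"] = k
    store.update(data)               # incoming records replace stored ones wholesale

upsert({"chunk-1": {"content": "v1"}})
print(sorted(store["chunk-1"]))  # ['_id', 'content', 'create_time', 'update_time']

upsert({"chunk-1": {"content": "v2"}})
print(sorted(store["chunk-1"]))  # ['_id', 'content', 'update_time']

Because the stored record is replaced wholesale, an existing key keeps its create_time only if the caller passes it back in the new record; the MongoDB backend below keeps create_time on the database side instead, via `$setOnInsert`.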
lightrag/kg/mongo_impl.py CHANGED
@@ -98,11 +98,21 @@ class MongoKVStorage(BaseKVStorage):

    async def get_by_id(self, id: str) -> dict[str, Any] | None:
        # Unified handling for flattened keys
-        return await self._data.find_one({"_id": id})
+        doc = await self._data.find_one({"_id": id})
+        if doc:
+            # Ensure time fields are present, provide default values for old data
+            doc.setdefault("create_time", 0)
+            doc.setdefault("update_time", 0)
+        return doc

    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
        cursor = self._data.find({"_id": {"$in": ids}})
-        return await cursor.to_list()
+        docs = await cursor.to_list()
+        # Ensure time fields are present for all documents
+        for doc in docs:
+            doc.setdefault("create_time", 0)
+            doc.setdefault("update_time", 0)
+        return docs

    async def filter_keys(self, keys: set[str]) -> set[str]:
        cursor = self._data.find({"_id": {"$in": list(keys)}}, {"_id": 1})
@@ -119,6 +129,9 @@ class MongoKVStorage(BaseKVStorage):
        result = {}
        async for doc in cursor:
            doc_id = doc.pop("_id")
+            # Ensure time fields are present for all documents
+            doc.setdefault("create_time", 0)
+            doc.setdefault("update_time", 0)
            result[doc_id] = doc
        return result

@@ -132,6 +145,8 @@ class MongoKVStorage(BaseKVStorage):
        from pymongo import UpdateOne

        operations = []
+        current_time = int(time.time())  # Get current Unix timestamp
+
        for k, v in data.items():
            # For text_chunks namespace, ensure llm_cache_list field exists
            if self.namespace.endswith("text_chunks"):
@@ -139,7 +154,20 @@ class MongoKVStorage(BaseKVStorage):
                    v["llm_cache_list"] = []

            v["_id"] = k  # Use flattened key as _id
-            operations.append(UpdateOne({"_id": k}, {"$set": v}, upsert=True))
+            v["update_time"] = current_time  # Always update update_time
+
+            operations.append(
+                UpdateOne(
+                    {"_id": k},
+                    {
+                        "$set": v,  # Update all fields including update_time
+                        "$setOnInsert": {
+                            "create_time": current_time
+                        },  # Set create_time only on insert
+                    },
+                    upsert=True,
+                )
+            )

        if operations:
            await self._data.bulk_write(operations)
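The `$set` / `$setOnInsert` split is standard MongoDB upsert behaviour: `$set` is applied on every write, `$setOnInsert` only when the upsert actually inserts, so a create_time written once is never overwritten. A hedged standalone sketch of the same pattern (assumes a MongoDB instance on localhost:27017; the demo/kv names are made up):

import time
from pymongo import MongoClient, UpdateOne

# Assumption: a local MongoDB is reachable; "demo" and "kv" are made-up names.
coll = MongoClient("mongodb://localhost:27017")["demo"]["kv"]

def upsert(records: dict[str, dict]) -> None:
    now = int(time.time())
    ops = []
    for k, v in records.items():
        v["update_time"] = now  # refreshed on every write
        ops.append(
            UpdateOne(
                {"_id": k},  # the filter supplies _id when the upsert inserts
                {
                    "$set": v,
                    "$setOnInsert": {"create_time": now},  # written only on first insert
                },
                upsert=True,
            )
        )
    if ops:
        coll.bulk_write(ops)

upsert({"chunk-1": {"content": "v1"}})
created = coll.find_one({"_id": "chunk-1"})["create_time"]
time.sleep(1)
upsert({"chunk-1": {"content": "v2"}})
assert coll.find_one({"_id": "chunk-1"})["create_time"] == created  # preserved across upserts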
lightrag/kg/postgres_impl.py CHANGED
@@ -752,6 +752,8 @@ class PGKVStorage(BaseKVStorage):
                    "original_prompt": row.get("original_prompt", ""),
                    "chunk_id": row.get("chunk_id"),
                    "mode": row.get("mode", "default"),
+                    "create_time": row.get("create_time", 0),
+                    "update_time": row.get("update_time", 0),
                }
                processed_results[row["id"]] = processed_row
            return processed_results
@@ -767,6 +769,8 @@ class PGKVStorage(BaseKVStorage):
                except json.JSONDecodeError:
                    llm_cache_list = []
                row["llm_cache_list"] = llm_cache_list
+                row["create_time"] = row.get("create_time", 0)
+                row["update_time"] = row.get("update_time", 0)
                processed_results[row["id"]] = row
            return processed_results

@@ -791,6 +795,8 @@ class PGKVStorage(BaseKVStorage):
            except json.JSONDecodeError:
                llm_cache_list = []
            response["llm_cache_list"] = llm_cache_list
+            response["create_time"] = response.get("create_time", 0)
+            response["update_time"] = response.get("update_time", 0)

        # Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
        if response and is_namespace(
@@ -804,6 +810,8 @@ class PGKVStorage(BaseKVStorage):
                "original_prompt": response.get("original_prompt", ""),
                "chunk_id": response.get("chunk_id"),
                "mode": response.get("mode", "default"),
+                "create_time": response.get("create_time", 0),
+                "update_time": response.get("update_time", 0),
            }

        return response if response else None
@@ -827,6 +835,8 @@ class PGKVStorage(BaseKVStorage):
                except json.JSONDecodeError:
                    llm_cache_list = []
                result["llm_cache_list"] = llm_cache_list
+                result["create_time"] = result.get("create_time", 0)
+                result["update_time"] = result.get("update_time", 0)

        # Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
        if results and is_namespace(
@@ -842,6 +852,8 @@ class PGKVStorage(BaseKVStorage):
                    "original_prompt": row.get("original_prompt", ""),
                    "chunk_id": row.get("chunk_id"),
                    "mode": row.get("mode", "default"),
+                    "create_time": row.get("create_time", 0),
+                    "update_time": row.get("update_time", 0),
                }
                processed_results.append(processed_row)
            return processed_results
@@ -2941,10 +2953,12 @@ SQL_TEMPLATES = {
    """,
    "get_by_id_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
                                chunk_order_index, full_doc_id, file_path,
-                                COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list
+                                COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list,
+                                create_time, update_time
                                FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
                                """,
-    "get_by_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type
+    "get_by_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type,
+                                create_time, update_time
                                FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id=$2
                                """,
    "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id
@@ -2955,10 +2969,12 @@ SQL_TEMPLATES = {
    """,
    "get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
                                chunk_order_index, full_doc_id, file_path,
-                                COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list
+                                COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list,
+                                create_time, update_time
                                FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
                                """,
-    "get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type
+    "get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type,
+                                create_time, update_time
                                FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id IN ({ids})
                                """,
    "filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})",
lightrag/kg/redis_impl.py CHANGED
@@ -132,7 +132,13 @@ class RedisKVStorage(BaseKVStorage):
        async with self._get_redis_connection() as redis:
            try:
                data = await redis.get(f"{self.namespace}:{id}")
-                return json.loads(data) if data else None
+                if data:
+                    result = json.loads(data)
+                    # Ensure time fields are present, provide default values for old data
+                    result.setdefault("create_time", 0)
+                    result.setdefault("update_time", 0)
+                    return result
+                return None
            except json.JSONDecodeError as e:
                logger.error(f"JSON decode error for id {id}: {e}")
                return None
@@ -144,7 +150,19 @@ class RedisKVStorage(BaseKVStorage):
                for id in ids:
                    pipe.get(f"{self.namespace}:{id}")
                results = await pipe.execute()
-                return [json.loads(result) if result else None for result in results]
+
+                processed_results = []
+                for result in results:
+                    if result:
+                        data = json.loads(result)
+                        # Ensure time fields are present for all documents
+                        data.setdefault("create_time", 0)
+                        data.setdefault("update_time", 0)
+                        processed_results.append(data)
+                    else:
+                        processed_results.append(None)
+
+                return processed_results
            except json.JSONDecodeError as e:
                logger.error(f"JSON decode error in batch get: {e}")
                return [None] * len(ids)
@@ -176,7 +194,11 @@ class RedisKVStorage(BaseKVStorage):
                    # Extract the ID part (after namespace:)
                    key_id = key.split(":", 1)[1]
                    try:
-                        result[key_id] = json.loads(value)
+                        data = json.loads(value)
+                        # Ensure time fields are present for all documents
+                        data.setdefault("create_time", 0)
+                        data.setdefault("update_time", 0)
+                        result[key_id] = data
                    except json.JSONDecodeError as e:
                        logger.error(f"JSON decode error for key {key}: {e}")
                        continue
@@ -200,21 +222,41 @@ class RedisKVStorage(BaseKVStorage):
    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        if not data:
            return
+
+        import time
+
+        current_time = int(time.time())  # Get current Unix timestamp
+
        async with self._get_redis_connection() as redis:
            try:
-                # For text_chunks namespace, ensure llm_cache_list field exists
-                if "text_chunks" in self.namespace:
-                    for chunk_id, chunk_data in data.items():
-                        if "llm_cache_list" not in chunk_data:
-                            chunk_data["llm_cache_list"] = []
-
+                # Check which keys already exist to determine create vs update
+                pipe = redis.pipeline()
+                for k in data.keys():
+                    pipe.exists(f"{self.namespace}:{k}")
+                exists_results = await pipe.execute()
+
+                # Add timestamps to data
+                for i, (k, v) in enumerate(data.items()):
+                    # For text_chunks namespace, ensure llm_cache_list field exists
+                    if "text_chunks" in self.namespace:
+                        if "llm_cache_list" not in v:
+                            v["llm_cache_list"] = []
+
+                    # Add timestamps based on whether key exists
+                    if exists_results[i]:  # Key exists, only update update_time
+                        v["update_time"] = current_time
+                    else:  # New key, set both create_time and update_time
+                        v["create_time"] = current_time
+                        v["update_time"] = current_time
+
+                    v["_id"] = k
+
+                # Store the data
                pipe = redis.pipeline()
                for k, v in data.items():
                    pipe.set(f"{self.namespace}:{k}", json.dumps(v))
                await pipe.execute()

-                for k in data:
-                    data[k]["_id"] = k
            except json.JSONEncodeError as e:
                logger.error(f"JSON encode error during upsert: {e}")
                raise
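The Redis backend has to decide create-vs-update before writing, so it first pipelines EXISTS checks for every key and then pipelines the SETs. A minimal sketch of the same two-round-trip pattern with redis.asyncio (assumes a local Redis server; the demo prefix is made up):

import asyncio
import json
import time

from redis.asyncio import Redis

async def upsert(data: dict[str, dict]) -> None:
    r = Redis(host="localhost", port=6379)  # assumption: local Redis on the default port
    now = int(time.time())

    # Round trip 1: find out which keys already exist
    pipe = r.pipeline()
    for k in data:
        pipe.exists(f"demo:{k}")
    exists = await pipe.execute()

    # Round trip 2: stamp timestamps accordingly and write every record
    pipe = r.pipeline()
    for flag, (k, v) in zip(exists, data.items()):
        v["update_time"] = now
        if not flag:  # new key: also record create_time
            v["create_time"] = now
        pipe.set(f"demo:{k}", json.dumps(v))
    await pipe.execute()
    await r.close()

asyncio.run(upsert({"chunk-1": {"content": "v1"}}))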
lightrag/operate.py CHANGED
@@ -273,8 +273,6 @@ async def _rebuild_knowledge_from_chunks(
        all_referenced_chunk_ids.update(chunk_ids)
    for chunk_ids in relationships_to_rebuild.values():
        all_referenced_chunk_ids.update(chunk_ids)
-    # sort all_referenced_chunk_ids to get a stable order in merge stage
-    all_referenced_chunk_ids = sorted(all_referenced_chunk_ids)

    status_message = f"Rebuilding knowledge from {len(all_referenced_chunk_ids)} cached chunk extractions"
    logger.info(status_message)
@@ -464,12 +462,22 @@ async def _get_cached_extraction_results(
        ):
            chunk_id = cache_entry["chunk_id"]
            extraction_result = cache_entry["return"]
+            create_time = cache_entry.get(
+                "create_time", 0
+            )  # Get creation time, default to 0
            valid_entries += 1

            # Support multiple LLM caches per chunk
            if chunk_id not in cached_results:
                cached_results[chunk_id] = []
-            cached_results[chunk_id].append(extraction_result)
+            # Store tuple with extraction result and creation time for sorting
+            cached_results[chunk_id].append((extraction_result, create_time))
+
+    # Sort extraction results by create_time for each chunk
+    for chunk_id in cached_results:
+        # Sort by create_time (x[1]), then extract only extraction_result (x[0])
+        cached_results[chunk_id].sort(key=lambda x: x[1])
+        cached_results[chunk_id] = [item[0] for item in cached_results[chunk_id]]

    logger.info(
        f"Found {valid_entries} valid cache entries, {len(cached_results)} chunks with results"
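Worth noting for existing deployments: cache entries written before this change carry no create_time and therefore all default to 0, and Python's list.sort is stable, so those legacy entries keep their previous relative order rather than being reshuffled. A two-line illustration with made-up values:

entries = [("legacy A", 0), ("legacy B", 0), ("new entry", 1_700_000_000)]
entries.sort(key=lambda x: x[1])      # stable sort: legacy A stays ahead of legacy B
print([name for name, _ in entries])  # ['legacy A', 'legacy B', 'new entry']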