yangdx committed
Commit 6d1b218 · 1 Parent(s): db40764

fix: Improve async handling and FAISS storage reliability

- Add async context manager support
- Fix embedding data type conversion
- Improve error handling in FAISS ops
- Add multiprocess storage sync

Files changed (2)
  1. lightrag/api/README.md +1 -1
  2. lightrag/kg/faiss_impl.py +37 -37
lightrag/api/README.md CHANGED
@@ -186,7 +186,7 @@ LightRAG supports binding to various LLM/Embedding backends:
 * openai & openai compatible
 * azure_openai
 
-Use environment variables `LLM_BINDING ` or CLI argument `--llm-binding` to select LLM backend type. Use environment variables `EMBEDDING_BINDING ` or CLI argument `--embedding-binding` to select LLM backend type.
+Use environment variables `LLM_BINDING` or CLI argument `--llm-binding` to select LLM backend type. Use environment variables `EMBEDDING_BINDING` or CLI argument `--embedding-binding` to select Embedding backend type.
 
 ### Storage Types Supported
 
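The env var and flag names above come from the README text itself; the `resolve_binding` helper and the precedence order below are a hypothetical sketch, not LightRAG's actual startup code.

```python
import argparse
import os
from typing import Optional


def resolve_binding(cli_value: Optional[str], env_var: str, default: str = "openai") -> str:
    # Assumed precedence: CLI flag wins, then the environment variable, then a default.
    return cli_value or os.environ.get(env_var, default)


parser = argparse.ArgumentParser()
parser.add_argument("--llm-binding")
parser.add_argument("--embedding-binding")
args = parser.parse_args([])  # empty argv keeps the example self-contained

llm_binding = resolve_binding(args.llm_binding, "LLM_BINDING")
embedding_binding = resolve_binding(args.embedding_binding, "EMBEDDING_BINDING")
print(llm_binding, embedding_binding)  # e.g. "openai openai" when nothing is set
```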
lightrag/kg/faiss_impl.py CHANGED
@@ -71,7 +71,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
     async def _get_index(self):
         """Check if the storage should be reloaded"""
         # Acquire lock to prevent concurrent read and write
-        with self._storage_lock:
+        async with self._storage_lock:
             # Check if storage was updated by another process
             if (is_multiprocess and self.storage_updated.value) or (
                 not is_multiprocess and self.storage_updated
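The `with` to `async with` change matters because `self._storage_lock` is acquired inside a coroutine; assuming it is an asyncio-style lock (as the new code implies), it can only be held via `async with`, which lets other coroutines run while waiting. A minimal sketch of the pattern with a hypothetical `StorageDemo` class and a single-process flag:

```python
import asyncio


class StorageDemo:
    def __init__(self) -> None:
        self._storage_lock = asyncio.Lock()  # assumption: an asyncio-style lock
        self.storage_updated = False         # single-process flavor of the flag

    async def _get_index(self):
        # A plain `with self._storage_lock:` would not work here; an asyncio
        # lock must be acquired with `async with` inside a coroutine.
        async with self._storage_lock:
            if self.storage_updated:
                print("storage updated elsewhere, reloading index...")
                self.storage_updated = False
            return "index"


print(asyncio.run(StorageDemo()._get_index()))
```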
@@ -139,7 +139,8 @@ class FaissVectorDBStorage(BaseVectorStorage):
             )
             return []
 
-        # Normalize embeddings for cosine similarity (in-place)
+        # Convert to float32 and normalize embeddings for cosine similarity (in-place)
+        embeddings = embeddings.astype(np.float32)
         faiss.normalize_L2(embeddings)
 
         # Upsert logic:
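The added `astype(np.float32)` guards against embedding backends returning float64 arrays, which FAISS rejects. A self-contained sketch of the dtype conversion plus in-place L2 normalization that makes an inner-product index behave as cosine similarity (toy vectors, real FAISS calls):

```python
import faiss  # pip install faiss-cpu
import numpy as np

dim = 4
# numpy builds float64 by default; FAISS only accepts float32.
embeddings = np.array([[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]])
embeddings = embeddings.astype(np.float32)
faiss.normalize_L2(embeddings)  # in-place L2 normalization

index = faiss.IndexFlatIP(dim)  # inner product on unit vectors == cosine similarity
index.add(embeddings)
print(index.ntotal)  # 2
```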
@@ -153,7 +154,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
                 existing_ids_to_remove.append(faiss_internal_id)
 
         if existing_ids_to_remove:
-            self._remove_faiss_ids(existing_ids_to_remove)
+            await self._remove_faiss_ids(existing_ids_to_remove)
 
         # Step 2: Add new vectors
         index = await self._get_index()
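Because `_remove_faiss_ids` becomes a coroutine (see the `@@ -267` hunk further down), call sites such as this one must `await` it; a bare call only creates a coroutine object, the removal never executes, and Python reports a "never awaited" RuntimeWarning. A minimal illustration with a hypothetical `Demo` class:

```python
import asyncio


class Demo:
    def __init__(self) -> None:
        self.removed = []

    async def _remove_faiss_ids(self, fid_list):
        self.removed.extend(fid_list)


async def main() -> None:
    demo = Demo()
    demo._remove_faiss_ids([1, 2])        # bug: coroutine created, never awaited
    await demo._remove_faiss_ids([3, 4])  # fix: the removal actually runs
    print(demo.removed)  # [3, 4]


asyncio.run(main())
```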
@@ -185,7 +186,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
 
         # Perform the similarity search
         index = await self._get_index()
-        distances, indices = index().search(embedding, top_k)
+        distances, indices = index.search(embedding, top_k)
 
         distances = distances[0]
         indices = indices[0]
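The search fix removes a stray call: `_get_index()` already returns the FAISS index object, so it is searched directly rather than called again. A self-contained sketch of the query path, using real FAISS APIs but invented toy data:

```python
import faiss
import numpy as np

dim = 4
index = faiss.IndexFlatIP(dim)
data = np.random.rand(10, dim).astype(np.float32)
faiss.normalize_L2(data)
index.add(data)

query = np.random.rand(1, dim).astype(np.float32)  # shape (1, dim), float32
faiss.normalize_L2(query)
distances, indices = index.search(query, 3)  # top_k = 3
print(distances[0], indices[0])  # cosine scores and internal FAISS ids
```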
@@ -229,7 +230,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
                     to_remove.append(fid)
 
         if to_remove:
-            self._remove_faiss_ids(to_remove)
+            await self._remove_faiss_ids(to_remove)
             logger.debug(
                 f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
             )
@@ -251,7 +252,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
 
         logger.debug(f"Found {len(relations)} relations for {entity_name}")
         if relations:
-            self._remove_faiss_ids(relations)
+            await self._remove_faiss_ids(relations)
             logger.debug(f"Deleted {len(relations)} relations for {entity_name}")
 
     # --------------------------------------------------------------------------------
@@ -267,7 +268,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
                 return fid
         return None
 
-    def _remove_faiss_ids(self, fid_list):
+    async def _remove_faiss_ids(self, fid_list):
        """
        Remove a list of internal Faiss IDs from the index.
        Because IndexFlatIP doesn't support 'removals',
@@ -283,7 +284,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
                 vectors_to_keep.append(vec_meta["__vector__"])  # stored as list
                 new_id_to_meta[new_fid] = vec_meta
 
-        with self._storage_lock:
+        async with self._storage_lock:
             # Re-init index
             self._index = faiss.IndexFlatIP(self._dim)
             if vectors_to_keep:
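A standalone sketch of the rebuild-on-delete pattern that `_remove_faiss_ids` implements: keep the surviving vectors, re-create the `IndexFlatIP`, and renumber the internal ids from zero. The variable names mirror the diff, but the code is a simplified toy, not the method itself:

```python
import faiss
import numpy as np

dim = 4
vectors = np.random.rand(5, dim).astype(np.float32)
faiss.normalize_L2(vectors)
id_to_meta = {fid: {"__vector__": vectors[fid].tolist()} for fid in range(5)}

ids_to_remove = {1, 3}
survivors = [meta for fid, meta in id_to_meta.items() if fid not in ids_to_remove]

# Re-init the index and re-add only the kept vectors under new sequential ids.
new_index = faiss.IndexFlatIP(dim)
if survivors:
    kept = np.array([m["__vector__"] for m in survivors], dtype=np.float32)
    new_index.add(kept)
new_id_to_meta = dict(enumerate(survivors))

print(new_index.ntotal, sorted(new_id_to_meta))  # 3 [0, 1, 2]
```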
@@ -339,35 +340,34 @@ class FaissVectorDBStorage(BaseVectorStorage):
             self._index = faiss.IndexFlatIP(self._dim)
             self._id_to_meta = {}
 
-
     async def index_done_callback(self) -> None:
         # Check if storage was updated by another process
         if is_multiprocess and self.storage_updated.value:
             # Storage was updated by another process, reload data instead of saving
             logger.warning(
                 f"Storage for FAISS {self.namespace} was updated by another process, reloading..."
             )
-            with self._storage_lock:
+            async with self._storage_lock:
                 self._index = faiss.IndexFlatIP(self._dim)
                 self._id_to_meta = {}
                 self._load_faiss_index()
                 self.storage_updated.value = False
             return False  # Return error
 
         # Acquire lock and perform persistence
         async with self._storage_lock:
             try:
                 # Save data to disk
                 self._save_faiss_index()
                 # Notify other processes that data has been updated
                 await set_all_update_flags(self.namespace)
                 # Reset own update flag to avoid self-reloading
                 if is_multiprocess:
                     self.storage_updated.value = False
                 else:
                     self.storage_updated = False
             except Exception as e:
                 logger.error(f"Error saving FAISS index for {self.namespace}: {e}")
                 return False  # Return error
 
         return True  # Return success
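Finally, a simplified sketch of the multiprocess sync idea behind `index_done_callback`: after persisting, the writer marks the namespace as updated for every worker and then clears its own flag so it does not reload its own write. `set_all_update_flags` below is an assumed stand-in with a made-up flag table, not LightRAG's shared-storage helper:

```python
from multiprocessing import Value

# One "storage updated" flag per worker for a given namespace (simplified:
# real code would share these across processes, e.g. via a manager).
flags = {"vdb_chunks": {"worker-a": Value("b", 0), "worker-b": Value("b", 0)}}


def set_all_update_flags(namespace: str) -> None:
    # Assumed stand-in: mark the namespace as updated for every worker.
    for flag in flags[namespace].values():
        flag.value = 1


def after_save(namespace: str, me: str) -> None:
    set_all_update_flags(namespace)
    flags[namespace][me].value = 0  # reset own flag to avoid self-reloading


after_save("vdb_chunks", "worker-a")
print(flags["vdb_chunks"]["worker-a"].value, flags["vdb_chunks"]["worker-b"].value)  # 0 1
```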