yangdx committed on
Commit
8571c18
·
1 Parent(s): 05d541c

feat: add multi-process support for FAISS vector storage

Browse files

• Add storage update flag and locks
• Support cross-process index reload
• Add async initialize method

Files changed (1) hide show
  1. lightrag/kg/faiss_impl.py +64 -14
lightrag/kg/faiss_impl.py CHANGED
@@ -15,7 +15,12 @@ if not pm.is_installed("faiss"):
15
  pm.install("faiss")
16
 
17
  import faiss # type: ignore
18
- from threading import Lock as ThreadLock
 
 
 
 
 
19
 
20
 
21
  @final
@@ -45,29 +50,43 @@ class FaissVectorDBStorage(BaseVectorStorage):
45
  self._max_batch_size = self.global_config["embedding_batch_num"]
46
  # Embedding dimension (e.g. 768) must match your embedding function
47
  self._dim = self.embedding_func.embedding_dim
48
- self._storage_lock = ThreadLock()
49
-
50
  # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
51
  # If you have a large number of vectors, you might want IVF or other indexes.
52
  # For demonstration, we use a simple IndexFlatIP.
53
  self._index = faiss.IndexFlatIP(self._dim)
54
-
55
  # Keep a local store for metadata, IDs, etc.
56
  # Maps <int faiss_id> → metadata (including your original ID).
57
  self._id_to_meta = {}
58
 
59
- # Attempt to load an existing index + metadata from disk
60
- with self._storage_lock:
61
- self._load_faiss_index()
62
 
63
- def _get_index(self):
64
- """Check if the shtorage should be reloaded"""
65
- return self._index
 
 
 
66
 
67
- async def index_done_callback(self) -> None:
 
 
68
  with self._storage_lock:
69
- self._save_faiss_index()
 
 
 
 
 
 
 
 
 
 
 
 
70
 
 
71
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
72
  """
73
  Insert or update vectors in the Faiss index.
@@ -135,7 +154,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
135
  self._remove_faiss_ids(existing_ids_to_remove)
136
 
137
  # Step 2: Add new vectors
138
- index = self._get_index()
139
  start_idx = index.ntotal
140
  index.add(embeddings)
141
 
@@ -163,7 +182,8 @@ class FaissVectorDBStorage(BaseVectorStorage):
163
  )
164
 
165
  # Perform the similarity search
166
- distances, indices = self._get_index().search(embedding, top_k)
 
167
 
168
  distances = distances[0]
169
  indices = indices[0]
@@ -316,3 +336,33 @@ class FaissVectorDBStorage(BaseVectorStorage):
316
  logger.warning("Starting with an empty Faiss index.")
317
  self._index = faiss.IndexFlatIP(self._dim)
318
  self._id_to_meta = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  pm.install("faiss")
16
 
17
  import faiss # type: ignore
18
+ from .shared_storage import (
19
+ get_storage_lock,
20
+ get_update_flag,
21
+ set_all_update_flags,
22
+ is_multiprocess,
23
+ )
24
 
25
 
26
  @final
 
50
  self._max_batch_size = self.global_config["embedding_batch_num"]
51
  # Embedding dimension (e.g. 768) must match your embedding function
52
  self._dim = self.embedding_func.embedding_dim
53
+
 
54
  # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
55
  # If you have a large number of vectors, you might want IVF or other indexes.
56
  # For demonstration, we use a simple IndexFlatIP.
57
  self._index = faiss.IndexFlatIP(self._dim)
 
58
  # Keep a local store for metadata, IDs, etc.
59
  # Maps <int faiss_id> → metadata (including your original ID).
60
  self._id_to_meta = {}
61
 
62
+ self._load_faiss_index()
 
 
63
 
64
+ async def initialize(self):
65
+ """Initialize storage data"""
66
+ # Get the update flag for cross-process update notification
67
+ self.storage_updated = await get_update_flag(self.namespace)
68
+ # Get the storage lock for use in other methods
69
+ self._storage_lock = get_storage_lock()
70
 
71
+ async def _get_index(self):
72
+ """Check if the shtorage should be reloaded"""
73
+ # Acquire lock to prevent concurrent read and write
74
  with self._storage_lock:
75
+ # Check if storage was updated by another process
76
+ if (is_multiprocess and self.storage_updated.value) or \
77
+ (not is_multiprocess and self.storage_updated):
78
+ logger.info(f"Process {os.getpid()} FAISS reloading {self.namespace} due to update by another process")
79
+ # Reload data
80
+ self._index = faiss.IndexFlatIP(self._dim)
81
+ self._id_to_meta = {}
82
+ self._load_faiss_index()
83
+ if is_multiprocess:
84
+ self.storage_updated.value = False
85
+ else:
86
+ self.storage_updated = False
87
+ return self._index
88
 
89
+
90
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
91
  """
92
  Insert or update vectors in the Faiss index.
 
154
  self._remove_faiss_ids(existing_ids_to_remove)
155
 
156
  # Step 2: Add new vectors
157
+ index = await self._get_index()
158
  start_idx = index.ntotal
159
  index.add(embeddings)
160
 
 
182
  )
183
 
184
  # Perform the similarity search
185
+ index = await self._get_index()
186
+ distances, indices = index.search(embedding, top_k)
187
 
188
  distances = distances[0]
189
  indices = indices[0]
 
336
  logger.warning("Starting with an empty Faiss index.")
337
  self._index = faiss.IndexFlatIP(self._dim)
338
  self._id_to_meta = {}
339
+
340
+ async def index_done_callback(self) -> None:
341
+ # Check if storage was updated by another process
342
+ if is_multiprocess and self.storage_updated.value:
343
+ # Storage was updated by another process, reload data instead of saving
344
+ logger.warning(f"Storage for FAISS {self.namespace} was updated by another process, reloading...")
345
+ with self._storage_lock:
346
+ self._index = faiss.IndexFlatIP(self._dim)
347
+ self._id_to_meta = {}
348
+ self._load_faiss_index()
349
+ self.storage_updated.value = False
350
+ return False # Return error
351
+
352
+ # Acquire lock and perform persistence
353
+ async with self._storage_lock:
354
+ try:
355
+ # Save data to disk
356
+ self._save_faiss_index()
357
+ # Set all update flags to False
358
+ await set_all_update_flags(self.namespace)
359
+ # Reset own update flag to avoid self-reloading
360
+ if is_multiprocess:
361
+ self.storage_updated.value = False
362
+ else:
363
+ self.storage_updated = False
364
+ except Exception as e:
365
+ logger.error(f"Error saving FAISS index for {self.namespace}: {e}")
366
+ return False # Return error
367
+
368
+ return True # Return success