yangdx
committed on
Commit
·
8571c18
1
Parent(s):
05d541c
feat: add multi-process support for FAISS vector storage
Browse files
• Add storage update flag and locks
• Support cross-process index reload
• Add async initialize method
- lightrag/kg/faiss_impl.py +64 -14
lightrag/kg/faiss_impl.py
CHANGED
@@ -15,7 +15,12 @@ if not pm.is_installed("faiss"):
|
|
15 |
pm.install("faiss")
|
16 |
|
17 |
import faiss # type: ignore
|
18 |
-
from
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
|
21 |
@final
|
@@ -45,29 +50,43 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
|
45 |
self._max_batch_size = self.global_config["embedding_batch_num"]
|
46 |
# Embedding dimension (e.g. 768) must match your embedding function
|
47 |
self._dim = self.embedding_func.embedding_dim
|
48 |
-
|
49 |
-
|
50 |
# Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
|
51 |
# If you have a large number of vectors, you might want IVF or other indexes.
|
52 |
# For demonstration, we use a simple IndexFlatIP.
|
53 |
self._index = faiss.IndexFlatIP(self._dim)
|
54 |
-
|
55 |
# Keep a local store for metadata, IDs, etc.
|
56 |
# Maps <int faiss_id> → metadata (including your original ID).
|
57 |
self._id_to_meta = {}
|
58 |
|
59 |
-
|
60 |
-
with self._storage_lock:
|
61 |
-
self._load_faiss_index()
|
62 |
|
63 |
-
def
|
64 |
-
"""
|
65 |
-
|
|
|
|
|
|
|
66 |
|
67 |
-
async def
|
|
|
|
|
68 |
with self._storage_lock:
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
|
|
71 |
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
72 |
"""
|
73 |
Insert or update vectors in the Faiss index.
|
@@ -135,7 +154,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
|
135 |
self._remove_faiss_ids(existing_ids_to_remove)
|
136 |
|
137 |
# Step 2: Add new vectors
|
138 |
-
index = self._get_index()
|
139 |
start_idx = index.ntotal
|
140 |
index.add(embeddings)
|
141 |
|
@@ -163,7 +182,8 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
|
163 |
)
|
164 |
|
165 |
# Perform the similarity search
|
166 |
-
|
|
|
167 |
|
168 |
distances = distances[0]
|
169 |
indices = indices[0]
|
@@ -316,3 +336,33 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
|
316 |
logger.warning("Starting with an empty Faiss index.")
|
317 |
self._index = faiss.IndexFlatIP(self._dim)
|
318 |
self._id_to_meta = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
pm.install("faiss")
|
16 |
|
17 |
import faiss # type: ignore
|
18 |
+
from .shared_storage import (
|
19 |
+
get_storage_lock,
|
20 |
+
get_update_flag,
|
21 |
+
set_all_update_flags,
|
22 |
+
is_multiprocess,
|
23 |
+
)
|
24 |
|
25 |
|
26 |
@final
|
|
|
50 |
self._max_batch_size = self.global_config["embedding_batch_num"]
|
51 |
# Embedding dimension (e.g. 768) must match your embedding function
|
52 |
self._dim = self.embedding_func.embedding_dim
|
53 |
+
|
|
|
54 |
# Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
|
55 |
# If you have a large number of vectors, you might want IVF or other indexes.
|
56 |
# For demonstration, we use a simple IndexFlatIP.
|
57 |
self._index = faiss.IndexFlatIP(self._dim)
|
|
|
58 |
# Keep a local store for metadata, IDs, etc.
|
59 |
# Maps <int faiss_id> → metadata (including your original ID).
|
60 |
self._id_to_meta = {}
|
61 |
|
62 |
+
self._load_faiss_index()
|
|
|
|
|
63 |
|
64 |
+
async def initialize(self):
|
65 |
+
"""Initialize storage data"""
|
66 |
+
# Get the update flag for cross-process update notification
|
67 |
+
self.storage_updated = await get_update_flag(self.namespace)
|
68 |
+
# Get the storage lock for use in other methods
|
69 |
+
self._storage_lock = get_storage_lock()
|
70 |
|
71 |
+
async def _get_index(self):
|
72 |
+
"""Check if the storage should be reloaded"""
|
73 |
+
# Acquire lock to prevent concurrent read and write
|
74 |
with self._storage_lock:
|
75 |
+
# Check if storage was updated by another process
|
76 |
+
if (is_multiprocess and self.storage_updated.value) or \
|
77 |
+
(not is_multiprocess and self.storage_updated):
|
78 |
+
logger.info(f"Process {os.getpid()} FAISS reloading {self.namespace} due to update by another process")
|
79 |
+
# Reload data
|
80 |
+
self._index = faiss.IndexFlatIP(self._dim)
|
81 |
+
self._id_to_meta = {}
|
82 |
+
self._load_faiss_index()
|
83 |
+
if is_multiprocess:
|
84 |
+
self.storage_updated.value = False
|
85 |
+
else:
|
86 |
+
self.storage_updated = False
|
87 |
+
return self._index
|
88 |
|
89 |
+
|
90 |
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
91 |
"""
|
92 |
Insert or update vectors in the Faiss index.
|
|
|
154 |
self._remove_faiss_ids(existing_ids_to_remove)
|
155 |
|
156 |
# Step 2: Add new vectors
|
157 |
+
index = await self._get_index()
|
158 |
start_idx = index.ntotal
|
159 |
index.add(embeddings)
|
160 |
|
|
|
182 |
)
|
183 |
|
184 |
# Perform the similarity search
|
185 |
+
index = await self._get_index()
|
186 |
+
distances, indices = index().search(embedding, top_k)
|
187 |
|
188 |
distances = distances[0]
|
189 |
indices = indices[0]
|
|
|
336 |
logger.warning("Starting with an empty Faiss index.")
|
337 |
self._index = faiss.IndexFlatIP(self._dim)
|
338 |
self._id_to_meta = {}
|
339 |
+
|
340 |
+
async def index_done_callback(self) -> None:
|
341 |
+
# Check if storage was updated by another process
|
342 |
+
if is_multiprocess and self.storage_updated.value:
|
343 |
+
# Storage was updated by another process, reload data instead of saving
|
344 |
+
logger.warning(f"Storage for FAISS {self.namespace} was updated by another process, reloading...")
|
345 |
+
with self._storage_lock:
|
346 |
+
self._index = faiss.IndexFlatIP(self._dim)
|
347 |
+
self._id_to_meta = {}
|
348 |
+
self._load_faiss_index()
|
349 |
+
self.storage_updated.value = False
|
350 |
+
return False # Return error
|
351 |
+
|
352 |
+
# Acquire lock and perform persistence
|
353 |
+
async with self._storage_lock:
|
354 |
+
try:
|
355 |
+
# Save data to disk
|
356 |
+
self._save_faiss_index()
|
357 |
+
# Set all update flags to False
|
358 |
+
await set_all_update_flags(self.namespace)
|
359 |
+
# Reset own update flag to avoid self-reloading
|
360 |
+
if is_multiprocess:
|
361 |
+
self.storage_updated.value = False
|
362 |
+
else:
|
363 |
+
self.storage_updated = False
|
364 |
+
except Exception as e:
|
365 |
+
logger.error(f"Error saving FAISS index for {self.namespace}: {e}")
|
366 |
+
return False # Return error
|
367 |
+
|
368 |
+
return True # Return success
|