yangdx
committed on
Commit
·
05d541c
1
Parent(s):
2655164
Improve multi-process data synchronization and persistence in storage implementations
Browse files
• Remove _get_client() or _get_graph() from index_done_callback
• Add return value for index_done_callback
lightrag/kg/nano_vector_db_impl.py
CHANGED
@@ -66,7 +66,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
|
66 |
# Check if data needs to be reloaded
|
67 |
if (is_multiprocess and self.storage_updated.value) or \
|
68 |
(not is_multiprocess and self.storage_updated):
|
69 |
-
logger.info(f"
|
70 |
# Reload data
|
71 |
self._client = NanoVectorDB(
|
72 |
self.embedding_func.embedding_dim,
|
@@ -199,7 +199,8 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
|
199 |
except Exception as e:
|
200 |
logger.error(f"Error deleting relations for {entity_name}: {e}")
|
201 |
|
202 |
-
async def index_done_callback(self) ->
|
|
|
203 |
# Check if storage was updated by another process
|
204 |
if is_multiprocess and self.storage_updated.value:
|
205 |
# Storage was updated by another process, reload data instead of saving
|
@@ -213,14 +214,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
|
213 |
return False # Return error
|
214 |
|
215 |
# Acquire lock and perform persistence
|
216 |
-
client = await self._get_client()
|
217 |
async with self._storage_lock:
|
218 |
try:
|
219 |
# Save data to disk
|
220 |
-
|
221 |
# Notify other processes that data has been updated
|
222 |
await set_all_update_flags(self.namespace)
|
223 |
-
# Reset own update flag to avoid self-
|
224 |
if is_multiprocess:
|
225 |
self.storage_updated.value = False
|
226 |
else:
|
@@ -229,3 +229,5 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
|
229 |
except Exception as e:
|
230 |
logger.error(f"Error saving data for {self.namespace}: {e}")
|
231 |
return False # Return error
|
|
|
|
|
|
66 |
# Check if data needs to be reloaded
|
67 |
if (is_multiprocess and self.storage_updated.value) or \
|
68 |
(not is_multiprocess and self.storage_updated):
|
69 |
+
logger.info(f"Process {os.getpid()} reloading {self.namespace} due to update by another process")
|
70 |
# Reload data
|
71 |
self._client = NanoVectorDB(
|
72 |
self.embedding_func.embedding_dim,
|
|
|
199 |
except Exception as e:
|
200 |
logger.error(f"Error deleting relations for {entity_name}: {e}")
|
201 |
|
202 |
+
async def index_done_callback(self) -> bool:
|
203 |
+
"""Save data to disk"""
|
204 |
# Check if storage was updated by another process
|
205 |
if is_multiprocess and self.storage_updated.value:
|
206 |
# Storage was updated by another process, reload data instead of saving
|
|
|
214 |
return False # Return error
|
215 |
|
216 |
# Acquire lock and perform persistence
|
|
|
217 |
async with self._storage_lock:
|
218 |
try:
|
219 |
# Save data to disk
|
220 |
+
self._get_client.save()
|
221 |
# Notify other processes that data has been updated
|
222 |
await set_all_update_flags(self.namespace)
|
223 |
+
# Reset own update flag to avoid self-reloading
|
224 |
if is_multiprocess:
|
225 |
self.storage_updated.value = False
|
226 |
else:
|
|
|
229 |
except Exception as e:
|
230 |
logger.error(f"Error saving data for {self.namespace}: {e}")
|
231 |
return False # Return error
|
232 |
+
|
233 |
+
return True # Return success
|
lightrag/kg/networkx_impl.py
CHANGED
@@ -110,7 +110,7 @@ class NetworkXStorage(BaseGraphStorage):
|
|
110 |
# Check if data needs to be reloaded
|
111 |
if (is_multiprocess and self.storage_updated.value) or \
|
112 |
(not is_multiprocess and self.storage_updated):
|
113 |
-
logger.info(f"
|
114 |
# Reload data
|
115 |
self._graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
|
116 |
# Reset update flag
|
@@ -329,7 +329,8 @@ class NetworkXStorage(BaseGraphStorage):
|
|
329 |
)
|
330 |
return result
|
331 |
|
332 |
-
async def index_done_callback(self) ->
|
|
|
333 |
# Check if storage was updated by another process
|
334 |
if is_multiprocess and self.storage_updated.value:
|
335 |
# Storage was updated by another process, reload data instead of saving
|
@@ -340,14 +341,13 @@ class NetworkXStorage(BaseGraphStorage):
|
|
340 |
return False # Return error
|
341 |
|
342 |
# Acquire lock and perform persistence
|
343 |
-
graph = await self._get_graph()
|
344 |
async with self._storage_lock:
|
345 |
try:
|
346 |
# Save data to disk
|
347 |
-
NetworkXStorage.write_nx_graph(
|
348 |
# Notify other processes that data has been updated
|
349 |
await set_all_update_flags(self.namespace)
|
350 |
-
# Reset own update flag to avoid self-
|
351 |
if is_multiprocess:
|
352 |
self.storage_updated.value = False
|
353 |
else:
|
@@ -356,3 +356,5 @@ class NetworkXStorage(BaseGraphStorage):
|
|
356 |
except Exception as e:
|
357 |
logger.error(f"Error saving graph for {self.namespace}: {e}")
|
358 |
return False # Return error
|
|
|
|
|
|
110 |
# Check if data needs to be reloaded
|
111 |
if (is_multiprocess and self.storage_updated.value) or \
|
112 |
(not is_multiprocess and self.storage_updated):
|
113 |
+
logger.info(f"Process {os.getpid()} reloading graph {self.namespace} due to update by another process")
|
114 |
# Reload data
|
115 |
self._graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
|
116 |
# Reset update flag
|
|
|
329 |
)
|
330 |
return result
|
331 |
|
332 |
+
async def index_done_callback(self) -> bool:
|
333 |
+
"""Save data to disk"""
|
334 |
# Check if storage was updated by another process
|
335 |
if is_multiprocess and self.storage_updated.value:
|
336 |
# Storage was updated by another process, reload data instead of saving
|
|
|
341 |
return False # Return error
|
342 |
|
343 |
# Acquire lock and perform persistence
|
|
|
344 |
async with self._storage_lock:
|
345 |
try:
|
346 |
# Save data to disk
|
347 |
+
NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
|
348 |
# Notify other processes that data has been updated
|
349 |
await set_all_update_flags(self.namespace)
|
350 |
+
# Reset own update flag to avoid self-reloading
|
351 |
if is_multiprocess:
|
352 |
self.storage_updated.value = False
|
353 |
else:
|
|
|
356 |
except Exception as e:
|
357 |
logger.error(f"Error saving graph for {self.namespace}: {e}")
|
358 |
return False # Return error
|
359 |
+
|
360 |
+
return True
|