yangdx
commited on
Commit
·
1ce985e
1
Parent(s):
87ba2e6
Refactor shared storage to safely handle multi-process initialization and data sharing
Browse files- lightrag/kg/json_doc_status_impl.py +15 -5
- lightrag/kg/shared_storage.py +78 -50
lightrag/kg/json_doc_status_impl.py
CHANGED
@@ -12,7 +12,11 @@ from lightrag.utils import (
|
|
12 |
logger,
|
13 |
write_json,
|
14 |
)
|
15 |
-
from .shared_storage import
|
|
|
|
|
|
|
|
|
16 |
|
17 |
|
18 |
@final
|
@@ -24,11 +28,17 @@ class JsonDocStatusStorage(DocStatusStorage):
|
|
24 |
working_dir = self.global_config["working_dir"]
|
25 |
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
|
26 |
self._storage_lock = get_storage_lock()
|
|
|
|
|
|
|
27 |
self._data = get_namespace_data(self.namespace)
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
32 |
|
33 |
async def filter_keys(self, keys: set[str]) -> set[str]:
|
34 |
"""Return keys that should be processed (not in storage or not successfully processed)"""
|
|
|
12 |
logger,
|
13 |
write_json,
|
14 |
)
|
15 |
+
from .shared_storage import (
|
16 |
+
get_namespace_data,
|
17 |
+
get_storage_lock,
|
18 |
+
try_initialize_namespace,
|
19 |
+
)
|
20 |
|
21 |
|
22 |
@final
|
|
|
28 |
working_dir = self.global_config["working_dir"]
|
29 |
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
|
30 |
self._storage_lock = get_storage_lock()
|
31 |
+
|
32 |
+
# check need_init must before get_namespace_data
|
33 |
+
need_init = try_initialize_namespace(self.namespace)
|
34 |
self._data = get_namespace_data(self.namespace)
|
35 |
+
if need_init:
|
36 |
+
loaded_data = load_json(self._file_name) or {}
|
37 |
+
with self._storage_lock:
|
38 |
+
self._data.update(loaded_data)
|
39 |
+
logger.info(
|
40 |
+
f"Loaded document status storage with {len(loaded_data)} records"
|
41 |
+
)
|
42 |
|
43 |
async def filter_keys(self, keys: set[str]) -> set[str]:
|
44 |
"""Return keys that should be processed (not in storage or not successfully processed)"""
|
lightrag/kg/shared_storage.py
CHANGED
@@ -1,30 +1,74 @@
|
|
|
|
1 |
from multiprocessing.synchronize import Lock as ProcessLock
|
2 |
from threading import Lock as ThreadLock
|
3 |
from multiprocessing import Manager
|
4 |
from typing import Any, Dict, Optional, Union
|
|
|
5 |
|
6 |
-
# 定义类型变量
|
7 |
LockType = Union[ProcessLock, ThreadLock]
|
8 |
|
9 |
-
# 全局变量
|
10 |
-
_shared_data: Optional[Dict[str, Any]] = None
|
11 |
-
_namespace_objects: Optional[Dict[str, Any]] = None
|
12 |
-
_global_lock: Optional[LockType] = None
|
13 |
is_multiprocess = False
|
14 |
-
manager = None
|
15 |
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
def _get_global_lock() -> LockType:
|
23 |
-
global _global_lock, is_multiprocess
|
24 |
|
25 |
if _global_lock is None:
|
26 |
if is_multiprocess:
|
27 |
-
_global_lock =
|
28 |
else:
|
29 |
_global_lock = ThreadLock()
|
30 |
|
@@ -38,56 +82,40 @@ def get_scan_lock() -> LockType:
|
|
38 |
"""return scan_progress lock for data consistency"""
|
39 |
return get_storage_lock()
|
40 |
|
41 |
-
def get_shared_data() -> Dict[str, Any]:
|
42 |
-
"""
|
43 |
-
return shared data for all storage types
|
44 |
-
create mult-process save share data only if need for better performance
|
45 |
-
"""
|
46 |
-
global _shared_data, is_multiprocess
|
47 |
-
|
48 |
-
if _shared_data is None:
|
49 |
-
lock = _get_global_lock()
|
50 |
-
with lock:
|
51 |
-
if _shared_data is None:
|
52 |
-
if is_multiprocess:
|
53 |
-
_shared_data = manager.dict()
|
54 |
-
else:
|
55 |
-
_shared_data = {}
|
56 |
-
|
57 |
-
return _shared_data
|
58 |
-
|
59 |
def get_namespace_object(namespace: str) -> Any:
|
60 |
"""Get an object for specific namespace"""
|
61 |
-
global
|
62 |
-
|
63 |
-
if
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
_namespace_objects = {}
|
68 |
-
|
69 |
-
if namespace not in _namespace_objects:
|
70 |
lock = _get_global_lock()
|
71 |
with lock:
|
72 |
-
if namespace not in
|
73 |
if is_multiprocess:
|
74 |
-
|
75 |
else:
|
76 |
-
|
77 |
|
78 |
-
return
|
|
|
|
|
79 |
|
80 |
def get_namespace_data(namespace: str) -> Dict[str, Any]:
|
81 |
"""get storage space for specific storage type(namespace)"""
|
82 |
-
|
83 |
-
lock = _get_global_lock()
|
84 |
|
85 |
-
if
|
|
|
|
|
|
|
|
|
86 |
with lock:
|
87 |
-
if namespace not in
|
88 |
-
|
89 |
|
90 |
-
return
|
91 |
|
92 |
def get_scan_progress() -> Dict[str, Any]:
|
93 |
"""get storage space for document scanning progress data"""
|
|
|
1 |
+
import os
|
2 |
from multiprocessing.synchronize import Lock as ProcessLock
|
3 |
from threading import Lock as ThreadLock
|
4 |
from multiprocessing import Manager
|
5 |
from typing import Any, Dict, Optional, Union
|
6 |
+
from lightrag.utils import logger
|
7 |
|
|
|
8 |
LockType = Union[ProcessLock, ThreadLock]
|
9 |
|
|
|
|
|
|
|
|
|
10 |
is_multiprocess = False
|
|
|
11 |
|
12 |
+
_manager = None
|
13 |
+
_global_lock: Optional[LockType] = None
|
14 |
+
|
15 |
+
# shared data for storage across processes
|
16 |
+
_shared_dicts: Optional[Dict[str, Any]] = {}
|
17 |
+
_share_objects: Optional[Dict[str, Any]] = {}
|
18 |
+
_init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized
|
19 |
+
|
20 |
+
def initialize_share_data():
|
21 |
+
"""Initialize shared data, only called if multiple processes where workers > 1"""
|
22 |
+
global _manager, _shared_dicts, _share_objects, _init_flags, is_multiprocess
|
23 |
+
is_multiprocess = True
|
24 |
+
|
25 |
+
logger.info(f"Process {os.getpid()} initializing shared storage")
|
26 |
+
|
27 |
+
# Initialize manager
|
28 |
+
if _manager is None:
|
29 |
+
_manager = Manager()
|
30 |
+
logger.info(f"Process {os.getpid()} created manager")
|
31 |
+
|
32 |
+
# Create shared dictionaries with manager
|
33 |
+
_shared_dicts = _manager.dict()
|
34 |
+
_share_objects = _manager.dict()
|
35 |
+
_init_flags = _manager.dict() # 使用共享字典存储初始化标志
|
36 |
+
logger.info(f"Process {os.getpid()} created shared dictionaries")
|
37 |
+
|
38 |
+
def try_initialize_namespace(namespace: str) -> bool:
|
39 |
+
"""
|
40 |
+
尝试初始化命名空间。返回True表示当前进程获得了初始化权限。
|
41 |
+
使用共享字典的原子操作确保只有一个进程能成功初始化。
|
42 |
+
"""
|
43 |
+
global _init_flags, _manager
|
44 |
+
|
45 |
+
if is_multiprocess:
|
46 |
+
if _init_flags is None:
|
47 |
+
raise RuntimeError("Shared storage not initialized. Call initialize_share_data() first.")
|
48 |
+
else:
|
49 |
+
if _init_flags is None:
|
50 |
+
_init_flags = {}
|
51 |
+
|
52 |
+
logger.info(f"Process {os.getpid()} trying to initialize namespace {namespace}")
|
53 |
+
|
54 |
+
# 使用全局锁保护共享字典的访问
|
55 |
+
with _get_global_lock():
|
56 |
+
# 检查是否已经初始化
|
57 |
+
if namespace not in _init_flags:
|
58 |
+
# 设置初始化标志
|
59 |
+
_init_flags[namespace] = True
|
60 |
+
logger.info(f"Process {os.getpid()} ready to initialize namespace {namespace}")
|
61 |
+
return True
|
62 |
+
|
63 |
+
logger.info(f"Process {os.getpid()} found namespace {namespace} already initialized")
|
64 |
+
return False
|
65 |
|
66 |
def _get_global_lock() -> LockType:
|
67 |
+
global _global_lock, is_multiprocess, _manager
|
68 |
|
69 |
if _global_lock is None:
|
70 |
if is_multiprocess:
|
71 |
+
_global_lock = _manager.Lock() # Use manager for lock
|
72 |
else:
|
73 |
_global_lock = ThreadLock()
|
74 |
|
|
|
82 |
"""return scan_progress lock for data consistency"""
|
83 |
return get_storage_lock()
|
84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
def get_namespace_object(namespace: str) -> Any:
|
86 |
"""Get an object for specific namespace"""
|
87 |
+
global _share_objects, is_multiprocess, _manager
|
88 |
+
|
89 |
+
if is_multiprocess and not _manager:
|
90 |
+
raise RuntimeError("Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first.")
|
91 |
+
|
92 |
+
if namespace not in _share_objects:
|
|
|
|
|
|
|
93 |
lock = _get_global_lock()
|
94 |
with lock:
|
95 |
+
if namespace not in _share_objects:
|
96 |
if is_multiprocess:
|
97 |
+
_share_objects[namespace] = _manager.Value('O', None)
|
98 |
else:
|
99 |
+
_share_objects[namespace] = None
|
100 |
|
101 |
+
return _share_objects[namespace]
|
102 |
+
|
103 |
+
# 移除不再使用的函数
|
104 |
|
105 |
def get_namespace_data(namespace: str) -> Dict[str, Any]:
|
106 |
"""get storage space for specific storage type(namespace)"""
|
107 |
+
global _shared_dicts, is_multiprocess, _manager
|
|
|
108 |
|
109 |
+
if is_multiprocess and not _manager:
|
110 |
+
raise RuntimeError("Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first.")
|
111 |
+
|
112 |
+
if namespace not in _shared_dicts:
|
113 |
+
lock = _get_global_lock()
|
114 |
with lock:
|
115 |
+
if namespace not in _shared_dicts:
|
116 |
+
_shared_dicts[namespace] = {}
|
117 |
|
118 |
+
return _shared_dicts[namespace]
|
119 |
|
120 |
def get_scan_progress() -> Dict[str, Any]:
|
121 |
"""get storage space for document scanning progress data"""
|