yangdx committed on
Commit
1ce985e
·
1 Parent(s): 87ba2e6

Refactor shared storage to safely handle multi-process initialization and data sharing

Browse files
lightrag/kg/json_doc_status_impl.py CHANGED
@@ -12,7 +12,11 @@ from lightrag.utils import (
12
  logger,
13
  write_json,
14
  )
15
- from .shared_storage import get_namespace_data, get_storage_lock
 
 
 
 
16
 
17
 
18
  @final
@@ -24,11 +28,17 @@ class JsonDocStatusStorage(DocStatusStorage):
24
  working_dir = self.global_config["working_dir"]
25
  self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
26
  self._storage_lock = get_storage_lock()
 
 
 
27
  self._data = get_namespace_data(self.namespace)
28
- with self._storage_lock:
29
- if not self._data:
30
- self._data.update(load_json(self._file_name) or {})
31
- logger.info(f"Loaded document status storage with {len(self._data)} records")
 
 
 
32
 
33
  async def filter_keys(self, keys: set[str]) -> set[str]:
34
  """Return keys that should be processed (not in storage or not successfully processed)"""
 
12
  logger,
13
  write_json,
14
  )
15
+ from .shared_storage import (
16
+ get_namespace_data,
17
+ get_storage_lock,
18
+ try_initialize_namespace,
19
+ )
20
 
21
 
22
  @final
 
28
  working_dir = self.global_config["working_dir"]
29
  self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
30
  self._storage_lock = get_storage_lock()
31
+
32
+ # need_init must be checked before calling get_namespace_data
33
+ need_init = try_initialize_namespace(self.namespace)
34
  self._data = get_namespace_data(self.namespace)
35
+ if need_init:
36
+ loaded_data = load_json(self._file_name) or {}
37
+ with self._storage_lock:
38
+ self._data.update(loaded_data)
39
+ logger.info(
40
+ f"Loaded document status storage with {len(loaded_data)} records"
41
+ )
42
 
43
  async def filter_keys(self, keys: set[str]) -> set[str]:
44
  """Return keys that should be processed (not in storage or not successfully processed)"""
lightrag/kg/shared_storage.py CHANGED
@@ -1,30 +1,74 @@
 
1
  from multiprocessing.synchronize import Lock as ProcessLock
2
  from threading import Lock as ThreadLock
3
  from multiprocessing import Manager
4
  from typing import Any, Dict, Optional, Union
 
5
 
6
- # Type alias definitions
7
  LockType = Union[ProcessLock, ThreadLock]
8
 
9
- # Global variables
10
- _shared_data: Optional[Dict[str, Any]] = None
11
- _namespace_objects: Optional[Dict[str, Any]] = None
12
- _global_lock: Optional[LockType] = None
13
  is_multiprocess = False
14
- manager = None
15
 
16
- def initialize_manager():
17
- """Initialize manager, only for multiple processes where workers > 1"""
18
- global manager
19
- if manager is None:
20
- manager = Manager()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  def _get_global_lock() -> LockType:
23
- global _global_lock, is_multiprocess
24
 
25
  if _global_lock is None:
26
  if is_multiprocess:
27
- _global_lock = manager.Lock()
28
  else:
29
  _global_lock = ThreadLock()
30
 
@@ -38,56 +82,40 @@ def get_scan_lock() -> LockType:
38
  """return scan_progress lock for data consistency"""
39
  return get_storage_lock()
40
 
41
- def get_shared_data() -> Dict[str, Any]:
42
- """
43
- return shared data for all storage types
44
- create mult-process save share data only if need for better performance
45
- """
46
- global _shared_data, is_multiprocess
47
-
48
- if _shared_data is None:
49
- lock = _get_global_lock()
50
- with lock:
51
- if _shared_data is None:
52
- if is_multiprocess:
53
- _shared_data = manager.dict()
54
- else:
55
- _shared_data = {}
56
-
57
- return _shared_data
58
-
59
  def get_namespace_object(namespace: str) -> Any:
60
  """Get an object for specific namespace"""
61
- global _namespace_objects, is_multiprocess
62
-
63
- if _namespace_objects is None:
64
- lock = _get_global_lock()
65
- with lock:
66
- if _namespace_objects is None:
67
- _namespace_objects = {}
68
-
69
- if namespace not in _namespace_objects:
70
  lock = _get_global_lock()
71
  with lock:
72
- if namespace not in _namespace_objects:
73
  if is_multiprocess:
74
- _namespace_objects[namespace] = manager.Value('O', None)
75
  else:
76
- _namespace_objects[namespace] = None
77
 
78
- return _namespace_objects[namespace]
 
 
79
 
80
  def get_namespace_data(namespace: str) -> Dict[str, Any]:
81
  """get storage space for specific storage type(namespace)"""
82
- shared_data = get_shared_data()
83
- lock = _get_global_lock()
84
 
85
- if namespace not in shared_data:
 
 
 
 
86
  with lock:
87
- if namespace not in shared_data:
88
- shared_data[namespace] = {}
89
 
90
- return shared_data[namespace]
91
 
92
  def get_scan_progress() -> Dict[str, Any]:
93
  """get storage space for document scanning progress data"""
 
1
+ import os
2
  from multiprocessing.synchronize import Lock as ProcessLock
3
  from threading import Lock as ThreadLock
4
  from multiprocessing import Manager
5
  from typing import Any, Dict, Optional, Union
6
+ from lightrag.utils import logger
7
 
 
8
  LockType = Union[ProcessLock, ThreadLock]
9
 
 
 
 
 
10
  is_multiprocess = False
 
11
 
12
+ _manager = None
13
+ _global_lock: Optional[LockType] = None
14
+
15
+ # shared data for storage across processes
16
+ _shared_dicts: Optional[Dict[str, Any]] = {}
17
+ _share_objects: Optional[Dict[str, Any]] = {}
18
+ _init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized
19
+
20
def initialize_share_data():
    """Create the manager-backed shared containers for multi-process mode.

    Intended to be called only when more than one worker process is used.
    Sets ``is_multiprocess`` to True and, on the first call, creates the
    ``multiprocessing.Manager`` together with the shared dictionaries
    (``_shared_dicts``, ``_share_objects`` and ``_init_flags``).

    The call is idempotent: once the manager exists, later calls leave the
    shared dictionaries untouched so namespaces that were already registered
    are not discarded.
    """
    global _manager, _shared_dicts, _share_objects, _init_flags, is_multiprocess
    is_multiprocess = True

    logger.info(f"Process {os.getpid()} initializing shared storage")

    if _manager is not None:
        # Already initialized: re-creating the dictionaries here would drop
        # every namespace and init flag stored so far.
        return

    # Initialize manager
    _manager = Manager()
    logger.info(f"Process {os.getpid()} created manager")

    # Create shared dictionaries with manager
    _shared_dicts = _manager.dict()
    _share_objects = _manager.dict()
    _init_flags = _manager.dict()  # shared dict holding the per-namespace init flags
    logger.info(f"Process {os.getpid()} created shared dictionaries")
37
+
38
def try_initialize_namespace(namespace: str) -> bool:
    """Try to claim the right to initialize ``namespace``.

    The first caller to ask about a given namespace gets ``True`` and is
    expected to load that namespace's data; every later caller gets
    ``False``. The flag lookup and update happen while holding the global
    lock.

    Args:
        namespace: Name of the storage namespace to claim.

    Returns:
        True if this call recorded the initialization flag, False if the
        flag was already present.

    Raises:
        RuntimeError: In multi-process mode when the init-flag dictionary
            has not been created yet.
    """
    global _init_flags

    if _init_flags is None:
        if is_multiprocess:
            raise RuntimeError("Shared storage not initialized. Call initialize_share_data() first.")
        # Single-process mode: a plain dict is created lazily on first use.
        _init_flags = {}

    pid = os.getpid()
    logger.info(f"Process {pid} trying to initialize namespace {namespace}")

    # Hold the global lock while reading and writing the flag dictionary.
    with _get_global_lock():
        already_flagged = namespace in _init_flags
        if not already_flagged:
            # Record the flag so later callers see the namespace as taken.
            _init_flags[namespace] = True
            logger.info(f"Process {pid} ready to initialize namespace {namespace}")
            return True

    logger.info(f"Process {pid} found namespace {namespace} already initialized")
    return False
65
 
66
  def _get_global_lock() -> LockType:
67
+ global _global_lock, is_multiprocess, _manager
68
 
69
  if _global_lock is None:
70
  if is_multiprocess:
71
+ _global_lock = _manager.Lock() # Use manager for lock
72
  else:
73
  _global_lock = ThreadLock()
74
 
 
82
  """return scan_progress lock for data consistency"""
83
  return get_storage_lock()
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
def get_namespace_object(namespace: str) -> Any:
    """Return the object slot registered for ``namespace``.

    On first access the slot is created: a ``_manager.Value('O', None)`` in
    multi-process mode, or plain ``None`` otherwise.

    Args:
        namespace: Name of the namespace whose object is requested.

    Returns:
        Whatever is stored in ``_share_objects`` under ``namespace``.

    Raises:
        RuntimeError: In multi-process mode when no manager has been created.
    """
    global _share_objects, is_multiprocess, _manager

    if is_multiprocess and not _manager:
        raise RuntimeError("Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first.")

    missing = namespace not in _share_objects
    if missing:
        with _get_global_lock():
            # Re-check under the lock: the slot may have been created while
            # this caller was waiting.
            if namespace not in _share_objects:
                _share_objects[namespace] = (
                    _manager.Value('O', None) if is_multiprocess else None
                )

    return _share_objects[namespace]
102
+
103
+ # Removed functions that are no longer used
104
 
105
def get_namespace_data(namespace: str) -> Dict[str, Any]:
    """Return the storage dictionary for a specific storage type (namespace).

    The dictionary is created on first access. In single-process mode it is
    a plain ``dict``. In multi-process mode it is a manager-backed dict, so
    that writes made through the returned object are visible to the other
    processes.

    Args:
        namespace: Name of the storage namespace.

    Returns:
        The dictionary registered under ``namespace``. In multi-process mode
        this is a manager dict proxy; copy it with ``dict(...)`` before
        handing it to code that requires a real ``dict`` (e.g. JSON dumping).

    Raises:
        RuntimeError: In multi-process mode when no manager has been created.
    """
    global _shared_dicts, is_multiprocess, _manager

    if is_multiprocess and not _manager:
        raise RuntimeError("Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first.")

    if namespace not in _shared_dicts:
        lock = _get_global_lock()
        with lock:
            # Re-check under the lock: the entry may have been created while
            # this caller was waiting.
            if namespace not in _shared_dicts:
                if is_multiprocess:
                    # A plain dict stored inside a manager dict is returned
                    # as a copy on every read, so updates to it would never
                    # reach other processes. A nested manager dict is shared.
                    _shared_dicts[namespace] = _manager.dict()
                else:
                    _shared_dicts[namespace] = {}

    return _shared_dicts[namespace]
119
 
120
  def get_scan_progress() -> Dict[str, Any]:
121
  """get storage space for document scanning progress data"""