yangdx committed on
Commit
556f361
·
1 Parent(s): 1ce985e

Fix linting

Browse files
.env.example CHANGED
@@ -141,4 +141,4 @@ QDRANT_URL=http://localhost:16333
141
  # QDRANT_API_KEY=your-api-key
142
 
143
  ### Redis
144
- REDIS_URI=redis://localhost:6379
 
141
  # QDRANT_API_KEY=your-api-key
142
 
143
  ### Redis
144
+ REDIS_URI=redis://localhost:6379
lightrag/api/lightrag_server.py CHANGED
@@ -54,11 +54,12 @@ config.read("config.ini")
54
 
55
  class LightragPathFilter(logging.Filter):
56
  """Filter for lightrag logger to filter out frequent path access logs"""
 
57
  def __init__(self):
58
  super().__init__()
59
  # Define paths to be filtered
60
  self.filtered_paths = ["/documents", "/health", "/webui/"]
61
-
62
  def filter(self, record):
63
  try:
64
  # Check if record has the required attributes for an access log
@@ -90,11 +91,13 @@ def create_app(args):
90
  # Initialize verbose debug setting
91
  # Can not use the logger at the top of this module when workers > 1
92
  from lightrag.utils import set_verbose_debug, logger
 
93
  # Setup logging
94
  logger.setLevel(getattr(logging, args.log_level))
95
  set_verbose_debug(args.verbose)
96
 
97
  from lightrag.kg.shared_storage import is_multiprocess
 
98
  logger.info(f"==== Multi-processor mode: {is_multiprocess} ====")
99
 
100
  # Verify that bindings are correctly setup
@@ -147,9 +150,7 @@ def create_app(args):
147
  # Auto scan documents if enabled
148
  if args.auto_scan_at_startup:
149
  # Create background task
150
- task = asyncio.create_task(
151
- run_scanning_process(rag, doc_manager)
152
- )
153
  app.state.background_tasks.add(task)
154
  task.add_done_callback(app.state.background_tasks.discard)
155
 
@@ -411,17 +412,19 @@ def get_application():
411
  """Factory function for creating the FastAPI application"""
412
  # Configure logging for this worker process
413
  configure_logging()
414
-
415
  # Get args from environment variable
416
- args_json = os.environ.get('LIGHTRAG_ARGS')
417
  if not args_json:
418
  args = parse_args() # Fallback to parsing args if env var not set
419
  else:
420
  import types
 
421
  args = types.SimpleNamespace(**json.loads(args_json))
422
-
423
  if args.workers > 1:
424
  from lightrag.kg.shared_storage import initialize_share_data
 
425
  initialize_share_data()
426
 
427
  return create_app(args)
@@ -434,58 +437,61 @@ def configure_logging():
434
  logger = logging.getLogger(logger_name)
435
  logger.handlers = []
436
  logger.filters = []
437
-
438
  # Configure basic logging
439
- logging.config.dictConfig({
440
- "version": 1,
441
- "disable_existing_loggers": False,
442
- "formatters": {
443
- "default": {
444
- "format": "%(levelname)s: %(message)s",
445
- },
446
- },
447
- "handlers": {
448
- "default": {
449
- "formatter": "default",
450
- "class": "logging.StreamHandler",
451
- "stream": "ext://sys.stderr",
452
  },
453
- },
454
- "loggers": {
455
- "uvicorn.access": {
456
- "handlers": ["default"],
457
- "level": "INFO",
458
- "propagate": False,
459
- "filters": ["path_filter"],
460
  },
461
- "lightrag": {
462
- "handlers": ["default"],
463
- "level": "INFO",
464
- "propagate": False,
465
- "filters": ["path_filter"],
 
 
 
 
 
 
 
 
466
  },
467
- },
468
- "filters": {
469
- "path_filter": {
470
- "()": "lightrag.api.lightrag_server.LightragPathFilter",
471
  },
472
- },
473
- })
 
474
 
475
  def main():
476
  from multiprocessing import freeze_support
 
477
  freeze_support()
478
-
479
  args = parse_args()
480
  # Save args to environment variable for child processes
481
- os.environ['LIGHTRAG_ARGS'] = json.dumps(vars(args))
482
 
483
  # Configure logging before starting uvicorn
484
  configure_logging()
485
 
486
  display_splash_screen(args)
487
 
488
-
489
  uvicorn_config = {
490
  "app": "lightrag.api.lightrag_server:get_application",
491
  "factory": True,
 
54
 
55
  class LightragPathFilter(logging.Filter):
56
  """Filter for lightrag logger to filter out frequent path access logs"""
57
+
58
  def __init__(self):
59
  super().__init__()
60
  # Define paths to be filtered
61
  self.filtered_paths = ["/documents", "/health", "/webui/"]
62
+
63
  def filter(self, record):
64
  try:
65
  # Check if record has the required attributes for an access log
 
91
  # Initialize verbose debug setting
92
  # Can not use the logger at the top of this module when workers > 1
93
  from lightrag.utils import set_verbose_debug, logger
94
+
95
  # Setup logging
96
  logger.setLevel(getattr(logging, args.log_level))
97
  set_verbose_debug(args.verbose)
98
 
99
  from lightrag.kg.shared_storage import is_multiprocess
100
+
101
  logger.info(f"==== Multi-processor mode: {is_multiprocess} ====")
102
 
103
  # Verify that bindings are correctly setup
 
150
  # Auto scan documents if enabled
151
  if args.auto_scan_at_startup:
152
  # Create background task
153
+ task = asyncio.create_task(run_scanning_process(rag, doc_manager))
 
 
154
  app.state.background_tasks.add(task)
155
  task.add_done_callback(app.state.background_tasks.discard)
156
 
 
412
  """Factory function for creating the FastAPI application"""
413
  # Configure logging for this worker process
414
  configure_logging()
415
+
416
  # Get args from environment variable
417
+ args_json = os.environ.get("LIGHTRAG_ARGS")
418
  if not args_json:
419
  args = parse_args() # Fallback to parsing args if env var not set
420
  else:
421
  import types
422
+
423
  args = types.SimpleNamespace(**json.loads(args_json))
424
+
425
  if args.workers > 1:
426
  from lightrag.kg.shared_storage import initialize_share_data
427
+
428
  initialize_share_data()
429
 
430
  return create_app(args)
 
437
  logger = logging.getLogger(logger_name)
438
  logger.handlers = []
439
  logger.filters = []
440
+
441
  # Configure basic logging
442
+ logging.config.dictConfig(
443
+ {
444
+ "version": 1,
445
+ "disable_existing_loggers": False,
446
+ "formatters": {
447
+ "default": {
448
+ "format": "%(levelname)s: %(message)s",
449
+ },
 
 
 
 
 
450
  },
451
+ "handlers": {
452
+ "default": {
453
+ "formatter": "default",
454
+ "class": "logging.StreamHandler",
455
+ "stream": "ext://sys.stderr",
456
+ },
 
457
  },
458
+ "loggers": {
459
+ "uvicorn.access": {
460
+ "handlers": ["default"],
461
+ "level": "INFO",
462
+ "propagate": False,
463
+ "filters": ["path_filter"],
464
+ },
465
+ "lightrag": {
466
+ "handlers": ["default"],
467
+ "level": "INFO",
468
+ "propagate": False,
469
+ "filters": ["path_filter"],
470
+ },
471
  },
472
+ "filters": {
473
+ "path_filter": {
474
+ "()": "lightrag.api.lightrag_server.LightragPathFilter",
475
+ },
476
  },
477
+ }
478
+ )
479
+
480
 
481
  def main():
482
  from multiprocessing import freeze_support
483
+
484
  freeze_support()
485
+
486
  args = parse_args()
487
  # Save args to environment variable for child processes
488
+ os.environ["LIGHTRAG_ARGS"] = json.dumps(vars(args))
489
 
490
  # Configure logging before starting uvicorn
491
  configure_logging()
492
 
493
  display_splash_screen(args)
494
 
 
495
  uvicorn_config = {
496
  "app": "lightrag.api.lightrag_server:get_application",
497
  "factory": True,
lightrag/api/routers/document_routes.py CHANGED
@@ -375,62 +375,70 @@ async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path:
375
 
376
 
377
  async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
378
- """Background task to scan and index documents"""
379
  scan_progress = get_scan_progress()
380
  scan_lock = get_scan_lock()
381
-
382
  # Initialize scan_progress if not already initialized
383
  if not scan_progress:
384
- scan_progress.update({
385
- "is_scanning": False,
386
- "current_file": "",
387
- "indexed_count": 0,
388
- "total_files": 0,
389
- "progress": 0,
390
- })
391
-
 
 
392
  with scan_lock:
393
  if scan_progress.get("is_scanning", False):
394
- ASCIIColors.info(
395
- "Skip document scanning(another scanning is active)"
396
- )
397
  return
398
- scan_progress.update({
399
- "is_scanning": True,
400
- "current_file": "",
401
- "indexed_count": 0,
402
- "total_files": 0,
403
- "progress": 0,
404
- })
 
 
405
 
406
  try:
407
  new_files = doc_manager.scan_directory_for_new_files()
408
  total_files = len(new_files)
409
- scan_progress.update({
410
- "current_file": "",
411
- "total_files": total_files,
412
- "indexed_count": 0,
413
- "progress": 0,
414
- })
 
 
415
 
416
  logging.info(f"Found {total_files} new files to index.")
417
  for idx, file_path in enumerate(new_files):
418
  try:
419
  progress = (idx / total_files * 100) if total_files > 0 else 0
420
- scan_progress.update({
421
- "current_file": os.path.basename(file_path),
422
- "indexed_count": idx,
423
- "progress": progress,
424
- })
425
-
 
 
426
  await pipeline_index_file(rag, file_path)
427
-
428
  progress = ((idx + 1) / total_files * 100) if total_files > 0 else 0
429
- scan_progress.update({
430
- "current_file": os.path.basename(file_path),
431
- "indexed_count": idx + 1,
432
- "progress": progress,
433
- })
 
 
434
 
435
  except Exception as e:
436
  logging.error(f"Error indexing file {file_path}: {str(e)}")
@@ -438,13 +446,15 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
438
  except Exception as e:
439
  logging.error(f"Error during scanning process: {str(e)}")
440
  finally:
441
- scan_progress.update({
442
- "is_scanning": False,
443
- "current_file": "",
444
- "indexed_count": 0,
445
- "total_files": 0,
446
- "progress": 0,
447
- })
 
 
448
 
449
 
450
  def create_document_routes(
 
375
 
376
 
377
  async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
378
+ """Background task to scan and index documents"""
379
  scan_progress = get_scan_progress()
380
  scan_lock = get_scan_lock()
381
+
382
  # Initialize scan_progress if not already initialized
383
  if not scan_progress:
384
+ scan_progress.update(
385
+ {
386
+ "is_scanning": False,
387
+ "current_file": "",
388
+ "indexed_count": 0,
389
+ "total_files": 0,
390
+ "progress": 0,
391
+ }
392
+ )
393
+
394
  with scan_lock:
395
  if scan_progress.get("is_scanning", False):
396
+ ASCIIColors.info("Skip document scanning(another scanning is active)")
 
 
397
  return
398
+ scan_progress.update(
399
+ {
400
+ "is_scanning": True,
401
+ "current_file": "",
402
+ "indexed_count": 0,
403
+ "total_files": 0,
404
+ "progress": 0,
405
+ }
406
+ )
407
 
408
  try:
409
  new_files = doc_manager.scan_directory_for_new_files()
410
  total_files = len(new_files)
411
+ scan_progress.update(
412
+ {
413
+ "current_file": "",
414
+ "total_files": total_files,
415
+ "indexed_count": 0,
416
+ "progress": 0,
417
+ }
418
+ )
419
 
420
  logging.info(f"Found {total_files} new files to index.")
421
  for idx, file_path in enumerate(new_files):
422
  try:
423
  progress = (idx / total_files * 100) if total_files > 0 else 0
424
+ scan_progress.update(
425
+ {
426
+ "current_file": os.path.basename(file_path),
427
+ "indexed_count": idx,
428
+ "progress": progress,
429
+ }
430
+ )
431
+
432
  await pipeline_index_file(rag, file_path)
433
+
434
  progress = ((idx + 1) / total_files * 100) if total_files > 0 else 0
435
+ scan_progress.update(
436
+ {
437
+ "current_file": os.path.basename(file_path),
438
+ "indexed_count": idx + 1,
439
+ "progress": progress,
440
+ }
441
+ )
442
 
443
  except Exception as e:
444
  logging.error(f"Error indexing file {file_path}: {str(e)}")
 
446
  except Exception as e:
447
  logging.error(f"Error during scanning process: {str(e)}")
448
  finally:
449
+ scan_progress.update(
450
+ {
451
+ "is_scanning": False,
452
+ "current_file": "",
453
+ "indexed_count": 0,
454
+ "total_files": 0,
455
+ "progress": 0,
456
+ }
457
+ )
458
 
459
 
460
  def create_document_routes(
lightrag/api/utils_api.py CHANGED
@@ -433,7 +433,6 @@ def display_splash_screen(args: argparse.Namespace) -> None:
433
  ASCIIColors.white(" └─ Document Status Storage: ", end="")
434
  ASCIIColors.yellow(f"{args.doc_status_storage}")
435
 
436
-
437
  # Server Status
438
  ASCIIColors.green("\n✨ Server starting up...\n")
439
 
 
433
  ASCIIColors.white(" └─ Document Status Storage: ", end="")
434
  ASCIIColors.yellow(f"{args.doc_status_storage}")
435
 
 
436
  # Server Status
437
  ASCIIColors.green("\n✨ Server starting up...\n")
438
 
lightrag/kg/faiss_impl.py CHANGED
@@ -8,14 +8,19 @@ import numpy as np
8
  from dataclasses import dataclass
9
  import pipmaster as pm
10
 
11
- from lightrag.utils import logger,compute_mdhash_id
12
  from lightrag.base import BaseVectorStorage
13
- from .shared_storage import get_namespace_data, get_storage_lock, get_namespace_object, is_multiprocess
 
 
 
 
 
14
 
15
  if not pm.is_installed("faiss"):
16
  pm.install("faiss")
17
 
18
- import faiss # type: ignore
19
 
20
 
21
  @final
@@ -46,10 +51,10 @@ class FaissVectorDBStorage(BaseVectorStorage):
46
  # Embedding dimension (e.g. 768) must match your embedding function
47
  self._dim = self.embedding_func.embedding_dim
48
  self._storage_lock = get_storage_lock()
49
-
50
- self._index = get_namespace_object('faiss_indices')
51
- self._id_to_meta = get_namespace_data('faiss_meta')
52
-
53
  with self._storage_lock:
54
  if is_multiprocess:
55
  if self._index.value is None:
@@ -68,7 +73,6 @@ class FaissVectorDBStorage(BaseVectorStorage):
68
  self._id_to_meta.update({})
69
  self._load_faiss_index()
70
 
71
-
72
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
73
  """
74
  Insert or update vectors in the Faiss index.
@@ -168,7 +172,9 @@ class FaissVectorDBStorage(BaseVectorStorage):
168
 
169
  # Perform the similarity search
170
  with self._storage_lock:
171
- distances, indices = (self._index.value if is_multiprocess else self._index).search(embedding, top_k)
 
 
172
 
173
  distances = distances[0]
174
  indices = indices[0]
@@ -232,7 +238,10 @@ class FaissVectorDBStorage(BaseVectorStorage):
232
  with self._storage_lock:
233
  relations = []
234
  for fid, meta in self._id_to_meta.items():
235
- if meta.get("src_id") == entity_name or meta.get("tgt_id") == entity_name:
 
 
 
236
  relations.append(fid)
237
 
238
  logger.debug(f"Found {len(relations)} relations for {entity_name}")
@@ -292,7 +301,10 @@ class FaissVectorDBStorage(BaseVectorStorage):
292
  Save the current Faiss index + metadata to disk so it can persist across runs.
293
  """
294
  with self._storage_lock:
295
- faiss.write_index(self._index.value if is_multiprocess else self._index, self._faiss_index_file)
 
 
 
296
 
297
  # Save metadata dict to JSON. Convert all keys to strings for JSON storage.
298
  # _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... } }
@@ -320,7 +332,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
320
  self._index.value = loaded_index
321
  else:
322
  self._index = loaded_index
323
-
324
  # Load metadata
325
  with open(self._meta_file, "r", encoding="utf-8") as f:
326
  stored_dict = json.load(f)
 
8
  from dataclasses import dataclass
9
  import pipmaster as pm
10
 
11
+ from lightrag.utils import logger, compute_mdhash_id
12
  from lightrag.base import BaseVectorStorage
13
+ from .shared_storage import (
14
+ get_namespace_data,
15
+ get_storage_lock,
16
+ get_namespace_object,
17
+ is_multiprocess,
18
+ )
19
 
20
  if not pm.is_installed("faiss"):
21
  pm.install("faiss")
22
 
23
+ import faiss # type: ignore
24
 
25
 
26
  @final
 
51
  # Embedding dimension (e.g. 768) must match your embedding function
52
  self._dim = self.embedding_func.embedding_dim
53
  self._storage_lock = get_storage_lock()
54
+
55
+ self._index = get_namespace_object("faiss_indices")
56
+ self._id_to_meta = get_namespace_data("faiss_meta")
57
+
58
  with self._storage_lock:
59
  if is_multiprocess:
60
  if self._index.value is None:
 
73
  self._id_to_meta.update({})
74
  self._load_faiss_index()
75
 
 
76
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
77
  """
78
  Insert or update vectors in the Faiss index.
 
172
 
173
  # Perform the similarity search
174
  with self._storage_lock:
175
+ distances, indices = (
176
+ self._index.value if is_multiprocess else self._index
177
+ ).search(embedding, top_k)
178
 
179
  distances = distances[0]
180
  indices = indices[0]
 
238
  with self._storage_lock:
239
  relations = []
240
  for fid, meta in self._id_to_meta.items():
241
+ if (
242
+ meta.get("src_id") == entity_name
243
+ or meta.get("tgt_id") == entity_name
244
+ ):
245
  relations.append(fid)
246
 
247
  logger.debug(f"Found {len(relations)} relations for {entity_name}")
 
301
  Save the current Faiss index + metadata to disk so it can persist across runs.
302
  """
303
  with self._storage_lock:
304
+ faiss.write_index(
305
+ self._index.value if is_multiprocess else self._index,
306
+ self._faiss_index_file,
307
+ )
308
 
309
  # Save metadata dict to JSON. Convert all keys to strings for JSON storage.
310
  # _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... } }
 
332
  self._index.value = loaded_index
333
  else:
334
  self._index = loaded_index
335
+
336
  # Load metadata
337
  with open(self._meta_file, "r", encoding="utf-8") as f:
338
  stored_dict = json.load(f)
lightrag/kg/json_kv_impl.py CHANGED
@@ -26,7 +26,6 @@ class JsonKVStorage(BaseKVStorage):
26
  self._data: dict[str, Any] = load_json(self._file_name) or {}
27
  logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
28
 
29
-
30
  async def index_done_callback(self) -> None:
31
  # 文件写入需要加锁,防止多个进程同时写入导致文件损坏
32
  with self._storage_lock:
 
26
  self._data: dict[str, Any] = load_json(self._file_name) or {}
27
  logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
28
 
 
29
  async def index_done_callback(self) -> None:
30
  # 文件写入需要加锁,防止多个进程同时写入导致文件损坏
31
  with self._storage_lock:
lightrag/kg/nano_vector_db_impl.py CHANGED
@@ -25,7 +25,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
25
  def __post_init__(self):
26
  # Initialize lock only for file operations
27
  self._storage_lock = get_storage_lock()
28
-
29
  # Use global config value if specified, otherwise use default
30
  kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
31
  cosine_threshold = kwargs.get("cosine_better_than_threshold")
@@ -39,22 +39,28 @@ class NanoVectorDBStorage(BaseVectorStorage):
39
  self.global_config["working_dir"], f"vdb_{self.namespace}.json"
40
  )
41
  self._max_batch_size = self.global_config["embedding_batch_num"]
42
-
43
  self._client = get_namespace_object(self.namespace)
44
-
45
  with self._storage_lock:
46
  if is_multiprocess:
47
  if self._client.value is None:
48
  self._client.value = NanoVectorDB(
49
- self.embedding_func.embedding_dim, storage_file=self._client_file_name
 
 
 
 
50
  )
51
- logger.info(f"Initialized vector DB client for namespace {self.namespace}")
52
  else:
53
  if self._client is None:
54
  self._client = NanoVectorDB(
55
- self.embedding_func.embedding_dim, storage_file=self._client_file_name
 
 
 
 
56
  )
57
- logger.info(f"Initialized vector DB client for namespace {self.namespace}")
58
 
59
  def _get_client(self):
60
  """Get the appropriate client instance based on multiprocess mode"""
@@ -104,7 +110,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
104
  # Execute embedding outside of lock to avoid long lock times
105
  embedding = await self.embedding_func([query])
106
  embedding = embedding[0]
107
-
108
  with self._storage_lock:
109
  client = self._get_client()
110
  results = client.query(
@@ -150,7 +156,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
150
  logger.debug(
151
  f"Attempting to delete entity {entity_name} with ID {entity_id}"
152
  )
153
-
154
  with self._storage_lock:
155
  client = self._get_client()
156
  # Check if the entity exists
@@ -172,7 +178,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
172
  for dp in storage["data"]
173
  if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
174
  ]
175
- logger.debug(f"Found {len(relations)} relations for entity {entity_name}")
 
 
176
  ids_to_delete = [relation["__id__"] for relation in relations]
177
 
178
  if ids_to_delete:
 
25
  def __post_init__(self):
26
  # Initialize lock only for file operations
27
  self._storage_lock = get_storage_lock()
28
+
29
  # Use global config value if specified, otherwise use default
30
  kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
31
  cosine_threshold = kwargs.get("cosine_better_than_threshold")
 
39
  self.global_config["working_dir"], f"vdb_{self.namespace}.json"
40
  )
41
  self._max_batch_size = self.global_config["embedding_batch_num"]
42
+
43
  self._client = get_namespace_object(self.namespace)
44
+
45
  with self._storage_lock:
46
  if is_multiprocess:
47
  if self._client.value is None:
48
  self._client.value = NanoVectorDB(
49
+ self.embedding_func.embedding_dim,
50
+ storage_file=self._client_file_name,
51
+ )
52
+ logger.info(
53
+ f"Initialized vector DB client for namespace {self.namespace}"
54
  )
 
55
  else:
56
  if self._client is None:
57
  self._client = NanoVectorDB(
58
+ self.embedding_func.embedding_dim,
59
+ storage_file=self._client_file_name,
60
+ )
61
+ logger.info(
62
+ f"Initialized vector DB client for namespace {self.namespace}"
63
  )
 
64
 
65
  def _get_client(self):
66
  """Get the appropriate client instance based on multiprocess mode"""
 
110
  # Execute embedding outside of lock to avoid long lock times
111
  embedding = await self.embedding_func([query])
112
  embedding = embedding[0]
113
+
114
  with self._storage_lock:
115
  client = self._get_client()
116
  results = client.query(
 
156
  logger.debug(
157
  f"Attempting to delete entity {entity_name} with ID {entity_id}"
158
  )
159
+
160
  with self._storage_lock:
161
  client = self._get_client()
162
  # Check if the entity exists
 
178
  for dp in storage["data"]
179
  if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
180
  ]
181
+ logger.debug(
182
+ f"Found {len(relations)} relations for entity {entity_name}"
183
+ )
184
  ids_to_delete = [relation["__id__"] for relation in relations]
185
 
186
  if ids_to_delete:
lightrag/kg/networkx_impl.py CHANGED
@@ -78,29 +78,33 @@ class NetworkXStorage(BaseGraphStorage):
78
  with self._storage_lock:
79
  if is_multiprocess:
80
  if self._graph.value is None:
81
- preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
 
 
82
  self._graph.value = preloaded_graph or nx.Graph()
83
  if preloaded_graph:
84
  logger.info(
85
- f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
86
  )
87
  else:
88
  logger.info("Created new empty graph")
89
  else:
90
  if self._graph is None:
91
- preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
 
 
92
  self._graph = preloaded_graph or nx.Graph()
93
  if preloaded_graph:
94
  logger.info(
95
- f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
96
  )
97
  else:
98
  logger.info("Created new empty graph")
99
 
100
  self._node_embed_algorithms = {
101
- "node2vec": self._node2vec_embed,
102
  }
103
-
104
  def _get_graph(self):
105
  """Get the appropriate graph instance based on multiprocess mode"""
106
  if is_multiprocess:
@@ -248,11 +252,13 @@ class NetworkXStorage(BaseGraphStorage):
248
 
249
  with self._storage_lock:
250
  graph = self._get_graph()
251
-
252
  # Handle special case for "*" label
253
  if node_label == "*":
254
  # For "*", return the entire graph including all nodes and edges
255
- subgraph = graph.copy() # Create a copy to avoid modifying the original graph
 
 
256
  else:
257
  # Find nodes with matching node id (partial match)
258
  nodes_to_explore = []
@@ -272,9 +278,9 @@ class NetworkXStorage(BaseGraphStorage):
272
  if len(subgraph.nodes()) > max_graph_nodes:
273
  origin_nodes = len(subgraph.nodes())
274
  node_degrees = dict(subgraph.degree())
275
- top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[
276
- :max_graph_nodes
277
- ]
278
  top_node_ids = [node[0] for node in top_nodes]
279
  # Create new subgraph with only top nodes
280
  subgraph = subgraph.subgraph(top_node_ids)
 
78
  with self._storage_lock:
79
  if is_multiprocess:
80
  if self._graph.value is None:
81
+ preloaded_graph = NetworkXStorage.load_nx_graph(
82
+ self._graphml_xml_file
83
+ )
84
  self._graph.value = preloaded_graph or nx.Graph()
85
  if preloaded_graph:
86
  logger.info(
87
+ f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
88
  )
89
  else:
90
  logger.info("Created new empty graph")
91
  else:
92
  if self._graph is None:
93
+ preloaded_graph = NetworkXStorage.load_nx_graph(
94
+ self._graphml_xml_file
95
+ )
96
  self._graph = preloaded_graph or nx.Graph()
97
  if preloaded_graph:
98
  logger.info(
99
+ f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
100
  )
101
  else:
102
  logger.info("Created new empty graph")
103
 
104
  self._node_embed_algorithms = {
105
+ "node2vec": self._node2vec_embed,
106
  }
107
+
108
  def _get_graph(self):
109
  """Get the appropriate graph instance based on multiprocess mode"""
110
  if is_multiprocess:
 
252
 
253
  with self._storage_lock:
254
  graph = self._get_graph()
255
+
256
  # Handle special case for "*" label
257
  if node_label == "*":
258
  # For "*", return the entire graph including all nodes and edges
259
+ subgraph = (
260
+ graph.copy()
261
+ ) # Create a copy to avoid modifying the original graph
262
  else:
263
  # Find nodes with matching node id (partial match)
264
  nodes_to_explore = []
 
278
  if len(subgraph.nodes()) > max_graph_nodes:
279
  origin_nodes = len(subgraph.nodes())
280
  node_degrees = dict(subgraph.degree())
281
+ top_nodes = sorted(
282
+ node_degrees.items(), key=lambda x: x[1], reverse=True
283
+ )[:max_graph_nodes]
284
  top_node_ids = [node[0] for node in top_nodes]
285
  # Create new subgraph with only top nodes
286
  subgraph = subgraph.subgraph(top_node_ids)
lightrag/kg/shared_storage.py CHANGED
@@ -17,106 +17,125 @@ _shared_dicts: Optional[Dict[str, Any]] = {}
17
  _share_objects: Optional[Dict[str, Any]] = {}
18
  _init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized
19
 
 
20
  def initialize_share_data():
21
  """Initialize shared data, only called if multiple processes where workers > 1"""
22
  global _manager, _shared_dicts, _share_objects, _init_flags, is_multiprocess
23
  is_multiprocess = True
24
-
25
  logger.info(f"Process {os.getpid()} initializing shared storage")
26
-
27
  # Initialize manager
28
  if _manager is None:
29
  _manager = Manager()
30
  logger.info(f"Process {os.getpid()} created manager")
31
-
32
  # Create shared dictionaries with manager
33
  _shared_dicts = _manager.dict()
34
  _share_objects = _manager.dict()
35
  _init_flags = _manager.dict() # 使用共享字典存储初始化标志
36
  logger.info(f"Process {os.getpid()} created shared dictionaries")
37
 
 
38
  def try_initialize_namespace(namespace: str) -> bool:
39
  """
40
  尝试初始化命名空间。返回True表示当前进程获得了初始化权限。
41
  使用共享字典的原子操作确保只有一个进程能成功初始化。
42
  """
43
  global _init_flags, _manager
44
-
45
  if is_multiprocess:
46
  if _init_flags is None:
47
- raise RuntimeError("Shared storage not initialized. Call initialize_share_data() first.")
 
 
48
  else:
49
  if _init_flags is None:
50
  _init_flags = {}
51
-
52
  logger.info(f"Process {os.getpid()} trying to initialize namespace {namespace}")
53
-
54
  # 使用全局锁保护共享字典的访问
55
  with _get_global_lock():
56
  # 检查是否已经初始化
57
  if namespace not in _init_flags:
58
  # 设置初始化标志
59
  _init_flags[namespace] = True
60
- logger.info(f"Process {os.getpid()} ready to initialize namespace {namespace}")
 
 
61
  return True
62
-
63
- logger.info(f"Process {os.getpid()} found namespace {namespace} already initialized")
 
 
64
  return False
65
 
 
66
  def _get_global_lock() -> LockType:
67
  global _global_lock, is_multiprocess, _manager
68
-
69
  if _global_lock is None:
70
  if is_multiprocess:
71
  _global_lock = _manager.Lock() # Use manager for lock
72
  else:
73
  _global_lock = ThreadLock()
74
-
75
  return _global_lock
76
 
 
77
  def get_storage_lock() -> LockType:
78
  """return storage lock for data consistency"""
79
  return _get_global_lock()
80
 
 
81
  def get_scan_lock() -> LockType:
82
  """return scan_progress lock for data consistency"""
83
  return get_storage_lock()
84
 
 
85
  def get_namespace_object(namespace: str) -> Any:
86
  """Get an object for specific namespace"""
87
  global _share_objects, is_multiprocess, _manager
88
-
89
  if is_multiprocess and not _manager:
90
- raise RuntimeError("Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first.")
 
 
91
 
92
  if namespace not in _share_objects:
93
  lock = _get_global_lock()
94
  with lock:
95
  if namespace not in _share_objects:
96
  if is_multiprocess:
97
- _share_objects[namespace] = _manager.Value('O', None)
98
  else:
99
  _share_objects[namespace] = None
100
-
101
  return _share_objects[namespace]
102
 
 
103
  # 移除不再使用的函数
104
 
 
105
  def get_namespace_data(namespace: str) -> Dict[str, Any]:
106
  """get storage space for specific storage type(namespace)"""
107
  global _shared_dicts, is_multiprocess, _manager
108
-
109
  if is_multiprocess and not _manager:
110
- raise RuntimeError("Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first.")
 
 
111
 
112
  if namespace not in _shared_dicts:
113
  lock = _get_global_lock()
114
  with lock:
115
  if namespace not in _shared_dicts:
116
  _shared_dicts[namespace] = {}
117
-
118
  return _shared_dicts[namespace]
119
 
 
120
  def get_scan_progress() -> Dict[str, Any]:
121
  """get storage space for document scanning progress data"""
122
- return get_namespace_data('scan_progress')
 
17
  _share_objects: Optional[Dict[str, Any]] = {}
18
  _init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized
19
 
20
+
21
  def initialize_share_data():
22
  """Initialize shared data, only called if multiple processes where workers > 1"""
23
  global _manager, _shared_dicts, _share_objects, _init_flags, is_multiprocess
24
  is_multiprocess = True
25
+
26
  logger.info(f"Process {os.getpid()} initializing shared storage")
27
+
28
  # Initialize manager
29
  if _manager is None:
30
  _manager = Manager()
31
  logger.info(f"Process {os.getpid()} created manager")
32
+
33
  # Create shared dictionaries with manager
34
  _shared_dicts = _manager.dict()
35
  _share_objects = _manager.dict()
36
  _init_flags = _manager.dict() # 使用共享字典存储初始化标志
37
  logger.info(f"Process {os.getpid()} created shared dictionaries")
38
 
39
+
40
  def try_initialize_namespace(namespace: str) -> bool:
41
  """
42
  尝试初始化命名空间。返回True表示当前进程获得了初始化权限。
43
  使用共享字典的原子操作确保只有一个进程能成功初始化。
44
  """
45
  global _init_flags, _manager
46
+
47
  if is_multiprocess:
48
  if _init_flags is None:
49
+ raise RuntimeError(
50
+ "Shared storage not initialized. Call initialize_share_data() first."
51
+ )
52
  else:
53
  if _init_flags is None:
54
  _init_flags = {}
55
+
56
  logger.info(f"Process {os.getpid()} trying to initialize namespace {namespace}")
57
+
58
  # 使用全局锁保护共享字典的访问
59
  with _get_global_lock():
60
  # 检查是否已经初始化
61
  if namespace not in _init_flags:
62
  # 设置初始化标志
63
  _init_flags[namespace] = True
64
+ logger.info(
65
+ f"Process {os.getpid()} ready to initialize namespace {namespace}"
66
+ )
67
  return True
68
+
69
+ logger.info(
70
+ f"Process {os.getpid()} found namespace {namespace} already initialized"
71
+ )
72
  return False
73
 
74
+
75
  def _get_global_lock() -> LockType:
76
  global _global_lock, is_multiprocess, _manager
77
+
78
  if _global_lock is None:
79
  if is_multiprocess:
80
  _global_lock = _manager.Lock() # Use manager for lock
81
  else:
82
  _global_lock = ThreadLock()
83
+
84
  return _global_lock
85
 
86
+
87
  def get_storage_lock() -> LockType:
88
  """return storage lock for data consistency"""
89
  return _get_global_lock()
90
 
91
+
92
  def get_scan_lock() -> LockType:
93
  """return scan_progress lock for data consistency"""
94
  return get_storage_lock()
95
 
96
+
97
  def get_namespace_object(namespace: str) -> Any:
98
  """Get an object for specific namespace"""
99
  global _share_objects, is_multiprocess, _manager
100
+
101
  if is_multiprocess and not _manager:
102
+ raise RuntimeError(
103
+ "Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first."
104
+ )
105
 
106
  if namespace not in _share_objects:
107
  lock = _get_global_lock()
108
  with lock:
109
  if namespace not in _share_objects:
110
  if is_multiprocess:
111
+ _share_objects[namespace] = _manager.Value("O", None)
112
  else:
113
  _share_objects[namespace] = None
114
+
115
  return _share_objects[namespace]
116
 
117
+
118
  # 移除不再使用的函数
119
 
120
+
121
  def get_namespace_data(namespace: str) -> Dict[str, Any]:
122
  """get storage space for specific storage type(namespace)"""
123
  global _shared_dicts, is_multiprocess, _manager
124
+
125
  if is_multiprocess and not _manager:
126
+ raise RuntimeError(
127
+ "Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first."
128
+ )
129
 
130
  if namespace not in _shared_dicts:
131
  lock = _get_global_lock()
132
  with lock:
133
  if namespace not in _shared_dicts:
134
  _shared_dicts[namespace] = {}
135
+
136
  return _shared_dicts[namespace]
137
 
138
+
139
  def get_scan_progress() -> Dict[str, Any]:
140
  """get storage space for document scanning progress data"""
141
+ return get_namespace_data("scan_progress")
lightrag/lightrag.py CHANGED
@@ -266,7 +266,7 @@ class LightRAG:
266
 
267
  _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED)
268
 
269
- def __post_init__(self):
270
  os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True)
271
  set_logger(self.log_file_path, self.log_level)
272
  logger.info(f"Logger initialized for working directory: {self.working_dir}")
 
266
 
267
  _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED)
268
 
269
+ def __post_init__(self):
270
  os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True)
271
  set_logger(self.log_file_path, self.log_level)
272
  logger.info(f"Logger initialized for working directory: {self.working_dir}")
lightrag/utils.py CHANGED
@@ -55,6 +55,7 @@ def set_verbose_debug(enabled: bool):
55
  global VERBOSE_DEBUG
56
  VERBOSE_DEBUG = enabled
57
 
 
58
  statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0}
59
 
60
  # Initialize logger
@@ -100,6 +101,7 @@ class UnlimitedSemaphore:
100
 
101
  ENCODER = None
102
 
 
103
  @dataclass
104
  class EmbeddingFunc:
105
  embedding_dim: int
 
55
  global VERBOSE_DEBUG
56
  VERBOSE_DEBUG = enabled
57
 
58
+
59
  statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0}
60
 
61
  # Initialize logger
 
101
 
102
  ENCODER = None
103
 
104
+
105
  @dataclass
106
  class EmbeddingFunc:
107
  embedding_dim: int