Commit dc0c15b (unverified) · committed by YanSte · 2 parent(s): 3548e23 8db9467

Merge pull request #969 from danielaskdd/add-multi-worker-support

.gitignore CHANGED
@@ -21,6 +21,7 @@ site/
 
 # Logs / Reports
 *.log
+ *.log.*
 *.logfire
 *.coverage/
 log/
.env.example → env.example RENAMED
@@ -1,6 +1,9 @@
+ ### This is a sample .env file
+
 ### Server Configuration
 # HOST=0.0.0.0
 # PORT=9621
+ # WORKERS=1
 # NAMESPACE_PREFIX=lightrag # separating data from difference Lightrag instances
 # CORS_ORIGINS=http://localhost:3000,http://localhost:8080
 
@@ -22,6 +25,9 @@
 ### Logging level
 # LOG_LEVEL=INFO
 # VERBOSE=False
+ # LOG_DIR=/path/to/log/directory # Log file directory path, defaults to current working directory
+ # LOG_MAX_BYTES=10485760 # Log file max size in bytes, defaults to 10MB
+ # LOG_BACKUP_COUNT=5 # Number of backup files to keep, defaults to 5
 
 ### Max async calls for LLM
 # MAX_ASYNC=4
@@ -138,3 +144,6 @@ MONGODB_GRAPH=false # deprecated (keep for backward compatibility)
 ### Qdrant
 QDRANT_URL=http://localhost:16333
 # QDRANT_API_KEY=your-api-key
+
+ ### Redis
+ REDIS_URI=redis://localhost:6379
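
The API server reads this file through python-dotenv (`load_dotenv(override=True)` in `lightrag_server.py`, below). A minimal sketch of consuming the variables added in this commit, assuming the sample has been copied to `.env` in the working directory:

```python
import os

from dotenv import load_dotenv  # python-dotenv, already used by the API server

load_dotenv(override=True)  # picks up .env from the current working directory

workers = int(os.getenv("WORKERS", 1))                        # worker count for Gunicorn mode
log_dir = os.getenv("LOG_DIR", os.getcwd())                   # where lightrag.log is written
redis_uri = os.getenv("REDIS_URI", "redis://localhost:6379")  # new Redis setting
print(workers, log_dir, redis_uri)
```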
lightrag/api/README.md CHANGED
@@ -24,6 +24,8 @@ pip install -e ".[api]"
 
 ### Starting API Server with Default Settings
 
+ After installing LightRAG with API support, you can start the server with this command: `lightrag-server`
+
 LightRAG requires both LLM and Embedding Model to work together to complete document indexing and querying tasks. LightRAG supports binding to various LLM/Embedding backends:
 
 * ollama
@@ -92,10 +94,43 @@ LLM_BINDING_API_KEY=your_api_key Light_server --llm-binding openai-ollama
 LLM_BINDING_API_KEY=your_api_key Light_server --llm-binding openai --embedding-binding openai
 
 # start with ollama llm and ollama embedding (no apikey is needed)
- Light_server --llm-binding ollama --embedding-binding ollama
+ light-server --llm-binding ollama --embedding-binding ollama
+ ```
+
+ ### Starting API Server with Gunicorn (Production)
+
+ For production deployments, it's recommended to use Gunicorn as the WSGI server to handle concurrent requests efficiently. LightRAG provides a dedicated Gunicorn startup script that handles shared data initialization, process management, and other critical functionalities.
+
+ ```bash
+ # Start with lightrag-gunicorn command
+ lightrag-gunicorn --workers 4
+
+ # Alternatively, you can use the module directly
+ python -m lightrag.api.run_with_gunicorn --workers 4
 ```
 
+ The `--workers` parameter is crucial for performance:
+
+ - Determines how many worker processes Gunicorn will spawn to handle requests
+ - Each worker can handle concurrent requests using asyncio
+ - Recommended value is (2 x number_of_cores) + 1
+ - For example, on a 4-core machine, use 9 workers: (2 x 4) + 1 = 9
+ - Consider your server's memory when setting this value, as each worker consumes memory
+
+ Other important startup parameters:
+
+ - `--host`: Server listening address (default: 0.0.0.0)
+ - `--port`: Server listening port (default: 9621)
+ - `--timeout`: Request handling timeout (default: 150 seconds)
+ - `--log-level`: Logging level (default: INFO)
+ - `--ssl`: Enable HTTPS
+ - `--ssl-certfile`: Path to SSL certificate file
+ - `--ssl-keyfile`: Path to SSL private key file
+
+ The command line parameters and environment variables for run_with_gunicorn.py are exactly the same as those for `light-server`.
+
 ### For Azure OpenAI Backend
+
 Azure OpenAI API can be created using the following commands in Azure CLI (you need to install Azure CLI first from [https://docs.microsoft.com/en-us/cli/azure/install-azure-cli](https://docs.microsoft.com/en-us/cli/azure/install-azure-cli)):
 ```bash
 # Change the resource group name, location and OpenAI resource name as needed
@@ -186,7 +221,7 @@ LightRAG supports binding to various LLM/Embedding backends:
 * openai & openai compatible
 * azure_openai
 
- Use environment variables `LLM_BINDING ` or CLI argument `--llm-binding` to select LLM backend type. Use environment variables `EMBEDDING_BINDING ` or CLI argument `--embedding-binding` to select LLM backend type.
+ Use environment variables `LLM_BINDING` or CLI argument `--llm-binding` to select the LLM backend type. Use environment variables `EMBEDDING_BINDING` or CLI argument `--embedding-binding` to select the embedding backend type.
 
 ### Storage Types Supported
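
The `(2 x number_of_cores) + 1` rule of thumb from the Gunicorn section above can be computed directly; a small illustrative snippet (not part of this PR):

```python
import os

cores = os.cpu_count() or 1   # logical cores visible to the process
recommended = 2 * cores + 1   # rule of thumb quoted in the README
print(f"{cores} cores -> lightrag-gunicorn --workers {recommended}")
# A 4-core machine prints: 4 cores -> lightrag-gunicorn --workers 9
```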
 
lightrag/api/gunicorn_config.py ADDED
@@ -0,0 +1,187 @@
+ # gunicorn_config.py
+ import os
+ import logging
+ from lightrag.kg.shared_storage import finalize_share_data
+ from lightrag.api.lightrag_server import LightragPathFilter
+
+ # Get log directory path from environment variable
+ log_dir = os.getenv("LOG_DIR", os.getcwd())
+ log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log"))
+
+ # Get log file max size and backup count from environment variables
+ log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760))  # Default 10MB
+ log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5))  # Default 5 backups
+
+ # These variables will be set by run_with_gunicorn.py
+ workers = None
+ bind = None
+ loglevel = None
+ certfile = None
+ keyfile = None
+
+ # Enable preload_app option
+ preload_app = True
+
+ # Use Uvicorn worker
+ worker_class = "uvicorn.workers.UvicornWorker"
+
+ # Other Gunicorn configurations
+ timeout = int(os.getenv("TIMEOUT", 150))  # Default 150s to match run_with_gunicorn.py
+ keepalive = int(os.getenv("KEEPALIVE", 5))  # Default 5s
+
+ # Logging configuration
+ errorlog = os.getenv("ERROR_LOG", log_file_path)  # Default write to lightrag.log
+ accesslog = os.getenv("ACCESS_LOG", log_file_path)  # Default write to lightrag.log
+
+ logconfig_dict = {
+     "version": 1,
+     "disable_existing_loggers": False,
+     "formatters": {
+         "standard": {"format": "%(asctime)s [%(levelname)s] %(name)s: %(message)s"},
+     },
+     "handlers": {
+         "console": {
+             "class": "logging.StreamHandler",
+             "formatter": "standard",
+             "stream": "ext://sys.stdout",
+         },
+         "file": {
+             "class": "logging.handlers.RotatingFileHandler",
+             "formatter": "standard",
+             "filename": log_file_path,
+             "maxBytes": log_max_bytes,
+             "backupCount": log_backup_count,
+             "encoding": "utf8",
+         },
+     },
+     "filters": {
+         "path_filter": {
+             "()": "lightrag.api.lightrag_server.LightragPathFilter",
+         },
+     },
+     "loggers": {
+         "lightrag": {
+             "handlers": ["console", "file"],
+             "level": loglevel.upper() if loglevel else "INFO",
+             "propagate": False,
+         },
+         "gunicorn": {
+             "handlers": ["console", "file"],
+             "level": loglevel.upper() if loglevel else "INFO",
+             "propagate": False,
+         },
+         "gunicorn.error": {
+             "handlers": ["console", "file"],
+             "level": loglevel.upper() if loglevel else "INFO",
+             "propagate": False,
+         },
+         "gunicorn.access": {
+             "handlers": ["console", "file"],
+             "level": loglevel.upper() if loglevel else "INFO",
+             "propagate": False,
+             "filters": ["path_filter"],
+         },
+     },
+ }
+
+
+ def on_starting(server):
+     """
+     Executed when Gunicorn starts, before forking the first worker processes
+     You can use this function to do more initialization tasks for all processes
+     """
+     print("=" * 80)
+     print(f"GUNICORN MASTER PROCESS: on_starting jobs for {workers} worker(s)")
+     print(f"Process ID: {os.getpid()}")
+     print("=" * 80)
+
+     # Memory usage monitoring
+     try:
+         import psutil
+
+         process = psutil.Process(os.getpid())
+         memory_info = process.memory_info()
+         msg = (
+             f"Memory usage after initialization: {memory_info.rss / 1024 / 1024:.2f} MB"
+         )
+         print(msg)
+     except ImportError:
+         print("psutil not installed, skipping memory usage reporting")
+
+     print("Gunicorn initialization complete, forking workers...\n")
+
+
+ def on_exit(server):
+     """
+     Executed when Gunicorn is shutting down.
+     This is a good place to release shared resources.
+     """
+     print("=" * 80)
+     print("GUNICORN MASTER PROCESS: Shutting down")
+     print(f"Process ID: {os.getpid()}")
+     print("=" * 80)
+
+     # Release shared resources
+     finalize_share_data()
+
+     print("=" * 80)
+     print("Gunicorn shutdown complete")
+     print("=" * 80)
+
+
+ def post_fork(server, worker):
+     """
+     Executed after a worker has been forked.
+     This is a good place to set up worker-specific configurations.
+     """
+     # Configure formatters
+     detailed_formatter = logging.Formatter(
+         "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+     )
+     simple_formatter = logging.Formatter("%(levelname)s: %(message)s")
+
+     def setup_logger(logger_name: str, level: str = "INFO", add_filter: bool = False):
+         """Set up a logger with console and file handlers"""
+         logger_instance = logging.getLogger(logger_name)
+         logger_instance.setLevel(level)
+         logger_instance.handlers = []  # Clear existing handlers
+         logger_instance.propagate = False
+
+         # Add console handler
+         console_handler = logging.StreamHandler()
+         console_handler.setFormatter(simple_formatter)
+         console_handler.setLevel(level)
+         logger_instance.addHandler(console_handler)
+
+         # Add file handler
+         file_handler = logging.handlers.RotatingFileHandler(
+             filename=log_file_path,
+             maxBytes=log_max_bytes,
+             backupCount=log_backup_count,
+             encoding="utf-8",
+         )
+         file_handler.setFormatter(detailed_formatter)
+         file_handler.setLevel(level)
+         logger_instance.addHandler(file_handler)
+
+         # Add path filter if requested
+         if add_filter:
+             path_filter = LightragPathFilter()
+             logger_instance.addFilter(path_filter)
+
+     # Set up main loggers
+     log_level = loglevel.upper() if loglevel else "INFO"
+     setup_logger("uvicorn", log_level)
+     setup_logger("uvicorn.access", log_level, add_filter=True)
+     setup_logger("lightrag", log_level, add_filter=True)
+
+     # Set up lightrag submodule loggers
+     for name in logging.root.manager.loggerDict:
+         if name.startswith("lightrag."):
+             setup_logger(name, log_level, add_filter=True)
+
+     # Disable uvicorn.error logger
+     uvicorn_error_logger = logging.getLogger("uvicorn.error")
+     uvicorn_error_logger.handlers = []
+     uvicorn_error_logger.setLevel(logging.CRITICAL)
+     uvicorn_error_logger.propagate = False
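
The log-rotation settings in this config map directly onto the standard library's `RotatingFileHandler`; a standalone sketch of the same behaviour with the default values (10 MB per file, 5 backups), independent of Gunicorn:

```python
import logging
from logging.handlers import RotatingFileHandler

handler = RotatingFileHandler(
    "lightrag.log", maxBytes=10_485_760, backupCount=5, encoding="utf8"
)
handler.setFormatter(
    logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
)

log = logging.getLogger("lightrag")
log.addHandler(handler)
log.setLevel(logging.INFO)
log.info("Goes to lightrag.log; older output rolls over to lightrag.log.1 ... lightrag.log.5")
```

The rotated backups (`lightrag.log.1` and so on) are also why this PR adds `*.log.*` to `.gitignore`.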
lightrag/api/lightrag_server.py CHANGED
@@ -8,11 +8,12 @@ from fastapi import (
 )
 from fastapi.responses import FileResponse
 import asyncio
- import threading
 import os
- from fastapi.staticfiles import StaticFiles
 import logging
- from typing import Dict
+ import logging.config
+ import uvicorn
+ import pipmaster as pm
+ from fastapi.staticfiles import StaticFiles
 from pathlib import Path
 import configparser
 from ascii_colors import ASCIIColors
@@ -29,7 +30,6 @@ from lightrag import LightRAG
 from lightrag.types import GPTKeywordExtractionFormat
 from lightrag.api import __api_version__
 from lightrag.utils import EmbeddingFunc
- from lightrag.utils import logger
 from .routers.document_routes import (
     DocumentManager,
     create_document_routes,
@@ -39,33 +39,25 @@ from .routers.query_routes import create_query_routes
 from .routers.graph_routes import create_graph_routes
 from .routers.ollama_api import OllamaAPI
 
+ from lightrag.utils import logger, set_verbose_debug
+ from lightrag.kg.shared_storage import (
+     get_namespace_data,
+     get_pipeline_status_lock,
+     initialize_pipeline_status,
+     get_all_update_flags_status,
+ )
+
 # Load environment variables
- try:
-     load_dotenv(override=True)
- except Exception as e:
-     logger.warning(f"Failed to load .env file: {e}")
+ load_dotenv(override=True)
 
 # Initialize config parser
 config = configparser.ConfigParser()
 config.read("config.ini")
 
- # Global configuration
- global_top_k = 60  # default value
 
- # Global progress tracker
- scan_progress: Dict = {
-     "is_scanning": False,
-     "current_file": "",
-     "indexed_count": 0,
-     "total_files": 0,
-     "progress": 0,
- }
+ class LightragPathFilter(logging.Filter):
+     """Filter for lightrag logger to filter out frequent path access logs"""
 
- # Lock for thread-safe operations
- progress_lock = threading.Lock()
-
-
- class AccessLogFilter(logging.Filter):
     def __init__(self):
         super().__init__()
         # Define paths to be filtered
@@ -73,17 +65,18 @@ class AccessLogFilter(logging.Filter):
 
     def filter(self, record):
         try:
+             # Check if record has the required attributes for an access log
             if not hasattr(record, "args") or not isinstance(record.args, tuple):
                 return True
             if len(record.args) < 5:
                 return True
 
+             # Extract method, path and status from the record args
             method = record.args[1]
             path = record.args[2]
             status = record.args[4]
-             # print(f"Debug - Method: {method}, Path: {path}, Status: {status}")
-             # print(f"Debug - Filtered paths: {self.filtered_paths}")
 
+             # Filter out successful GET requests to filtered paths
             if (
                 method == "GET"
                 and (status == 200 or status == 304)
@@ -92,19 +85,14 @@ class AccessLogFilter(logging.Filter):
                 return False
 
             return True
-
         except Exception:
+             # In case of any error, let the message through
             return True
 
 
 def create_app(args):
-     # Set global top_k
-     global global_top_k
-     global_top_k = args.top_k  # save top_k from args
-
-     # Initialize verbose debug setting
-     from lightrag.utils import set_verbose_debug
-
+     # Setup logging
+     logger.setLevel(args.log_level)
     set_verbose_debug(args.verbose)
 
     # Verify that bindings are correctly setup
@@ -138,11 +126,6 @@ def create_app(args):
         if not os.path.exists(args.ssl_keyfile):
             raise Exception(f"SSL key file not found: {args.ssl_keyfile}")
 
-     # Setup logging
-     logging.basicConfig(
-         format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
-     )
-
     # Check if API key is provided either through env var or args
     api_key = os.getenv("LIGHTRAG_API_KEY") or args.key
 
@@ -158,28 +141,23 @@ def create_app(args):
         try:
             # Initialize database connections
            await rag.initialize_storages()
+             await initialize_pipeline_status()
 
            # Auto scan documents if enabled
            if args.auto_scan_at_startup:
-                 # Start scanning in background
-                 with progress_lock:
-                     if not scan_progress["is_scanning"]:
-                         scan_progress["is_scanning"] = True
-                         scan_progress["indexed_count"] = 0
-                         scan_progress["progress"] = 0
-                         # Create background task
-                         task = asyncio.create_task(
-                             run_scanning_process(rag, doc_manager)
-                         )
-                         app.state.background_tasks.add(task)
-                         task.add_done_callback(app.state.background_tasks.discard)
-                         ASCIIColors.info(
-                             f"Started background scanning of documents from {args.input_dir}"
-                         )
-                     else:
-                         ASCIIColors.info(
-                             "Skip document scanning(another scanning is active)"
-                         )
+                 # Check if a task is already running (with lock protection)
+                 pipeline_status = await get_namespace_data("pipeline_status")
+                 should_start_task = False
+                 async with get_pipeline_status_lock():
+                     if not pipeline_status.get("busy", False):
+                         should_start_task = True
+                 # Only start the task if no other task is running
+                 if should_start_task:
+                     # Create background task
+                     task = asyncio.create_task(run_scanning_process(rag, doc_manager))
+                     app.state.background_tasks.add(task)
+                     task.add_done_callback(app.state.background_tasks.discard)
+                     logger.info("Auto scan task started at startup.")
 
            ASCIIColors.green("\nServer is ready to accept connections! 🚀\n")
 
@@ -398,6 +376,9 @@ def create_app(args):
     @app.get("/health", dependencies=[Depends(optional_api_key)])
     async def get_status():
         """Get current system status"""
+         # Get update flags status for all namespaces
+         update_status = await get_all_update_flags_status()
+
         return {
             "status": "healthy",
             "working_directory": str(args.working_dir),
@@ -417,6 +398,7 @@ def create_app(args):
                 "graph_storage": args.graph_storage,
                 "vector_storage": args.vector_storage,
             },
+             "update_status": update_status,
         }
 
     # Webui mount webui/index.html
@@ -435,12 +417,30 @@ def create_app(args):
     return app
 
 
- def main():
-     args = parse_args()
-     import uvicorn
-     import logging.config
+ def get_application(args=None):
+     """Factory function for creating the FastAPI application"""
+     if args is None:
+         args = parse_args()
+     return create_app(args)
+
+
+ def configure_logging():
+     """Configure logging for uvicorn startup"""
+
+     # Reset any existing handlers to ensure clean configuration
+     for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]:
+         logger = logging.getLogger(logger_name)
+         logger.handlers = []
+         logger.filters = []
+
+     # Get log directory path from environment variable
+     log_dir = os.getenv("LOG_DIR", os.getcwd())
+     log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log"))
+
+     # Get log file max size and backup count from environment variables
+     log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760))  # Default 10MB
+     log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5))  # Default 5 backups
 
-     # Configure uvicorn logging
     logging.config.dictConfig(
         {
             "version": 1,
@@ -449,36 +449,106 @@ def main():
                 "default": {
                     "format": "%(levelname)s: %(message)s",
                 },
+                 "detailed": {
+                     "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+                 },
             },
             "handlers": {
-                 "default": {
+                 "console": {
                     "formatter": "default",
                     "class": "logging.StreamHandler",
                     "stream": "ext://sys.stderr",
                 },
+                 "file": {
+                     "formatter": "detailed",
+                     "class": "logging.handlers.RotatingFileHandler",
+                     "filename": log_file_path,
+                     "maxBytes": log_max_bytes,
+                     "backupCount": log_backup_count,
+                     "encoding": "utf-8",
+                 },
             },
             "loggers": {
+                 # Configure all uvicorn related loggers
+                 "uvicorn": {
+                     "handlers": ["console", "file"],
+                     "level": "INFO",
+                     "propagate": False,
+                 },
                 "uvicorn.access": {
-                     "handlers": ["default"],
+                     "handlers": ["console", "file"],
                     "level": "INFO",
                     "propagate": False,
+                     "filters": ["path_filter"],
+                 },
+                 "uvicorn.error": {
+                     "handlers": ["console", "file"],
+                     "level": "INFO",
+                     "propagate": False,
+                 },
+                 "lightrag": {
+                     "handlers": ["console", "file"],
+                     "level": "INFO",
+                     "propagate": False,
+                     "filters": ["path_filter"],
+                 },
+             },
+             "filters": {
+                 "path_filter": {
+                     "()": "lightrag.api.lightrag_server.LightragPathFilter",
                 },
             },
         }
     )
 
-     # Add filter to uvicorn access logger
-     uvicorn_access_logger = logging.getLogger("uvicorn.access")
-     uvicorn_access_logger.addFilter(AccessLogFilter())
 
-     app = create_app(args)
+ def check_and_install_dependencies():
+     """Check and install required dependencies"""
+     required_packages = [
+         "uvicorn",
+         "tiktoken",
+         "fastapi",
+         # Add other required packages here
+     ]
+
+     for package in required_packages:
+         if not pm.is_installed(package):
+             print(f"Installing {package}...")
+             pm.install(package)
+             print(f"{package} installed successfully")
+
+
+ def main():
+     # Check if running under Gunicorn
+     if "GUNICORN_CMD_ARGS" in os.environ:
+         # If started with Gunicorn, return directly as Gunicorn will call get_application
+         print("Running under Gunicorn - worker management handled by Gunicorn")
+         return
+
+     # Check and install dependencies
+     check_and_install_dependencies()
+
+     from multiprocessing import freeze_support
+
+     freeze_support()
+
+     # Configure logging before parsing args
+     configure_logging()
+
+     args = parse_args(is_uvicorn_mode=True)
     display_splash_screen(args)
+
+     # Create application instance directly instead of using factory function
+     app = create_app(args)
+
+     # Start Uvicorn in single process mode
     uvicorn_config = {
-         "app": app,
+         "app": app,  # Pass application instance directly instead of string path
         "host": args.host,
         "port": args.port,
         "log_config": None,  # Disable default config
     }
+
     if args.ssl:
         uvicorn_config.update(
             {
@@ -486,6 +556,8 @@ def main():
                 "ssl_keyfile": args.ssl_keyfile,
             }
         )
+
+     print(f"Starting Uvicorn server in single-process mode on {args.host}:{args.port}")
     uvicorn.run(**uvicorn_config)
 
 
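
With `update_status` added to the payload, the health endpoint can be probed as before; an illustrative check using `requests`, assuming a locally running server on the default port 9621 with no API key configured:

```python
import requests

resp = requests.get("http://localhost:9621/health", timeout=10)
resp.raise_for_status()
health = resp.json()

print(health["status"])            # "healthy"
print("update_status" in health)   # True after this PR
```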
 
lightrag/api/routers/document_routes.py CHANGED
@@ -3,8 +3,7 @@ This module contains all document-related routes for the LightRAG API.
 """
 
 import asyncio
- import logging
- import os
+ from lightrag.utils import logger
 import aiofiles
 import shutil
 import traceback
@@ -12,7 +11,6 @@ import pipmaster as pm
 from datetime import datetime
 from pathlib import Path
 from typing import Dict, List, Optional, Any
-
 from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
 from pydantic import BaseModel, Field, field_validator
 
@@ -23,18 +21,6 @@ from ..utils_api import get_api_key_dependency
 
 router = APIRouter(prefix="/documents", tags=["documents"])
 
- # Global progress tracker
- scan_progress: Dict = {
-     "is_scanning": False,
-     "current_file": "",
-     "indexed_count": 0,
-     "total_files": 0,
-     "progress": 0,
- }
-
- # Lock for thread-safe operations
- progress_lock = asyncio.Lock()
-
 # Temporary file prefix
 temp_prefix = "__tmp__"
 
@@ -161,19 +147,12 @@ class DocumentManager:
         """Scan input directory for new files"""
         new_files = []
         for ext in self.supported_extensions:
-             logging.debug(f"Scanning for {ext} files in {self.input_dir}")
+             logger.debug(f"Scanning for {ext} files in {self.input_dir}")
             for file_path in self.input_dir.rglob(f"*{ext}"):
                 if file_path not in self.indexed_files:
                     new_files.append(file_path)
         return new_files
 
-     # def scan_directory(self) -> List[Path]:
-     #     new_files = []
-     #     for ext in self.supported_extensions:
-     #         for file_path in self.input_dir.rglob(f"*{ext}"):
-     #             new_files.append(file_path)
-     #     return new_files
-
     def mark_as_indexed(self, file_path: Path):
         self.indexed_files.add(file_path)
 
@@ -287,7 +266,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
                 )
                 content += "\n"
             case _:
-                 logging.error(
+                 logger.error(
                     f"Unsupported file type: {file_path.name} (extension {ext})"
                 )
                 return False
@@ -295,20 +274,20 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
         # Insert into the RAG queue
         if content:
            await rag.apipeline_enqueue_documents(content)
-             logging.info(f"Successfully fetched and enqueued file: {file_path.name}")
+             logger.info(f"Successfully fetched and enqueued file: {file_path.name}")
            return True
         else:
-             logging.error(f"No content could be extracted from file: {file_path.name}")
+             logger.error(f"No content could be extracted from file: {file_path.name}")
 
     except Exception as e:
-         logging.error(f"Error processing or enqueueing file {file_path.name}: {str(e)}")
-         logging.error(traceback.format_exc())
+         logger.error(f"Error processing or enqueueing file {file_path.name}: {str(e)}")
+         logger.error(traceback.format_exc())
     finally:
         if file_path.name.startswith(temp_prefix):
            try:
                file_path.unlink()
            except Exception as e:
-                 logging.error(f"Error deleting file {file_path}: {str(e)}")
+                 logger.error(f"Error deleting file {file_path}: {str(e)}")
     return False
 
 
@@ -324,8 +303,8 @@ async def pipeline_index_file(rag: LightRAG, file_path: Path):
         await rag.apipeline_process_enqueue_documents()
 
     except Exception as e:
-         logging.error(f"Error indexing file {file_path.name}: {str(e)}")
-         logging.error(traceback.format_exc())
+         logger.error(f"Error indexing file {file_path.name}: {str(e)}")
+         logger.error(traceback.format_exc())
 
 
 async def pipeline_index_files(rag: LightRAG, file_paths: List[Path]):
@@ -349,8 +328,8 @@ async def pipeline_index_files(rag: LightRAG, file_paths: List[Path]):
         if enqueued:
             await rag.apipeline_process_enqueue_documents()
     except Exception as e:
-         logging.error(f"Error indexing files: {str(e)}")
-         logging.error(traceback.format_exc())
+         logger.error(f"Error indexing files: {str(e)}")
+         logger.error(traceback.format_exc())
 
 
 async def pipeline_index_texts(rag: LightRAG, texts: List[str]):
@@ -393,30 +372,17 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
     """Background task to scan and index documents"""
     try:
         new_files = doc_manager.scan_directory_for_new_files()
-         scan_progress["total_files"] = len(new_files)
+         total_files = len(new_files)
+         logger.info(f"Found {total_files} new files to index.")
 
-         logging.info(f"Found {len(new_files)} new files to index.")
-         for file_path in new_files:
+         for idx, file_path in enumerate(new_files):
            try:
-                 async with progress_lock:
-                     scan_progress["current_file"] = os.path.basename(file_path)
-
                await pipeline_index_file(rag, file_path)
-
-                 async with progress_lock:
-                     scan_progress["indexed_count"] += 1
-                     scan_progress["progress"] = (
-                         scan_progress["indexed_count"] / scan_progress["total_files"]
-                     ) * 100
-
            except Exception as e:
-                 logging.error(f"Error indexing file {file_path}: {str(e)}")
+                 logger.error(f"Error indexing file {file_path}: {str(e)}")
 
     except Exception as e:
-         logging.error(f"Error during scanning process: {str(e)}")
-     finally:
-         async with progress_lock:
-             scan_progress["is_scanning"] = False
+         logger.error(f"Error during scanning process: {str(e)}")
 
 
 def create_document_routes(
@@ -436,34 +402,10 @@ def create_document_routes(
         Returns:
             dict: A dictionary containing the scanning status
         """
-         async with progress_lock:
-             if scan_progress["is_scanning"]:
-                 return {"status": "already_scanning"}
-
-             scan_progress["is_scanning"] = True
-             scan_progress["indexed_count"] = 0
-             scan_progress["progress"] = 0
-
         # Start the scanning process in the background
         background_tasks.add_task(run_scanning_process, rag, doc_manager)
         return {"status": "scanning_started"}
 
-     @router.get("/scan-progress")
-     async def get_scan_progress():
-         """
-         Get the current progress of the document scanning process.
-
-         Returns:
-             dict: A dictionary containing the current scanning progress information including:
-                 - is_scanning: Whether a scan is currently in progress
-                 - current_file: The file currently being processed
-                 - indexed_count: Number of files indexed so far
-                 - total_files: Total number of files to process
-                 - progress: Percentage of completion
-         """
-         async with progress_lock:
-             return scan_progress
-
     @router.post("/upload", dependencies=[Depends(optional_api_key)])
     async def upload_to_input_dir(
         background_tasks: BackgroundTasks, file: UploadFile = File(...)
@@ -504,8 +446,8 @@ def create_document_routes(
                 message=f"File '{file.filename}' uploaded successfully. Processing will continue in background.",
             )
         except Exception as e:
-             logging.error(f"Error /documents/upload: {file.filename}: {str(e)}")
-             logging.error(traceback.format_exc())
+             logger.error(f"Error /documents/upload: {file.filename}: {str(e)}")
+             logger.error(traceback.format_exc())
             raise HTTPException(status_code=500, detail=str(e))
 
     @router.post(
@@ -537,8 +479,8 @@
                 message="Text successfully received. Processing will continue in background.",
             )
         except Exception as e:
-             logging.error(f"Error /documents/text: {str(e)}")
-             logging.error(traceback.format_exc())
+             logger.error(f"Error /documents/text: {str(e)}")
+             logger.error(traceback.format_exc())
             raise HTTPException(status_code=500, detail=str(e))
 
     @router.post(
@@ -572,8 +514,8 @@
                 message="Text successfully received. Processing will continue in background.",
             )
         except Exception as e:
-             logging.error(f"Error /documents/text: {str(e)}")
-             logging.error(traceback.format_exc())
+             logger.error(f"Error /documents/text: {str(e)}")
+             logger.error(traceback.format_exc())
             raise HTTPException(status_code=500, detail=str(e))
 
     @router.post(
@@ -615,8 +557,8 @@
                 message=f"File '{file.filename}' saved successfully. Processing will continue in background.",
            )
         except Exception as e:
-             logging.error(f"Error /documents/file: {str(e)}")
-             logging.error(traceback.format_exc())
+             logger.error(f"Error /documents/file: {str(e)}")
+             logger.error(traceback.format_exc())
            raise HTTPException(status_code=500, detail=str(e))
 
     @router.post(
@@ -678,8 +620,8 @@
 
            return InsertResponse(status=status, message=status_message)
         except Exception as e:
-             logging.error(f"Error /documents/batch: {str(e)}")
-             logging.error(traceback.format_exc())
+             logger.error(f"Error /documents/batch: {str(e)}")
+             logger.error(traceback.format_exc())
            raise HTTPException(status_code=500, detail=str(e))
 
     @router.delete(
@@ -706,8 +648,42 @@
                status="success", message="All documents cleared successfully"
            )
         except Exception as e:
-             logging.error(f"Error DELETE /documents: {str(e)}")
-             logging.error(traceback.format_exc())
+             logger.error(f"Error DELETE /documents: {str(e)}")
+             logger.error(traceback.format_exc())
+             raise HTTPException(status_code=500, detail=str(e))
+
+     @router.get("/pipeline_status", dependencies=[Depends(optional_api_key)])
+     async def get_pipeline_status():
+         """
+         Get the current status of the document indexing pipeline.
+
+         This endpoint returns information about the current state of the document processing pipeline,
+         including whether it's busy, the current job name, when it started, how many documents
+         are being processed, how many batches there are, and which batch is currently being processed.
+
+         Returns:
+             dict: A dictionary containing the pipeline status information
+         """
+         try:
+             from lightrag.kg.shared_storage import get_namespace_data
+
+             pipeline_status = await get_namespace_data("pipeline_status")
+
+             # Convert to regular dict if it's a Manager.dict
+             status_dict = dict(pipeline_status)
+
+             # Convert history_messages to a regular list if it's a Manager.list
+             if "history_messages" in status_dict:
+                 status_dict["history_messages"] = list(status_dict["history_messages"])
+
+             # Format the job_start time if it exists
+             if status_dict.get("job_start"):
+                 status_dict["job_start"] = str(status_dict["job_start"])
+
+             return status_dict
+         except Exception as e:
+             logger.error(f"Error getting pipeline status: {str(e)}")
+             logger.error(traceback.format_exc())
            raise HTTPException(status_code=500, detail=str(e))
 
     @router.get("", dependencies=[Depends(optional_api_key)])
@@ -763,8 +739,8 @@
            )
            return response
         except Exception as e:
-             logging.error(f"Error GET /documents: {str(e)}")
-             logging.error(traceback.format_exc())
+             logger.error(f"Error GET /documents: {str(e)}")
+             logger.error(traceback.format_exc())
            raise HTTPException(status_code=500, detail=str(e))
 
     return router
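
The new `/documents/pipeline_status` route replaces the removed `/documents/scan-progress` endpoint. A sketch of polling it until the pipeline is idle, assuming a local server on the default port 9621; fields other than those listed in the docstring above may vary:

```python
import time

import requests

BASE = "http://localhost:9621"

while True:
    status = requests.get(f"{BASE}/documents/pipeline_status", timeout=10).json()
    if not status.get("busy", False):
        break
    time.sleep(2)

for message in status.get("history_messages", []):
    print(message)
```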
lightrag/api/run_with_gunicorn.py ADDED
@@ -0,0 +1,203 @@
+ #!/usr/bin/env python
+ """
+ Start LightRAG server with Gunicorn
+ """
+
+ import os
+ import sys
+ import signal
+ import pipmaster as pm
+ from lightrag.api.utils_api import parse_args, display_splash_screen
+ from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data
+
+
+ def check_and_install_dependencies():
+     """Check and install required dependencies"""
+     required_packages = [
+         "gunicorn",
+         "tiktoken",
+         "psutil",
+         # Add other required packages here
+     ]
+
+     for package in required_packages:
+         if not pm.is_installed(package):
+             print(f"Installing {package}...")
+             pm.install(package)
+             print(f"{package} installed successfully")
+
+
+ # Signal handler for graceful shutdown
+ def signal_handler(sig, frame):
+     print("\n\n" + "=" * 80)
+     print("RECEIVED TERMINATION SIGNAL")
+     print(f"Process ID: {os.getpid()}")
+     print("=" * 80 + "\n")
+
+     # Release shared resources
+     finalize_share_data()
+
+     # Exit with success status
+     sys.exit(0)
+
+
+ def main():
+     # Check and install dependencies
+     check_and_install_dependencies()
+
+     # Register signal handlers for graceful shutdown
+     signal.signal(signal.SIGINT, signal_handler)  # Ctrl+C
+     signal.signal(signal.SIGTERM, signal_handler)  # kill command
+
+     # Parse all arguments using parse_args
+     args = parse_args(is_uvicorn_mode=False)
+
+     # Display startup information
+     display_splash_screen(args)
+
+     print("🚀 Starting LightRAG with Gunicorn")
+     print(f"🔄 Worker management: Gunicorn (workers={args.workers})")
+     print("🔍 Preloading app: Enabled")
+     print("📝 Note: Using Gunicorn's preload feature for shared data initialization")
+     print("\n\n" + "=" * 80)
+     print("MAIN PROCESS INITIALIZATION")
+     print(f"Process ID: {os.getpid()}")
+     print(f"Workers setting: {args.workers}")
+     print("=" * 80 + "\n")
+
+     # Import Gunicorn's StandaloneApplication
+     from gunicorn.app.base import BaseApplication
+
+     # Define a custom application class that loads our config
+     class GunicornApp(BaseApplication):
+         def __init__(self, app, options=None):
+             self.options = options or {}
+             self.application = app
+             super().__init__()
+
+         def load_config(self):
+             # Define valid Gunicorn configuration options
+             valid_options = {
+                 "bind",
+                 "workers",
+                 "worker_class",
+                 "timeout",
+                 "keepalive",
+                 "preload_app",
+                 "errorlog",
+                 "accesslog",
+                 "loglevel",
+                 "certfile",
+                 "keyfile",
+                 "limit_request_line",
+                 "limit_request_fields",
+                 "limit_request_field_size",
+                 "graceful_timeout",
+                 "max_requests",
+                 "max_requests_jitter",
+             }
+
+             # Special hooks that need to be set separately
+             special_hooks = {
+                 "on_starting",
+                 "on_reload",
+                 "on_exit",
+                 "pre_fork",
+                 "post_fork",
+                 "pre_exec",
+                 "pre_request",
+                 "post_request",
+                 "worker_init",
+                 "worker_exit",
+                 "nworkers_changed",
+                 "child_exit",
+             }
+
+             # Import and configure the gunicorn_config module
+             from lightrag.api import gunicorn_config
+
+             # Set configuration variables in gunicorn_config, prioritizing command line arguments
+             gunicorn_config.workers = (
+                 args.workers if args.workers else int(os.getenv("WORKERS", 1))
+             )
+
+             # Bind configuration prioritizes command line arguments
+             host = args.host if args.host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0")
+             port = args.port if args.port != 9621 else int(os.getenv("PORT", 9621))
+             gunicorn_config.bind = f"{host}:{port}"
+
+             # Log level configuration prioritizes command line arguments
+             gunicorn_config.loglevel = (
+                 args.log_level.lower()
+                 if args.log_level
+                 else os.getenv("LOG_LEVEL", "info")
+             )
+
+             # Timeout configuration prioritizes command line arguments
+             gunicorn_config.timeout = (
+                 args.timeout if args.timeout else int(os.getenv("TIMEOUT", 150))
+             )
+
+             # Keepalive configuration
+             gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5))
+
+             # SSL configuration prioritizes command line arguments
+             if args.ssl or os.getenv("SSL", "").lower() in (
+                 "true",
+                 "1",
+                 "yes",
+                 "t",
+                 "on",
+             ):
+                 gunicorn_config.certfile = (
+                     args.ssl_certfile
+                     if args.ssl_certfile
+                     else os.getenv("SSL_CERTFILE")
+                 )
+                 gunicorn_config.keyfile = (
+                     args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE")
+                 )
+
+             # Set configuration options from the module
+             for key in dir(gunicorn_config):
+                 if key in valid_options:
+                     value = getattr(gunicorn_config, key)
+                     # Skip functions like on_starting and None values
+                     if not callable(value) and value is not None:
+                         self.cfg.set(key, value)
+                 # Set special hooks
+                 elif key in special_hooks:
+                     value = getattr(gunicorn_config, key)
+                     if callable(value):
+                         self.cfg.set(key, value)
+
+             if hasattr(gunicorn_config, "logconfig_dict"):
+                 self.cfg.set(
+                     "logconfig_dict", getattr(gunicorn_config, "logconfig_dict")
+                 )
+
+         def load(self):
+             # Import the application
+             from lightrag.api.lightrag_server import get_application
+
+             return get_application(args)
+
+     # Create the application
+     app = GunicornApp("")
+
+     # Force workers to be an integer and greater than 1 for multi-process mode
+     workers_count = int(args.workers)
+     if workers_count > 1:
+         # Set a flag to indicate we're in the main process
+         os.environ["LIGHTRAG_MAIN_PROCESS"] = "1"
+         initialize_share_data(workers_count)
+     else:
+         initialize_share_data(1)
+
+     # Run the application
+     print("\nStarting Gunicorn with direct Python API...")
+     app.run()
+
+
+ if __name__ == "__main__":
+     main()
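
`GunicornApp` follows Gunicorn's documented custom-application pattern (subclass `BaseApplication`, override `load_config` and `load`). Stripped of the LightRAG-specific configuration it reduces to roughly this sketch; the placeholder ASGI app and option values are illustrative only:

```python
from gunicorn.app.base import BaseApplication


async def tiny_asgi_app(scope, receive, send):
    # Trivial placeholder application, not LightRAG
    if scope["type"] == "http":
        await send({"type": "http.response.start", "status": 200, "headers": []})
        await send({"type": "http.response.body", "body": b"ok"})


class TinyGunicornApp(BaseApplication):
    def __init__(self, app, options=None):
        self.options = options or {}
        self.application = app
        super().__init__()

    def load_config(self):
        # Push plain option values into Gunicorn's config object
        for key, value in self.options.items():
            self.cfg.set(key, value)

    def load(self):
        return self.application


if __name__ == "__main__":
    TinyGunicornApp(
        tiny_asgi_app,
        {
            "bind": "0.0.0.0:9621",
            "workers": 4,
            "worker_class": "uvicorn.workers.UvicornWorker",
        },
    ).run()
```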
lightrag/api/utils_api.py CHANGED
@@ -6,6 +6,7 @@ import os
 import argparse
 from typing import Optional
 import sys
+ import logging
 from ascii_colors import ASCIIColors
 from lightrag.api import __api_version__
 from fastapi import HTTPException, Security
@@ -110,10 +111,13 @@ def get_env_value(env_key: str, default: any, value_type: type = str) -> any:
     return default
 
 
- def parse_args() -> argparse.Namespace:
+ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
     """
     Parse command line arguments with environment variable fallback
 
+     Args:
+         is_uvicorn_mode: Whether running under uvicorn mode
+
     Returns:
         argparse.Namespace: Parsed arguments
     """
@@ -260,6 +264,14 @@
         help="Enable automatic scanning when the program starts",
     )
 
+     # Server workers configuration
+     parser.add_argument(
+         "--workers",
+         type=int,
+         default=get_env_value("WORKERS", 1, int),
+         help="Number of worker processes (default: from env or 1)",
+     )
+
     # LLM and embedding bindings
     parser.add_argument(
         "--llm-binding",
@@ -278,6 +290,15 @@
 
     args = parser.parse_args()
 
+     # If in uvicorn mode and workers > 1, force it to 1 and log warning
+     if is_uvicorn_mode and args.workers > 1:
+         original_workers = args.workers
+         args.workers = 1
+         # Log warning directly here
+         logging.warning(
+             f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1"
+         )
+
     # convert relative path to absolute path
     args.working_dir = os.path.abspath(args.working_dir)
     args.input_dir = os.path.abspath(args.input_dir)
@@ -346,17 +367,27 @@ def display_splash_screen(args: argparse.Namespace) -> None:
     ASCIIColors.yellow(f"{args.host}")
     ASCIIColors.white(" ├─ Port: ", end="")
     ASCIIColors.yellow(f"{args.port}")
+     ASCIIColors.white(" ├─ Workers: ", end="")
+     ASCIIColors.yellow(f"{args.workers}")
     ASCIIColors.white(" ├─ CORS Origins: ", end="")
     ASCIIColors.yellow(f"{os.getenv('CORS_ORIGINS', '*')}")
     ASCIIColors.white(" ├─ SSL Enabled: ", end="")
     ASCIIColors.yellow(f"{args.ssl}")
-     ASCIIColors.white(" └─ API Key: ", end="")
-     ASCIIColors.yellow("Set" if args.key else "Not Set")
     if args.ssl:
         ASCIIColors.white(" ├─ SSL Cert: ", end="")
         ASCIIColors.yellow(f"{args.ssl_certfile}")
-         ASCIIColors.white(" └─ SSL Key: ", end="")
+         ASCIIColors.white(" ├─ SSL Key: ", end="")
         ASCIIColors.yellow(f"{args.ssl_keyfile}")
+     ASCIIColors.white(" ├─ Ollama Emulating Model: ", end="")
+     ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}")
+     ASCIIColors.white(" ├─ Log Level: ", end="")
+     ASCIIColors.yellow(f"{args.log_level}")
+     ASCIIColors.white(" ├─ Verbose Debug: ", end="")
+     ASCIIColors.yellow(f"{args.verbose}")
+     ASCIIColors.white(" ├─ Timeout: ", end="")
+     ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}")
+     ASCIIColors.white(" └─ API Key: ", end="")
+     ASCIIColors.yellow("Set" if args.key else "Not Set")
 
     # Directory Configuration
     ASCIIColors.magenta("\n📂 Directory Configuration:")
@@ -415,16 +446,6 @@ def display_splash_screen(args: argparse.Namespace) -> None:
     ASCIIColors.white(" └─ Document Status Storage: ", end="")
     ASCIIColors.yellow(f"{args.doc_status_storage}")
 
-     ASCIIColors.magenta("\n🛠️ System Configuration:")
-     ASCIIColors.white(" ├─ Ollama Emulating Model: ", end="")
-     ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}")
-     ASCIIColors.white(" ├─ Log Level: ", end="")
-     ASCIIColors.yellow(f"{args.log_level}")
-     ASCIIColors.white(" ├─ Verbose Debug: ", end="")
-     ASCIIColors.yellow(f"{args.verbose}")
-     ASCIIColors.white(" └─ Timeout: ", end="")
-     ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}")
-
     # Server Status
     ASCIIColors.green("\n✨ Server starting up...\n")
 
@@ -478,7 +499,6 @@ def display_splash_screen(args: argparse.Namespace) -> None:
     ASCIIColors.cyan(""" 3. Basic Operations:
     - POST /upload_document: Upload new documents to RAG
     - POST /query: Query your document collection
-     - GET /collections: List available collections
 
     4. Monitor the server:
     - Check server logs for detailed operation information
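
In short, the new `is_uvicorn_mode` flag makes the single-process `lightrag-server` entry point clamp the worker count, while the Gunicorn launcher keeps it. A sketch of the difference, assuming the process was started with `--workers 4` (or `WORKERS=4` in the environment):

```python
from lightrag.api.utils_api import parse_args

args = parse_args(is_uvicorn_mode=True)  # logs a warning and forces args.workers = 1
print(args.workers)                      # 1

# run_with_gunicorn.py calls parse_args(is_uvicorn_mode=False) instead,
# so the requested worker count is preserved for Gunicorn.
```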
lightrag/kg/faiss_impl.py CHANGED
@@ -2,25 +2,25 @@ import os
2
  import time
3
  import asyncio
4
  from typing import Any, final
5
-
6
  import json
7
  import numpy as np
8
 
9
  from dataclasses import dataclass
10
  import pipmaster as pm
11
 
12
- from lightrag.utils import (
13
- logger,
14
- compute_mdhash_id,
15
- )
16
- from lightrag.base import (
17
- BaseVectorStorage,
18
- )
19
 
20
  if not pm.is_installed("faiss"):
21
  pm.install("faiss")
22
 
23
- import faiss
 
 
 
 
 
 
24
 
25
 
26
  @final
@@ -55,14 +55,40 @@ class FaissVectorDBStorage(BaseVectorStorage):
55
  # If you have a large number of vectors, you might want IVF or other indexes.
56
  # For demonstration, we use a simple IndexFlatIP.
57
  self._index = faiss.IndexFlatIP(self._dim)
58
-
59
  # Keep a local store for metadata, IDs, etc.
60
  # Maps <int faiss_id> → metadata (including your original ID).
61
  self._id_to_meta = {}
62
 
63
- # Attempt to load an existing index + metadata from disk
64
  self._load_faiss_index()
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
67
  """
68
  Insert or update vectors in the Faiss index.
@@ -113,7 +139,8 @@ class FaissVectorDBStorage(BaseVectorStorage):
113
  )
114
  return []
115
 
116
- # Normalize embeddings for cosine similarity (in-place)
 
117
  faiss.normalize_L2(embeddings)
118
 
119
  # Upsert logic:
@@ -127,18 +154,19 @@ class FaissVectorDBStorage(BaseVectorStorage):
127
  existing_ids_to_remove.append(faiss_internal_id)
128
 
129
  if existing_ids_to_remove:
130
- self._remove_faiss_ids(existing_ids_to_remove)
131
 
132
  # Step 2: Add new vectors
133
- start_idx = self._index.ntotal
134
- self._index.add(embeddings)
 
135
 
136
  # Step 3: Store metadata + vector for each new ID
137
  for i, meta in enumerate(list_data):
138
  fid = start_idx + i
139
  # Store the raw vector so we can rebuild if something is removed
140
  meta["__vector__"] = embeddings[i].tolist()
141
- self._id_to_meta[fid] = meta
142
 
143
  logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
144
  return [m["__id__"] for m in list_data]
@@ -157,7 +185,8 @@ class FaissVectorDBStorage(BaseVectorStorage):
157
  )
158
 
159
  # Perform the similarity search
160
- distances, indices = self._index.search(embedding, top_k)
 
161
 
162
  distances = distances[0]
163
  indices = indices[0]
@@ -201,8 +230,8 @@ class FaissVectorDBStorage(BaseVectorStorage):
201
  to_remove.append(fid)
202
 
203
  if to_remove:
204
- self._remove_faiss_ids(to_remove)
205
- logger.info(
206
  f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
207
  )
208
 
@@ -223,12 +252,9 @@ class FaissVectorDBStorage(BaseVectorStorage):
223
 
224
  logger.debug(f"Found {len(relations)} relations for {entity_name}")
225
  if relations:
226
- self._remove_faiss_ids(relations)
227
  logger.debug(f"Deleted {len(relations)} relations for {entity_name}")
228
 
229
- async def index_done_callback(self) -> None:
230
- self._save_faiss_index()
231
-
232
  # --------------------------------------------------------------------------------
233
  # Internal helper methods
234
  # --------------------------------------------------------------------------------
@@ -242,7 +268,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
242
  return fid
243
  return None
244
 
245
- def _remove_faiss_ids(self, fid_list):
246
  """
247
  Remove a list of internal Faiss IDs from the index.
248
  Because IndexFlatIP doesn't support 'removals',
@@ -258,13 +284,14 @@ class FaissVectorDBStorage(BaseVectorStorage):
258
  vectors_to_keep.append(vec_meta["__vector__"]) # stored as list
259
  new_id_to_meta[new_fid] = vec_meta
260
 
261
- # Re-init index
262
- self._index = faiss.IndexFlatIP(self._dim)
263
- if vectors_to_keep:
264
- arr = np.array(vectors_to_keep, dtype=np.float32)
265
- self._index.add(arr)
 
266
 
267
- self._id_to_meta = new_id_to_meta
268
 
269
  def _save_faiss_index(self):
270
  """
@@ -312,3 +339,35 @@ class FaissVectorDBStorage(BaseVectorStorage):
312
  logger.warning("Starting with an empty Faiss index.")
313
  self._index = faiss.IndexFlatIP(self._dim)
314
  self._id_to_meta = {}
 
 
 
 
2
  import time
3
  import asyncio
4
  from typing import Any, final
 
5
  import json
6
  import numpy as np
7
 
8
  from dataclasses import dataclass
9
  import pipmaster as pm
10
 
11
+ from lightrag.utils import logger, compute_mdhash_id
12
+ from lightrag.base import BaseVectorStorage
 
 
 
 
 
13
 
14
  if not pm.is_installed("faiss"):
15
  pm.install("faiss")
16
 
17
+ import faiss # type: ignore
18
+ from .shared_storage import (
19
+ get_storage_lock,
20
+ get_update_flag,
21
+ set_all_update_flags,
22
+ is_multiprocess,
23
+ )
24
 
25
 
26
  @final
 
55
  # If you have a large number of vectors, you might want IVF or other indexes.
56
  # For demonstration, we use a simple IndexFlatIP.
57
  self._index = faiss.IndexFlatIP(self._dim)
 
58
  # Keep a local store for metadata, IDs, etc.
59
  # Maps <int faiss_id> → metadata (including your original ID).
60
  self._id_to_meta = {}
61
 
 
62
  self._load_faiss_index()
63
 
64
+ async def initialize(self):
65
+ """Initialize storage data"""
66
+ # Get the update flag for cross-process update notification
67
+ self.storage_updated = await get_update_flag(self.namespace)
68
+ # Get the storage lock for use in other methods
69
+ self._storage_lock = get_storage_lock()
70
+
71
+ async def _get_index(self):
72
+ """Check if the shtorage should be reloaded"""
73
+ # Acquire lock to prevent concurrent read and write
74
+ async with self._storage_lock:
75
+ # Check if storage was updated by another process
76
+ if (is_multiprocess and self.storage_updated.value) or (
77
+ not is_multiprocess and self.storage_updated
78
+ ):
79
+ logger.info(
80
+ f"Process {os.getpid()} FAISS reloading {self.namespace} due to update by another process"
81
+ )
82
+ # Reload data
83
+ self._index = faiss.IndexFlatIP(self._dim)
84
+ self._id_to_meta = {}
85
+ self._load_faiss_index()
86
+ if is_multiprocess:
87
+ self.storage_updated.value = False
88
+ else:
89
+ self.storage_updated = False
90
+ return self._index
91
+
92
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
93
  """
94
  Insert or update vectors in the Faiss index.
 
139
  )
140
  return []
141
 
142
+ # Convert to float32 and normalize embeddings for cosine similarity (in-place)
143
+ embeddings = embeddings.astype(np.float32)
144
  faiss.normalize_L2(embeddings)
145
 
146
  # Upsert logic:
 
154
  existing_ids_to_remove.append(faiss_internal_id)
155
 
156
  if existing_ids_to_remove:
157
+ await self._remove_faiss_ids(existing_ids_to_remove)
158
 
159
  # Step 2: Add new vectors
160
+ index = await self._get_index()
161
+ start_idx = index.ntotal
162
+ index.add(embeddings)
163
 
164
  # Step 3: Store metadata + vector for each new ID
165
  for i, meta in enumerate(list_data):
166
  fid = start_idx + i
167
  # Store the raw vector so we can rebuild if something is removed
168
  meta["__vector__"] = embeddings[i].tolist()
169
+ self._id_to_meta.update({fid: meta})
170
 
171
  logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
172
  return [m["__id__"] for m in list_data]
 
185
  )
186
 
187
  # Perform the similarity search
188
+ index = await self._get_index()
189
+ distances, indices = index.search(embedding, top_k)
190
 
191
  distances = distances[0]
192
  indices = indices[0]
 
230
  to_remove.append(fid)
231
 
232
  if to_remove:
233
+ await self._remove_faiss_ids(to_remove)
234
+ logger.debug(
235
  f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
236
  )
237
 
 
252
 
253
  logger.debug(f"Found {len(relations)} relations for {entity_name}")
254
  if relations:
255
+ await self._remove_faiss_ids(relations)
256
  logger.debug(f"Deleted {len(relations)} relations for {entity_name}")
257
 
 
 
 
258
  # --------------------------------------------------------------------------------
259
  # Internal helper methods
260
  # --------------------------------------------------------------------------------
 
268
  return fid
269
  return None
270
 
271
+ async def _remove_faiss_ids(self, fid_list):
272
  """
273
  Remove a list of internal Faiss IDs from the index.
274
  Because IndexFlatIP doesn't support 'removals',
 
284
  vectors_to_keep.append(vec_meta["__vector__"]) # stored as list
285
  new_id_to_meta[new_fid] = vec_meta
286
 
287
+ async with self._storage_lock:
288
+ # Re-init index
289
+ self._index = faiss.IndexFlatIP(self._dim)
290
+ if vectors_to_keep:
291
+ arr = np.array(vectors_to_keep, dtype=np.float32)
292
+ self._index.add(arr)
293
 
294
+ self._id_to_meta = new_id_to_meta
295
 
296
  def _save_faiss_index(self):
297
  """
 
339
  logger.warning("Starting with an empty Faiss index.")
340
  self._index = faiss.IndexFlatIP(self._dim)
341
  self._id_to_meta = {}
342
+
343
+ async def index_done_callback(self) -> None:
344
+ # Check if storage was updated by another process
345
+ if is_multiprocess and self.storage_updated.value:
346
+ # Storage was updated by another process, reload data instead of saving
347
+ logger.warning(
348
+ f"Storage for FAISS {self.namespace} was updated by another process, reloading..."
349
+ )
350
+ async with self._storage_lock:
351
+ self._index = faiss.IndexFlatIP(self._dim)
352
+ self._id_to_meta = {}
353
+ self._load_faiss_index()
354
+ self.storage_updated.value = False
355
+ return False # Return error
356
+
357
+ # Acquire lock and perform persistence
358
+ async with self._storage_lock:
359
+ try:
360
+ # Save data to disk
361
+ self._save_faiss_index()
362
+ # Notify other processes that data has been updated
363
+ await set_all_update_flags(self.namespace)
364
+ # Reset own update flag to avoid self-reloading
365
+ if is_multiprocess:
366
+ self.storage_updated.value = False
367
+ else:
368
+ self.storage_updated = False
369
+ except Exception as e:
370
+ logger.error(f"Error saving FAISS index for {self.namespace}: {e}")
371
+ return False # Return error
372
+
373
+ return True # Return success
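A minimal sketch (not part of this PR) of the reload-on-update pattern that the new `initialize()` / `_get_index()` / `index_done_callback()` methods above implement: every access goes through a lock, and if another process has flagged an update, the in-memory index is rebuilt from disk before use. `reload()` and the flag argument here are illustrative stand-ins, not the real FAISS calls.

```python
import asyncio

class ReloadOnUpdate:
    """Abstraction of the reload-on-update pattern used by the storage backends."""

    def __init__(self, is_multiprocess: bool, storage_updated, lock: asyncio.Lock):
        self.is_multiprocess = is_multiprocess
        self.storage_updated = storage_updated  # Manager.Value("b", ...) or plain bool
        self._lock = lock
        self._data = self.reload()

    def _updated_elsewhere(self) -> bool:
        # Manager.Value exposes .value; the single-process fallback is a plain bool
        return self.storage_updated.value if self.is_multiprocess else self.storage_updated

    def _clear_flag(self) -> None:
        if self.is_multiprocess:
            self.storage_updated.value = False
        else:
            self.storage_updated = False

    async def get(self):
        async with self._lock:
            if self._updated_elsewhere():
                self._data = self.reload()  # another process persisted newer data
                self._clear_flag()
            return self._data

    def reload(self):
        return {}  # placeholder for re-reading the real index from disk

async def _demo():
    storage = ReloadOnUpdate(False, storage_updated=False, lock=asyncio.Lock())
    print(await storage.get())  # {}

if __name__ == "__main__":
    asyncio.run(_demo())
```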
lightrag/kg/json_doc_status_impl.py CHANGED
@@ -12,6 +12,11 @@ from lightrag.utils import (
12
  logger,
13
  write_json,
14
  )
 
 
 
 
 
15
 
16
 
17
  @final
@@ -22,26 +27,42 @@ class JsonDocStatusStorage(DocStatusStorage):
22
  def __post_init__(self):
23
  working_dir = self.global_config["working_dir"]
24
  self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
25
- self._data: dict[str, Any] = load_json(self._file_name) or {}
26
- logger.info(f"Loaded document status storage with {len(self._data)} records")
 
 
 
 
27
 
28
  async def filter_keys(self, keys: set[str]) -> set[str]:
29
  """Return keys that should be processed (not in storage or not successfully processed)"""
30
- return set(keys) - set(self._data.keys())
 
31
 
32
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
33
  result: list[dict[str, Any]] = []
34
- for id in ids:
35
- data = self._data.get(id, None)
36
- if data:
37
- result.append(data)
 
38
  return result
39
 
40
  async def get_status_counts(self) -> dict[str, int]:
41
  """Get counts of documents in each status"""
42
  counts = {status.value: 0 for status in DocStatus}
43
- for doc in self._data.values():
44
- counts[doc["status"]] += 1
 
45
  return counts
46
 
47
  async def get_docs_by_status(
@@ -49,39 +70,48 @@ class JsonDocStatusStorage(DocStatusStorage):
49
  ) -> dict[str, DocProcessingStatus]:
50
  """Get all documents with a specific status"""
51
  result = {}
52
- for k, v in self._data.items():
53
- if v["status"] == status.value:
54
- try:
55
- # Make a copy of the data to avoid modifying the original
56
- data = v.copy()
57
- # If content is missing, use content_summary as content
58
- if "content" not in data and "content_summary" in data:
59
- data["content"] = data["content_summary"]
60
- result[k] = DocProcessingStatus(**data)
61
- except KeyError as e:
62
- logger.error(f"Missing required field for document {k}: {e}")
63
- continue
 
64
  return result
65
 
66
  async def index_done_callback(self) -> None:
67
- write_json(self._data, self._file_name)
 
 
 
 
68
 
69
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
70
  logger.info(f"Inserting {len(data)} to {self.namespace}")
71
  if not data:
72
  return
73
 
74
- self._data.update(data)
 
75
  await self.index_done_callback()
76
 
77
  async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
78
- return self._data.get(id)
 
79
 
80
  async def delete(self, doc_ids: list[str]):
81
- for doc_id in doc_ids:
82
- self._data.pop(doc_id, None)
 
83
  await self.index_done_callback()
84
 
85
  async def drop(self) -> None:
86
  """Drop the storage"""
87
- self._data.clear()
 
 
12
  logger,
13
  write_json,
14
  )
15
+ from .shared_storage import (
16
+ get_namespace_data,
17
+ get_storage_lock,
18
+ try_initialize_namespace,
19
+ )
20
 
21
 
22
  @final
 
27
  def __post_init__(self):
28
  working_dir = self.global_config["working_dir"]
29
  self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
30
+ self._storage_lock = get_storage_lock()
31
+ self._data = None
32
+
33
+ async def initialize(self):
34
+ """Initialize storage data"""
35
+ # check need_init must before get_namespace_data
36
+ need_init = try_initialize_namespace(self.namespace)
37
+ self._data = await get_namespace_data(self.namespace)
38
+ if need_init:
39
+ loaded_data = load_json(self._file_name) or {}
40
+ async with self._storage_lock:
41
+ self._data.update(loaded_data)
42
+ logger.info(
43
+ f"Loaded document status storage with {len(loaded_data)} records"
44
+ )
45
 
46
  async def filter_keys(self, keys: set[str]) -> set[str]:
47
  """Return keys that should be processed (not in storage or not successfully processed)"""
48
+ async with self._storage_lock:
49
+ return set(keys) - set(self._data.keys())
50
 
51
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
52
  result: list[dict[str, Any]] = []
53
+ async with self._storage_lock:
54
+ for id in ids:
55
+ data = self._data.get(id, None)
56
+ if data:
57
+ result.append(data)
58
  return result
59
 
60
  async def get_status_counts(self) -> dict[str, int]:
61
  """Get counts of documents in each status"""
62
  counts = {status.value: 0 for status in DocStatus}
63
+ async with self._storage_lock:
64
+ for doc in self._data.values():
65
+ counts[doc["status"]] += 1
66
  return counts
67
 
68
  async def get_docs_by_status(
 
70
  ) -> dict[str, DocProcessingStatus]:
71
  """Get all documents with a specific status"""
72
  result = {}
73
+ async with self._storage_lock:
74
+ for k, v in self._data.items():
75
+ if v["status"] == status.value:
76
+ try:
77
+ # Make a copy of the data to avoid modifying the original
78
+ data = v.copy()
79
+ # If content is missing, use content_summary as content
80
+ if "content" not in data and "content_summary" in data:
81
+ data["content"] = data["content_summary"]
82
+ result[k] = DocProcessingStatus(**data)
83
+ except KeyError as e:
84
+ logger.error(f"Missing required field for document {k}: {e}")
85
+ continue
86
  return result
87
 
88
  async def index_done_callback(self) -> None:
89
+ async with self._storage_lock:
90
+ data_dict = (
91
+ dict(self._data) if hasattr(self._data, "_getvalue") else self._data
92
+ )
93
+ write_json(data_dict, self._file_name)
94
 
95
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
96
  logger.info(f"Inserting {len(data)} to {self.namespace}")
97
  if not data:
98
  return
99
 
100
+ async with self._storage_lock:
101
+ self._data.update(data)
102
  await self.index_done_callback()
103
 
104
  async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
105
+ async with self._storage_lock:
106
+ return self._data.get(id)
107
 
108
  async def delete(self, doc_ids: list[str]):
109
+ async with self._storage_lock:
110
+ for doc_id in doc_ids:
111
+ self._data.pop(doc_id, None)
112
  await self.index_done_callback()
113
 
114
  async def drop(self) -> None:
115
  """Drop the storage"""
116
+ async with self._storage_lock:
117
+ self._data.clear()
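A minimal sketch (not part of this PR) of the initialization handshake used above: `try_initialize_namespace()` grants exactly one worker the right to load the JSON file from disk, and every worker then shares the same namespace dict returned by `get_namespace_data()`. The real helpers live in `lightrag/kg/shared_storage.py`, added later in this diff; the single-process stand-ins below only show the control flow.

```python
import asyncio

_init_flags: dict[str, bool] = {}
_shared_dicts: dict[str, dict] = {}

def try_initialize_namespace(namespace: str) -> bool:
    # First caller wins the right to load data from disk; later callers reuse it.
    if namespace not in _init_flags:
        _init_flags[namespace] = True
        return True
    return False

async def get_namespace_data(namespace: str) -> dict:
    return _shared_dicts.setdefault(namespace, {})

async def open_doc_status(namespace: str, lock: asyncio.Lock) -> dict:
    need_init = try_initialize_namespace(namespace)  # must be checked before get_namespace_data
    data = await get_namespace_data(namespace)
    if need_init:
        async with lock:
            data.update({"doc-1": {"status": "pending"}})  # stand-in for load_json(...)
    return data

if __name__ == "__main__":
    print(asyncio.run(open_doc_status("doc_status", asyncio.Lock())))
```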
lightrag/kg/json_kv_impl.py CHANGED
@@ -1,4 +1,3 @@
1
- import asyncio
2
  import os
3
  from dataclasses import dataclass
4
  from typing import Any, final
@@ -11,6 +10,11 @@ from lightrag.utils import (
11
  logger,
12
  write_json,
13
  )
 
 
 
 
 
14
 
15
 
16
  @final
@@ -19,37 +23,56 @@ class JsonKVStorage(BaseKVStorage):
19
  def __post_init__(self):
20
  working_dir = self.global_config["working_dir"]
21
  self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
22
- self._data: dict[str, Any] = load_json(self._file_name) or {}
23
- self._lock = asyncio.Lock()
24
- logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
 
 
 
 
25
 
26
  async def index_done_callback(self) -> None:
27
- write_json(self._data, self._file_name)
 
 
 
 
28
 
29
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
30
- return self._data.get(id)
 
31
 
32
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
33
- return [
34
- (
35
- {k: v for k, v in self._data[id].items()}
36
- if self._data.get(id, None)
37
- else None
38
- )
39
- for id in ids
40
- ]
 
41
 
42
  async def filter_keys(self, keys: set[str]) -> set[str]:
43
- return set(keys) - set(self._data.keys())
 
44
 
45
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
46
  logger.info(f"Inserting {len(data)} to {self.namespace}")
47
  if not data:
48
  return
49
- left_data = {k: v for k, v in data.items() if k not in self._data}
50
- self._data.update(left_data)
 
51
 
52
  async def delete(self, ids: list[str]) -> None:
53
- for doc_id in ids:
54
- self._data.pop(doc_id, None)
 
55
  await self.index_done_callback()
 
 
1
  import os
2
  from dataclasses import dataclass
3
  from typing import Any, final
 
10
  logger,
11
  write_json,
12
  )
13
+ from .shared_storage import (
14
+ get_namespace_data,
15
+ get_storage_lock,
16
+ try_initialize_namespace,
17
+ )
18
 
19
 
20
  @final
 
23
  def __post_init__(self):
24
  working_dir = self.global_config["working_dir"]
25
  self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
26
+ self._storage_lock = get_storage_lock()
27
+ self._data = None
28
+
29
+ async def initialize(self):
30
+ """Initialize storage data"""
31
+ # check need_init must before get_namespace_data
32
+ need_init = try_initialize_namespace(self.namespace)
33
+ self._data = await get_namespace_data(self.namespace)
34
+ if need_init:
35
+ loaded_data = load_json(self._file_name) or {}
36
+ async with self._storage_lock:
37
+ self._data.update(loaded_data)
38
+ logger.info(f"Load KV {self.namespace} with {len(loaded_data)} data")
39
 
40
  async def index_done_callback(self) -> None:
41
+ async with self._storage_lock:
42
+ data_dict = (
43
+ dict(self._data) if hasattr(self._data, "_getvalue") else self._data
44
+ )
45
+ write_json(data_dict, self._file_name)
46
 
47
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
48
+ async with self._storage_lock:
49
+ return self._data.get(id)
50
 
51
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
52
+ async with self._storage_lock:
53
+ return [
54
+ (
55
+ {k: v for k, v in self._data[id].items()}
56
+ if self._data.get(id, None)
57
+ else None
58
+ )
59
+ for id in ids
60
+ ]
61
 
62
  async def filter_keys(self, keys: set[str]) -> set[str]:
63
+ async with self._storage_lock:
64
+ return set(keys) - set(self._data.keys())
65
 
66
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
67
  logger.info(f"Inserting {len(data)} to {self.namespace}")
68
  if not data:
69
  return
70
+ async with self._storage_lock:
71
+ left_data = {k: v for k, v in data.items() if k not in self._data}
72
+ self._data.update(left_data)
73
 
74
  async def delete(self, ids: list[str]) -> None:
75
+ async with self._storage_lock:
76
+ for doc_id in ids:
77
+ self._data.pop(doc_id, None)
78
  await self.index_done_callback()
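A minimal sketch (not part of this PR) of why `index_done_callback()` above copies the shared data before writing: in multi-worker mode `self._data` is a `multiprocessing.Manager().dict()` proxy rather than a plain dict, so it is converted with `dict(...)` (detected via `_getvalue`) before JSON serialization.

```python
import json
from multiprocessing import Manager

if __name__ == "__main__":
    manager = Manager()
    shared = manager.dict()  # DictProxy, not a plain dict
    shared["doc-1"] = {"status": "pending"}

    # Same guard as in index_done_callback: proxies expose _getvalue(),
    # plain dicts do not, so only the proxy gets copied before serialization.
    data_dict = dict(shared) if hasattr(shared, "_getvalue") else shared
    print(json.dumps(data_dict))

    manager.shutdown()
```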
lightrag/kg/nano_vector_db_impl.py CHANGED
@@ -3,7 +3,6 @@ import os
3
  from typing import Any, final
4
  from dataclasses import dataclass
5
  import numpy as np
6
-
7
  import time
8
 
9
  from lightrag.utils import (
@@ -11,22 +10,29 @@ from lightrag.utils import (
11
  compute_mdhash_id,
12
  )
13
  import pipmaster as pm
14
- from lightrag.base import (
15
- BaseVectorStorage,
16
- )
17
 
18
  if not pm.is_installed("nano-vectordb"):
19
  pm.install("nano-vectordb")
20
 
21
  from nano_vectordb import NanoVectorDB
 
 
 
 
 
 
22
 
23
 
24
  @final
25
  @dataclass
26
  class NanoVectorDBStorage(BaseVectorStorage):
27
  def __post_init__(self):
28
- # Initialize lock only for file operations
29
- self._save_lock = asyncio.Lock()
 
 
 
30
  # Use global config value if specified, otherwise use default
31
  kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
32
  cosine_threshold = kwargs.get("cosine_better_than_threshold")
@@ -40,10 +46,43 @@ class NanoVectorDBStorage(BaseVectorStorage):
40
  self.global_config["working_dir"], f"vdb_{self.namespace}.json"
41
  )
42
  self._max_batch_size = self.global_config["embedding_batch_num"]
 
43
  self._client = NanoVectorDB(
44
- self.embedding_func.embedding_dim, storage_file=self._client_file_name
 
45
  )
46
 
 
 
 
 
 
47
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
48
  logger.info(f"Inserting {len(data)} to {self.namespace}")
49
  if not data:
@@ -64,6 +103,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
64
  for i in range(0, len(contents), self._max_batch_size)
65
  ]
66
 
 
67
  embedding_tasks = [self.embedding_func(batch) for batch in batches]
68
  embeddings_list = await asyncio.gather(*embedding_tasks)
69
 
@@ -71,7 +111,8 @@ class NanoVectorDBStorage(BaseVectorStorage):
71
  if len(embeddings) == len(list_data):
72
  for i, d in enumerate(list_data):
73
  d["__vector__"] = embeddings[i]
74
- results = self._client.upsert(datas=list_data)
 
75
  return results
76
  else:
77
  # sometimes the embedding is not returned correctly. just log it.
@@ -80,9 +121,12 @@ class NanoVectorDBStorage(BaseVectorStorage):
80
  )
81
 
82
  async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
 
83
  embedding = await self.embedding_func([query])
84
  embedding = embedding[0]
85
- results = self._client.query(
 
 
86
  query=embedding,
87
  top_k=top_k,
88
  better_than_threshold=self.cosine_better_than_threshold,
@@ -99,8 +143,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
99
  return results
100
 
101
  @property
102
- def client_storage(self):
103
- return getattr(self._client, "_NanoVectorDB__storage")
 
104
 
105
  async def delete(self, ids: list[str]):
106
  """Delete vectors with specified IDs
@@ -109,8 +154,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
109
  ids: List of vector IDs to be deleted
110
  """
111
  try:
112
- self._client.delete(ids)
113
- logger.info(
 
114
  f"Successfully deleted {len(ids)} vectors from {self.namespace}"
115
  )
116
  except Exception as e:
@@ -122,9 +168,11 @@ class NanoVectorDBStorage(BaseVectorStorage):
122
  logger.debug(
123
  f"Attempting to delete entity {entity_name} with ID {entity_id}"
124
  )
 
125
  # Check if the entity exists
126
- if self._client.get([entity_id]):
127
- await self.delete([entity_id])
 
128
  logger.debug(f"Successfully deleted entity {entity_name}")
129
  else:
130
  logger.debug(f"Entity {entity_name} not found in storage")
@@ -133,16 +181,19 @@ class NanoVectorDBStorage(BaseVectorStorage):
133
 
134
  async def delete_entity_relation(self, entity_name: str) -> None:
135
  try:
 
 
136
  relations = [
137
  dp
138
- for dp in self.client_storage["data"]
139
  if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
140
  ]
141
  logger.debug(f"Found {len(relations)} relations for entity {entity_name}")
142
  ids_to_delete = [relation["__id__"] for relation in relations]
143
 
144
  if ids_to_delete:
145
- await self.delete(ids_to_delete)
 
146
  logger.debug(
147
  f"Deleted {len(ids_to_delete)} relations for {entity_name}"
148
  )
@@ -151,6 +202,37 @@ class NanoVectorDBStorage(BaseVectorStorage):
151
  except Exception as e:
152
  logger.error(f"Error deleting relations for {entity_name}: {e}")
153
 
154
- async def index_done_callback(self) -> None:
155
- async with self._save_lock:
156
- self._client.save()
 
 
 
 
 
3
  from typing import Any, final
4
  from dataclasses import dataclass
5
  import numpy as np
 
6
  import time
7
 
8
  from lightrag.utils import (
 
10
  compute_mdhash_id,
11
  )
12
  import pipmaster as pm
13
+ from lightrag.base import BaseVectorStorage
 
 
14
 
15
  if not pm.is_installed("nano-vectordb"):
16
  pm.install("nano-vectordb")
17
 
18
  from nano_vectordb import NanoVectorDB
19
+ from .shared_storage import (
20
+ get_storage_lock,
21
+ get_update_flag,
22
+ set_all_update_flags,
23
+ is_multiprocess,
24
+ )
25
 
26
 
27
  @final
28
  @dataclass
29
  class NanoVectorDBStorage(BaseVectorStorage):
30
  def __post_init__(self):
31
+ # Initialize basic attributes
32
+ self._client = None
33
+ self._storage_lock = None
34
+ self.storage_updated = None
35
+
36
  # Use global config value if specified, otherwise use default
37
  kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
38
  cosine_threshold = kwargs.get("cosine_better_than_threshold")
 
46
  self.global_config["working_dir"], f"vdb_{self.namespace}.json"
47
  )
48
  self._max_batch_size = self.global_config["embedding_batch_num"]
49
+
50
  self._client = NanoVectorDB(
51
+ self.embedding_func.embedding_dim,
52
+ storage_file=self._client_file_name,
53
  )
54
 
55
+ async def initialize(self):
56
+ """Initialize storage data"""
57
+ # Get the update flag for cross-process update notification
58
+ self.storage_updated = await get_update_flag(self.namespace)
59
+ # Get the storage lock for use in other methods
60
+ self._storage_lock = get_storage_lock()
61
+
62
+ async def _get_client(self):
63
+ """Check if the storage should be reloaded"""
64
+ # Acquire lock to prevent concurrent read and write
65
+ async with self._storage_lock:
66
+ # Check if data needs to be reloaded
67
+ if (is_multiprocess and self.storage_updated.value) or (
68
+ not is_multiprocess and self.storage_updated
69
+ ):
70
+ logger.info(
71
+ f"Process {os.getpid()} reloading {self.namespace} due to update by another process"
72
+ )
73
+ # Reload data
74
+ self._client = NanoVectorDB(
75
+ self.embedding_func.embedding_dim,
76
+ storage_file=self._client_file_name,
77
+ )
78
+ # Reset update flag
79
+ if is_multiprocess:
80
+ self.storage_updated.value = False
81
+ else:
82
+ self.storage_updated = False
83
+
84
+ return self._client
85
+
86
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
87
  logger.info(f"Inserting {len(data)} to {self.namespace}")
88
  if not data:
 
103
  for i in range(0, len(contents), self._max_batch_size)
104
  ]
105
 
106
+ # Execute embedding outside of lock to avoid long lock times
107
  embedding_tasks = [self.embedding_func(batch) for batch in batches]
108
  embeddings_list = await asyncio.gather(*embedding_tasks)
109
 
 
111
  if len(embeddings) == len(list_data):
112
  for i, d in enumerate(list_data):
113
  d["__vector__"] = embeddings[i]
114
+ client = await self._get_client()
115
+ results = client.upsert(datas=list_data)
116
  return results
117
  else:
118
  # sometimes the embedding is not returned correctly. just log it.
 
121
  )
122
 
123
  async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
124
+ # Execute embedding outside of lock to avoid long lock times
125
  embedding = await self.embedding_func([query])
126
  embedding = embedding[0]
127
+
128
+ client = await self._get_client()
129
+ results = client.query(
130
  query=embedding,
131
  top_k=top_k,
132
  better_than_threshold=self.cosine_better_than_threshold,
 
143
  return results
144
 
145
  @property
146
+ async def client_storage(self):
147
+ client = await self._get_client()
148
+ return getattr(client, "_NanoVectorDB__storage")
149
 
150
  async def delete(self, ids: list[str]):
151
  """Delete vectors with specified IDs
 
154
  ids: List of vector IDs to be deleted
155
  """
156
  try:
157
+ client = await self._get_client()
158
+ client.delete(ids)
159
+ logger.debug(
160
  f"Successfully deleted {len(ids)} vectors from {self.namespace}"
161
  )
162
  except Exception as e:
 
168
  logger.debug(
169
  f"Attempting to delete entity {entity_name} with ID {entity_id}"
170
  )
171
+
172
  # Check if the entity exists
173
+ client = await self._get_client()
174
+ if client.get([entity_id]):
175
+ client.delete([entity_id])
176
  logger.debug(f"Successfully deleted entity {entity_name}")
177
  else:
178
  logger.debug(f"Entity {entity_name} not found in storage")
 
181
 
182
  async def delete_entity_relation(self, entity_name: str) -> None:
183
  try:
184
+ client = await self._get_client()
185
+ storage = getattr(client, "_NanoVectorDB__storage")
186
  relations = [
187
  dp
188
+ for dp in storage["data"]
189
  if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
190
  ]
191
  logger.debug(f"Found {len(relations)} relations for entity {entity_name}")
192
  ids_to_delete = [relation["__id__"] for relation in relations]
193
 
194
  if ids_to_delete:
195
+ client = await self._get_client()
196
+ client.delete(ids_to_delete)
197
  logger.debug(
198
  f"Deleted {len(ids_to_delete)} relations for {entity_name}"
199
  )
 
202
  except Exception as e:
203
  logger.error(f"Error deleting relations for {entity_name}: {e}")
204
 
205
+ async def index_done_callback(self) -> bool:
206
+ """Save data to disk"""
207
+ # Check if storage was updated by another process
208
+ if is_multiprocess and self.storage_updated.value:
209
+ # Storage was updated by another process, reload data instead of saving
210
+ logger.warning(
211
+ f"Storage for {self.namespace} was updated by another process, reloading..."
212
+ )
213
+ self._client = NanoVectorDB(
214
+ self.embedding_func.embedding_dim,
215
+ storage_file=self._client_file_name,
216
+ )
217
+ # Reset update flag
218
+ self.storage_updated.value = False
219
+ return False # Return error
220
+
221
+ # Acquire lock and perform persistence
222
+ async with self._storage_lock:
223
+ try:
224
+ # Save data to disk
225
+ self._client.save()
226
+ # Notify other processes that data has been updated
227
+ await set_all_update_flags(self.namespace)
228
+ # Reset own update flag to avoid self-reloading
229
+ if is_multiprocess:
230
+ self.storage_updated.value = False
231
+ else:
232
+ self.storage_updated = False
233
+ return True # Return success
234
+ except Exception as e:
235
+ logger.error(f"Error saving data for {self.namespace}: {e}")
236
+ return False # Return error
237
+
238
+ return True # Return success
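The `_storage_lock` used by these backends comes from `get_storage_lock()` in the new `shared_storage` module further down, which wraps either an `asyncio.Lock` (single process) or a `multiprocessing` lock (multi-worker) behind one `async with` interface. A reduced sketch of that idea (not the real `UnifiedLock`) is shown below.

```python
import asyncio
import multiprocessing

class DualModeLock:
    """Reduced version of the UnifiedLock idea: one async context manager
    over either an asyncio.Lock or a multiprocessing lock."""

    def __init__(self, lock, is_async: bool):
        self._lock = lock
        self._is_async = is_async

    async def __aenter__(self):
        if self._is_async:
            await self._lock.acquire()
        else:
            self._lock.acquire()  # blocking acquire of the process lock
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        self._lock.release()

async def _demo():
    async with DualModeLock(asyncio.Lock(), is_async=True):
        print("guarded by asyncio lock")
    async with DualModeLock(multiprocessing.Lock(), is_async=False):
        print("guarded by process lock")

if __name__ == "__main__":
    asyncio.run(_demo())
```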
lightrag/kg/networkx_impl.py CHANGED
@@ -1,18 +1,12 @@
1
  import os
2
  from dataclasses import dataclass
3
  from typing import Any, final
4
-
5
  import numpy as np
6
 
7
-
8
  from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
9
- from lightrag.utils import (
10
- logger,
11
- )
12
 
13
- from lightrag.base import (
14
- BaseGraphStorage,
15
- )
16
  import pipmaster as pm
17
 
18
  if not pm.is_installed("networkx"):
@@ -23,6 +17,12 @@ if not pm.is_installed("graspologic"):
23
 
24
  import networkx as nx
25
  from graspologic import embed
 
 
 
 
 
 
26
 
27
 
28
  @final
@@ -78,56 +78,101 @@ class NetworkXStorage(BaseGraphStorage):
78
  self._graphml_xml_file = os.path.join(
79
  self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
80
  )
 
 
 
 
 
81
  preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
82
  if preloaded_graph is not None:
83
  logger.info(
84
  f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
85
  )
 
 
86
  self._graph = preloaded_graph or nx.Graph()
 
87
  self._node_embed_algorithms = {
88
  "node2vec": self._node2vec_embed,
89
  }
90
 
91
- async def index_done_callback(self) -> None:
92
- NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
 
 
 
 
 
 
93
 
94
  async def has_node(self, node_id: str) -> bool:
95
- return self._graph.has_node(node_id)
 
96
 
97
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
98
- return self._graph.has_edge(source_node_id, target_node_id)
 
99
 
100
  async def get_node(self, node_id: str) -> dict[str, str] | None:
101
- return self._graph.nodes.get(node_id)
 
102
 
103
  async def node_degree(self, node_id: str) -> int:
104
- return self._graph.degree(node_id)
 
105
 
106
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
107
- return self._graph.degree(src_id) + self._graph.degree(tgt_id)
 
108
 
109
  async def get_edge(
110
  self, source_node_id: str, target_node_id: str
111
  ) -> dict[str, str] | None:
112
- return self._graph.edges.get((source_node_id, target_node_id))
 
113
 
114
  async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
115
- if self._graph.has_node(source_node_id):
116
- return list(self._graph.edges(source_node_id))
 
117
  return None
118
 
119
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
120
- self._graph.add_node(node_id, **node_data)
 
121
 
122
  async def upsert_edge(
123
  self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
124
  ) -> None:
125
- self._graph.add_edge(source_node_id, target_node_id, **edge_data)
 
126
 
127
  async def delete_node(self, node_id: str) -> None:
128
- if self._graph.has_node(node_id):
129
- self._graph.remove_node(node_id)
130
- logger.info(f"Node {node_id} deleted from the graph.")
 
131
  else:
132
  logger.warning(f"Node {node_id} not found in the graph for deletion.")
133
 
@@ -138,35 +183,37 @@ class NetworkXStorage(BaseGraphStorage):
138
  raise ValueError(f"Node embedding algorithm {algorithm} not supported")
139
  return await self._node_embed_algorithms[algorithm]()
140
 
141
- # @TODO: NOT USED
142
  async def _node2vec_embed(self):
 
143
  embeddings, nodes = embed.node2vec_embed(
144
- self._graph,
145
  **self.global_config["node2vec_params"],
146
  )
147
-
148
- nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
149
  return embeddings, nodes_ids
150
 
151
- def remove_nodes(self, nodes: list[str]):
152
  """Delete multiple nodes
153
 
154
  Args:
155
  nodes: List of node IDs to be deleted
156
  """
 
157
  for node in nodes:
158
- if self._graph.has_node(node):
159
- self._graph.remove_node(node)
160
 
161
- def remove_edges(self, edges: list[tuple[str, str]]):
162
  """Delete multiple edges
163
 
164
  Args:
165
  edges: List of edges to be deleted, each edge is a (source, target) tuple
166
  """
 
167
  for source, target in edges:
168
- if self._graph.has_edge(source, target):
169
- self._graph.remove_edge(source, target)
170
 
171
  async def get_all_labels(self) -> list[str]:
172
  """
@@ -174,8 +221,9 @@ class NetworkXStorage(BaseGraphStorage):
174
  Returns:
175
  [label1, label2, ...] # Alphabetically sorted label list
176
  """
 
177
  labels = set()
178
- for node in self._graph.nodes():
179
  labels.add(str(node)) # Add node id as a label
180
 
181
  # Return sorted list
@@ -198,16 +246,18 @@ class NetworkXStorage(BaseGraphStorage):
198
  seen_nodes = set()
199
  seen_edges = set()
200
 
 
 
201
  # Handle special case for "*" label
202
  if node_label == "*":
203
  # For "*", return the entire graph including all nodes and edges
204
  subgraph = (
205
- self._graph.copy()
206
  ) # Create a copy to avoid modifying the original graph
207
  else:
208
  # Find nodes with matching node id (partial match)
209
  nodes_to_explore = []
210
- for n, attr in self._graph.nodes(data=True):
211
  if node_label in str(n): # Use partial matching
212
  nodes_to_explore.append(n)
213
 
@@ -216,7 +266,7 @@ class NetworkXStorage(BaseGraphStorage):
216
  return result
217
 
218
  # Get subgraph using ego_graph
219
- subgraph = nx.ego_graph(self._graph, nodes_to_explore[0], radius=max_depth)
220
 
221
  # Check if number of nodes exceeds max_graph_nodes
222
  max_graph_nodes = 500
@@ -278,9 +328,41 @@ class NetworkXStorage(BaseGraphStorage):
278
  )
279
  seen_edges.add(edge_id)
280
 
281
- # logger.info(result.edges)
282
-
283
  logger.info(
284
  f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
285
  )
286
  return result
 
 
 
 
 
1
  import os
2
  from dataclasses import dataclass
3
  from typing import Any, final
 
4
  import numpy as np
5
 
 
6
  from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
7
+ from lightrag.utils import logger
8
+ from lightrag.base import BaseGraphStorage
 
9
 
 
 
 
10
  import pipmaster as pm
11
 
12
  if not pm.is_installed("networkx"):
 
17
 
18
  import networkx as nx
19
  from graspologic import embed
20
+ from .shared_storage import (
21
+ get_storage_lock,
22
+ get_update_flag,
23
+ set_all_update_flags,
24
+ is_multiprocess,
25
+ )
26
 
27
 
28
  @final
 
78
  self._graphml_xml_file = os.path.join(
79
  self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
80
  )
81
+ self._storage_lock = None
82
+ self.storage_updated = None
83
+ self._graph = None
84
+
85
+ # Load initial graph
86
  preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
87
  if preloaded_graph is not None:
88
  logger.info(
89
  f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
90
  )
91
+ else:
92
+ logger.info("Created new empty graph")
93
  self._graph = preloaded_graph or nx.Graph()
94
+
95
  self._node_embed_algorithms = {
96
  "node2vec": self._node2vec_embed,
97
  }
98
 
99
+ async def initialize(self):
100
+ """Initialize storage data"""
101
+ # Get the update flag for cross-process update notification
102
+ self.storage_updated = await get_update_flag(self.namespace)
103
+ # Get the storage lock for use in other methods
104
+ self._storage_lock = get_storage_lock()
105
+
106
+ async def _get_graph(self):
107
+ """Check if the storage should be reloaded"""
108
+ # Acquire lock to prevent concurrent read and write
109
+ async with self._storage_lock:
110
+ # Check if data needs to be reloaded
111
+ if (is_multiprocess and self.storage_updated.value) or (
112
+ not is_multiprocess and self.storage_updated
113
+ ):
114
+ logger.info(
115
+ f"Process {os.getpid()} reloading graph {self.namespace} due to update by another process"
116
+ )
117
+ # Reload data
118
+ self._graph = (
119
+ NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
120
+ )
121
+ # Reset update flag
122
+ if is_multiprocess:
123
+ self.storage_updated.value = False
124
+ else:
125
+ self.storage_updated = False
126
+
127
+ return self._graph
128
 
129
  async def has_node(self, node_id: str) -> bool:
130
+ graph = await self._get_graph()
131
+ return graph.has_node(node_id)
132
 
133
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
134
+ graph = await self._get_graph()
135
+ return graph.has_edge(source_node_id, target_node_id)
136
 
137
  async def get_node(self, node_id: str) -> dict[str, str] | None:
138
+ graph = await self._get_graph()
139
+ return graph.nodes.get(node_id)
140
 
141
  async def node_degree(self, node_id: str) -> int:
142
+ graph = await self._get_graph()
143
+ return graph.degree(node_id)
144
 
145
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
146
+ graph = await self._get_graph()
147
+ return graph.degree(src_id) + graph.degree(tgt_id)
148
 
149
  async def get_edge(
150
  self, source_node_id: str, target_node_id: str
151
  ) -> dict[str, str] | None:
152
+ graph = await self._get_graph()
153
+ return graph.edges.get((source_node_id, target_node_id))
154
 
155
  async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
156
+ graph = await self._get_graph()
157
+ if graph.has_node(source_node_id):
158
+ return list(graph.edges(source_node_id))
159
  return None
160
 
161
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
162
+ graph = await self._get_graph()
163
+ graph.add_node(node_id, **node_data)
164
 
165
  async def upsert_edge(
166
  self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
167
  ) -> None:
168
+ graph = await self._get_graph()
169
+ graph.add_edge(source_node_id, target_node_id, **edge_data)
170
 
171
  async def delete_node(self, node_id: str) -> None:
172
+ graph = await self._get_graph()
173
+ if graph.has_node(node_id):
174
+ graph.remove_node(node_id)
175
+ logger.debug(f"Node {node_id} deleted from the graph.")
176
  else:
177
  logger.warning(f"Node {node_id} not found in the graph for deletion.")
178
 
 
183
  raise ValueError(f"Node embedding algorithm {algorithm} not supported")
184
  return await self._node_embed_algorithms[algorithm]()
185
 
186
+ # TODO: NOT USED
187
  async def _node2vec_embed(self):
188
+ graph = await self._get_graph()
189
  embeddings, nodes = embed.node2vec_embed(
190
+ graph,
191
  **self.global_config["node2vec_params"],
192
  )
193
+ nodes_ids = [graph.nodes[node_id]["id"] for node_id in nodes]
 
194
  return embeddings, nodes_ids
195
 
196
+ async def remove_nodes(self, nodes: list[str]):
197
  """Delete multiple nodes
198
 
199
  Args:
200
  nodes: List of node IDs to be deleted
201
  """
202
+ graph = await self._get_graph()
203
  for node in nodes:
204
+ if graph.has_node(node):
205
+ graph.remove_node(node)
206
 
207
+ async def remove_edges(self, edges: list[tuple[str, str]]):
208
  """Delete multiple edges
209
 
210
  Args:
211
  edges: List of edges to be deleted, each edge is a (source, target) tuple
212
  """
213
+ graph = await self._get_graph()
214
  for source, target in edges:
215
+ if graph.has_edge(source, target):
216
+ graph.remove_edge(source, target)
217
 
218
  async def get_all_labels(self) -> list[str]:
219
  """
 
221
  Returns:
222
  [label1, label2, ...] # Alphabetically sorted label list
223
  """
224
+ graph = await self._get_graph()
225
  labels = set()
226
+ for node in graph.nodes():
227
  labels.add(str(node)) # Add node id as a label
228
 
229
  # Return sorted list
 
246
  seen_nodes = set()
247
  seen_edges = set()
248
 
249
+ graph = await self._get_graph()
250
+
251
  # Handle special case for "*" label
252
  if node_label == "*":
253
  # For "*", return the entire graph including all nodes and edges
254
  subgraph = (
255
+ graph.copy()
256
  ) # Create a copy to avoid modifying the original graph
257
  else:
258
  # Find nodes with matching node id (partial match)
259
  nodes_to_explore = []
260
+ for n, attr in graph.nodes(data=True):
261
  if node_label in str(n): # Use partial matching
262
  nodes_to_explore.append(n)
263
 
 
266
  return result
267
 
268
  # Get subgraph using ego_graph
269
+ subgraph = nx.ego_graph(graph, nodes_to_explore[0], radius=max_depth)
270
 
271
  # Check if number of nodes exceeds max_graph_nodes
272
  max_graph_nodes = 500
 
328
  )
329
  seen_edges.add(edge_id)
330
 
 
 
331
  logger.info(
332
  f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
333
  )
334
  return result
335
+
336
+ async def index_done_callback(self) -> bool:
337
+ """Save data to disk"""
338
+ # Check if storage was updated by another process
339
+ if is_multiprocess and self.storage_updated.value:
340
+ # Storage was updated by another process, reload data instead of saving
341
+ logger.warning(
342
+ f"Graph for {self.namespace} was updated by another process, reloading..."
343
+ )
344
+ self._graph = (
345
+ NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
346
+ )
347
+ # Reset update flag
348
+ self.storage_updated.value = False
349
+ return False # Return error
350
+
351
+ # Acquire lock and perform persistence
352
+ async with self._storage_lock:
353
+ try:
354
+ # Save data to disk
355
+ NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
356
+ # Notify other processes that data has been updated
357
+ await set_all_update_flags(self.namespace)
358
+ # Reset own update flag to avoid self-reloading
359
+ if is_multiprocess:
360
+ self.storage_updated.value = False
361
+ else:
362
+ self.storage_updated = False
363
+ return True # Return success
364
+ except Exception as e:
365
+ logger.error(f"Error saving graph for {self.namespace}: {e}")
366
+ return False # Return error
367
+
368
+ return True
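A minimal sketch (not part of this PR) of the notification flow that `index_done_callback()` above relies on: each worker registers its own update flag, a successful save flips every registered flag via `set_all_update_flags()`, and the writer then clears only its own flag so it does not reload the data it just wrote. Plain single-element lists stand in for the real `Manager.Value` objects.

```python
import asyncio

_update_flags: dict[str, list] = {}

async def get_update_flag(namespace: str) -> list:
    # Each worker keeps a reference to its own single-element flag.
    flag = [False]
    _update_flags.setdefault(namespace, []).append(flag)
    return flag

async def set_all_update_flags(namespace: str) -> None:
    for flag in _update_flags[namespace]:
        flag[0] = True

async def _demo():
    writer_flag = await get_update_flag("graph")
    reader_flag = await get_update_flag("graph")

    # Writer persists data, notifies everyone, then clears only its own flag.
    await set_all_update_flags("graph")
    writer_flag[0] = False

    print(writer_flag[0], reader_flag[0])  # False True -> only the readers reload

if __name__ == "__main__":
    asyncio.run(_demo())
```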
lightrag/kg/postgres_impl.py CHANGED
@@ -38,8 +38,8 @@ import pipmaster as pm
38
  if not pm.is_installed("asyncpg"):
39
  pm.install("asyncpg")
40
 
41
- import asyncpg
42
- from asyncpg import Pool
43
 
44
 
45
  class PostgreSQLDB:
 
38
  if not pm.is_installed("asyncpg"):
39
  pm.install("asyncpg")
40
 
41
+ import asyncpg # type: ignore
42
+ from asyncpg import Pool # type: ignore
43
 
44
 
45
  class PostgreSQLDB:
lightrag/kg/shared_storage.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import asyncio
4
+ from multiprocessing.synchronize import Lock as ProcessLock
5
+ from multiprocessing import Manager
6
+ from typing import Any, Dict, Optional, Union, TypeVar, Generic
7
+
8
+
9
+ # Define a direct print function for critical logs that must be visible in all processes
10
+ def direct_log(message, level="INFO"):
11
+ """
12
+ Log a message directly to stderr to ensure visibility in all processes,
13
+ including the Gunicorn master process.
14
+ """
15
+ print(f"{level}: {message}", file=sys.stderr, flush=True)
16
+
17
+
18
+ T = TypeVar("T")
19
+ LockType = Union[ProcessLock, asyncio.Lock]
20
+
21
+ is_multiprocess = None
22
+ _workers = None
23
+ _manager = None
24
+ _initialized = None
25
+
26
+ # shared data for storage across processes
27
+ _shared_dicts: Optional[Dict[str, Any]] = None
28
+ _init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized
29
+ _update_flags: Optional[Dict[str, bool]] = None # namespace -> updated
30
+
31
+ # locks for mutex access
32
+ _storage_lock: Optional[LockType] = None
33
+ _internal_lock: Optional[LockType] = None
34
+ _pipeline_status_lock: Optional[LockType] = None
35
+
36
+
37
+ class UnifiedLock(Generic[T]):
38
+ """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock"""
39
+
40
+ def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool):
41
+ self._lock = lock
42
+ self._is_async = is_async
43
+
44
+ async def __aenter__(self) -> "UnifiedLock[T]":
45
+ if self._is_async:
46
+ await self._lock.acquire()
47
+ else:
48
+ self._lock.acquire()
49
+ return self
50
+
51
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
52
+ if self._is_async:
53
+ self._lock.release()
54
+ else:
55
+ self._lock.release()
56
+
57
+ def __enter__(self) -> "UnifiedLock[T]":
58
+ """For backward compatibility"""
59
+ if self._is_async:
60
+ raise RuntimeError("Use 'async with' for shared_storage lock")
61
+ self._lock.acquire()
62
+ return self
63
+
64
+ def __exit__(self, exc_type, exc_val, exc_tb):
65
+ """For backward compatibility"""
66
+ if self._is_async:
67
+ raise RuntimeError("Use 'async with' for shared_storage lock")
68
+ self._lock.release()
69
+
70
+
71
+ def get_internal_lock() -> UnifiedLock:
72
+ """return unified storage lock for data consistency"""
73
+ return UnifiedLock(lock=_internal_lock, is_async=not is_multiprocess)
74
+
75
+
76
+ def get_storage_lock() -> UnifiedLock:
77
+ """return unified storage lock for data consistency"""
78
+ return UnifiedLock(lock=_storage_lock, is_async=not is_multiprocess)
79
+
80
+
81
+ def get_pipeline_status_lock() -> UnifiedLock:
82
+ """return unified storage lock for data consistency"""
83
+ return UnifiedLock(lock=_pipeline_status_lock, is_async=not is_multiprocess)
84
+
85
+
86
+ def initialize_share_data(workers: int = 1):
87
+ """
88
+ Initialize shared storage data for single or multi-process mode.
89
+
90
+ When used with Gunicorn's preload feature, this function is called once in the
91
+ master process before forking worker processes, allowing all workers to share
92
+ the same initialized data.
93
+
94
+ In single-process mode, this function is called in FASTAPI lifespan function.
95
+
96
+ The function determines whether to use cross-process shared variables for data storage
97
+ based on the number of workers. If workers=1, it uses thread locks and local dictionaries.
98
+ If workers>1, it uses process locks and shared dictionaries managed by multiprocessing.Manager.
99
+
100
+ Args:
101
+ workers (int): Number of worker processes. If 1, single-process mode is used.
102
+ If > 1, multi-process mode with shared memory is used.
103
+ """
104
+ global \
105
+ _manager, \
106
+ _workers, \
107
+ is_multiprocess, \
108
+ _storage_lock, \
109
+ _internal_lock, \
110
+ _pipeline_status_lock, \
111
+ _shared_dicts, \
112
+ _init_flags, \
113
+ _initialized, \
114
+ _update_flags
115
+
116
+ # Check if already initialized
117
+ if _initialized:
118
+ direct_log(
119
+ f"Process {os.getpid()} Shared-Data already initialized (multiprocess={is_multiprocess})"
120
+ )
121
+ return
122
+
123
+ _manager = Manager()
124
+ _workers = workers
125
+
126
+ if workers > 1:
127
+ is_multiprocess = True
128
+ _internal_lock = _manager.Lock()
129
+ _storage_lock = _manager.Lock()
130
+ _pipeline_status_lock = _manager.Lock()
131
+ _shared_dicts = _manager.dict()
132
+ _init_flags = _manager.dict()
133
+ _update_flags = _manager.dict()
134
+ direct_log(
135
+ f"Process {os.getpid()} Shared-Data created for Multiple Process (workers={workers})"
136
+ )
137
+ else:
138
+ is_multiprocess = False
139
+ _internal_lock = asyncio.Lock()
140
+ _storage_lock = asyncio.Lock()
141
+ _pipeline_status_lock = asyncio.Lock()
142
+ _shared_dicts = {}
143
+ _init_flags = {}
144
+ _update_flags = {}
145
+ direct_log(f"Process {os.getpid()} Shared-Data created for Single Process")
146
+
147
+ # Mark as initialized
148
+ _initialized = True
149
+
150
+
151
+ async def initialize_pipeline_status():
152
+ """
153
+ Initialize pipeline namespace with default values.
154
+ This function is called during FASTAPI lifespan for each worker.
155
+ """
156
+ pipeline_namespace = await get_namespace_data("pipeline_status")
157
+
158
+ async with get_internal_lock():
159
+ # Check if already initialized by checking for required fields
160
+ if "busy" in pipeline_namespace:
161
+ return
162
+
163
+ # Create a shared list object for history_messages
164
+ history_messages = _manager.list() if is_multiprocess else []
165
+ pipeline_namespace.update(
166
+ {
167
+ "busy": False, # Control concurrent processes
168
+ "job_name": "Default Job", # Current job name (indexing files/indexing texts)
169
+ "job_start": None, # Job start time
170
+ "docs": 0, # Total number of documents to be indexed
171
+ "batchs": 0, # Number of batches for processing documents
172
+ "cur_batch": 0, # Current processing batch
173
+ "request_pending": False, # Flag for pending request for processing
174
+ "latest_message": "", # Latest message from pipeline processing
175
+ "history_messages": history_messages, # 使用共享列表对象
176
+ }
177
+ )
178
+ direct_log(f"Process {os.getpid()} Pipeline namespace initialized")
179
+
180
+
181
+ async def get_update_flag(namespace: str):
182
+ """
183
+ Create a namespace's update flag for a worker.
184
+ Return the update flag to the caller for referencing or reset.
185
+ """
186
+ global _update_flags
187
+ if _update_flags is None:
188
+ raise ValueError("Try to create namespace before Shared-Data is initialized")
189
+
190
+ async with get_internal_lock():
191
+ if namespace not in _update_flags:
192
+ if is_multiprocess and _manager is not None:
193
+ _update_flags[namespace] = _manager.list()
194
+ else:
195
+ _update_flags[namespace] = []
196
+ direct_log(
197
+ f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]"
198
+ )
199
+
200
+ if is_multiprocess and _manager is not None:
201
+ new_update_flag = _manager.Value("b", False)
202
+ else:
203
+ new_update_flag = False
204
+
205
+ _update_flags[namespace].append(new_update_flag)
206
+ return new_update_flag
207
+
208
+
209
+ async def set_all_update_flags(namespace: str):
210
+ """Set all update flag of namespace indicating all workers need to reload data from files"""
211
+ global _update_flags
212
+ if _update_flags is None:
213
+ raise ValueError("Try to create namespace before Shared-Data is initialized")
214
+
215
+ async with get_internal_lock():
216
+ if namespace not in _update_flags:
217
+ raise ValueError(f"Namespace {namespace} not found in update flags")
218
+ # Update flags for both modes
219
+ for i in range(len(_update_flags[namespace])):
220
+ if is_multiprocess:
221
+ _update_flags[namespace][i].value = True
222
+ else:
223
+ _update_flags[namespace][i] = True
224
+
225
+
226
+ async def get_all_update_flags_status() -> Dict[str, list]:
227
+ """
228
+ Get update flags status for all namespaces.
229
+
230
+ Returns:
231
+ Dict[str, list]: A dictionary mapping namespace names to lists of update flag statuses
232
+ """
233
+ if _update_flags is None:
234
+ return {}
235
+
236
+ result = {}
237
+ async with get_internal_lock():
238
+ for namespace, flags in _update_flags.items():
239
+ worker_statuses = []
240
+ for flag in flags:
241
+ if is_multiprocess:
242
+ worker_statuses.append(flag.value)
243
+ else:
244
+ worker_statuses.append(flag)
245
+ result[namespace] = worker_statuses
246
+
247
+ return result
248
+
249
+
250
+ def try_initialize_namespace(namespace: str) -> bool:
251
+ """
252
+ Returns True if the current worker(process) gets initialization permission for loading data later.
253
+ A worker that does not get the permission is prohibited from loading data from files.
254
+ """
255
+ global _init_flags, _manager
256
+
257
+ if _init_flags is None:
258
+ raise ValueError("Try to create nanmespace before Shared-Data is initialized")
259
+
260
+ if namespace not in _init_flags:
261
+ _init_flags[namespace] = True
262
+ direct_log(
263
+ f"Process {os.getpid()} ready to initialize storage namespace: [{namespace}]"
264
+ )
265
+ return True
266
+ direct_log(
267
+ f"Process {os.getpid()} storage namespace already initialized: [{namespace}]"
268
+ )
269
+ return False
270
+
271
+
272
+ async def get_namespace_data(namespace: str) -> Dict[str, Any]:
273
+ """get the shared data reference for specific namespace"""
274
+ if _shared_dicts is None:
275
+ direct_log(
276
+ f"Error: try to getnanmespace before it is initialized, pid={os.getpid()}",
277
+ level="ERROR",
278
+ )
279
+ raise ValueError("Shared dictionaries not initialized")
280
+
281
+ async with get_internal_lock():
282
+ if namespace not in _shared_dicts:
283
+ if is_multiprocess and _manager is not None:
284
+ _shared_dicts[namespace] = _manager.dict()
285
+ else:
286
+ _shared_dicts[namespace] = {}
287
+
288
+ return _shared_dicts[namespace]
289
+
290
+
291
+ def finalize_share_data():
292
+ """
293
+ Release shared resources and clean up.
294
+
295
+ This function should be called when the application is shutting down
296
+ to properly release shared resources and avoid memory leaks.
297
+
298
+ In multi-process mode, it shuts down the Manager and releases all shared objects.
299
+ In single-process mode, it simply resets the global variables.
300
+ """
301
+ global \
302
+ _manager, \
303
+ is_multiprocess, \
304
+ _storage_lock, \
305
+ _internal_lock, \
306
+ _pipeline_status_lock, \
307
+ _shared_dicts, \
308
+ _init_flags, \
309
+ _initialized, \
310
+ _update_flags
311
+
312
+ # Check if already initialized
313
+ if not _initialized:
314
+ direct_log(
315
+ f"Process {os.getpid()} storage data not initialized, nothing to finalize"
316
+ )
317
+ return
318
+
319
+ direct_log(
320
+ f"Process {os.getpid()} finalizing storage data (multiprocess={is_multiprocess})"
321
+ )
322
+
323
+ # In multi-process mode, shut down the Manager
324
+ if is_multiprocess and _manager is not None:
325
+ try:
326
+ # Clear shared resources before shutting down Manager
327
+ if _shared_dicts is not None:
328
+ # Clear pipeline status history messages first if exists
329
+ try:
330
+ pipeline_status = _shared_dicts.get("pipeline_status", {})
331
+ if "history_messages" in pipeline_status:
332
+ pipeline_status["history_messages"].clear()
333
+ except Exception:
334
+ pass # Ignore any errors during history messages cleanup
335
+ _shared_dicts.clear()
336
+ if _init_flags is not None:
337
+ _init_flags.clear()
338
+ if _update_flags is not None:
339
+ # Clear each namespace's update flags list and Value objects
340
+ try:
341
+ for namespace in _update_flags:
342
+ flags_list = _update_flags[namespace]
343
+ if isinstance(flags_list, list):
344
+ # Clear Value objects in the list
345
+ for flag in flags_list:
346
+ if hasattr(
347
+ flag, "value"
348
+ ): # Check if it's a Value object
349
+ flag.value = False
350
+ flags_list.clear()
351
+ except Exception:
352
+ pass # Ignore any errors during update flags cleanup
353
+ _update_flags.clear()
354
+
355
+ # Shut down the Manager - this will automatically clean up all shared resources
356
+ _manager.shutdown()
357
+ direct_log(f"Process {os.getpid()} Manager shutdown complete")
358
+ except Exception as e:
359
+ direct_log(
360
+ f"Process {os.getpid()} Error shutting down Manager: {e}", level="ERROR"
361
+ )
362
+
363
+ # Reset global variables
364
+ _manager = None
365
+ _initialized = None
366
+ is_multiprocess = None
367
+ _shared_dicts = None
368
+ _init_flags = None
369
+ _storage_lock = None
370
+ _internal_lock = None
371
+ _pipeline_status_lock = None
372
+ _update_flags = None
373
+
374
+ direct_log(f"Process {os.getpid()} storage data finalization complete")
lightrag/lightrag.py CHANGED
@@ -45,7 +45,6 @@ from .utils import (
45
  lazy_external_import,
46
  limit_async_func_call,
47
  logger,
48
- set_logger,
49
  )
50
  from .types import KnowledgeGraph
51
  from dotenv import load_dotenv
@@ -268,9 +267,14 @@ class LightRAG:
268
 
269
  def __post_init__(self):
270
  os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True)
271
- set_logger(self.log_file_path, self.log_level)
272
  logger.info(f"Logger initialized for working directory: {self.working_dir}")
273
 
 
 
 
 
 
 
274
  if not os.path.exists(self.working_dir):
275
  logger.info(f"Creating working directory {self.working_dir}")
276
  os.makedirs(self.working_dir)
@@ -692,117 +696,221 @@ class LightRAG:
692
  3. Process each chunk for entity and relation extraction
693
  4. Update the document status
694
  """
695
- # 1. Get all pending, failed, and abnormally terminated processing documents.
696
- # Run the asynchronous status retrievals in parallel using asyncio.gather
697
- processing_docs, failed_docs, pending_docs = await asyncio.gather(
698
- self.doc_status.get_docs_by_status(DocStatus.PROCESSING),
699
- self.doc_status.get_docs_by_status(DocStatus.FAILED),
700
- self.doc_status.get_docs_by_status(DocStatus.PENDING),
701
  )
702
 
703
- to_process_docs: dict[str, DocProcessingStatus] = {}
704
- to_process_docs.update(processing_docs)
705
- to_process_docs.update(failed_docs)
706
- to_process_docs.update(pending_docs)
 
 
707
 
708
- if not to_process_docs:
709
- logger.info("All documents have been processed or are duplicates")
710
- return
 
 
711
 
712
- # 2. split docs into chunks, insert chunks, update doc status
713
- docs_batches = [
714
- list(to_process_docs.items())[i : i + self.max_parallel_insert]
715
- for i in range(0, len(to_process_docs), self.max_parallel_insert)
716
- ]
 
 
717
 
718
- logger.info(f"Number of batches to process: {len(docs_batches)}.")
719
-
720
- batches: list[Any] = []
721
- # 3. iterate over batches
722
- for batch_idx, docs_batch in enumerate(docs_batches):
723
-
724
- async def batch(
725
- batch_idx: int,
726
- docs_batch: list[tuple[str, DocProcessingStatus]],
727
- size_batch: int,
728
- ) -> None:
729
- logger.info(f"Start processing batch {batch_idx + 1} of {size_batch}.")
730
- # 4. iterate over batch
731
- for doc_id_processing_status in docs_batch:
732
- doc_id, status_doc = doc_id_processing_status
733
- # Generate chunks from document
734
- chunks: dict[str, Any] = {
735
- compute_mdhash_id(dp["content"], prefix="chunk-"): {
736
- **dp,
737
- "full_doc_id": doc_id,
738
- }
739
- for dp in self.chunking_func(
740
- status_doc.content,
741
- split_by_character,
742
- split_by_character_only,
743
- self.chunk_overlap_token_size,
744
- self.chunk_token_size,
745
- self.tiktoken_model_name,
746
- )
747
- }
748
- # Process document (text chunks and full docs) in parallel
749
- tasks = [
750
- self.doc_status.upsert(
751
- {
752
- doc_id: {
753
- "status": DocStatus.PROCESSING,
754
- "updated_at": datetime.now().isoformat(),
755
- "content": status_doc.content,
756
- "content_summary": status_doc.content_summary,
757
- "content_length": status_doc.content_length,
758
- "created_at": status_doc.created_at,
759
- }
760
- }
761
- ),
762
- self.chunks_vdb.upsert(chunks),
763
- self._process_entity_relation_graph(chunks),
764
- self.full_docs.upsert(
765
- {doc_id: {"content": status_doc.content}}
766
- ),
767
- self.text_chunks.upsert(chunks),
768
- ]
769
- try:
770
- await asyncio.gather(*tasks)
771
- await self.doc_status.upsert(
772
- {
773
- doc_id: {
774
- "status": DocStatus.PROCESSED,
775
- "chunks_count": len(chunks),
776
- "content": status_doc.content,
777
- "content_summary": status_doc.content_summary,
778
- "content_length": status_doc.content_length,
779
- "created_at": status_doc.created_at,
780
- "updated_at": datetime.now().isoformat(),
781
- }
782
- }
783
  )
784
- except Exception as e:
785
- logger.error(f"Failed to process document {doc_id}: {str(e)}")
786
- await self.doc_status.upsert(
787
- {
788
- doc_id: {
789
- "status": DocStatus.FAILED,
790
- "error": str(e),
791
- "content": status_doc.content,
792
- "content_summary": status_doc.content_summary,
793
- "content_length": status_doc.content_length,
794
- "created_at": status_doc.created_at,
795
- "updated_at": datetime.now().isoformat(),
796
  }
 
 
797
  }
 
 
798
  )
799
- continue
800
- logger.info(f"Completed batch {batch_idx + 1} of {len(docs_batches)}.")
 
 
 
801
 
802
- batches.append(batch(batch_idx, docs_batch, len(docs_batches)))
 
 
803
 
804
- await asyncio.gather(*batches)
805
- await self._insert_done()
 
 
 
 
 
 
806
 
807
  async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
808
  try:
@@ -833,7 +941,16 @@ class LightRAG:
833
  if storage_inst is not None
834
  ]
835
  await asyncio.gather(*tasks)
836
- logger.info("All Insert done")
 
 
837
 
838
  def insert_custom_kg(self, custom_kg: dict[str, Any]) -> None:
839
  loop = always_get_an_event_loop()
 
45
  lazy_external_import,
46
  limit_async_func_call,
47
  logger,
 
48
  )
49
  from .types import KnowledgeGraph
50
  from dotenv import load_dotenv
 
267
 
268
  def __post_init__(self):
269
  os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True)
 
270
  logger.info(f"Logger initialized for working directory: {self.working_dir}")
271
 
272
+ from lightrag.kg.shared_storage import (
273
+ initialize_share_data,
274
+ )
275
+
276
+ initialize_share_data()
277
+
278
  if not os.path.exists(self.working_dir):
279
  logger.info(f"Creating working directory {self.working_dir}")
280
  os.makedirs(self.working_dir)
 
696
  3. Process each chunk for entity and relation extraction
697
  4. Update the document status
698
  """
699
+ from lightrag.kg.shared_storage import (
700
+ get_namespace_data,
701
+ get_pipeline_status_lock,
 
 
 
702
  )
703
 
704
+ # Get pipeline status shared data and lock
705
+ pipeline_status = await get_namespace_data("pipeline_status")
706
+ pipeline_status_lock = get_pipeline_status_lock()
707
+
708
+ # Check if another process is already processing the queue
709
+ async with pipeline_status_lock:
710
+ # Ensure only one worker is processing documents
711
+ if not pipeline_status.get("busy", False):
712
+ # First check whether there are any documents that need processing
713
+ processing_docs, failed_docs, pending_docs = await asyncio.gather(
714
+ self.doc_status.get_docs_by_status(DocStatus.PROCESSING),
715
+ self.doc_status.get_docs_by_status(DocStatus.FAILED),
716
+ self.doc_status.get_docs_by_status(DocStatus.PENDING),
717
+ )
718
 
719
+ to_process_docs: dict[str, DocProcessingStatus] = {}
720
+ to_process_docs.update(processing_docs)
721
+ to_process_docs.update(failed_docs)
722
+ to_process_docs.update(pending_docs)
723
+
724
+ # If there are no documents to process, return immediately and leave pipeline_status unchanged
725
+ if not to_process_docs:
726
+ logger.info("No documents to process")
727
+ return
728
+
729
+ # Documents need processing, update pipeline_status
730
+ pipeline_status.update(
731
+ {
732
+ "busy": True,
733
+ "job_name": "indexing files",
734
+ "job_start": datetime.now().isoformat(),
735
+ "docs": 0,
736
+ "batchs": 0,
737
+ "cur_batch": 0,
738
+ "request_pending": False, # Clear any previous request
739
+ "latest_message": "",
740
+ }
741
+ )
742
+ # Clear history_messages in place so it remains a shared list object
743
+ del pipeline_status["history_messages"][:]
744
+ else:
745
+ # Another process is busy, just set request flag and return
746
+ pipeline_status["request_pending"] = True
747
+ logger.info(
748
+ "Another process is already processing the document queue. Request queued."
749
+ )
750
+ return
751
 
752
+ try:
753
+ # Process documents until no more documents or requests
754
+ while True:
755
+ if not to_process_docs:
756
+ log_message = "All documents have been processed or are duplicates"
757
+ logger.info(log_message)
758
+ pipeline_status["latest_message"] = log_message
759
+ pipeline_status["history_messages"].append(log_message)
760
+ break
761
+
762
+ # 2. split docs into chunks, insert chunks, update doc status
763
+ docs_batches = [
764
+ list(to_process_docs.items())[i : i + self.max_parallel_insert]
765
+ for i in range(0, len(to_process_docs), self.max_parallel_insert)
766
+ ]
767
 
768
+ log_message = f"Number of batches to process: {len(docs_batches)}."
769
+ logger.info(log_message)
770
+
771
+ # Update pipeline status with current batch information
772
+ pipeline_status["docs"] += len(to_process_docs)
773
+ pipeline_status["batchs"] += len(docs_batches)
774
+ pipeline_status["latest_message"] = log_message
775
+ pipeline_status["history_messages"].append(log_message)
776
+
777
+ batches: list[Any] = []
778
+ # 3. iterate over batches
779
+ for batch_idx, docs_batch in enumerate(docs_batches):
780
+ # Update current batch in pipeline status (directly, as it's atomic)
781
+ pipeline_status["cur_batch"] += 1
782
+
783
+ async def batch(
784
+ batch_idx: int,
785
+ docs_batch: list[tuple[str, DocProcessingStatus]],
786
+ size_batch: int,
787
+ ) -> None:
788
+ log_message = (
789
+ f"Start processing batch {batch_idx + 1} of {size_batch}."
 
 
790
  )
791
+ logger.info(log_message)
792
+ pipeline_status["latest_message"] = log_message
793
+ pipeline_status["history_messages"].append(log_message)
794
+ # 4. iterate over batch
795
+ for doc_id_processing_status in docs_batch:
796
+ doc_id, status_doc = doc_id_processing_status
797
+ # Generate chunks from document
798
+ chunks: dict[str, Any] = {
799
+ compute_mdhash_id(dp["content"], prefix="chunk-"): {
800
+ **dp,
801
+ "full_doc_id": doc_id,
 
802
  }
803
+ for dp in self.chunking_func(
804
+ status_doc.content,
805
+ split_by_character,
806
+ split_by_character_only,
807
+ self.chunk_overlap_token_size,
808
+ self.chunk_token_size,
809
+ self.tiktoken_model_name,
810
+ )
811
  }
812
+ # Process document (text chunks and full docs) in parallel
813
+ tasks = [
814
+ self.doc_status.upsert(
815
+ {
816
+ doc_id: {
817
+ "status": DocStatus.PROCESSING,
818
+ "updated_at": datetime.now().isoformat(),
819
+ "content": status_doc.content,
820
+ "content_summary": status_doc.content_summary,
821
+ "content_length": status_doc.content_length,
822
+ "created_at": status_doc.created_at,
823
+ }
824
+ }
825
+ ),
826
+ self.chunks_vdb.upsert(chunks),
827
+ self._process_entity_relation_graph(chunks),
828
+ self.full_docs.upsert(
829
+ {doc_id: {"content": status_doc.content}}
830
+ ),
831
+ self.text_chunks.upsert(chunks),
832
+ ]
833
+ try:
834
+ await asyncio.gather(*tasks)
835
+ await self.doc_status.upsert(
836
+ {
837
+ doc_id: {
838
+ "status": DocStatus.PROCESSED,
839
+ "chunks_count": len(chunks),
840
+ "content": status_doc.content,
841
+ "content_summary": status_doc.content_summary,
842
+ "content_length": status_doc.content_length,
843
+ "created_at": status_doc.created_at,
844
+ "updated_at": datetime.now().isoformat(),
845
+ }
846
+ }
847
+ )
848
+ except Exception as e:
849
+ logger.error(
850
+ f"Failed to process document {doc_id}: {str(e)}"
851
+ )
852
+ await self.doc_status.upsert(
853
+ {
854
+ doc_id: {
855
+ "status": DocStatus.FAILED,
856
+ "error": str(e),
857
+ "content": status_doc.content,
858
+ "content_summary": status_doc.content_summary,
859
+ "content_length": status_doc.content_length,
860
+ "created_at": status_doc.created_at,
861
+ "updated_at": datetime.now().isoformat(),
862
+ }
863
+ }
864
+ )
865
+ continue
866
+ log_message = (
867
+ f"Completed batch {batch_idx + 1} of {len(docs_batches)}."
868
  )
869
+ logger.info(log_message)
870
+ pipeline_status["latest_message"] = log_message
871
+ pipeline_status["history_messages"].append(log_message)
872
+
873
+ batches.append(batch(batch_idx, docs_batch, len(docs_batches)))
874
 
875
+ await asyncio.gather(*batches)
876
+ await self._insert_done()
877
+
878
+ # Check if there's a pending request to process more documents (with lock)
879
+ has_pending_request = False
880
+ async with pipeline_status_lock:
881
+ has_pending_request = pipeline_status.get("request_pending", False)
882
+ if has_pending_request:
883
+ # Clear the request flag before checking for more documents
884
+ pipeline_status["request_pending"] = False
885
+
886
+ if not has_pending_request:
887
+ break
888
+
889
+ log_message = "Processing additional documents due to pending request"
890
+ logger.info(log_message)
891
+ pipeline_status["latest_message"] = log_message
892
+ pipeline_status["history_messages"].append(log_message)
893
+
894
+ # Get any new documents that are pending processing
895
+ processing_docs, failed_docs, pending_docs = await asyncio.gather(
896
+ self.doc_status.get_docs_by_status(DocStatus.PROCESSING),
897
+ self.doc_status.get_docs_by_status(DocStatus.FAILED),
898
+ self.doc_status.get_docs_by_status(DocStatus.PENDING),
899
+ )
900
+
901
+ to_process_docs = {}
902
+ to_process_docs.update(processing_docs)
903
+ to_process_docs.update(failed_docs)
904
+ to_process_docs.update(pending_docs)
905
 
906
+ finally:
907
+ log_message = "Document processing pipeline completed"
908
+ logger.info(log_message)
909
+ # Always reset busy status when done or if an exception occurs (with lock)
910
+ async with pipeline_status_lock:
911
+ pipeline_status["busy"] = False
912
+ pipeline_status["latest_message"] = log_message
913
+ pipeline_status["history_messages"].append(log_message)
914
 
915
  async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
916
  try:
 
941
  if storage_inst is not None
942
  ]
943
  await asyncio.gather(*tasks)
944
+
945
+ log_message = "All Insert done"
946
+ logger.info(log_message)
947
+
948
+ # Get pipeline_status and update latest_message and history_messages
949
+ from lightrag.kg.shared_storage import get_namespace_data
950
+
951
+ pipeline_status = await get_namespace_data("pipeline_status")
952
+ pipeline_status["latest_message"] = log_message
953
+ pipeline_status["history_messages"].append(log_message)
954
 
955
  def insert_custom_kg(self, custom_kg: dict[str, Any]) -> None:
956
  loop = always_get_an_event_loop()
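
The rewritten `apipeline_process_enqueue_documents` serialises queue ownership through `pipeline_status`: the first worker to take the lock marks itself `busy`, later callers only set `request_pending`, and the busy worker keeps looping until no request is left before clearing `busy` in its `finally` block. Below is a simplified, self-contained sketch of that handshake, with a plain dict and `asyncio.Lock` standing in for the shared-storage objects; `fetch_pending()` is a hypothetical stand-in for the `doc_status` queries.

```python
import asyncio

pipeline_status = {"busy": False, "request_pending": False}
pipeline_status_lock = asyncio.Lock()


async def fetch_pending() -> list[str]:
    return []  # hypothetical: return ids of queued documents


async def run_pipeline() -> None:
    async with pipeline_status_lock:
        if pipeline_status["busy"]:
            # Another worker owns the queue; just record that more work arrived.
            pipeline_status["request_pending"] = True
            return
        pipeline_status["busy"] = True
    try:
        while True:
            for _doc_id in await fetch_pending():
                pass  # process one document
            async with pipeline_status_lock:
                if not pipeline_status["request_pending"]:
                    break
                pipeline_status["request_pending"] = False  # loop again for new docs
    finally:
        async with pipeline_status_lock:
            pipeline_status["busy"] = False


if __name__ == "__main__":
    asyncio.run(run_pipeline())
```
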
lightrag/operate.py CHANGED
@@ -339,6 +339,9 @@ async def extract_entities(
339
  global_config: dict[str, str],
340
  llm_response_cache: BaseKVStorage | None = None,
341
  ) -> None:
 
 
 
342
  use_llm_func: callable = global_config["llm_model_func"]
343
  entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
344
  enable_llm_cache_for_entity_extract: bool = global_config[
@@ -499,9 +502,10 @@ async def extract_entities(
499
  processed_chunks += 1
500
  entities_count = len(maybe_nodes)
501
  relations_count = len(maybe_edges)
502
- logger.info(
503
- f" Chunk {processed_chunks}/{total_chunks}: extracted {entities_count} entities and {relations_count} relationships (deduplicated)"
504
- )
 
505
  return dict(maybe_nodes), dict(maybe_edges)
506
 
507
  tasks = [_process_single_content(c) for c in ordered_chunks]
@@ -530,17 +534,27 @@ async def extract_entities(
530
  )
531
 
532
  if not (all_entities_data or all_relationships_data):
533
- logger.info("Didn't extract any entities and relationships.")
 
 
 
534
  return
535
 
536
  if not all_entities_data:
537
- logger.info("Didn't extract any entities")
 
 
 
538
  if not all_relationships_data:
539
- logger.info("Didn't extract any relationships")
540
-
541
- logger.info(
542
- f"Extracted {len(all_entities_data)} entities and {len(all_relationships_data)} relationships (deduplicated)"
543
- )
 
 
 
 
544
  verbose_debug(
545
  f"New entities:{all_entities_data}, relationships:{all_relationships_data}"
546
  )
 
339
  global_config: dict[str, str],
340
  llm_response_cache: BaseKVStorage | None = None,
341
  ) -> None:
342
+ from lightrag.kg.shared_storage import get_namespace_data
343
+
344
+ pipeline_status = await get_namespace_data("pipeline_status")
345
  use_llm_func: callable = global_config["llm_model_func"]
346
  entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
347
  enable_llm_cache_for_entity_extract: bool = global_config[
 
502
  processed_chunks += 1
503
  entities_count = len(maybe_nodes)
504
  relations_count = len(maybe_edges)
505
+ log_message = f" Chunk {processed_chunks}/{total_chunks}: extracted {entities_count} entities and {relations_count} relationships (deduplicated)"
506
+ logger.info(log_message)
507
+ pipeline_status["latest_message"] = log_message
508
+ pipeline_status["history_messages"].append(log_message)
509
  return dict(maybe_nodes), dict(maybe_edges)
510
 
511
  tasks = [_process_single_content(c) for c in ordered_chunks]
 
534
  )
535
 
536
  if not (all_entities_data or all_relationships_data):
537
+ log_message = "Didn't extract any entities and relationships."
538
+ logger.info(log_message)
539
+ pipeline_status["latest_message"] = log_message
540
+ pipeline_status["history_messages"].append(log_message)
541
  return
542
 
543
  if not all_entities_data:
544
+ log_message = "Didn't extract any entities"
545
+ logger.info(log_message)
546
+ pipeline_status["latest_message"] = log_message
547
+ pipeline_status["history_messages"].append(log_message)
548
  if not all_relationships_data:
549
+ log_message = "Didn't extract any relationships"
550
+ logger.info(log_message)
551
+ pipeline_status["latest_message"] = log_message
552
+ pipeline_status["history_messages"].append(log_message)
553
+
554
+ log_message = f"Extracted {len(all_entities_data)} entities and {len(all_relationships_data)} relationships (deduplicated)"
555
+ logger.info(log_message)
556
+ pipeline_status["latest_message"] = log_message
557
+ pipeline_status["history_messages"].append(log_message)
558
  verbose_debug(
559
  f"New entities:{all_entities_data}, relationships:{all_relationships_data}"
560
  )
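
Each status update in `extract_entities` now repeats the same three steps: log the message, store it as `latest_message`, and append it to `history_messages`. The PR keeps these inline; a small helper like the sketch below (not part of this change, shown only to make the pattern explicit) would express it once.

```python
import logging

logger = logging.getLogger("lightrag")


def report(pipeline_status: dict, message: str) -> None:
    """Log a message and mirror it into the shared pipeline status."""
    logger.info(message)
    pipeline_status["latest_message"] = message
    # history_messages may be a Manager-backed list in multi-worker mode,
    # so append in place instead of rebinding it.
    pipeline_status["history_messages"].append(message)
```

With such a helper, each of the four-line blocks above would collapse to a single `report(pipeline_status, log_message)` call.
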
lightrag/utils.py CHANGED
@@ -56,45 +56,29 @@ def set_verbose_debug(enabled: bool):
56
  VERBOSE_DEBUG = enabled
57
 
58
 
59
- class UnlimitedSemaphore:
60
- """A context manager that allows unlimited access."""
61
-
62
- async def __aenter__(self):
63
- pass
64
-
65
- async def __aexit__(self, exc_type, exc, tb):
66
- pass
67
-
68
-
69
- ENCODER = None
70
-
71
  statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0}
72
 
 
73
  logger = logging.getLogger("lightrag")
 
 
 
74
 
75
  # Set httpx logging level to WARNING
76
  logging.getLogger("httpx").setLevel(logging.WARNING)
77
 
78
 
79
- def set_logger(log_file: str, level: int = logging.DEBUG):
80
- """Set up file logging with the specified level.
81
 
82
- Args:
83
- log_file: Path to the log file
84
- level: Logging level (e.g. logging.DEBUG, logging.INFO)
85
- """
86
- logger.setLevel(level)
87
 
88
- file_handler = logging.FileHandler(log_file, encoding="utf-8")
89
- file_handler.setLevel(level)
90
 
91
- formatter = logging.Formatter(
92
- "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
93
- )
94
- file_handler.setFormatter(formatter)
95
 
96
- if not logger.handlers:
97
- logger.addHandler(file_handler)
98
 
99
 
100
  @dataclass
 
56
  VERBOSE_DEBUG = enabled
57
 
58
 
 
59
  statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0}
60
 
61
+ # Initialize logger
62
  logger = logging.getLogger("lightrag")
63
+ logger.propagate = False  # prevent log messages from propagating to the root logger
64
+ # Let the main application configure the handlers
65
+ logger.setLevel(logging.INFO)
66
 
67
  # Set httpx logging level to WARNING
68
  logging.getLogger("httpx").setLevel(logging.WARNING)
69
 
70
 
71
+ class UnlimitedSemaphore:
72
+ """A context manager that allows unlimited access."""
73
 
74
+ async def __aenter__(self):
75
+ pass
 
 
 
76
 
77
+ async def __aexit__(self, exc_type, exc, tb):
78
+ pass
79
 
 
 
 
 
80
 
81
+ ENCODER = None
 
82
 
83
 
84
  @dataclass
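
With `set_logger()` removed and propagation disabled, the `lightrag` logger no longer attaches a file handler on its own; the hosting application is expected to configure handlers. A minimal sketch using only the standard library is shown below; the file name, size limit, and backup count are illustrative values, and the format string mirrors the one the removed `set_logger()` used.

```python
import logging
from logging.handlers import RotatingFileHandler

lightrag_logger = logging.getLogger("lightrag")
lightrag_logger.setLevel(logging.INFO)

file_handler = RotatingFileHandler(
    "lightrag.log",              # illustrative path
    maxBytes=10 * 1024 * 1024,   # illustrative size limit
    backupCount=5,               # illustrative backup count
    encoding="utf-8",
)
file_handler.setFormatter(
    logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)
lightrag_logger.addHandler(file_handler)
```
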
run_with_gunicorn.py ADDED
@@ -0,0 +1,203 @@
1
+ #!/usr/bin/env python
2
+ """
3
+ Start LightRAG server with Gunicorn
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import signal
9
+ import pipmaster as pm
10
+ from lightrag.api.utils_api import parse_args, display_splash_screen
11
+ from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data
12
+
13
+
14
+ def check_and_install_dependencies():
15
+ """Check and install required dependencies"""
16
+ required_packages = [
17
+ "gunicorn",
18
+ "tiktoken",
19
+ "psutil",
20
+ # Add other required packages here
21
+ ]
22
+
23
+ for package in required_packages:
24
+ if not pm.is_installed(package):
25
+ print(f"Installing {package}...")
26
+ pm.install(package)
27
+ print(f"{package} installed successfully")
28
+
29
+
30
+ # Signal handler for graceful shutdown
31
+ def signal_handler(sig, frame):
32
+ print("\n\n" + "=" * 80)
33
+ print("RECEIVED TERMINATION SIGNAL")
34
+ print(f"Process ID: {os.getpid()}")
35
+ print("=" * 80 + "\n")
36
+
37
+ # Release shared resources
38
+ finalize_share_data()
39
+
40
+ # Exit with success status
41
+ sys.exit(0)
42
+
43
+
44
+ def main():
45
+ # Check and install dependencies
46
+ check_and_install_dependencies()
47
+
48
+ # Register signal handlers for graceful shutdown
49
+ signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
50
+ signal.signal(signal.SIGTERM, signal_handler) # kill command
51
+
52
+ # Parse all arguments using parse_args
53
+ args = parse_args(is_uvicorn_mode=False)
54
+
55
+ # Display startup information
56
+ display_splash_screen(args)
57
+
58
+ print("🚀 Starting LightRAG with Gunicorn")
59
+ print(f"🔄 Worker management: Gunicorn (workers={args.workers})")
60
+ print("🔍 Preloading app: Enabled")
61
+ print("📝 Note: Using Gunicorn's preload feature for shared data initialization")
62
+ print("\n\n" + "=" * 80)
63
+ print("MAIN PROCESS INITIALIZATION")
64
+ print(f"Process ID: {os.getpid()}")
65
+ print(f"Workers setting: {args.workers}")
66
+ print("=" * 80 + "\n")
67
+
68
+ # Import Gunicorn's StandaloneApplication
69
+ from gunicorn.app.base import BaseApplication
70
+
71
+ # Define a custom application class that loads our config
72
+ class GunicornApp(BaseApplication):
73
+ def __init__(self, app, options=None):
74
+ self.options = options or {}
75
+ self.application = app
76
+ super().__init__()
77
+
78
+ def load_config(self):
79
+ # Define valid Gunicorn configuration options
80
+ valid_options = {
81
+ "bind",
82
+ "workers",
83
+ "worker_class",
84
+ "timeout",
85
+ "keepalive",
86
+ "preload_app",
87
+ "errorlog",
88
+ "accesslog",
89
+ "loglevel",
90
+ "certfile",
91
+ "keyfile",
92
+ "limit_request_line",
93
+ "limit_request_fields",
94
+ "limit_request_field_size",
95
+ "graceful_timeout",
96
+ "max_requests",
97
+ "max_requests_jitter",
98
+ }
99
+
100
+ # Special hooks that need to be set separately
101
+ special_hooks = {
102
+ "on_starting",
103
+ "on_reload",
104
+ "on_exit",
105
+ "pre_fork",
106
+ "post_fork",
107
+ "pre_exec",
108
+ "pre_request",
109
+ "post_request",
110
+ "worker_init",
111
+ "worker_exit",
112
+ "nworkers_changed",
113
+ "child_exit",
114
+ }
115
+
116
+ # Import and configure the gunicorn_config module
117
+ import gunicorn_config
118
+
119
+ # Set configuration variables in gunicorn_config, prioritizing command line arguments
120
+ gunicorn_config.workers = (
121
+ args.workers if args.workers else int(os.getenv("WORKERS", 1))
122
+ )
123
+
124
+ # Bind configuration prioritizes command line arguments
125
+ host = args.host if args.host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0")
126
+ port = args.port if args.port != 9621 else int(os.getenv("PORT", 9621))
127
+ gunicorn_config.bind = f"{host}:{port}"
128
+
129
+ # Log level configuration prioritizes command line arguments
130
+ gunicorn_config.loglevel = (
131
+ args.log_level.lower()
132
+ if args.log_level
133
+ else os.getenv("LOG_LEVEL", "info")
134
+ )
135
+
136
+ # Timeout configuration prioritizes command line arguments
137
+ gunicorn_config.timeout = (
138
+ args.timeout if args.timeout else int(os.getenv("TIMEOUT", 150))
139
+ )
140
+
141
+ # Keepalive configuration
142
+ gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5))
143
+
144
+ # SSL configuration prioritizes command line arguments
145
+ if args.ssl or os.getenv("SSL", "").lower() in (
146
+ "true",
147
+ "1",
148
+ "yes",
149
+ "t",
150
+ "on",
151
+ ):
152
+ gunicorn_config.certfile = (
153
+ args.ssl_certfile
154
+ if args.ssl_certfile
155
+ else os.getenv("SSL_CERTFILE")
156
+ )
157
+ gunicorn_config.keyfile = (
158
+ args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE")
159
+ )
160
+
161
+ # Set configuration options from the module
162
+ for key in dir(gunicorn_config):
163
+ if key in valid_options:
164
+ value = getattr(gunicorn_config, key)
165
+ # Skip functions like on_starting and None values
166
+ if not callable(value) and value is not None:
167
+ self.cfg.set(key, value)
168
+ # Set special hooks
169
+ elif key in special_hooks:
170
+ value = getattr(gunicorn_config, key)
171
+ if callable(value):
172
+ self.cfg.set(key, value)
173
+
174
+ if hasattr(gunicorn_config, "logconfig_dict"):
175
+ self.cfg.set(
176
+ "logconfig_dict", getattr(gunicorn_config, "logconfig_dict")
177
+ )
178
+
179
+ def load(self):
180
+ # Import the application
181
+ from lightrag.api.lightrag_server import get_application
182
+
183
+ return get_application(args)
184
+
185
+ # Create the application
186
+ app = GunicornApp("")
187
+
188
+ # Force workers to be an integer and greater than 1 for multi-process mode
189
+ workers_count = int(args.workers)
190
+ if workers_count > 1:
191
+ # Set a flag to indicate we're in the main process
192
+ os.environ["LIGHTRAG_MAIN_PROCESS"] = "1"
193
+ initialize_share_data(workers_count)
194
+ else:
195
+ initialize_share_data(1)
196
+
197
+ # Run the application
198
+ print("\nStarting Gunicorn with direct Python API...")
199
+ app.run()
200
+
201
+
202
+ if __name__ == "__main__":
203
+ main()
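
`load_config()` above imports a `gunicorn_config` module and copies its attributes into Gunicorn's settings, so `run_with_gunicorn.py` assumes such a module is importable (it is added elsewhere in this PR and not shown here). The sketch below only illustrates the shape that module is expected to have; every value is a placeholder that `main()` may override from CLI arguments or environment variables, and the worker class is an assumption.

```python
# gunicorn_config.py (illustrative sketch, not the file shipped with this PR)
workers = 2
bind = "0.0.0.0:9621"
worker_class = "uvicorn.workers.UvicornWorker"  # assumed ASGI worker class
timeout = 150
keepalive = 5
preload_app = True  # lets initialize_share_data() run once in the master process


def on_starting(server):
    # Gunicorn hook: runs in the master process before workers are forked.
    print(f"Gunicorn master starting with {workers} workers")
```
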
setup.py CHANGED
@@ -112,6 +112,7 @@ setuptools.setup(
112
  entry_points={
113
  "console_scripts": [
114
  "lightrag-server=lightrag.api.lightrag_server:main [api]",
 
115
  "lightrag-viewer=lightrag.tools.lightrag_visualizer.graph_visualizer:main [tools]",
116
  ],
117
  },
 
112
  entry_points={
113
  "console_scripts": [
114
  "lightrag-server=lightrag.api.lightrag_server:main [api]",
115
+ "lightrag-gunicorn=lightrag.api.run_with_gunicorn:main [api]",
116
  "lightrag-viewer=lightrag.tools.lightrag_visualizer.graph_visualizer:main [tools]",
117
  ],
118
  },
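
The new `lightrag-gunicorn` console script points at `lightrag.api.run_with_gunicorn:main`, so installing the package with the `[api]` extra makes the multi-worker launcher available as a command. The snippet below is only a sanity-check equivalent of that entry point; it assumes the module is packaged under `lightrag.api` as the entry-point string declares.

```python
# Equivalent of invoking the "lightrag-gunicorn" console script directly.
from lightrag.api.run_with_gunicorn import main

if __name__ == "__main__":
    main()
```
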