Merge pull request #969 from danielaskdd/add-multi-worker-support
- .gitignore +1 -0
- .env.example → env.example +9 -0
- lightrag/api/README.md +37 -2
- lightrag/api/gunicorn_config.py +187 -0
- lightrag/api/lightrag_server.py +141 -69
- lightrag/api/routers/document_routes.py +65 -89
- lightrag/api/run_with_gunicorn.py +203 -0
- lightrag/api/utils_api.py +35 -15
- lightrag/kg/faiss_impl.py +89 -30
- lightrag/kg/json_doc_status_impl.py +57 -27
- lightrag/kg/json_kv_impl.py +42 -19
- lightrag/kg/nano_vector_db_impl.py +102 -20
- lightrag/kg/networkx_impl.py +121 -39
- lightrag/kg/postgres_impl.py +2 -2
- lightrag/kg/shared_storage.py +374 -0
- lightrag/lightrag.py +220 -103
- lightrag/operate.py +24 -10
- lightrag/utils.py +11 -27
- run_with_gunicorn.py +203 -0
- setup.py +1 -0
.gitignore
CHANGED
@@ -21,6 +21,7 @@ site/
 
 # Logs / Reports
 *.log
+*.log.*
 *.logfire
 *.coverage/
 log/

.env.example → env.example
RENAMED
@@ -1,6 +1,9 @@
+### This is sample file of .env
+
 ### Server Configuration
 # HOST=0.0.0.0
 # PORT=9621
+# WORKERS=1
 # NAMESPACE_PREFIX=lightrag # separating data from difference Lightrag instances
 # CORS_ORIGINS=http://localhost:3000,http://localhost:8080
 
@@ -22,6 +25,9 @@
 ### Logging level
 # LOG_LEVEL=INFO
 # VERBOSE=False
+# LOG_DIR=/path/to/log/directory  # Log file directory path, defaults to current working directory
+# LOG_MAX_BYTES=10485760  # Log file max size in bytes, defaults to 10MB
+# LOG_BACKUP_COUNT=5  # Number of backup files to keep, defaults to 5
 
 ### Max async calls for LLM
 # MAX_ASYNC=4
 
@@ -138,3 +144,6 @@ MONGODB_GRAPH=false # deprecated (keep for backward compatibility)
 ### Qdrant
 QDRANT_URL=http://localhost:16333
 # QDRANT_API_KEY=your-api-key
+
+### Redis
+REDIS_URI=redis://localhost:6379

lightrag/api/README.md
CHANGED
@@ -24,6 +24,8 @@ pip install -e ".[api]"
 
 ### Starting API Server with Default Settings
 
+After installing LightRAG with API support, you can start LightRAG with this command: `lightrag-server`
+
 LightRAG requires both LLM and Embedding Model to work together to complete document indexing and querying tasks. LightRAG supports binding to various LLM/Embedding backends:
 
 * ollama
 
@@ -92,10 +94,43 @@ LLM_BINDING_API_KEY=your_api_key Light_server --llm-binding openai-ollama
 LLM_BINDING_API_KEY=your_api_key Light_server --llm-binding openai --embedding-binding openai
 
 # start with ollama llm and ollama embedding (no apikey is needed)
-
+light-server --llm-binding ollama --embedding-binding ollama
+```
+
+### Starting API Server with Gunicorn (Production)
+
+For production deployments, it's recommended to use Gunicorn as the WSGI server to handle concurrent requests efficiently. LightRAG provides a dedicated Gunicorn startup script that handles shared data initialization, process management, and other critical functionalities.
+
+```bash
+# Start with lightrag-gunicorn command
+lightrag-gunicorn --workers 4
+
+# Alternatively, you can use the module directly
+python -m lightrag.api.run_with_gunicorn --workers 4
 ```
 
+The `--workers` parameter is crucial for performance:
+
+- Determines how many worker processes Gunicorn will spawn to handle requests
+- Each worker can handle concurrent requests using asyncio
+- Recommended value is (2 x number_of_cores) + 1
+- For example, on a 4-core machine, use 9 workers: (2 x 4) + 1 = 9
+- Consider your server's memory when setting this value, as each worker consumes memory
+
+Other important startup parameters:
+
+- `--host`: Server listening address (default: 0.0.0.0)
+- `--port`: Server listening port (default: 9621)
+- `--timeout`: Request handling timeout (default: 150 seconds)
+- `--log-level`: Logging level (default: INFO)
+- `--ssl`: Enable HTTPS
+- `--ssl-certfile`: Path to SSL certificate file
+- `--ssl-keyfile`: Path to SSL private key file
+
+The command line parameters and environment variables for run_with_gunicorn.py are exactly the same as for `light-server`.
+
 ### For Azure OpenAI Backend
+
 Azure OpenAI API can be created using the following commands in Azure CLI (you need to install Azure CLI first from [https://docs.microsoft.com/en-us/cli/azure/install-azure-cli](https://docs.microsoft.com/en-us/cli/azure/install-azure-cli)):
 ```bash
 # Change the resource group name, location and OpenAI resource name as needed
 
@@ -186,7 +221,7 @@ LightRAG supports binding to various LLM/Embedding backends:
 * openai & openai compatible
 * azure_openai
 
-Use environment variables `LLM_BINDING
+Use environment variables `LLM_BINDING` or CLI argument `--llm-binding` to select the LLM backend type. Use environment variables `EMBEDDING_BINDING` or CLI argument `--embedding-binding` to select the Embedding backend type.
 
 ### Storage Types Supported

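As a quick illustration of the options documented in the README section above, a production startup on a 4-core machine might look like the following. This is a sketch only; the worker count follows the (2 x cores) + 1 rule of thumb, and the certificate paths are placeholder values, not part of this PR:

```bash
# Illustrative production startup (placeholder paths)
lightrag-gunicorn \
  --workers 9 \
  --host 0.0.0.0 \
  --port 9621 \
  --timeout 150 \
  --log-level INFO \
  --ssl \
  --ssl-certfile /path/to/cert.pem \
  --ssl-keyfile /path/to/key.pem
```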
lightrag/api/gunicorn_config.py
ADDED
@@ -0,0 +1,187 @@
# gunicorn_config.py
import os
import logging
from lightrag.kg.shared_storage import finalize_share_data
from lightrag.api.lightrag_server import LightragPathFilter

# Get log directory path from environment variable
log_dir = os.getenv("LOG_DIR", os.getcwd())
log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log"))

# Get log file max size and backup count from environment variables
log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760))  # Default 10MB
log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5))  # Default 5 backups

# These variables will be set by run_with_gunicorn.py
workers = None
bind = None
loglevel = None
certfile = None
keyfile = None

# Enable preload_app option
preload_app = True

# Use Uvicorn worker
worker_class = "uvicorn.workers.UvicornWorker"

# Other Gunicorn configurations
timeout = int(os.getenv("TIMEOUT", 150))  # Default 150s to match run_with_gunicorn.py
keepalive = int(os.getenv("KEEPALIVE", 5))  # Default 5s

# Logging configuration
errorlog = os.getenv("ERROR_LOG", log_file_path)  # Default write to lightrag.log
accesslog = os.getenv("ACCESS_LOG", log_file_path)  # Default write to lightrag.log

logconfig_dict = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "standard": {"format": "%(asctime)s [%(levelname)s] %(name)s: %(message)s"},
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "formatter": "standard",
            "stream": "ext://sys.stdout",
        },
        "file": {
            "class": "logging.handlers.RotatingFileHandler",
            "formatter": "standard",
            "filename": log_file_path,
            "maxBytes": log_max_bytes,
            "backupCount": log_backup_count,
            "encoding": "utf8",
        },
    },
    "filters": {
        "path_filter": {
            "()": "lightrag.api.lightrag_server.LightragPathFilter",
        },
    },
    "loggers": {
        "lightrag": {
            "handlers": ["console", "file"],
            "level": loglevel.upper() if loglevel else "INFO",
            "propagate": False,
        },
        "gunicorn": {
            "handlers": ["console", "file"],
            "level": loglevel.upper() if loglevel else "INFO",
            "propagate": False,
        },
        "gunicorn.error": {
            "handlers": ["console", "file"],
            "level": loglevel.upper() if loglevel else "INFO",
            "propagate": False,
        },
        "gunicorn.access": {
            "handlers": ["console", "file"],
            "level": loglevel.upper() if loglevel else "INFO",
            "propagate": False,
            "filters": ["path_filter"],
        },
    },
}


def on_starting(server):
    """
    Executed when Gunicorn starts, before forking the first worker processes
    You can use this function to do more initialization tasks for all processes
    """
    print("=" * 80)
    print(f"GUNICORN MASTER PROCESS: on_starting jobs for {workers} worker(s)")
    print(f"Process ID: {os.getpid()}")
    print("=" * 80)

    # Memory usage monitoring
    try:
        import psutil

        process = psutil.Process(os.getpid())
        memory_info = process.memory_info()
        msg = (
            f"Memory usage after initialization: {memory_info.rss / 1024 / 1024:.2f} MB"
        )
        print(msg)
    except ImportError:
        print("psutil not installed, skipping memory usage reporting")

    print("Gunicorn initialization complete, forking workers...\n")


def on_exit(server):
    """
    Executed when Gunicorn is shutting down.
    This is a good place to release shared resources.
    """
    print("=" * 80)
    print("GUNICORN MASTER PROCESS: Shutting down")
    print(f"Process ID: {os.getpid()}")
    print("=" * 80)

    # Release shared resources
    finalize_share_data()

    print("=" * 80)
    print("Gunicorn shutdown complete")
    print("=" * 80)


def post_fork(server, worker):
    """
    Executed after a worker has been forked.
    This is a good place to set up worker-specific configurations.
    """
    # Configure formatters
    detailed_formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    simple_formatter = logging.Formatter("%(levelname)s: %(message)s")

    def setup_logger(logger_name: str, level: str = "INFO", add_filter: bool = False):
        """Set up a logger with console and file handlers"""
        logger_instance = logging.getLogger(logger_name)
        logger_instance.setLevel(level)
        logger_instance.handlers = []  # Clear existing handlers
        logger_instance.propagate = False

        # Add console handler
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(simple_formatter)
        console_handler.setLevel(level)
        logger_instance.addHandler(console_handler)

        # Add file handler
        file_handler = logging.handlers.RotatingFileHandler(
            filename=log_file_path,
            maxBytes=log_max_bytes,
            backupCount=log_backup_count,
            encoding="utf-8",
        )
        file_handler.setFormatter(detailed_formatter)
        file_handler.setLevel(level)
        logger_instance.addHandler(file_handler)

        # Add path filter if requested
        if add_filter:
            path_filter = LightragPathFilter()
            logger_instance.addFilter(path_filter)

    # Set up main loggers
    log_level = loglevel.upper() if loglevel else "INFO"
    setup_logger("uvicorn", log_level)
    setup_logger("uvicorn.access", log_level, add_filter=True)
    setup_logger("lightrag", log_level, add_filter=True)

    # Set up lightrag submodule loggers
    for name in logging.root.manager.loggerDict:
        if name.startswith("lightrag."):
            setup_logger(name, log_level, add_filter=True)

    # Disable uvicorn.error logger
    uvicorn_error_logger = logging.getLogger("uvicorn.error")
    uvicorn_error_logger.handlers = []
    uvicorn_error_logger.setLevel(logging.CRITICAL)
    uvicorn_error_logger.propagate = False

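Because gunicorn_config.py reads several settings from the environment (TIMEOUT, KEEPALIVE, ERROR_LOG, ACCESS_LOG, LOG_DIR, LOG_MAX_BYTES, LOG_BACKUP_COUNT), a deployment can tune rotation and timeouts without editing the file. A minimal sketch; the values shown are arbitrary examples, not defaults from this PR:

```bash
# Illustrative only: override the defaults read by gunicorn_config.py
export LOG_DIR=/var/log/lightrag        # lightrag.log is written here
export LOG_MAX_BYTES=20971520           # rotate at 20MB instead of the 10MB default
export LOG_BACKUP_COUNT=10              # keep 10 rotated log files
export TIMEOUT=300                      # Gunicorn worker timeout in seconds
export KEEPALIVE=5
lightrag-gunicorn --workers 4
```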
lightrag/api/lightrag_server.py
CHANGED
@@ -8,11 +8,12 @@ from fastapi import (
 )
 from fastapi.responses import FileResponse
 import asyncio
-import threading
 import os
-from fastapi.staticfiles import StaticFiles
 import logging
-
+import logging.config
+import uvicorn
+import pipmaster as pm
+from fastapi.staticfiles import StaticFiles
 from pathlib import Path
 import configparser
 from ascii_colors import ASCIIColors
@@ -29,7 +30,6 @@ from lightrag import LightRAG
 from lightrag.types import GPTKeywordExtractionFormat
 from lightrag.api import __api_version__
 from lightrag.utils import EmbeddingFunc
-from lightrag.utils import logger
 from .routers.document_routes import (
     DocumentManager,
     create_document_routes,
@@ -39,33 +39,25 @@ from .routers.query_routes import create_query_routes
 from .routers.graph_routes import create_graph_routes
 from .routers.ollama_api import OllamaAPI
 
+from lightrag.utils import logger, set_verbose_debug
+from lightrag.kg.shared_storage import (
+    get_namespace_data,
+    get_pipeline_status_lock,
+    initialize_pipeline_status,
+    get_all_update_flags_status,
+)
+
 # Load environment variables
-
-    load_dotenv(override=True)
-except Exception as e:
-    logger.warning(f"Failed to load .env file: {e}")
+load_dotenv(override=True)
 
 # Initialize config parser
 config = configparser.ConfigParser()
 config.read("config.ini")
 
-# Global configuration
-global_top_k = 60  # default value
 
-
-
-    "is_scanning": False,
-    "current_file": "",
-    "indexed_count": 0,
-    "total_files": 0,
-    "progress": 0,
-}
 
-# Lock for thread-safe operations
-progress_lock = threading.Lock()
-
-
-class AccessLogFilter(logging.Filter):
+class LightragPathFilter(logging.Filter):
+    """Filter for lightrag logger to filter out frequent path access logs"""
 
     def __init__(self):
         super().__init__()
         # Define paths to be filtered
@@ -73,17 +65,18 @@ class AccessLogFilter(logging.Filter):
 
     def filter(self, record):
         try:
+            # Check if record has the required attributes for an access log
            if not hasattr(record, "args") or not isinstance(record.args, tuple):
                return True
            if len(record.args) < 5:
                return True
 
+            # Extract method, path and status from the record args
            method = record.args[1]
            path = record.args[2]
            status = record.args[4]
-            # print(f"Debug - Method: {method}, Path: {path}, Status: {status}")
-            # print(f"Debug - Filtered paths: {self.filtered_paths}")
 
+            # Filter out successful GET requests to filtered paths
            if (
                method == "GET"
                and (status == 200 or status == 304)
@@ -92,19 +85,14 @@ class AccessLogFilter(logging.Filter):
                return False
 
            return True
-
        except Exception:
+            # In case of any error, let the message through
            return True
 
 
 def create_app(args):
-    #
-
-    global_top_k = args.top_k  # save top_k from args
-
-    # Initialize verbose debug setting
-    from lightrag.utils import set_verbose_debug
-
+    # Setup logging
+    logger.setLevel(args.log_level)
     set_verbose_debug(args.verbose)
 
     # Verify that bindings are correctly setup
@@ -138,11 +126,6 @@ def create_app(args):
        if not os.path.exists(args.ssl_keyfile):
            raise Exception(f"SSL key file not found: {args.ssl_keyfile}")
 
-    # Setup logging
-    logging.basicConfig(
-        format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
-    )
-
    # Check if API key is provided either through env var or args
    api_key = os.getenv("LIGHTRAG_API_KEY") or args.key
 
@@ -158,28 +141,23 @@ def create_app(args):
        try:
            # Initialize database connections
            await rag.initialize_storages()
+            await initialize_pipeline_status()
 
            # Auto scan documents if enabled
            if args.auto_scan_at_startup:
-                #
-                        f"Started background scanning of documents from {args.input_dir}"
-                    )
-                else:
-                    ASCIIColors.info(
-                        "Skip document scanning(another scanning is active)"
-                    )
+                # Check if a task is already running (with lock protection)
+                pipeline_status = await get_namespace_data("pipeline_status")
+                should_start_task = False
+                async with get_pipeline_status_lock():
+                    if not pipeline_status.get("busy", False):
+                        should_start_task = True
+                # Only start the task if no other task is running
+                if should_start_task:
+                    # Create background task
+                    task = asyncio.create_task(run_scanning_process(rag, doc_manager))
+                    app.state.background_tasks.add(task)
+                    task.add_done_callback(app.state.background_tasks.discard)
+                    logger.info("Auto scan task started at startup.")
 
            ASCIIColors.green("\nServer is ready to accept connections! 🚀\n")
 
@@ -398,6 +376,9 @@ def create_app(args):
    @app.get("/health", dependencies=[Depends(optional_api_key)])
    async def get_status():
        """Get current system status"""
+        # Get update flags status for all namespaces
+        update_status = await get_all_update_flags_status()
+
        return {
            "status": "healthy",
            "working_directory": str(args.working_dir),
@@ -417,6 +398,7 @@ def create_app(args):
                "graph_storage": args.graph_storage,
                "vector_storage": args.vector_storage,
            },
+            "update_status": update_status,
        }
 
    # Webui mount webui/index.html
@@ -435,12 +417,30 @@ def create_app(args):
    return app
 
 
-def
+def get_application(args=None):
+    """Factory function for creating the FastAPI application"""
+    if args is None:
+        args = parse_args()
+    return create_app(args)
+
+
+def configure_logging():
+    """Configure logging for uvicorn startup"""
+
+    # Reset any existing handlers to ensure clean configuration
+    for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]:
+        logger = logging.getLogger(logger_name)
+        logger.handlers = []
+        logger.filters = []
+
+    # Get log directory path from environment variable
+    log_dir = os.getenv("LOG_DIR", os.getcwd())
+    log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log"))
+
+    # Get log file max size and backup count from environment variables
+    log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760))  # Default 10MB
+    log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5))  # Default 5 backups
 
-    # Configure uvicorn logging
    logging.config.dictConfig(
        {
            "version": 1,
@@ -449,36 +449,106 @@ def main():
            "default": {
                "format": "%(levelname)s: %(message)s",
            },
+            "detailed": {
+                "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+            },
            },
            "handlers": {
-            "
+            "console": {
                "formatter": "default",
                "class": "logging.StreamHandler",
                "stream": "ext://sys.stderr",
            },
+            "file": {
+                "formatter": "detailed",
+                "class": "logging.handlers.RotatingFileHandler",
+                "filename": log_file_path,
+                "maxBytes": log_max_bytes,
+                "backupCount": log_backup_count,
+                "encoding": "utf-8",
+            },
            },
            "loggers": {
+            # Configure all uvicorn related loggers
+            "uvicorn": {
+                "handlers": ["console", "file"],
+                "level": "INFO",
+                "propagate": False,
+            },
            "uvicorn.access": {
-                "handlers": ["
+                "handlers": ["console", "file"],
+                "level": "INFO",
+                "propagate": False,
+                "filters": ["path_filter"],
+            },
+            "uvicorn.error": {
+                "handlers": ["console", "file"],
+                "level": "INFO",
+                "propagate": False,
+            },
+            "lightrag": {
+                "handlers": ["console", "file"],
                "level": "INFO",
                "propagate": False,
+                "filters": ["path_filter"],
+            },
+            },
+            "filters": {
+            "path_filter": {
+                "()": "lightrag.api.lightrag_server.LightragPathFilter",
            },
            },
        }
    )
 
-# Add filter to uvicorn access logger
-uvicorn_access_logger = logging.getLogger("uvicorn.access")
-uvicorn_access_logger.addFilter(AccessLogFilter())
 
-
+def check_and_install_dependencies():
+    """Check and install required dependencies"""
+    required_packages = [
+        "uvicorn",
+        "tiktoken",
+        "fastapi",
+        # Add other required packages here
+    ]
+
+    for package in required_packages:
+        if not pm.is_installed(package):
+            print(f"Installing {package}...")
+            pm.install(package)
+            print(f"{package} installed successfully")
+
+
+def main():
+    # Check if running under Gunicorn
+    if "GUNICORN_CMD_ARGS" in os.environ:
+        # If started with Gunicorn, return directly as Gunicorn will call get_application
+        print("Running under Gunicorn - worker management handled by Gunicorn")
+        return
+
+    # Check and install dependencies
+    check_and_install_dependencies()
+
+    from multiprocessing import freeze_support
+
+    freeze_support()
+
+    # Configure logging before parsing args
+    configure_logging()
+
+    args = parse_args(is_uvicorn_mode=True)
    display_splash_screen(args)
+
+    # Create application instance directly instead of using factory function
+    app = create_app(args)
+
+    # Start Uvicorn in single process mode
    uvicorn_config = {
-        "app": app,
+        "app": app,  # Pass application instance directly instead of string path
        "host": args.host,
        "port": args.port,
        "log_config": None,  # Disable default config
    }
+
    if args.ssl:
        uvicorn_config.update(
            {
@@ -486,6 +556,8 @@ def main():
                "ssl_keyfile": args.ssl_keyfile,
            }
        )
+
+    print(f"Starting Uvicorn server in single-process mode on {args.host}:{args.port}")
    uvicorn.run(**uvicorn_config)
 

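The /health route now also reports per-namespace update flags via get_all_update_flags_status(). A quick way to see this on a locally running server (default port 9621; if LIGHTRAG_API_KEY is configured, the request must additionally carry the API key, whose header name is not shown in this diff):

```bash
# Illustrative check of the extended health endpoint
curl -s http://localhost:9621/health
# The JSON response now includes an "update_status" field alongside the existing
# status, working_directory, and configuration fields shown in the diff above.
```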
lightrag/api/routers/document_routes.py
CHANGED
@@ -3,8 +3,7 @@ This module contains all document-related routes for the LightRAG API.
 """
 
 import asyncio
-import
-import os
+from lightrag.utils import logger
 import aiofiles
 import shutil
 import traceback
@@ -12,7 +11,6 @@ import pipmaster as pm
 from datetime import datetime
 from pathlib import Path
 from typing import Dict, List, Optional, Any
-
 from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
 from pydantic import BaseModel, Field, field_validator
 
@@ -23,18 +21,6 @@ from ..utils_api import get_api_key_dependency
 
 router = APIRouter(prefix="/documents", tags=["documents"])
 
-# Global progress tracker
-scan_progress: Dict = {
-    "is_scanning": False,
-    "current_file": "",
-    "indexed_count": 0,
-    "total_files": 0,
-    "progress": 0,
-}
-
-# Lock for thread-safe operations
-progress_lock = asyncio.Lock()
-
 # Temporary file prefix
 temp_prefix = "__tmp__"
 
@@ -161,19 +147,12 @@ class DocumentManager:
        """Scan input directory for new files"""
        new_files = []
        for ext in self.supported_extensions:
-
+            logger.debug(f"Scanning for {ext} files in {self.input_dir}")
            for file_path in self.input_dir.rglob(f"*{ext}"):
                if file_path not in self.indexed_files:
                    new_files.append(file_path)
        return new_files
 
-    # def scan_directory(self) -> List[Path]:
-    #     new_files = []
-    #     for ext in self.supported_extensions:
-    #         for file_path in self.input_dir.rglob(f"*{ext}"):
-    #             new_files.append(file_path)
-    #     return new_files
-
    def mark_as_indexed(self, file_path: Path):
        self.indexed_files.add(file_path)
 
@@ -287,7 +266,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
                    )
                    content += "\n"
                case _:
-
+                    logger.error(
                        f"Unsupported file type: {file_path.name} (extension {ext})"
                    )
                    return False
@@ -295,20 +274,20 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
        # Insert into the RAG queue
        if content:
            await rag.apipeline_enqueue_documents(content)
-
+            logger.info(f"Successfully fetched and enqueued file: {file_path.name}")
            return True
        else:
-
+            logger.error(f"No content could be extracted from file: {file_path.name}")
 
    except Exception as e:
-
-
+        logger.error(f"Error processing or enqueueing file {file_path.name}: {str(e)}")
+        logger.error(traceback.format_exc())
    finally:
        if file_path.name.startswith(temp_prefix):
            try:
                file_path.unlink()
            except Exception as e:
-
+                logger.error(f"Error deleting file {file_path}: {str(e)}")
    return False
 
 
@@ -324,8 +303,8 @@ async def pipeline_index_file(rag: LightRAG, file_path: Path):
        await rag.apipeline_process_enqueue_documents()
 
    except Exception as e:
-
-
+        logger.error(f"Error indexing file {file_path.name}: {str(e)}")
+        logger.error(traceback.format_exc())
 
 
 async def pipeline_index_files(rag: LightRAG, file_paths: List[Path]):
@@ -349,8 +328,8 @@ async def pipeline_index_files(rag: LightRAG, file_paths: List[Path]):
        if enqueued:
            await rag.apipeline_process_enqueue_documents()
    except Exception as e:
-
-
+        logger.error(f"Error indexing files: {str(e)}")
+        logger.error(traceback.format_exc())
 
 
 async def pipeline_index_texts(rag: LightRAG, texts: List[str]):
@@ -393,30 +372,17 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
    """Background task to scan and index documents"""
    try:
        new_files = doc_manager.scan_directory_for_new_files()
-
+        total_files = len(new_files)
+        logger.info(f"Found {total_files} new files to index.")
 
-
-        for file_path in new_files:
+        for idx, file_path in enumerate(new_files):
            try:
-                async with progress_lock:
-                    scan_progress["current_file"] = os.path.basename(file_path)
-
                await pipeline_index_file(rag, file_path)
-
-                async with progress_lock:
-                    scan_progress["indexed_count"] += 1
-                    scan_progress["progress"] = (
-                        scan_progress["indexed_count"] / scan_progress["total_files"]
-                    ) * 100
-
            except Exception as e:
-
+                logger.error(f"Error indexing file {file_path}: {str(e)}")
 
    except Exception as e:
-
-    finally:
-        async with progress_lock:
-            scan_progress["is_scanning"] = False
+        logger.error(f"Error during scanning process: {str(e)}")
 
 
 def create_document_routes(
@@ -436,34 +402,10 @@ def create_document_routes(
        Returns:
            dict: A dictionary containing the scanning status
        """
-        async with progress_lock:
-            if scan_progress["is_scanning"]:
-                return {"status": "already_scanning"}
-
-            scan_progress["is_scanning"] = True
-            scan_progress["indexed_count"] = 0
-            scan_progress["progress"] = 0
-
        # Start the scanning process in the background
        background_tasks.add_task(run_scanning_process, rag, doc_manager)
        return {"status": "scanning_started"}
 
-    @router.get("/scan-progress")
-    async def get_scan_progress():
-        """
-        Get the current progress of the document scanning process.
-
-        Returns:
-            dict: A dictionary containing the current scanning progress information including:
-                - is_scanning: Whether a scan is currently in progress
-                - current_file: The file currently being processed
-                - indexed_count: Number of files indexed so far
-                - total_files: Total number of files to process
-                - progress: Percentage of completion
-        """
-        async with progress_lock:
-            return scan_progress
-
    @router.post("/upload", dependencies=[Depends(optional_api_key)])
    async def upload_to_input_dir(
        background_tasks: BackgroundTasks, file: UploadFile = File(...)
@@ -504,8 +446,8 @@ def create_document_routes(
                message=f"File '{file.filename}' uploaded successfully. Processing will continue in background.",
            )
        except Exception as e:
-
-
+            logger.error(f"Error /documents/upload: {file.filename}: {str(e)}")
+            logger.error(traceback.format_exc())
            raise HTTPException(status_code=500, detail=str(e))
 
    @router.post(
@@ -537,8 +479,8 @@ def create_document_routes(
                message="Text successfully received. Processing will continue in background.",
            )
        except Exception as e:
-
-
+            logger.error(f"Error /documents/text: {str(e)}")
+            logger.error(traceback.format_exc())
            raise HTTPException(status_code=500, detail=str(e))
 
    @router.post(
@@ -572,8 +514,8 @@ def create_document_routes(
                message="Text successfully received. Processing will continue in background.",
            )
        except Exception as e:
-
-
+            logger.error(f"Error /documents/text: {str(e)}")
+            logger.error(traceback.format_exc())
            raise HTTPException(status_code=500, detail=str(e))
 
    @router.post(
@@ -615,8 +557,8 @@ def create_document_routes(
                message=f"File '{file.filename}' saved successfully. Processing will continue in background.",
            )
        except Exception as e:
-
-
+            logger.error(f"Error /documents/file: {str(e)}")
+            logger.error(traceback.format_exc())
            raise HTTPException(status_code=500, detail=str(e))
 
    @router.post(
@@ -678,8 +620,8 @@ def create_document_routes(
 
            return InsertResponse(status=status, message=status_message)
        except Exception as e:
-
-
+            logger.error(f"Error /documents/batch: {str(e)}")
+            logger.error(traceback.format_exc())
            raise HTTPException(status_code=500, detail=str(e))
 
    @router.delete(
@@ -706,8 +648,42 @@ def create_document_routes(
                status="success", message="All documents cleared successfully"
            )
        except Exception as e:
-
-
+            logger.error(f"Error DELETE /documents: {str(e)}")
+            logger.error(traceback.format_exc())
+            raise HTTPException(status_code=500, detail=str(e))
+
+    @router.get("/pipeline_status", dependencies=[Depends(optional_api_key)])
+    async def get_pipeline_status():
+        """
+        Get the current status of the document indexing pipeline.
+
+        This endpoint returns information about the current state of the document processing pipeline,
+        including whether it's busy, the current job name, when it started, how many documents
+        are being processed, how many batches there are, and which batch is currently being processed.
+
+        Returns:
+            dict: A dictionary containing the pipeline status information
+        """
+        try:
+            from lightrag.kg.shared_storage import get_namespace_data
+
+            pipeline_status = await get_namespace_data("pipeline_status")
+
+            # Convert to regular dict if it's a Manager.dict
+            status_dict = dict(pipeline_status)
+
+            # Convert history_messages to a regular list if it's a Manager.list
+            if "history_messages" in status_dict:
+                status_dict["history_messages"] = list(status_dict["history_messages"])
+
+            # Format the job_start time if it exists
+            if status_dict.get("job_start"):
+                status_dict["job_start"] = str(status_dict["job_start"])
+
+            return status_dict
+        except Exception as e:
+            logger.error(f"Error getting pipeline status: {str(e)}")
+            logger.error(traceback.format_exc())
            raise HTTPException(status_code=500, detail=str(e))
 
    @router.get("", dependencies=[Depends(optional_api_key)])
@@ -763,8 +739,8 @@ def create_document_routes(
            )
            return response
        except Exception as e:
-
-
+            logger.error(f"Error GET /documents: {str(e)}")
+            logger.error(traceback.format_exc())
            raise HTTPException(status_code=500, detail=str(e))
 
    return router

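The removed /documents/scan-progress route is superseded by the new /documents/pipeline_status endpoint, which reads its state from shared storage instead of a per-process dictionary. A minimal sketch of polling it during indexing (default host and port assumed; API-key authorization omitted for the same reason as above):

```bash
# Illustrative: poll the new pipeline status endpoint while documents are indexing
curl -s http://localhost:9621/documents/pipeline_status
# The returned dictionary contains the fields described in the endpoint's docstring,
# for example whether the pipeline is busy, the job start time, and history_messages.
```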
lightrag/api/run_with_gunicorn.py
ADDED
@@ -0,0 +1,203 @@
#!/usr/bin/env python
"""
Start LightRAG server with Gunicorn
"""

import os
import sys
import signal
import pipmaster as pm
from lightrag.api.utils_api import parse_args, display_splash_screen
from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data


def check_and_install_dependencies():
    """Check and install required dependencies"""
    required_packages = [
        "gunicorn",
        "tiktoken",
        "psutil",
        # Add other required packages here
    ]

    for package in required_packages:
        if not pm.is_installed(package):
            print(f"Installing {package}...")
            pm.install(package)
            print(f"{package} installed successfully")


# Signal handler for graceful shutdown
def signal_handler(sig, frame):
    print("\n\n" + "=" * 80)
    print("RECEIVED TERMINATION SIGNAL")
    print(f"Process ID: {os.getpid()}")
    print("=" * 80 + "\n")

    # Release shared resources
    finalize_share_data()

    # Exit with success status
    sys.exit(0)


def main():
    # Check and install dependencies
    check_and_install_dependencies()

    # Register signal handlers for graceful shutdown
    signal.signal(signal.SIGINT, signal_handler)  # Ctrl+C
    signal.signal(signal.SIGTERM, signal_handler)  # kill command

    # Parse all arguments using parse_args
    args = parse_args(is_uvicorn_mode=False)

    # Display startup information
    display_splash_screen(args)

    print("🚀 Starting LightRAG with Gunicorn")
    print(f"🔄 Worker management: Gunicorn (workers={args.workers})")
    print("🔍 Preloading app: Enabled")
    print("📝 Note: Using Gunicorn's preload feature for shared data initialization")
    print("\n\n" + "=" * 80)
    print("MAIN PROCESS INITIALIZATION")
    print(f"Process ID: {os.getpid()}")
    print(f"Workers setting: {args.workers}")
    print("=" * 80 + "\n")

    # Import Gunicorn's StandaloneApplication
    from gunicorn.app.base import BaseApplication

    # Define a custom application class that loads our config
    class GunicornApp(BaseApplication):
        def __init__(self, app, options=None):
            self.options = options or {}
            self.application = app
            super().__init__()

        def load_config(self):
            # Define valid Gunicorn configuration options
            valid_options = {
                "bind",
                "workers",
                "worker_class",
                "timeout",
                "keepalive",
                "preload_app",
                "errorlog",
                "accesslog",
                "loglevel",
                "certfile",
                "keyfile",
                "limit_request_line",
                "limit_request_fields",
                "limit_request_field_size",
                "graceful_timeout",
                "max_requests",
                "max_requests_jitter",
            }

            # Special hooks that need to be set separately
            special_hooks = {
                "on_starting",
                "on_reload",
                "on_exit",
                "pre_fork",
                "post_fork",
                "pre_exec",
                "pre_request",
                "post_request",
                "worker_init",
                "worker_exit",
                "nworkers_changed",
                "child_exit",
            }

            # Import and configure the gunicorn_config module
            from lightrag.api import gunicorn_config

            # Set configuration variables in gunicorn_config, prioritizing command line arguments
            gunicorn_config.workers = (
                args.workers if args.workers else int(os.getenv("WORKERS", 1))
            )

            # Bind configuration prioritizes command line arguments
            host = args.host if args.host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0")
            port = args.port if args.port != 9621 else int(os.getenv("PORT", 9621))
            gunicorn_config.bind = f"{host}:{port}"

            # Log level configuration prioritizes command line arguments
            gunicorn_config.loglevel = (
                args.log_level.lower()
                if args.log_level
                else os.getenv("LOG_LEVEL", "info")
            )

            # Timeout configuration prioritizes command line arguments
            gunicorn_config.timeout = (
                args.timeout if args.timeout else int(os.getenv("TIMEOUT", 150))
            )

            # Keepalive configuration
            gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5))

            # SSL configuration prioritizes command line arguments
            if args.ssl or os.getenv("SSL", "").lower() in (
                "true",
                "1",
                "yes",
                "t",
                "on",
            ):
                gunicorn_config.certfile = (
                    args.ssl_certfile
                    if args.ssl_certfile
                    else os.getenv("SSL_CERTFILE")
                )
                gunicorn_config.keyfile = (
                    args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE")
                )

            # Set configuration options from the module
            for key in dir(gunicorn_config):
                if key in valid_options:
                    value = getattr(gunicorn_config, key)
                    # Skip functions like on_starting and None values
                    if not callable(value) and value is not None:
                        self.cfg.set(key, value)
                # Set special hooks
                elif key in special_hooks:
                    value = getattr(gunicorn_config, key)
                    if callable(value):
                        self.cfg.set(key, value)

            if hasattr(gunicorn_config, "logconfig_dict"):
                self.cfg.set(
                    "logconfig_dict", getattr(gunicorn_config, "logconfig_dict")
                )

        def load(self):
            # Import the application
            from lightrag.api.lightrag_server import get_application

            return get_application(args)

    # Create the application
    app = GunicornApp("")

    # Force workers to be an integer and greater than 1 for multi-process mode
    workers_count = int(args.workers)
    if workers_count > 1:
        # Set a flag to indicate we're in the main process
        os.environ["LIGHTRAG_MAIN_PROCESS"] = "1"
        initialize_share_data(workers_count)
    else:
        initialize_share_data(1)

    # Run the application
    print("\nStarting Gunicorn with direct Python API...")
    app.run()


if __name__ == "__main__":
    main()

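run_with_gunicorn.py registers SIGINT/SIGTERM handlers, and gunicorn_config.py defines an on_exit hook; both paths call finalize_share_data() so shared resources are released when the server stops. A hedged sketch of a start/stop cycle (PID handling is an example only):

```bash
# Illustrative start/stop cycle
lightrag-gunicorn --workers 4 &
GUNICORN_PID=$!

# ... serve traffic ...

# SIGTERM triggers a graceful shutdown; shared resources are released by the
# registered signal handler or the Gunicorn on_exit hook (finalize_share_data)
kill -TERM "$GUNICORN_PID"
```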
lightrag/api/utils_api.py
CHANGED
@@ -6,6 +6,7 @@ import os
 import argparse
 from typing import Optional
 import sys
+import logging
 from ascii_colors import ASCIIColors
 from lightrag.api import __api_version__
 from fastapi import HTTPException, Security
@@ -110,10 +111,13 @@ def get_env_value(env_key: str, default: any, value_type: type = str) -> any:
     return default


-def parse_args() -> argparse.Namespace:
+def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
     """
     Parse command line arguments with environment variable fallback

+    Args:
+        is_uvicorn_mode: Whether running under uvicorn mode
+
     Returns:
         argparse.Namespace: Parsed arguments
     """
@@ -260,6 +264,14 @@ def parse_args() -> argparse.Namespace:
         help="Enable automatic scanning when the program starts",
     )

+    # Server workers configuration
+    parser.add_argument(
+        "--workers",
+        type=int,
+        default=get_env_value("WORKERS", 1, int),
+        help="Number of worker processes (default: from env or 1)",
+    )
+
     # LLM and embedding bindings
     parser.add_argument(
         "--llm-binding",
@@ -278,6 +290,15 @@ def parse_args() -> argparse.Namespace:

     args = parser.parse_args()

+    # If in uvicorn mode and workers > 1, force it to 1 and log warning
+    if is_uvicorn_mode and args.workers > 1:
+        original_workers = args.workers
+        args.workers = 1
+        # Log warning directly here
+        logging.warning(
+            f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1"
+        )
+
     # convert relative path to absolute path
     args.working_dir = os.path.abspath(args.working_dir)
     args.input_dir = os.path.abspath(args.input_dir)
@@ -346,17 +367,27 @@ def display_splash_screen(args: argparse.Namespace) -> None:
     ASCIIColors.yellow(f"{args.host}")
     ASCIIColors.white(" ├─ Port: ", end="")
     ASCIIColors.yellow(f"{args.port}")
+    ASCIIColors.white(" ├─ Workers: ", end="")
+    ASCIIColors.yellow(f"{args.workers}")
     ASCIIColors.white(" ├─ CORS Origins: ", end="")
     ASCIIColors.yellow(f"{os.getenv('CORS_ORIGINS', '*')}")
     ASCIIColors.white(" ├─ SSL Enabled: ", end="")
     ASCIIColors.yellow(f"{args.ssl}")
-    ASCIIColors.white(" └─ API Key: ", end="")
-    ASCIIColors.yellow("Set" if args.key else "Not Set")
     if args.ssl:
         ASCIIColors.white(" ├─ SSL Cert: ", end="")
         ASCIIColors.yellow(f"{args.ssl_certfile}")
+        ASCIIColors.white(" ├─ SSL Key: ", end="")
         ASCIIColors.yellow(f"{args.ssl_keyfile}")
+    ASCIIColors.white(" ├─ Ollama Emulating Model: ", end="")
+    ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}")
+    ASCIIColors.white(" ├─ Log Level: ", end="")
+    ASCIIColors.yellow(f"{args.log_level}")
+    ASCIIColors.white(" ├─ Verbose Debug: ", end="")
+    ASCIIColors.yellow(f"{args.verbose}")
+    ASCIIColors.white(" ├─ Timeout: ", end="")
+    ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}")
+    ASCIIColors.white(" └─ API Key: ", end="")
+    ASCIIColors.yellow("Set" if args.key else "Not Set")

     # Directory Configuration
     ASCIIColors.magenta("\n📂 Directory Configuration:")
@@ -415,16 +446,6 @@ def display_splash_screen(args: argparse.Namespace) -> None:
     ASCIIColors.white(" └─ Document Status Storage: ", end="")
     ASCIIColors.yellow(f"{args.doc_status_storage}")

-    ASCIIColors.magenta("\n🛠️ System Configuration:")
-    ASCIIColors.white(" ├─ Ollama Emulating Model: ", end="")
-    ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}")
-    ASCIIColors.white(" ├─ Log Level: ", end="")
-    ASCIIColors.yellow(f"{args.log_level}")
-    ASCIIColors.white(" ├─ Verbose Debug: ", end="")
-    ASCIIColors.yellow(f"{args.verbose}")
-    ASCIIColors.white(" └─ Timeout: ", end="")
-    ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}")
-
     # Server Status
     ASCIIColors.green("\n✨ Server starting up...\n")

@@ -478,7 +499,6 @@ def display_splash_screen(args: argparse.Namespace) -> None:
     ASCIIColors.cyan(""" 3. Basic Operations:
        - POST /upload_document: Upload new documents to RAG
        - POST /query: Query your document collection
-       - GET /collections: List available collections

     4. Monitor the server:
        - Check server logs for detailed operation information
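The new `--workers` option resolves its value as CLI flag first, then the `WORKERS` environment variable, then a default of 1, and is forced back to 1 when the server runs directly under uvicorn. A small sketch of the env-fallback idea, with a simplified stand-in for the project's `get_env_value` helper:

```python
# Illustrative sketch only; this get_env_value is a simplified stand-in,
# not the PR's exact helper.
import argparse
import os


def get_env_value(env_key: str, default, value_type=str):
    value = os.getenv(env_key)
    if value is None:
        return default
    try:
        return value_type(value)
    except ValueError:
        return default


parser = argparse.ArgumentParser()
parser.add_argument(
    "--workers",
    type=int,
    default=get_env_value("WORKERS", 1, int),
    help="Number of worker processes (default: from env or 1)",
)

# An explicit CLI flag overrides WORKERS, which overrides the default of 1.
args = parser.parse_args([])
print(args.workers)
```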
lightrag/kg/faiss_impl.py
CHANGED
@@ -2,25 +2,25 @@ import os
 import time
 import asyncio
 from typing import Any, final
 import json
 import numpy as np

 from dataclasses import dataclass
 import pipmaster as pm

-from lightrag.utils import (
-    compute_mdhash_id,
-)
-from lightrag.base import (
-    BaseVectorStorage,
-)
+from lightrag.utils import logger, compute_mdhash_id
+from lightrag.base import BaseVectorStorage

 if not pm.is_installed("faiss"):
     pm.install("faiss")

-import faiss
+import faiss  # type: ignore
+from .shared_storage import (
+    get_storage_lock,
+    get_update_flag,
+    set_all_update_flags,
+    is_multiprocess,
+)


 @final
@@ -55,14 +55,40 @@ class FaissVectorDBStorage(BaseVectorStorage):
         # If you have a large number of vectors, you might want IVF or other indexes.
         # For demonstration, we use a simple IndexFlatIP.
         self._index = faiss.IndexFlatIP(self._dim)
         # Keep a local store for metadata, IDs, etc.
         # Maps <int faiss_id> → metadata (including your original ID).
         self._id_to_meta = {}

-        # Attempt to load an existing index + metadata from disk
         self._load_faiss_index()

+    async def initialize(self):
+        """Initialize storage data"""
+        # Get the update flag for cross-process update notification
+        self.storage_updated = await get_update_flag(self.namespace)
+        # Get the storage lock for use in other methods
+        self._storage_lock = get_storage_lock()
+
+    async def _get_index(self):
+        """Check if the shtorage should be reloaded"""
+        # Acquire lock to prevent concurrent read and write
+        async with self._storage_lock:
+            # Check if storage was updated by another process
+            if (is_multiprocess and self.storage_updated.value) or (
+                not is_multiprocess and self.storage_updated
+            ):
+                logger.info(
+                    f"Process {os.getpid()} FAISS reloading {self.namespace} due to update by another process"
+                )
+                # Reload data
+                self._index = faiss.IndexFlatIP(self._dim)
+                self._id_to_meta = {}
+                self._load_faiss_index()
+                if is_multiprocess:
+                    self.storage_updated.value = False
+                else:
+                    self.storage_updated = False
+            return self._index
+
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         """
         Insert or update vectors in the Faiss index.
@@ -113,7 +139,8 @@ class FaissVectorDBStorage(BaseVectorStorage):
             )
             return []

+        # Convert to float32 and normalize embeddings for cosine similarity (in-place)
+        embeddings = embeddings.astype(np.float32)
         faiss.normalize_L2(embeddings)

         # Upsert logic:
@@ -127,18 +154,19 @@ class FaissVectorDBStorage(BaseVectorStorage):
                 existing_ids_to_remove.append(faiss_internal_id)

         if existing_ids_to_remove:
-            self._remove_faiss_ids(existing_ids_to_remove)
+            await self._remove_faiss_ids(existing_ids_to_remove)

         # Step 2: Add new vectors
+        index = await self._get_index()
+        start_idx = index.ntotal
+        index.add(embeddings)

         # Step 3: Store metadata + vector for each new ID
         for i, meta in enumerate(list_data):
             fid = start_idx + i
             # Store the raw vector so we can rebuild if something is removed
             meta["__vector__"] = embeddings[i].tolist()
+            self._id_to_meta.update({fid: meta})

         logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
         return [m["__id__"] for m in list_data]
@@ -157,7 +185,8 @@ class FaissVectorDBStorage(BaseVectorStorage):
         )

         # Perform the similarity search
+        index = await self._get_index()
+        distances, indices = index.search(embedding, top_k)

         distances = distances[0]
         indices = indices[0]
@@ -201,8 +230,8 @@ class FaissVectorDBStorage(BaseVectorStorage):
                 to_remove.append(fid)

         if to_remove:
-            self._remove_faiss_ids(to_remove)
+            await self._remove_faiss_ids(to_remove)
+            logger.debug(
                 f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
             )
@@ -223,12 +252,9 @@ class FaissVectorDBStorage(BaseVectorStorage):

         logger.debug(f"Found {len(relations)} relations for {entity_name}")
         if relations:
-            self._remove_faiss_ids(relations)
+            await self._remove_faiss_ids(relations)
             logger.debug(f"Deleted {len(relations)} relations for {entity_name}")

-    async def index_done_callback(self) -> None:
-        self._save_faiss_index()
-
     # --------------------------------------------------------------------------------
     # Internal helper methods
     # --------------------------------------------------------------------------------
@@ -242,7 +268,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
                 return fid
         return None

-    def _remove_faiss_ids(self, fid_list):
+    async def _remove_faiss_ids(self, fid_list):
         """
         Remove a list of internal Faiss IDs from the index.
         Because IndexFlatIP doesn't support 'removals',
@@ -258,13 +284,14 @@ class FaissVectorDBStorage(BaseVectorStorage):
                 vectors_to_keep.append(vec_meta["__vector__"])  # stored as list
                 new_id_to_meta[new_fid] = vec_meta

+        async with self._storage_lock:
+            # Re-init index
+            self._index = faiss.IndexFlatIP(self._dim)
+            if vectors_to_keep:
+                arr = np.array(vectors_to_keep, dtype=np.float32)
+                self._index.add(arr)

+            self._id_to_meta = new_id_to_meta

     def _save_faiss_index(self):
         """
@@ -312,3 +339,35 @@ class FaissVectorDBStorage(BaseVectorStorage):
             logger.warning("Starting with an empty Faiss index.")
             self._index = faiss.IndexFlatIP(self._dim)
             self._id_to_meta = {}
+
+    async def index_done_callback(self) -> None:
+        # Check if storage was updated by another process
+        if is_multiprocess and self.storage_updated.value:
+            # Storage was updated by another process, reload data instead of saving
+            logger.warning(
+                f"Storage for FAISS {self.namespace} was updated by another process, reloading..."
+            )
+            async with self._storage_lock:
+                self._index = faiss.IndexFlatIP(self._dim)
+                self._id_to_meta = {}
+                self._load_faiss_index()
+                self.storage_updated.value = False
+            return False  # Return error
+
+        # Acquire lock and perform persistence
+        async with self._storage_lock:
+            try:
+                # Save data to disk
+                self._save_faiss_index()
+                # Notify other processes that data has been updated
+                await set_all_update_flags(self.namespace)
+                # Reset own update flag to avoid self-reloading
+                if is_multiprocess:
+                    self.storage_updated.value = False
+                else:
+                    self.storage_updated = False
+            except Exception as e:
+                logger.error(f"Error saving FAISS index for {self.namespace}: {e}")
+                return False  # Return error
+
+        return True  # Return success
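The FAISS changes follow a reload-on-update pattern: every read or write first calls `_get_index()`, which checks a per-worker update flag under the storage lock and reloads the index from disk when another process has persisted new data. A storage-agnostic sketch of that pattern, simplified to plain asyncio objects with no multiprocessing (not the PR's implementation):

```python
# Simplified illustration of the reload-on-update idea.
import asyncio


class ReloadingStore:
    def __init__(self, load_fn):
        self._load_fn = load_fn      # callable that reloads data from disk
        self._data = load_fn()
        self._lock = asyncio.Lock()  # stands in for get_storage_lock()
        self._updated = False        # stands in for the per-worker update flag

    def mark_updated_by_other_process(self):
        self._updated = True

    async def get_data(self):
        async with self._lock:
            if self._updated:        # another worker persisted new data
                self._data = self._load_fn()
                self._updated = False
            return self._data


async def main():
    store = ReloadingStore(load_fn=lambda: {"vectors": []})
    store.mark_updated_by_other_process()
    print(await store.get_data())    # triggers a reload before returning


asyncio.run(main())
```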
lightrag/kg/json_doc_status_impl.py
CHANGED
@@ -12,6 +12,11 @@ from lightrag.utils import (
     logger,
     write_json,
 )
+from .shared_storage import (
+    get_namespace_data,
+    get_storage_lock,
+    try_initialize_namespace,
+)


 @final
@@ -22,26 +27,42 @@ class JsonDocStatusStorage(DocStatusStorage):
     def __post_init__(self):
         working_dir = self.global_config["working_dir"]
         self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
+        self._storage_lock = get_storage_lock()
+        self._data = None
+
+    async def initialize(self):
+        """Initialize storage data"""
+        # check need_init must before get_namespace_data
+        need_init = try_initialize_namespace(self.namespace)
+        self._data = await get_namespace_data(self.namespace)
+        if need_init:
+            loaded_data = load_json(self._file_name) or {}
+            async with self._storage_lock:
+                self._data.update(loaded_data)
+            logger.info(
+                f"Loaded document status storage with {len(loaded_data)} records"
+            )

     async def filter_keys(self, keys: set[str]) -> set[str]:
         """Return keys that should be processed (not in storage or not successfully processed)"""
+        async with self._storage_lock:
+            return set(keys) - set(self._data.keys())

     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         result: list[dict[str, Any]] = []
+        async with self._storage_lock:
+            for id in ids:
+                data = self._data.get(id, None)
+                if data:
+                    result.append(data)
         return result

     async def get_status_counts(self) -> dict[str, int]:
         """Get counts of documents in each status"""
         counts = {status.value: 0 for status in DocStatus}
+        async with self._storage_lock:
+            for doc in self._data.values():
+                counts[doc["status"]] += 1
         return counts

     async def get_docs_by_status(
@@ -49,39 +70,48 @@ class JsonDocStatusStorage(DocStatusStorage):
     ) -> dict[str, DocProcessingStatus]:
         """Get all documents with a specific status"""
         result = {}
+        async with self._storage_lock:
+            for k, v in self._data.items():
+                if v["status"] == status.value:
+                    try:
+                        # Make a copy of the data to avoid modifying the original
+                        data = v.copy()
+                        # If content is missing, use content_summary as content
+                        if "content" not in data and "content_summary" in data:
+                            data["content"] = data["content_summary"]
+                        result[k] = DocProcessingStatus(**data)
+                    except KeyError as e:
+                        logger.error(f"Missing required field for document {k}: {e}")
+                        continue
         return result

     async def index_done_callback(self) -> None:
+        async with self._storage_lock:
+            data_dict = (
+                dict(self._data) if hasattr(self._data, "_getvalue") else self._data
+            )
+            write_json(data_dict, self._file_name)

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         logger.info(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return

+        async with self._storage_lock:
+            self._data.update(data)
         await self.index_done_callback()

     async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+        async with self._storage_lock:
+            return self._data.get(id)

     async def delete(self, doc_ids: list[str]):
+        async with self._storage_lock:
+            for doc_id in doc_ids:
+                self._data.pop(doc_id, None)
         await self.index_done_callback()

     async def drop(self) -> None:
         """Drop the storage"""
+        async with self._storage_lock:
+            self._data.clear()
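The JSON-based storages initialize their shared namespace exactly once: `try_initialize_namespace` returns True only for the first caller, which then loads the file into the dictionary returned by `get_namespace_data`. A simplified single-process sketch of that idea follows; the real helpers live in `lightrag/kg/shared_storage.py` and may differ in detail.

```python
# Simplified stand-ins for the shared_storage namespace helpers.
_shared_dicts: dict[str, dict] = {}
_init_flags: dict[str, bool] = {}


def try_initialize_namespace(namespace: str) -> bool:
    """Return True only for the first caller, which should load data from disk."""
    if not _init_flags.get(namespace):
        _init_flags[namespace] = True
        return True
    return False


def get_namespace_data(namespace: str) -> dict:
    # All callers share the same dictionary object per namespace.
    return _shared_dicts.setdefault(namespace, {})


if try_initialize_namespace("doc_status"):
    get_namespace_data("doc_status").update({"doc-1": {"status": "pending"}})
print(get_namespace_data("doc_status"))
```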
lightrag/kg/json_kv_impl.py
CHANGED
@@ -1,4 +1,3 @@
-import asyncio
 import os
 from dataclasses import dataclass
 from typing import Any, final
@@ -11,6 +10,11 @@ from lightrag.utils import (
     logger,
     write_json,
 )
+from .shared_storage import (
+    get_namespace_data,
+    get_storage_lock,
+    try_initialize_namespace,
+)


 @final
@@ -19,37 +23,56 @@ class JsonKVStorage(BaseKVStorage):
     def __post_init__(self):
         working_dir = self.global_config["working_dir"]
         self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
+        self._storage_lock = get_storage_lock()
+        self._data = None
+
+    async def initialize(self):
+        """Initialize storage data"""
+        # check need_init must before get_namespace_data
+        need_init = try_initialize_namespace(self.namespace)
+        self._data = await get_namespace_data(self.namespace)
+        if need_init:
+            loaded_data = load_json(self._file_name) or {}
+            async with self._storage_lock:
+                self._data.update(loaded_data)
+            logger.info(f"Load KV {self.namespace} with {len(loaded_data)} data")

     async def index_done_callback(self) -> None:
+        async with self._storage_lock:
+            data_dict = (
+                dict(self._data) if hasattr(self._data, "_getvalue") else self._data
+            )
+            write_json(data_dict, self._file_name)

     async def get_by_id(self, id: str) -> dict[str, Any] | None:
+        async with self._storage_lock:
+            return self._data.get(id)

     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
+        async with self._storage_lock:
+            return [
+                (
+                    {k: v for k, v in self._data[id].items()}
+                    if self._data.get(id, None)
+                    else None
+                )
+                for id in ids
+            ]

     async def filter_keys(self, keys: set[str]) -> set[str]:
+        async with self._storage_lock:
+            return set(keys) - set(self._data.keys())

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         logger.info(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
+        async with self._storage_lock:
+            left_data = {k: v for k, v in data.items() if k not in self._data}
+            self._data.update(left_data)

     async def delete(self, ids: list[str]) -> None:
+        async with self._storage_lock:
+            for doc_id in ids:
+                self._data.pop(doc_id, None)
         await self.index_done_callback()
lightrag/kg/nano_vector_db_impl.py
CHANGED
@@ -3,7 +3,6 @@ import os
 from typing import Any, final
 from dataclasses import dataclass
 import numpy as np
-
 import time

 from lightrag.utils import (
@@ -11,22 +10,29 @@ from lightrag.utils import (
     compute_mdhash_id,
 )
 import pipmaster as pm
-from lightrag.base import (
-    BaseVectorStorage,
-)
+from lightrag.base import BaseVectorStorage

 if not pm.is_installed("nano-vectordb"):
     pm.install("nano-vectordb")

 from nano_vectordb import NanoVectorDB
+from .shared_storage import (
+    get_storage_lock,
+    get_update_flag,
+    set_all_update_flags,
+    is_multiprocess,
+)


 @final
 @dataclass
 class NanoVectorDBStorage(BaseVectorStorage):
     def __post_init__(self):
+        # Initialize basic attributes
+        self._client = None
+        self._storage_lock = None
+        self.storage_updated = None
+
         # Use global config value if specified, otherwise use default
         kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
         cosine_threshold = kwargs.get("cosine_better_than_threshold")
@@ -40,10 +46,43 @@ class NanoVectorDBStorage(BaseVectorStorage):
             self.global_config["working_dir"], f"vdb_{self.namespace}.json"
         )
         self._max_batch_size = self.global_config["embedding_batch_num"]
+
         self._client = NanoVectorDB(
+            self.embedding_func.embedding_dim,
+            storage_file=self._client_file_name,
         )

+    async def initialize(self):
+        """Initialize storage data"""
+        # Get the update flag for cross-process update notification
+        self.storage_updated = await get_update_flag(self.namespace)
+        # Get the storage lock for use in other methods
+        self._storage_lock = get_storage_lock()
+
+    async def _get_client(self):
+        """Check if the storage should be reloaded"""
+        # Acquire lock to prevent concurrent read and write
+        async with self._storage_lock:
+            # Check if data needs to be reloaded
+            if (is_multiprocess and self.storage_updated.value) or (
+                not is_multiprocess and self.storage_updated
+            ):
+                logger.info(
+                    f"Process {os.getpid()} reloading {self.namespace} due to update by another process"
+                )
+                # Reload data
+                self._client = NanoVectorDB(
+                    self.embedding_func.embedding_dim,
+                    storage_file=self._client_file_name,
+                )
+                # Reset update flag
+                if is_multiprocess:
+                    self.storage_updated.value = False
+                else:
+                    self.storage_updated = False
+
+        return self._client
+
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         logger.info(f"Inserting {len(data)} to {self.namespace}")
         if not data:
@@ -64,6 +103,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
             for i in range(0, len(contents), self._max_batch_size)
         ]

+        # Execute embedding outside of lock to avoid long lock times
         embedding_tasks = [self.embedding_func(batch) for batch in batches]
         embeddings_list = await asyncio.gather(*embedding_tasks)

@@ -71,7 +111,8 @@ class NanoVectorDBStorage(BaseVectorStorage):
         if len(embeddings) == len(list_data):
             for i, d in enumerate(list_data):
                 d["__vector__"] = embeddings[i]
+            client = await self._get_client()
+            results = client.upsert(datas=list_data)
             return results
         else:
             # sometimes the embedding is not returned correctly. just log it.
@@ -80,9 +121,12 @@ class NanoVectorDBStorage(BaseVectorStorage):
         )

     async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
+        # Execute embedding outside of lock to avoid long lock times
         embedding = await self.embedding_func([query])
         embedding = embedding[0]
+
+        client = await self._get_client()
+        results = client.query(
             query=embedding,
             top_k=top_k,
             better_than_threshold=self.cosine_better_than_threshold,
@@ -99,8 +143,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
         return results

     @property
-    def client_storage(self):
+    async def client_storage(self):
+        client = await self._get_client()
+        return getattr(client, "_NanoVectorDB__storage")

     async def delete(self, ids: list[str]):
         """Delete vectors with specified IDs
@@ -109,8 +154,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
             ids: List of vector IDs to be deleted
         """
         try:
+            client = await self._get_client()
+            client.delete(ids)
+            logger.debug(
                 f"Successfully deleted {len(ids)} vectors from {self.namespace}"
             )
         except Exception as e:
@@ -122,9 +168,11 @@ class NanoVectorDBStorage(BaseVectorStorage):
             logger.debug(
                 f"Attempting to delete entity {entity_name} with ID {entity_id}"
             )
+
             # Check if the entity exists
+            client = await self._get_client()
+            if client.get([entity_id]):
+                client.delete([entity_id])
                 logger.debug(f"Successfully deleted entity {entity_name}")
             else:
                 logger.debug(f"Entity {entity_name} not found in storage")
@@ -133,16 +181,19 @@ class NanoVectorDBStorage(BaseVectorStorage):

     async def delete_entity_relation(self, entity_name: str) -> None:
         try:
+            client = await self._get_client()
+            storage = getattr(client, "_NanoVectorDB__storage")
             relations = [
                 dp
-                for dp in
+                for dp in storage["data"]
                 if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
             ]
             logger.debug(f"Found {len(relations)} relations for entity {entity_name}")
             ids_to_delete = [relation["__id__"] for relation in relations]

             if ids_to_delete:
+                client = await self._get_client()
+                client.delete(ids_to_delete)
                 logger.debug(
                     f"Deleted {len(ids_to_delete)} relations for {entity_name}"
                 )
@@ -151,6 +202,37 @@ class NanoVectorDBStorage(BaseVectorStorage):
         except Exception as e:
             logger.error(f"Error deleting relations for {entity_name}: {e}")

-    async def index_done_callback(self) -> None:
+    async def index_done_callback(self) -> bool:
+        """Save data to disk"""
+        # Check if storage was updated by another process
+        if is_multiprocess and self.storage_updated.value:
+            # Storage was updated by another process, reload data instead of saving
+            logger.warning(
+                f"Storage for {self.namespace} was updated by another process, reloading..."
+            )
+            self._client = NanoVectorDB(
+                self.embedding_func.embedding_dim,
+                storage_file=self._client_file_name,
+            )
+            # Reset update flag
+            self.storage_updated.value = False
+            return False  # Return error
+
+        # Acquire lock and perform persistence
+        async with self._storage_lock:
+            try:
+                # Save data to disk
+                self._client.save()
+                # Notify other processes that data has been updated
+                await set_all_update_flags(self.namespace)
+                # Reset own update flag to avoid self-reloading
+                if is_multiprocess:
+                    self.storage_updated.value = False
+                else:
+                    self.storage_updated = False
+                return True  # Return success
+            except Exception as e:
+                logger.error(f"Error saving data for {self.namespace}: {e}")
+                return False  # Return error
+
+        return True  # Return success
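In multi-worker mode, the storages signal "data changed" to one another through Manager-backed boolean values (one per worker), which is what `storage_updated.value` refers to above. A minimal standalone demonstration of such a cross-process flag (not the PR's implementation):

```python
# One Manager-backed flag shared between a parent and a child process.
from multiprocessing import Manager, Process


def worker(flag):
    # The child observes the flag another process has set and resets it.
    if flag.value:
        print("reload needed")
        flag.value = False


if __name__ == "__main__":
    manager = Manager()
    flag = manager.Value("b", False)  # the real code keeps one flag per worker
    flag.value = True                 # another process signals an update
    p = Process(target=worker, args=(flag,))
    p.start()
    p.join()
```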
lightrag/kg/networkx_impl.py
CHANGED
@@ -1,18 +1,12 @@
 import os
 from dataclasses import dataclass
 from typing import Any, final
 import numpy as np

 from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
-from lightrag.utils import (
-)
-
-from lightrag.base import (
-    BaseGraphStorage,
-)
+from lightrag.utils import logger
+from lightrag.base import BaseGraphStorage

 import pipmaster as pm

 if not pm.is_installed("networkx"):
@@ -23,6 +17,12 @@ if not pm.is_installed("graspologic"):

 import networkx as nx
 from graspologic import embed
+from .shared_storage import (
+    get_storage_lock,
+    get_update_flag,
+    set_all_update_flags,
+    is_multiprocess,
+)


 @final
@@ -78,56 +78,101 @@ class NetworkXStorage(BaseGraphStorage):
         self._graphml_xml_file = os.path.join(
             self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
         )
+        self._storage_lock = None
+        self.storage_updated = None
+        self._graph = None
+
+        # Load initial graph
         preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
         if preloaded_graph is not None:
             logger.info(
                 f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
             )
+        else:
+            logger.info("Created new empty graph")
         self._graph = preloaded_graph or nx.Graph()
+
         self._node_embed_algorithms = {
             "node2vec": self._node2vec_embed,
         }

+    async def initialize(self):
+        """Initialize storage data"""
+        # Get the update flag for cross-process update notification
+        self.storage_updated = await get_update_flag(self.namespace)
+        # Get the storage lock for use in other methods
+        self._storage_lock = get_storage_lock()
+
+    async def _get_graph(self):
+        """Check if the storage should be reloaded"""
+        # Acquire lock to prevent concurrent read and write
+        async with self._storage_lock:
+            # Check if data needs to be reloaded
+            if (is_multiprocess and self.storage_updated.value) or (
+                not is_multiprocess and self.storage_updated
+            ):
+                logger.info(
+                    f"Process {os.getpid()} reloading graph {self.namespace} due to update by another process"
+                )
+                # Reload data
+                self._graph = (
+                    NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
+                )
+                # Reset update flag
+                if is_multiprocess:
+                    self.storage_updated.value = False
+                else:
+                    self.storage_updated = False
+
+        return self._graph

     async def has_node(self, node_id: str) -> bool:
+        graph = await self._get_graph()
+        return graph.has_node(node_id)

     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
+        graph = await self._get_graph()
+        return graph.has_edge(source_node_id, target_node_id)

     async def get_node(self, node_id: str) -> dict[str, str] | None:
+        graph = await self._get_graph()
+        return graph.nodes.get(node_id)

     async def node_degree(self, node_id: str) -> int:
+        graph = await self._get_graph()
+        return graph.degree(node_id)

     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
+        graph = await self._get_graph()
+        return graph.degree(src_id) + graph.degree(tgt_id)

     async def get_edge(
         self, source_node_id: str, target_node_id: str
     ) -> dict[str, str] | None:
+        graph = await self._get_graph()
+        return graph.edges.get((source_node_id, target_node_id))

     async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
+        graph = await self._get_graph()
+        if graph.has_node(source_node_id):
+            return list(graph.edges(source_node_id))
         return None

     async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
+        graph = await self._get_graph()
+        graph.add_node(node_id, **node_data)

     async def upsert_edge(
         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
     ) -> None:
+        graph = await self._get_graph()
+        graph.add_edge(source_node_id, target_node_id, **edge_data)

     async def delete_node(self, node_id: str) -> None:
+        graph = await self._get_graph()
+        if graph.has_node(node_id):
+            graph.remove_node(node_id)
+            logger.debug(f"Node {node_id} deleted from the graph.")
         else:
             logger.warning(f"Node {node_id} not found in the graph for deletion.")

@@ -138,35 +183,37 @@ class NetworkXStorage(BaseGraphStorage):
             raise ValueError(f"Node embedding algorithm {algorithm} not supported")
         return await self._node_embed_algorithms[algorithm]()

+    # TODO: NOT USED
     async def _node2vec_embed(self):
+        graph = await self._get_graph()
         embeddings, nodes = embed.node2vec_embed(
-            self._graph,
+            graph,
             **self.global_config["node2vec_params"],
         )
-        nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
+        nodes_ids = [graph.nodes[node_id]["id"] for node_id in nodes]
         return embeddings, nodes_ids

-    def remove_nodes(self, nodes: list[str]):
+    async def remove_nodes(self, nodes: list[str]):
         """Delete multiple nodes

         Args:
             nodes: List of node IDs to be deleted
         """
+        graph = await self._get_graph()
         for node in nodes:
+            if graph.has_node(node):
+                graph.remove_node(node)

-    def remove_edges(self, edges: list[tuple[str, str]]):
+    async def remove_edges(self, edges: list[tuple[str, str]]):
         """Delete multiple edges

         Args:
             edges: List of edges to be deleted, each edge is a (source, target) tuple
         """
+        graph = await self._get_graph()
         for source, target in edges:
+            if graph.has_edge(source, target):
+                graph.remove_edge(source, target)

     async def get_all_labels(self) -> list[str]:
         """
@@ -174,8 +221,9 @@ class NetworkXStorage(BaseGraphStorage):
         Returns:
             [label1, label2, ...]  # Alphabetically sorted label list
         """
+        graph = await self._get_graph()
         labels = set()
-        for node in
+        for node in graph.nodes():
             labels.add(str(node))  # Add node id as a label

         # Return sorted list
@@ -198,16 +246,18 @@ class NetworkXStorage(BaseGraphStorage):
         seen_nodes = set()
         seen_edges = set()

+        graph = await self._get_graph()
+
         # Handle special case for "*" label
         if node_label == "*":
             # For "*", return the entire graph including all nodes and edges
             subgraph = (
+                graph.copy()
             )  # Create a copy to avoid modifying the original graph
         else:
             # Find nodes with matching node id (partial match)
             nodes_to_explore = []
-            for n, attr in
+            for n, attr in graph.nodes(data=True):
                 if node_label in str(n):  # Use partial matching
                     nodes_to_explore.append(n)

@@ -216,7 +266,7 @@ class NetworkXStorage(BaseGraphStorage):
                 return result

             # Get subgraph using ego_graph
-            subgraph = nx.ego_graph(
+            subgraph = nx.ego_graph(graph, nodes_to_explore[0], radius=max_depth)

         # Check if number of nodes exceeds max_graph_nodes
         max_graph_nodes = 500
@@ -278,9 +328,41 @@ class NetworkXStorage(BaseGraphStorage):
                 )
                 seen_edges.add(edge_id)

-        # logger.info(result.edges)
-
         logger.info(
             f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
         )
         return result
+
+    async def index_done_callback(self) -> bool:
+        """Save data to disk"""
+        # Check if storage was updated by another process
+        if is_multiprocess and self.storage_updated.value:
+            # Storage was updated by another process, reload data instead of saving
+            logger.warning(
+                f"Graph for {self.namespace} was updated by another process, reloading..."
+            )
+            self._graph = (
+                NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
+            )
+            # Reset update flag
+            self.storage_updated.value = False
+            return False  # Return error
+
+        # Acquire lock and perform persistence
+        async with self._storage_lock:
+            try:
+                # Save data to disk
+                NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
+                # Notify other processes that data has been updated
+                await set_all_update_flags(self.namespace)
+                # Reset own update flag to avoid self-reloading
+                if is_multiprocess:
+                    self.storage_updated.value = False
+                else:
+                    self.storage_updated = False
+                return True  # Return success
+            except Exception as e:
+                logger.error(f"Error saving graph for {self.namespace}: {e}")
+                return False  # Return error
+
+        return True
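NetworkXStorage persists the graph by writing and re-reading GraphML on disk. A quick illustration of that round-trip with plain networkx calls; the file name here is only illustrative.

```python
# GraphML round-trip with networkx.
import networkx as nx

graph = nx.Graph()
graph.add_node("Entity A", description="example node")
graph.add_edge("Entity A", "Entity B", weight=1.0)

nx.write_graphml(graph, "example_graph.graphml")
reloaded = nx.read_graphml("example_graph.graphml")
print(reloaded.number_of_nodes(), reloaded.number_of_edges())
```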
lightrag/kg/postgres_impl.py
CHANGED
@@ -38,8 +38,8 @@ import pipmaster as pm
 if not pm.is_installed("asyncpg"):
     pm.install("asyncpg")

-import asyncpg
-from asyncpg import Pool
+import asyncpg  # type: ignore
+from asyncpg import Pool  # type: ignore


 class PostgreSQLDB:
lightrag/kg/shared_storage.py
ADDED
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import asyncio
|
4 |
+
from multiprocessing.synchronize import Lock as ProcessLock
|
5 |
+
from multiprocessing import Manager
|
6 |
+
from typing import Any, Dict, Optional, Union, TypeVar, Generic
|
7 |
+
|
8 |
+
|
9 |
+
# Define a direct print function for critical logs that must be visible in all processes
|
10 |
+
def direct_log(message, level="INFO"):
|
11 |
+
"""
|
12 |
+
Log a message directly to stderr to ensure visibility in all processes,
|
13 |
+
including the Gunicorn master process.
|
14 |
+
"""
|
15 |
+
print(f"{level}: {message}", file=sys.stderr, flush=True)
|
16 |
+
|
17 |
+
|
18 |
+
T = TypeVar("T")
|
19 |
+
LockType = Union[ProcessLock, asyncio.Lock]
|
20 |
+
|
21 |
+
is_multiprocess = None
|
22 |
+
_workers = None
|
23 |
+
_manager = None
|
24 |
+
_initialized = None
|
25 |
+
|
26 |
+
# shared data for storage across processes
|
27 |
+
_shared_dicts: Optional[Dict[str, Any]] = None
|
28 |
+
_init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized
|
29 |
+
_update_flags: Optional[Dict[str, bool]] = None # namespace -> updated
|
30 |
+
|
31 |
+
# locks for mutex access
|
32 |
+
_storage_lock: Optional[LockType] = None
|
33 |
+
_internal_lock: Optional[LockType] = None
|
34 |
+
_pipeline_status_lock: Optional[LockType] = None
|
35 |
+
|
36 |
+
|
37 |
+
class UnifiedLock(Generic[T]):
|
38 |
+
"""Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock"""
|
39 |
+
|
40 |
+
def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool):
|
41 |
+
self._lock = lock
|
42 |
+
self._is_async = is_async
|
43 |
+
|
44 |
+
async def __aenter__(self) -> "UnifiedLock[T]":
|
45 |
+
if self._is_async:
|
46 |
+
await self._lock.acquire()
|
47 |
+
else:
|
48 |
+
self._lock.acquire()
|
49 |
+
return self
|
50 |
+
|
51 |
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
52 |
+
if self._is_async:
|
53 |
+
self._lock.release()
|
54 |
+
else:
|
55 |
+
self._lock.release()
|
56 |
+
|
57 |
+
def __enter__(self) -> "UnifiedLock[T]":
|
58 |
+
"""For backward compatibility"""
|
59 |
+
if self._is_async:
|
60 |
+
raise RuntimeError("Use 'async with' for shared_storage lock")
|
61 |
+
self._lock.acquire()
|
62 |
+
return self
|
63 |
+
|
64 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
65 |
+
"""For backward compatibility"""
|
66 |
+
if self._is_async:
|
67 |
+
raise RuntimeError("Use 'async with' for shared_storage lock")
|
68 |
+
self._lock.release()
|
69 |
+
|
70 |
+
|
71 |
+
def get_internal_lock() -> UnifiedLock:
|
72 |
+
"""return unified storage lock for data consistency"""
|
73 |
+
return UnifiedLock(lock=_internal_lock, is_async=not is_multiprocess)
|
74 |
+
|
75 |
+
|
76 |
+
def get_storage_lock() -> UnifiedLock:
|
77 |
+
"""return unified storage lock for data consistency"""
|
78 |
+
return UnifiedLock(lock=_storage_lock, is_async=not is_multiprocess)
|
79 |
+
|
80 |
+
|
81 |
+
def get_pipeline_status_lock() -> UnifiedLock:
|
82 |
+
"""return unified storage lock for data consistency"""
|
83 |
+
return UnifiedLock(lock=_pipeline_status_lock, is_async=not is_multiprocess)
|
84 |
+
|
85 |
+
|
86 |
+
def initialize_share_data(workers: int = 1):
|
87 |
+
"""
|
88 |
+
Initialize shared storage data for single or multi-process mode.
|
89 |
+
|
90 |
+
When used with Gunicorn's preload feature, this function is called once in the
|
91 |
+
master process before forking worker processes, allowing all workers to share
|
92 |
+
the same initialized data.
|
93 |
+
|
94 |
+
In single-process mode, this function is called in FASTAPI lifespan function.
|
95 |
+
|
96 |
+
The function determines whether to use cross-process shared variables for data storage
|
97 |
+
based on the number of workers. If workers=1, it uses thread locks and local dictionaries.
|
98 |
+
If workers>1, it uses process locks and shared dictionaries managed by multiprocessing.Manager.
|
99 |
+
|
100 |
+
Args:
|
101 |
+
workers (int): Number of worker processes. If 1, single-process mode is used.
|
102 |
+
If > 1, multi-process mode with shared memory is used.
|
103 |
+
"""
|
104 |
+
global \
|
105 |
+
_manager, \
|
106 |
+
_workers, \
|
107 |
+
is_multiprocess, \
|
108 |
+
_storage_lock, \
|
109 |
+
_internal_lock, \
|
110 |
+
_pipeline_status_lock, \
|
111 |
+
_shared_dicts, \
|
112 |
+
_init_flags, \
|
113 |
+
_initialized, \
|
114 |
+
_update_flags
|
115 |
+
|
116 |
+
# Check if already initialized
|
117 |
+
if _initialized:
|
118 |
+
direct_log(
|
119 |
+
f"Process {os.getpid()} Shared-Data already initialized (multiprocess={is_multiprocess})"
|
120 |
+
)
|
121 |
+
return
|
122 |
+
|
123 |
+
_manager = Manager()
|
124 |
+
_workers = workers
|
125 |
+
|
126 |
+
if workers > 1:
|
127 |
+
is_multiprocess = True
|
128 |
+
_internal_lock = _manager.Lock()
|
129 |
+
_storage_lock = _manager.Lock()
|
130 |
+
_pipeline_status_lock = _manager.Lock()
|
131 |
+
_shared_dicts = _manager.dict()
|
132 |
+
_init_flags = _manager.dict()
|
133 |
+
_update_flags = _manager.dict()
|
134 |
+
direct_log(
|
135 |
+
f"Process {os.getpid()} Shared-Data created for Multiple Process (workers={workers})"
|
136 |
+
)
|
137 |
+
else:
|
138 |
+
is_multiprocess = False
|
139 |
+
_internal_lock = asyncio.Lock()
|
140 |
+
_storage_lock = asyncio.Lock()
|
141 |
+
_pipeline_status_lock = asyncio.Lock()
|
142 |
+
_shared_dicts = {}
|
143 |
+
_init_flags = {}
|
144 |
+
_update_flags = {}
|
145 |
+
direct_log(f"Process {os.getpid()} Shared-Data created for Single Process")
|
146 |
+
|
147 |
+
# Mark as initialized
|
148 |
+
_initialized = True
|
149 |
+
|
150 |
+
|
151 |
+
async def initialize_pipeline_status():
|
152 |
+
"""
|
153 |
+
Initialize pipeline namespace with default values.
|
154 |
+
This function is called during the FastAPI lifespan for each worker.
|
155 |
+
"""
|
156 |
+
pipeline_namespace = await get_namespace_data("pipeline_status")
|
157 |
+
|
158 |
+
async with get_internal_lock():
|
159 |
+
# Check if already initialized by checking for required fields
|
160 |
+
if "busy" in pipeline_namespace:
|
161 |
+
return
|
162 |
+
|
163 |
+
# Create a shared list object for history_messages
|
164 |
+
history_messages = _manager.list() if is_multiprocess else []
|
165 |
+
pipeline_namespace.update(
|
166 |
+
{
|
167 |
+
"busy": False, # Control concurrent processes
|
168 |
+
"job_name": "Default Job", # Current job name (indexing files/indexing texts)
|
169 |
+
"job_start": None, # Job start time
|
170 |
+
"docs": 0, # Total number of documents to be indexed
|
171 |
+
"batchs": 0, # Number of batches for processing documents
|
172 |
+
"cur_batch": 0, # Current processing batch
|
173 |
+
"request_pending": False, # Flag for pending request for processing
|
174 |
+
"latest_message": "", # Latest message from pipeline processing
|
175 |
+
"history_messages": history_messages, # 使用共享列表对象
|
176 |
+
}
|
177 |
+
)
|
178 |
+
direct_log(f"Process {os.getpid()} Pipeline namespace initialized")
|
179 |
+
|
180 |
+
|
181 |
+
async def get_update_flag(namespace: str):
|
182 |
+
"""
|
183 |
+
Create an update flag for a namespace for the current worker.
|
184 |
+
Return the update flag to the caller for later reference or reset.
|
185 |
+
"""
|
186 |
+
global _update_flags
|
187 |
+
if _update_flags is None:
|
188 |
+
raise ValueError("Try to create namespace before Shared-Data is initialized")
|
189 |
+
|
190 |
+
async with get_internal_lock():
|
191 |
+
if namespace not in _update_flags:
|
192 |
+
if is_multiprocess and _manager is not None:
|
193 |
+
_update_flags[namespace] = _manager.list()
|
194 |
+
else:
|
195 |
+
_update_flags[namespace] = []
|
196 |
+
direct_log(
|
197 |
+
f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]"
|
198 |
+
)
|
199 |
+
|
200 |
+
if is_multiprocess and _manager is not None:
|
201 |
+
new_update_flag = _manager.Value("b", False)
|
202 |
+
else:
|
203 |
+
new_update_flag = False
|
204 |
+
|
205 |
+
_update_flags[namespace].append(new_update_flag)
|
206 |
+
return new_update_flag
|
207 |
+
|
208 |
+
|
209 |
+
async def set_all_update_flags(namespace: str):
|
210 |
+
"""Set all update flag of namespace indicating all workers need to reload data from files"""
|
211 |
+
global _update_flags
|
212 |
+
if _update_flags is None:
|
213 |
+
raise ValueError("Try to create namespace before Shared-Data is initialized")
|
214 |
+
|
215 |
+
async with get_internal_lock():
|
216 |
+
if namespace not in _update_flags:
|
217 |
+
raise ValueError(f"Namespace {namespace} not found in update flags")
|
218 |
+
# Update flags for both modes
|
219 |
+
for i in range(len(_update_flags[namespace])):
|
220 |
+
if is_multiprocess:
|
221 |
+
_update_flags[namespace][i].value = True
|
222 |
+
else:
|
223 |
+
_update_flags[namespace][i] = True
|
224 |
+
|
225 |
+
|
226 |
+
async def get_all_update_flags_status() -> Dict[str, list]:
|
227 |
+
"""
|
228 |
+
Get update flags status for all namespaces.
|
229 |
+
|
230 |
+
Returns:
|
231 |
+
Dict[str, list]: A dictionary mapping namespace names to lists of update flag statuses
|
232 |
+
"""
|
233 |
+
if _update_flags is None:
|
234 |
+
return {}
|
235 |
+
|
236 |
+
result = {}
|
237 |
+
async with get_internal_lock():
|
238 |
+
for namespace, flags in _update_flags.items():
|
239 |
+
worker_statuses = []
|
240 |
+
for flag in flags:
|
241 |
+
if is_multiprocess:
|
242 |
+
worker_statuses.append(flag.value)
|
243 |
+
else:
|
244 |
+
worker_statuses.append(flag)
|
245 |
+
result[namespace] = worker_statuses
|
246 |
+
|
247 |
+
return result
|
248 |
+
|
249 |
+
|
250 |
+
def try_initialize_namespace(namespace: str) -> bool:
|
251 |
+
"""
|
252 |
+
Returns True if the current worker (process) gets initialization permission for loading data later.
|
253 |
+
A worker that does not get the permission is prohibited from loading data from files.
|
254 |
+
"""
|
255 |
+
global _init_flags, _manager
|
256 |
+
|
257 |
+
if _init_flags is None:
|
258 |
+
raise ValueError("Try to create nanmespace before Shared-Data is initialized")
|
259 |
+
|
260 |
+
if namespace not in _init_flags:
|
261 |
+
_init_flags[namespace] = True
|
262 |
+
direct_log(
|
263 |
+
f"Process {os.getpid()} ready to initialize storage namespace: [{namespace}]"
|
264 |
+
)
|
265 |
+
return True
|
266 |
+
direct_log(
|
267 |
+
f"Process {os.getpid()} storage namespace already initialized: [{namespace}]"
|
268 |
+
)
|
269 |
+
return False
|
270 |
+
|
271 |
+
|
272 |
+
async def get_namespace_data(namespace: str) -> Dict[str, Any]:
|
273 |
+
"""get the shared data reference for specific namespace"""
|
274 |
+
if _shared_dicts is None:
|
275 |
+
direct_log(
|
276 |
+
f"Error: try to getnanmespace before it is initialized, pid={os.getpid()}",
|
277 |
+
level="ERROR",
|
278 |
+
)
|
279 |
+
raise ValueError("Shared dictionaries not initialized")
|
280 |
+
|
281 |
+
async with get_internal_lock():
|
282 |
+
if namespace not in _shared_dicts:
|
283 |
+
if is_multiprocess and _manager is not None:
|
284 |
+
_shared_dicts[namespace] = _manager.dict()
|
285 |
+
else:
|
286 |
+
_shared_dicts[namespace] = {}
|
287 |
+
|
288 |
+
return _shared_dicts[namespace]
|
289 |
+
|
290 |
+
|
291 |
+
def finalize_share_data():
|
292 |
+
"""
|
293 |
+
Release shared resources and clean up.
|
294 |
+
|
295 |
+
This function should be called when the application is shutting down
|
296 |
+
to properly release shared resources and avoid memory leaks.
|
297 |
+
|
298 |
+
In multi-process mode, it shuts down the Manager and releases all shared objects.
|
299 |
+
In single-process mode, it simply resets the global variables.
|
300 |
+
"""
|
301 |
+
global \
|
302 |
+
_manager, \
|
303 |
+
is_multiprocess, \
|
304 |
+
_storage_lock, \
|
305 |
+
_internal_lock, \
|
306 |
+
_pipeline_status_lock, \
|
307 |
+
_shared_dicts, \
|
308 |
+
_init_flags, \
|
309 |
+
_initialized, \
|
310 |
+
_update_flags
|
311 |
+
|
312 |
+
# Check if already initialized
|
313 |
+
if not _initialized:
|
314 |
+
direct_log(
|
315 |
+
f"Process {os.getpid()} storage data not initialized, nothing to finalize"
|
316 |
+
)
|
317 |
+
return
|
318 |
+
|
319 |
+
direct_log(
|
320 |
+
f"Process {os.getpid()} finalizing storage data (multiprocess={is_multiprocess})"
|
321 |
+
)
|
322 |
+
|
323 |
+
# In multi-process mode, shut down the Manager
|
324 |
+
if is_multiprocess and _manager is not None:
|
325 |
+
try:
|
326 |
+
# Clear shared resources before shutting down Manager
|
327 |
+
if _shared_dicts is not None:
|
328 |
+
# Clear pipeline status history messages first if exists
|
329 |
+
try:
|
330 |
+
pipeline_status = _shared_dicts.get("pipeline_status", {})
|
331 |
+
if "history_messages" in pipeline_status:
|
332 |
+
pipeline_status["history_messages"].clear()
|
333 |
+
except Exception:
|
334 |
+
pass # Ignore any errors during history messages cleanup
|
335 |
+
_shared_dicts.clear()
|
336 |
+
if _init_flags is not None:
|
337 |
+
_init_flags.clear()
|
338 |
+
if _update_flags is not None:
|
339 |
+
# Clear each namespace's update flags list and Value objects
|
340 |
+
try:
|
341 |
+
for namespace in _update_flags:
|
342 |
+
flags_list = _update_flags[namespace]
|
343 |
+
if isinstance(flags_list, list):
|
344 |
+
# Clear Value objects in the list
|
345 |
+
for flag in flags_list:
|
346 |
+
if hasattr(
|
347 |
+
flag, "value"
|
348 |
+
): # Check if it's a Value object
|
349 |
+
flag.value = False
|
350 |
+
flags_list.clear()
|
351 |
+
except Exception:
|
352 |
+
pass # Ignore any errors during update flags cleanup
|
353 |
+
_update_flags.clear()
|
354 |
+
|
355 |
+
# Shut down the Manager - this will automatically clean up all shared resources
|
356 |
+
_manager.shutdown()
|
357 |
+
direct_log(f"Process {os.getpid()} Manager shutdown complete")
|
358 |
+
except Exception as e:
|
359 |
+
direct_log(
|
360 |
+
f"Process {os.getpid()} Error shutting down Manager: {e}", level="ERROR"
|
361 |
+
)
|
362 |
+
|
363 |
+
# Reset global variables
|
364 |
+
_manager = None
|
365 |
+
_initialized = None
|
366 |
+
is_multiprocess = None
|
367 |
+
_shared_dicts = None
|
368 |
+
_init_flags = None
|
369 |
+
_storage_lock = None
|
370 |
+
_internal_lock = None
|
371 |
+
_pipeline_status_lock = None
|
372 |
+
_update_flags = None
|
373 |
+
|
374 |
+
direct_log(f"Process {os.getpid()} storage data finalization complete")
|
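A minimal usage sketch of the shared_storage API added above (not part of the patch itself): initialize the shared data once, claim a namespace, and guard access with the unified lock. The `demo` namespace name is only for illustration.

```python
# Minimal sketch (not part of the patch): single-process usage of shared_storage.
# The "demo" namespace name is hypothetical.
import asyncio
from lightrag.kg.shared_storage import (
    initialize_share_data,
    initialize_pipeline_status,
    try_initialize_namespace,
    get_namespace_data,
    get_storage_lock,
)

async def main():
    # workers=1 -> asyncio locks and plain dicts; workers>1 -> Manager-backed objects
    initialize_share_data(workers=1)
    await initialize_pipeline_status()

    # Only the first caller per namespace gets permission to load data from files
    if try_initialize_namespace("demo"):
        data = await get_namespace_data("demo")
        async with get_storage_lock():
            data["loaded"] = True

asyncio.run(main())
```

With Gunicorn's preload feature, the same initialize_share_data call runs once in the master process, so the Manager-backed objects are inherited by every worker.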
lightrag/lightrag.py
CHANGED
@@ -45,7 +45,6 @@ from .utils import (
|
|
45 |
lazy_external_import,
|
46 |
limit_async_func_call,
|
47 |
logger,
|
48 |
-
set_logger,
|
49 |
)
|
50 |
from .types import KnowledgeGraph
|
51 |
from dotenv import load_dotenv
|
@@ -268,9 +267,14 @@ class LightRAG:
|
|
268 |
|
269 |
def __post_init__(self):
|
270 |
os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True)
|
271 |
-
set_logger(self.log_file_path, self.log_level)
|
272 |
logger.info(f"Logger initialized for working directory: {self.working_dir}")
|
273 |
|
|
274 |
if not os.path.exists(self.working_dir):
|
275 |
logger.info(f"Creating working directory {self.working_dir}")
|
276 |
os.makedirs(self.working_dir)
|
@@ -692,117 +696,221 @@ class LightRAG:
|
|
692 |
3. Process each chunk for entity and relation extraction
|
693 |
4. Update the document status
|
694 |
"""
|
695 |
-
|
696 |
-
|
697 |
-
|
698 |
-
self.doc_status.get_docs_by_status(DocStatus.PROCESSING),
|
699 |
-
self.doc_status.get_docs_by_status(DocStatus.FAILED),
|
700 |
-
self.doc_status.get_docs_by_status(DocStatus.PENDING),
|
701 |
)
|
702 |
|
703 |
-
|
704 |
-
|
705 |
-
|
706 |
-
|
|
707 |
|
708 |
-
|
709 |
-
|
710 |
-
|
|
711 |
|
712 |
-
|
713 |
-
|
714 |
-
|
715 |
-
|
716 |
-
|
|
717 |
|
718 |
-
|
719 |
-
|
720 |
-
|
721 |
-
|
722 |
-
|
723 |
-
|
724 |
-
|
725 |
-
|
726 |
-
|
727 |
-
|
728 |
-
|
729 |
-
|
730 |
-
|
731 |
-
|
732 |
-
|
733 |
-
|
734 |
-
|
735 |
-
|
736 |
-
|
737 |
-
|
738 |
-
|
739 |
-
|
740 |
-
status_doc.content,
|
741 |
-
split_by_character,
|
742 |
-
split_by_character_only,
|
743 |
-
self.chunk_overlap_token_size,
|
744 |
-
self.chunk_token_size,
|
745 |
-
self.tiktoken_model_name,
|
746 |
-
)
|
747 |
-
}
|
748 |
-
# Process document (text chunks and full docs) in parallel
|
749 |
-
tasks = [
|
750 |
-
self.doc_status.upsert(
|
751 |
-
{
|
752 |
-
doc_id: {
|
753 |
-
"status": DocStatus.PROCESSING,
|
754 |
-
"updated_at": datetime.now().isoformat(),
|
755 |
-
"content": status_doc.content,
|
756 |
-
"content_summary": status_doc.content_summary,
|
757 |
-
"content_length": status_doc.content_length,
|
758 |
-
"created_at": status_doc.created_at,
|
759 |
-
}
|
760 |
-
}
|
761 |
-
),
|
762 |
-
self.chunks_vdb.upsert(chunks),
|
763 |
-
self._process_entity_relation_graph(chunks),
|
764 |
-
self.full_docs.upsert(
|
765 |
-
{doc_id: {"content": status_doc.content}}
|
766 |
-
),
|
767 |
-
self.text_chunks.upsert(chunks),
|
768 |
-
]
|
769 |
-
try:
|
770 |
-
await asyncio.gather(*tasks)
|
771 |
-
await self.doc_status.upsert(
|
772 |
-
{
|
773 |
-
doc_id: {
|
774 |
-
"status": DocStatus.PROCESSED,
|
775 |
-
"chunks_count": len(chunks),
|
776 |
-
"content": status_doc.content,
|
777 |
-
"content_summary": status_doc.content_summary,
|
778 |
-
"content_length": status_doc.content_length,
|
779 |
-
"created_at": status_doc.created_at,
|
780 |
-
"updated_at": datetime.now().isoformat(),
|
781 |
-
}
|
782 |
-
}
|
783 |
)
|
784 |
-
|
785 |
-
|
786 |
-
|
787 |
-
|
788 |
-
|
789 |
-
|
790 |
-
|
791 |
-
|
792 |
-
|
793 |
-
|
794 |
-
"
|
795 |
-
"updated_at": datetime.now().isoformat(),
|
796 |
}
|
|
797 |
}
|
|
798 |
)
|
799 |
-
|
800 |
-
|
|
801 |
|
802 |
-
|
|
803 |
|
804 |
-
|
805 |
-
|
|
806 |
|
807 |
async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
|
808 |
try:
|
@@ -833,7 +941,16 @@ class LightRAG:
|
|
833 |
if storage_inst is not None
|
834 |
]
|
835 |
await asyncio.gather(*tasks)
|
836 |
-
|
|
837 |
|
838 |
def insert_custom_kg(self, custom_kg: dict[str, Any]) -> None:
|
839 |
loop = always_get_an_event_loop()
|
|
|
45 |
lazy_external_import,
|
46 |
limit_async_func_call,
|
47 |
logger,
|
|
|
48 |
)
|
49 |
from .types import KnowledgeGraph
|
50 |
from dotenv import load_dotenv
|
|
|
267 |
|
268 |
def __post_init__(self):
|
269 |
os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True)
|
|
|
270 |
logger.info(f"Logger initialized for working directory: {self.working_dir}")
|
271 |
|
272 |
+
from lightrag.kg.shared_storage import (
|
273 |
+
initialize_share_data,
|
274 |
+
)
|
275 |
+
|
276 |
+
initialize_share_data()
|
277 |
+
|
278 |
if not os.path.exists(self.working_dir):
|
279 |
logger.info(f"Creating working directory {self.working_dir}")
|
280 |
os.makedirs(self.working_dir)
|
|
|
696 |
3. Process each chunk for entity and relation extraction
|
697 |
4. Update the document status
|
698 |
"""
|
699 |
+
from lightrag.kg.shared_storage import (
|
700 |
+
get_namespace_data,
|
701 |
+
get_pipeline_status_lock,
|
|
702 |
)
|
703 |
|
704 |
+
# Get pipeline status shared data and lock
|
705 |
+
pipeline_status = await get_namespace_data("pipeline_status")
|
706 |
+
pipeline_status_lock = get_pipeline_status_lock()
|
707 |
+
|
708 |
+
# Check if another process is already processing the queue
|
709 |
+
async with pipeline_status_lock:
|
710 |
+
# Ensure only one worker is processing documents
|
711 |
+
if not pipeline_status.get("busy", False):
|
712 |
+
# First, check whether there are any documents to process
|
713 |
+
processing_docs, failed_docs, pending_docs = await asyncio.gather(
|
714 |
+
self.doc_status.get_docs_by_status(DocStatus.PROCESSING),
|
715 |
+
self.doc_status.get_docs_by_status(DocStatus.FAILED),
|
716 |
+
self.doc_status.get_docs_by_status(DocStatus.PENDING),
|
717 |
+
)
|
718 |
|
719 |
+
to_process_docs: dict[str, DocProcessingStatus] = {}
|
720 |
+
to_process_docs.update(processing_docs)
|
721 |
+
to_process_docs.update(failed_docs)
|
722 |
+
to_process_docs.update(pending_docs)
|
723 |
+
|
724 |
+
# If there are no documents to process, return directly and leave the pipeline_status content unchanged
|
725 |
+
if not to_process_docs:
|
726 |
+
logger.info("No documents to process")
|
727 |
+
return
|
728 |
+
|
729 |
+
# There are documents to process; update pipeline_status
|
730 |
+
pipeline_status.update(
|
731 |
+
{
|
732 |
+
"busy": True,
|
733 |
+
"job_name": "indexing files",
|
734 |
+
"job_start": datetime.now().isoformat(),
|
735 |
+
"docs": 0,
|
736 |
+
"batchs": 0,
|
737 |
+
"cur_batch": 0,
|
738 |
+
"request_pending": False, # Clear any previous request
|
739 |
+
"latest_message": "",
|
740 |
+
}
|
741 |
+
)
|
742 |
+
# Cleaning history_messages without breaking it as a shared list object
|
743 |
+
del pipeline_status["history_messages"][:]
|
744 |
+
else:
|
745 |
+
# Another process is busy, just set request flag and return
|
746 |
+
pipeline_status["request_pending"] = True
|
747 |
+
logger.info(
|
748 |
+
"Another process is already processing the document queue. Request queued."
|
749 |
+
)
|
750 |
+
return
|
751 |
|
752 |
+
try:
|
753 |
+
# Process documents until no more documents or requests
|
754 |
+
while True:
|
755 |
+
if not to_process_docs:
|
756 |
+
log_message = "All documents have been processed or are duplicates"
|
757 |
+
logger.info(log_message)
|
758 |
+
pipeline_status["latest_message"] = log_message
|
759 |
+
pipeline_status["history_messages"].append(log_message)
|
760 |
+
break
|
761 |
+
|
762 |
+
# 2. split docs into chunks, insert chunks, update doc status
|
763 |
+
docs_batches = [
|
764 |
+
list(to_process_docs.items())[i : i + self.max_parallel_insert]
|
765 |
+
for i in range(0, len(to_process_docs), self.max_parallel_insert)
|
766 |
+
]
|
767 |
|
768 |
+
log_message = f"Number of batches to process: {len(docs_batches)}."
|
769 |
+
logger.info(log_message)
|
770 |
+
|
771 |
+
# Update pipeline status with current batch information
|
772 |
+
pipeline_status["docs"] += len(to_process_docs)
|
773 |
+
pipeline_status["batchs"] += len(docs_batches)
|
774 |
+
pipeline_status["latest_message"] = log_message
|
775 |
+
pipeline_status["history_messages"].append(log_message)
|
776 |
+
|
777 |
+
batches: list[Any] = []
|
778 |
+
# 3. iterate over batches
|
779 |
+
for batch_idx, docs_batch in enumerate(docs_batches):
|
780 |
+
# Update current batch in pipeline status (directly, as it's atomic)
|
781 |
+
pipeline_status["cur_batch"] += 1
|
782 |
+
|
783 |
+
async def batch(
|
784 |
+
batch_idx: int,
|
785 |
+
docs_batch: list[tuple[str, DocProcessingStatus]],
|
786 |
+
size_batch: int,
|
787 |
+
) -> None:
|
788 |
+
log_message = (
|
789 |
+
f"Start processing batch {batch_idx + 1} of {size_batch}."
|
|
790 |
)
|
791 |
+
logger.info(log_message)
|
792 |
+
pipeline_status["latest_message"] = log_message
|
793 |
+
pipeline_status["history_messages"].append(log_message)
|
794 |
+
# 4. iterate over batch
|
795 |
+
for doc_id_processing_status in docs_batch:
|
796 |
+
doc_id, status_doc = doc_id_processing_status
|
797 |
+
# Generate chunks from document
|
798 |
+
chunks: dict[str, Any] = {
|
799 |
+
compute_mdhash_id(dp["content"], prefix="chunk-"): {
|
800 |
+
**dp,
|
801 |
+
"full_doc_id": doc_id,
|
|
|
802 |
}
|
803 |
+
for dp in self.chunking_func(
|
804 |
+
status_doc.content,
|
805 |
+
split_by_character,
|
806 |
+
split_by_character_only,
|
807 |
+
self.chunk_overlap_token_size,
|
808 |
+
self.chunk_token_size,
|
809 |
+
self.tiktoken_model_name,
|
810 |
+
)
|
811 |
}
|
812 |
+
# Process document (text chunks and full docs) in parallel
|
813 |
+
tasks = [
|
814 |
+
self.doc_status.upsert(
|
815 |
+
{
|
816 |
+
doc_id: {
|
817 |
+
"status": DocStatus.PROCESSING,
|
818 |
+
"updated_at": datetime.now().isoformat(),
|
819 |
+
"content": status_doc.content,
|
820 |
+
"content_summary": status_doc.content_summary,
|
821 |
+
"content_length": status_doc.content_length,
|
822 |
+
"created_at": status_doc.created_at,
|
823 |
+
}
|
824 |
+
}
|
825 |
+
),
|
826 |
+
self.chunks_vdb.upsert(chunks),
|
827 |
+
self._process_entity_relation_graph(chunks),
|
828 |
+
self.full_docs.upsert(
|
829 |
+
{doc_id: {"content": status_doc.content}}
|
830 |
+
),
|
831 |
+
self.text_chunks.upsert(chunks),
|
832 |
+
]
|
833 |
+
try:
|
834 |
+
await asyncio.gather(*tasks)
|
835 |
+
await self.doc_status.upsert(
|
836 |
+
{
|
837 |
+
doc_id: {
|
838 |
+
"status": DocStatus.PROCESSED,
|
839 |
+
"chunks_count": len(chunks),
|
840 |
+
"content": status_doc.content,
|
841 |
+
"content_summary": status_doc.content_summary,
|
842 |
+
"content_length": status_doc.content_length,
|
843 |
+
"created_at": status_doc.created_at,
|
844 |
+
"updated_at": datetime.now().isoformat(),
|
845 |
+
}
|
846 |
+
}
|
847 |
+
)
|
848 |
+
except Exception as e:
|
849 |
+
logger.error(
|
850 |
+
f"Failed to process document {doc_id}: {str(e)}"
|
851 |
+
)
|
852 |
+
await self.doc_status.upsert(
|
853 |
+
{
|
854 |
+
doc_id: {
|
855 |
+
"status": DocStatus.FAILED,
|
856 |
+
"error": str(e),
|
857 |
+
"content": status_doc.content,
|
858 |
+
"content_summary": status_doc.content_summary,
|
859 |
+
"content_length": status_doc.content_length,
|
860 |
+
"created_at": status_doc.created_at,
|
861 |
+
"updated_at": datetime.now().isoformat(),
|
862 |
+
}
|
863 |
+
}
|
864 |
+
)
|
865 |
+
continue
|
866 |
+
log_message = (
|
867 |
+
f"Completed batch {batch_idx + 1} of {len(docs_batches)}."
|
868 |
)
|
869 |
+
logger.info(log_message)
|
870 |
+
pipeline_status["latest_message"] = log_message
|
871 |
+
pipeline_status["history_messages"].append(log_message)
|
872 |
+
|
873 |
+
batches.append(batch(batch_idx, docs_batch, len(docs_batches)))
|
874 |
|
875 |
+
await asyncio.gather(*batches)
|
876 |
+
await self._insert_done()
|
877 |
+
|
878 |
+
# Check if there's a pending request to process more documents (with lock)
|
879 |
+
has_pending_request = False
|
880 |
+
async with pipeline_status_lock:
|
881 |
+
has_pending_request = pipeline_status.get("request_pending", False)
|
882 |
+
if has_pending_request:
|
883 |
+
# Clear the request flag before checking for more documents
|
884 |
+
pipeline_status["request_pending"] = False
|
885 |
+
|
886 |
+
if not has_pending_request:
|
887 |
+
break
|
888 |
+
|
889 |
+
log_message = "Processing additional documents due to pending request"
|
890 |
+
logger.info(log_message)
|
891 |
+
pipeline_status["latest_message"] = log_message
|
892 |
+
pipeline_status["history_messages"].append(log_message)
|
893 |
+
|
894 |
+
# Fetch new pending documents
|
895 |
+
processing_docs, failed_docs, pending_docs = await asyncio.gather(
|
896 |
+
self.doc_status.get_docs_by_status(DocStatus.PROCESSING),
|
897 |
+
self.doc_status.get_docs_by_status(DocStatus.FAILED),
|
898 |
+
self.doc_status.get_docs_by_status(DocStatus.PENDING),
|
899 |
+
)
|
900 |
+
|
901 |
+
to_process_docs = {}
|
902 |
+
to_process_docs.update(processing_docs)
|
903 |
+
to_process_docs.update(failed_docs)
|
904 |
+
to_process_docs.update(pending_docs)
|
905 |
|
906 |
+
finally:
|
907 |
+
log_message = "Document processing pipeline completed"
|
908 |
+
logger.info(log_message)
|
909 |
+
# Always reset busy status when done or if an exception occurs (with lock)
|
910 |
+
async with pipeline_status_lock:
|
911 |
+
pipeline_status["busy"] = False
|
912 |
+
pipeline_status["latest_message"] = log_message
|
913 |
+
pipeline_status["history_messages"].append(log_message)
|
914 |
|
915 |
async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
|
916 |
try:
|
|
|
941 |
if storage_inst is not None
|
942 |
]
|
943 |
await asyncio.gather(*tasks)
|
944 |
+
|
945 |
+
log_message = "All Insert done"
|
946 |
+
logger.info(log_message)
|
947 |
+
|
948 |
+
# Get pipeline_status and update latest_message and history_messages
|
949 |
+
from lightrag.kg.shared_storage import get_namespace_data
|
950 |
+
|
951 |
+
pipeline_status = await get_namespace_data("pipeline_status")
|
952 |
+
pipeline_status["latest_message"] = log_message
|
953 |
+
pipeline_status["history_messages"].append(log_message)
|
954 |
|
955 |
def insert_custom_kg(self, custom_kg: dict[str, Any]) -> None:
|
956 |
loop = always_get_an_event_loop()
|
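The pipeline changes above coordinate workers through a `busy` flag plus a `request_pending` flag held in the shared `pipeline_status` namespace. Stripped of the chunking and upsert work, the control flow reduces to roughly the following sketch (a simplification, not the shipped implementation; `process_batches` is a hypothetical stand-in for the real document handling):

```python
# Simplified sketch of the busy/request_pending coordination used in
# apipeline_process_enqueue_documents; process_batches() is a stand-in.
from lightrag.kg.shared_storage import get_namespace_data, get_pipeline_status_lock

async def run_pipeline(process_batches) -> None:
    pipeline_status = await get_namespace_data("pipeline_status")
    pipeline_status_lock = get_pipeline_status_lock()

    async with pipeline_status_lock:
        if pipeline_status.get("busy", False):
            # Another worker owns the queue; ask it to run one more round.
            pipeline_status["request_pending"] = True
            return
        pipeline_status["busy"] = True

    try:
        while True:
            await process_batches()
            async with pipeline_status_lock:
                if not pipeline_status.get("request_pending", False):
                    break
                pipeline_status["request_pending"] = False
    finally:
        async with pipeline_status_lock:
            pipeline_status["busy"] = False
```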
lightrag/operate.py
CHANGED
@@ -339,6 +339,9 @@ async def extract_entities(
|
|
339 |
global_config: dict[str, str],
|
340 |
llm_response_cache: BaseKVStorage | None = None,
|
341 |
) -> None:
|
|
342 |
use_llm_func: callable = global_config["llm_model_func"]
|
343 |
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
|
344 |
enable_llm_cache_for_entity_extract: bool = global_config[
|
@@ -499,9 +502,10 @@ async def extract_entities(
|
|
499 |
processed_chunks += 1
|
500 |
entities_count = len(maybe_nodes)
|
501 |
relations_count = len(maybe_edges)
|
502 |
-
|
503 |
-
|
504 |
-
|
|
|
505 |
return dict(maybe_nodes), dict(maybe_edges)
|
506 |
|
507 |
tasks = [_process_single_content(c) for c in ordered_chunks]
|
@@ -530,17 +534,27 @@ async def extract_entities(
|
|
530 |
)
|
531 |
|
532 |
if not (all_entities_data or all_relationships_data):
|
533 |
-
|
|
534 |
return
|
535 |
|
536 |
if not all_entities_data:
|
537 |
-
|
|
538 |
if not all_relationships_data:
|
539 |
-
|
540 |
-
|
541 |
-
|
542 |
-
|
543 |
-
|
|
544 |
verbose_debug(
|
545 |
f"New entities:{all_entities_data}, relationships:{all_relationships_data}"
|
546 |
)
|
|
|
339 |
global_config: dict[str, str],
|
340 |
llm_response_cache: BaseKVStorage | None = None,
|
341 |
) -> None:
|
342 |
+
from lightrag.kg.shared_storage import get_namespace_data
|
343 |
+
|
344 |
+
pipeline_status = await get_namespace_data("pipeline_status")
|
345 |
use_llm_func: callable = global_config["llm_model_func"]
|
346 |
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
|
347 |
enable_llm_cache_for_entity_extract: bool = global_config[
|
|
|
502 |
processed_chunks += 1
|
503 |
entities_count = len(maybe_nodes)
|
504 |
relations_count = len(maybe_edges)
|
505 |
+
log_message = f" Chunk {processed_chunks}/{total_chunks}: extracted {entities_count} entities and {relations_count} relationships (deduplicated)"
|
506 |
+
logger.info(log_message)
|
507 |
+
pipeline_status["latest_message"] = log_message
|
508 |
+
pipeline_status["history_messages"].append(log_message)
|
509 |
return dict(maybe_nodes), dict(maybe_edges)
|
510 |
|
511 |
tasks = [_process_single_content(c) for c in ordered_chunks]
|
|
|
534 |
)
|
535 |
|
536 |
if not (all_entities_data or all_relationships_data):
|
537 |
+
log_message = "Didn't extract any entities and relationships."
|
538 |
+
logger.info(log_message)
|
539 |
+
pipeline_status["latest_message"] = log_message
|
540 |
+
pipeline_status["history_messages"].append(log_message)
|
541 |
return
|
542 |
|
543 |
if not all_entities_data:
|
544 |
+
log_message = "Didn't extract any entities"
|
545 |
+
logger.info(log_message)
|
546 |
+
pipeline_status["latest_message"] = log_message
|
547 |
+
pipeline_status["history_messages"].append(log_message)
|
548 |
if not all_relationships_data:
|
549 |
+
log_message = "Didn't extract any relationships"
|
550 |
+
logger.info(log_message)
|
551 |
+
pipeline_status["latest_message"] = log_message
|
552 |
+
pipeline_status["history_messages"].append(log_message)
|
553 |
+
|
554 |
+
log_message = f"Extracted {len(all_entities_data)} entities and {len(all_relationships_data)} relationships (deduplicated)"
|
555 |
+
logger.info(log_message)
|
556 |
+
pipeline_status["latest_message"] = log_message
|
557 |
+
pipeline_status["history_messages"].append(log_message)
|
558 |
verbose_debug(
|
559 |
f"New entities:{all_entities_data}, relationships:{all_relationships_data}"
|
560 |
)
|
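Each progress update in extract_entities repeats the same three steps: log the message, store it as latest_message, and append it to history_messages so other workers can see it. The patch inlines those three lines at every progress point; a hypothetical helper capturing the pattern would look like this:

```python
# Hypothetical helper illustrating the progress-reporting pattern used in
# extract_entities; the patch writes these three lines inline instead.
from lightrag.utils import logger

def report_progress(pipeline_status, message: str) -> None:
    logger.info(message)                                  # local log output
    pipeline_status["latest_message"] = message           # visible to other workers
    pipeline_status["history_messages"].append(message)   # shared history list
```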
lightrag/utils.py
CHANGED
@@ -56,45 +56,29 @@ def set_verbose_debug(enabled: bool):
|
|
56 |
VERBOSE_DEBUG = enabled
|
57 |
|
58 |
|
59 |
-
class UnlimitedSemaphore:
|
60 |
-
"""A context manager that allows unlimited access."""
|
61 |
-
|
62 |
-
async def __aenter__(self):
|
63 |
-
pass
|
64 |
-
|
65 |
-
async def __aexit__(self, exc_type, exc, tb):
|
66 |
-
pass
|
67 |
-
|
68 |
-
|
69 |
-
ENCODER = None
|
70 |
-
|
71 |
statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0}
|
72 |
|
|
|
73 |
logger = logging.getLogger("lightrag")
|
|
74 |
|
75 |
# Set httpx logging level to WARNING
|
76 |
logging.getLogger("httpx").setLevel(logging.WARNING)
|
77 |
|
78 |
|
79 |
-
|
80 |
-
"""
|
81 |
|
82 |
-
|
83 |
-
|
84 |
-
level: Logging level (e.g. logging.DEBUG, logging.INFO)
|
85 |
-
"""
|
86 |
-
logger.setLevel(level)
|
87 |
|
88 |
-
|
89 |
-
|
90 |
|
91 |
-
formatter = logging.Formatter(
|
92 |
-
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
93 |
-
)
|
94 |
-
file_handler.setFormatter(formatter)
|
95 |
|
96 |
-
|
97 |
-
logger.addHandler(file_handler)
|
98 |
|
99 |
|
100 |
@dataclass
|
|
|
56 |
VERBOSE_DEBUG = enabled
|
57 |
|
58 |
|
|
59 |
statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0}
|
60 |
|
61 |
+
# Initialize logger
|
62 |
logger = logging.getLogger("lightrag")
|
63 |
+
logger.propagate = False # prevent log messages from being sent to the root logger
|
64 |
+
# Let the main application configure the handlers
|
65 |
+
logger.setLevel(logging.INFO)
|
66 |
|
67 |
# Set httpx logging level to WARNING
|
68 |
logging.getLogger("httpx").setLevel(logging.WARNING)
|
69 |
|
70 |
|
71 |
+
class UnlimitedSemaphore:
|
72 |
+
"""A context manager that allows unlimited access."""
|
73 |
|
74 |
+
async def __aenter__(self):
|
75 |
+
pass
|
|
|
76 |
|
77 |
+
async def __aexit__(self, exc_type, exc, tb):
|
78 |
+
pass
|
79 |
|
|
80 |
|
81 |
+
ENCODER = None
|
|
|
82 |
|
83 |
|
84 |
@dataclass
|
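With set_logger removed and propagation disabled, the library no longer installs any handlers on the "lightrag" logger; the hosting application is expected to do that itself. A minimal application-side sketch using only the standard library (file name and rotation values are illustrative, not the shipped defaults):

```python
# Minimal sketch: configure the "lightrag" logger from the application side,
# since lightrag.utils no longer attaches handlers itself.
import logging
from logging.handlers import RotatingFileHandler

handler = RotatingFileHandler(
    "lightrag.log", maxBytes=10 * 1024 * 1024, backupCount=5  # illustrative values
)
handler.setFormatter(
    logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)

lightrag_logger = logging.getLogger("lightrag")
lightrag_logger.setLevel(logging.INFO)
lightrag_logger.addHandler(handler)
```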
run_with_gunicorn.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
"""
|
3 |
+
Start LightRAG server with Gunicorn
|
4 |
+
"""
|
5 |
+
|
6 |
+
import os
|
7 |
+
import sys
|
8 |
+
import signal
|
9 |
+
import pipmaster as pm
|
10 |
+
from lightrag.api.utils_api import parse_args, display_splash_screen
|
11 |
+
from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data
|
12 |
+
|
13 |
+
|
14 |
+
def check_and_install_dependencies():
|
15 |
+
"""Check and install required dependencies"""
|
16 |
+
required_packages = [
|
17 |
+
"gunicorn",
|
18 |
+
"tiktoken",
|
19 |
+
"psutil",
|
20 |
+
# Add other required packages here
|
21 |
+
]
|
22 |
+
|
23 |
+
for package in required_packages:
|
24 |
+
if not pm.is_installed(package):
|
25 |
+
print(f"Installing {package}...")
|
26 |
+
pm.install(package)
|
27 |
+
print(f"{package} installed successfully")
|
28 |
+
|
29 |
+
|
30 |
+
# Signal handler for graceful shutdown
|
31 |
+
def signal_handler(sig, frame):
|
32 |
+
print("\n\n" + "=" * 80)
|
33 |
+
print("RECEIVED TERMINATION SIGNAL")
|
34 |
+
print(f"Process ID: {os.getpid()}")
|
35 |
+
print("=" * 80 + "\n")
|
36 |
+
|
37 |
+
# Release shared resources
|
38 |
+
finalize_share_data()
|
39 |
+
|
40 |
+
# Exit with success status
|
41 |
+
sys.exit(0)
|
42 |
+
|
43 |
+
|
44 |
+
def main():
|
45 |
+
# Check and install dependencies
|
46 |
+
check_and_install_dependencies()
|
47 |
+
|
48 |
+
# Register signal handlers for graceful shutdown
|
49 |
+
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
|
50 |
+
signal.signal(signal.SIGTERM, signal_handler) # kill command
|
51 |
+
|
52 |
+
# Parse all arguments using parse_args
|
53 |
+
args = parse_args(is_uvicorn_mode=False)
|
54 |
+
|
55 |
+
# Display startup information
|
56 |
+
display_splash_screen(args)
|
57 |
+
|
58 |
+
print("🚀 Starting LightRAG with Gunicorn")
|
59 |
+
print(f"🔄 Worker management: Gunicorn (workers={args.workers})")
|
60 |
+
print("🔍 Preloading app: Enabled")
|
61 |
+
print("📝 Note: Using Gunicorn's preload feature for shared data initialization")
|
62 |
+
print("\n\n" + "=" * 80)
|
63 |
+
print("MAIN PROCESS INITIALIZATION")
|
64 |
+
print(f"Process ID: {os.getpid()}")
|
65 |
+
print(f"Workers setting: {args.workers}")
|
66 |
+
print("=" * 80 + "\n")
|
67 |
+
|
68 |
+
# Import Gunicorn's StandaloneApplication
|
69 |
+
from gunicorn.app.base import BaseApplication
|
70 |
+
|
71 |
+
# Define a custom application class that loads our config
|
72 |
+
class GunicornApp(BaseApplication):
|
73 |
+
def __init__(self, app, options=None):
|
74 |
+
self.options = options or {}
|
75 |
+
self.application = app
|
76 |
+
super().__init__()
|
77 |
+
|
78 |
+
def load_config(self):
|
79 |
+
# Define valid Gunicorn configuration options
|
80 |
+
valid_options = {
|
81 |
+
"bind",
|
82 |
+
"workers",
|
83 |
+
"worker_class",
|
84 |
+
"timeout",
|
85 |
+
"keepalive",
|
86 |
+
"preload_app",
|
87 |
+
"errorlog",
|
88 |
+
"accesslog",
|
89 |
+
"loglevel",
|
90 |
+
"certfile",
|
91 |
+
"keyfile",
|
92 |
+
"limit_request_line",
|
93 |
+
"limit_request_fields",
|
94 |
+
"limit_request_field_size",
|
95 |
+
"graceful_timeout",
|
96 |
+
"max_requests",
|
97 |
+
"max_requests_jitter",
|
98 |
+
}
|
99 |
+
|
100 |
+
# Special hooks that need to be set separately
|
101 |
+
special_hooks = {
|
102 |
+
"on_starting",
|
103 |
+
"on_reload",
|
104 |
+
"on_exit",
|
105 |
+
"pre_fork",
|
106 |
+
"post_fork",
|
107 |
+
"pre_exec",
|
108 |
+
"pre_request",
|
109 |
+
"post_request",
|
110 |
+
"worker_init",
|
111 |
+
"worker_exit",
|
112 |
+
"nworkers_changed",
|
113 |
+
"child_exit",
|
114 |
+
}
|
115 |
+
|
116 |
+
# Import and configure the gunicorn_config module
|
117 |
+
import gunicorn_config
|
118 |
+
|
119 |
+
# Set configuration variables in gunicorn_config, prioritizing command line arguments
|
120 |
+
gunicorn_config.workers = (
|
121 |
+
args.workers if args.workers else int(os.getenv("WORKERS", 1))
|
122 |
+
)
|
123 |
+
|
124 |
+
# Bind configuration prioritizes command line arguments
|
125 |
+
host = args.host if args.host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0")
|
126 |
+
port = args.port if args.port != 9621 else int(os.getenv("PORT", 9621))
|
127 |
+
gunicorn_config.bind = f"{host}:{port}"
|
128 |
+
|
129 |
+
# Log level configuration prioritizes command line arguments
|
130 |
+
gunicorn_config.loglevel = (
|
131 |
+
args.log_level.lower()
|
132 |
+
if args.log_level
|
133 |
+
else os.getenv("LOG_LEVEL", "info")
|
134 |
+
)
|
135 |
+
|
136 |
+
# Timeout configuration prioritizes command line arguments
|
137 |
+
gunicorn_config.timeout = (
|
138 |
+
args.timeout if args.timeout else int(os.getenv("TIMEOUT", 150))
|
139 |
+
)
|
140 |
+
|
141 |
+
# Keepalive configuration
|
142 |
+
gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5))
|
143 |
+
|
144 |
+
# SSL configuration prioritizes command line arguments
|
145 |
+
if args.ssl or os.getenv("SSL", "").lower() in (
|
146 |
+
"true",
|
147 |
+
"1",
|
148 |
+
"yes",
|
149 |
+
"t",
|
150 |
+
"on",
|
151 |
+
):
|
152 |
+
gunicorn_config.certfile = (
|
153 |
+
args.ssl_certfile
|
154 |
+
if args.ssl_certfile
|
155 |
+
else os.getenv("SSL_CERTFILE")
|
156 |
+
)
|
157 |
+
gunicorn_config.keyfile = (
|
158 |
+
args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE")
|
159 |
+
)
|
160 |
+
|
161 |
+
# Set configuration options from the module
|
162 |
+
for key in dir(gunicorn_config):
|
163 |
+
if key in valid_options:
|
164 |
+
value = getattr(gunicorn_config, key)
|
165 |
+
# Skip functions like on_starting and None values
|
166 |
+
if not callable(value) and value is not None:
|
167 |
+
self.cfg.set(key, value)
|
168 |
+
# Set special hooks
|
169 |
+
elif key in special_hooks:
|
170 |
+
value = getattr(gunicorn_config, key)
|
171 |
+
if callable(value):
|
172 |
+
self.cfg.set(key, value)
|
173 |
+
|
174 |
+
if hasattr(gunicorn_config, "logconfig_dict"):
|
175 |
+
self.cfg.set(
|
176 |
+
"logconfig_dict", getattr(gunicorn_config, "logconfig_dict")
|
177 |
+
)
|
178 |
+
|
179 |
+
def load(self):
|
180 |
+
# Import the application
|
181 |
+
from lightrag.api.lightrag_server import get_application
|
182 |
+
|
183 |
+
return get_application(args)
|
184 |
+
|
185 |
+
# Create the application
|
186 |
+
app = GunicornApp("")
|
187 |
+
|
188 |
+
# Force workers to be an integer and greater than 1 for multi-process mode
|
189 |
+
workers_count = int(args.workers)
|
190 |
+
if workers_count > 1:
|
191 |
+
# Set a flag to indicate we're in the main process
|
192 |
+
os.environ["LIGHTRAG_MAIN_PROCESS"] = "1"
|
193 |
+
initialize_share_data(workers_count)
|
194 |
+
else:
|
195 |
+
initialize_share_data(1)
|
196 |
+
|
197 |
+
# Run the application
|
198 |
+
print("\nStarting Gunicorn with direct Python API...")
|
199 |
+
app.run()
|
200 |
+
|
201 |
+
|
202 |
+
if __name__ == "__main__":
|
203 |
+
main()
|
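GunicornApp.load_config() above copies plain module-level attributes off a gunicorn_config module (the PR ships lightrag/api/gunicorn_config.py, which is not reproduced in this hunk). A stripped-down sketch of the attributes it reads, with illustrative defaults rather than the shipped ones:

```python
# gunicorn_config.py -- illustrative sketch of the module-level settings that
# GunicornApp.load_config() copies into Gunicorn's configuration.
import os

workers = int(os.getenv("WORKERS", 1))
bind = f"{os.getenv('HOST', '0.0.0.0')}:{os.getenv('PORT', 9621)}"
loglevel = os.getenv("LOG_LEVEL", "info")
timeout = int(os.getenv("TIMEOUT", 150))
keepalive = int(os.getenv("KEEPALIVE", 5))
preload_app = True  # workers inherit the shared data initialized in the master

certfile = os.getenv("SSL_CERTFILE")  # only used when SSL is enabled
keyfile = os.getenv("SSL_KEYFILE")
```

Command-line arguments and environment variables still take precedence: run_with_gunicorn.py overrides workers, bind, loglevel, timeout, and the SSL paths from parse_args() before handing the module's values to Gunicorn.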
setup.py
CHANGED
@@ -112,6 +112,7 @@ setuptools.setup(
|
|
112 |
entry_points={
|
113 |
"console_scripts": [
|
114 |
"lightrag-server=lightrag.api.lightrag_server:main [api]",
|
|
|
115 |
"lightrag-viewer=lightrag.tools.lightrag_visualizer.graph_visualizer:main [tools]",
|
116 |
],
|
117 |
},
|
|
|
112 |
entry_points={
|
113 |
"console_scripts": [
|
114 |
"lightrag-server=lightrag.api.lightrag_server:main [api]",
|
115 |
+
"lightrag-gunicorn=lightrag.api.run_with_gunicorn:main [api]",
|
116 |
"lightrag-viewer=lightrag.tools.lightrag_visualizer.graph_visualizer:main [tools]",
|
117 |
],
|
118 |
},
|