yangdx
commited on
Commit
·
37b1fa9
1
Parent(s):
7c2f5b5
Refactor authentication logic and Swagger UI config
Browse files- Consolidate authentication dependencies
- Improve Swagger UI security parameters
lightrag/api/lightrag_server.py
CHANGED
@@ -19,6 +19,7 @@ from contextlib import asynccontextmanager
|
|
19 |
from dotenv import load_dotenv
|
20 |
from lightrag.api.utils_api import (
|
21 |
get_api_key_dependency,
|
|
|
22 |
parse_args,
|
23 |
get_default_host,
|
24 |
display_splash_screen,
|
@@ -135,19 +136,28 @@ def create_app(args):
|
|
135 |
await rag.finalize_storages()
|
136 |
|
137 |
# Initialize FastAPI
|
138 |
-
|
139 |
-
title
|
140 |
-
description
|
141 |
+ "(With authentication)"
|
142 |
if api_key
|
143 |
else "",
|
144 |
-
version
|
145 |
-
openapi_url
|
146 |
-
docs_url
|
147 |
-
redoc_url
|
148 |
-
openapi_tags
|
149 |
-
lifespan
|
150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
|
152 |
def get_cors_origins():
|
153 |
"""Get allowed origins from environment variable
|
@@ -167,13 +177,8 @@ def create_app(args):
|
|
167 |
allow_headers=["*"],
|
168 |
)
|
169 |
|
170 |
-
# Create
|
171 |
-
|
172 |
-
async def optional_api_key_dependency(request: Request):
|
173 |
-
# Create the dependency function with the request
|
174 |
-
api_key_dependency = get_api_key_dependency(api_key)
|
175 |
-
# Call the dependency function with the request
|
176 |
-
return await api_key_dependency(request)
|
177 |
|
178 |
# Create working directory if it doesn't exist
|
179 |
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
|
@@ -419,7 +424,7 @@ def create_app(args):
|
|
419 |
"api_version": __api_version__,
|
420 |
}
|
421 |
|
422 |
-
@app.get("/health", dependencies=[Depends(
|
423 |
async def get_status():
|
424 |
"""Get current system status"""
|
425 |
username = os.getenv("AUTH_USERNAME")
|
|
|
19 |
from dotenv import load_dotenv
|
20 |
from lightrag.api.utils_api import (
|
21 |
get_api_key_dependency,
|
22 |
+
get_combined_auth_dependency,
|
23 |
parse_args,
|
24 |
get_default_host,
|
25 |
display_splash_screen,
|
|
|
136 |
await rag.finalize_storages()
|
137 |
|
138 |
# Initialize FastAPI
|
139 |
+
app_kwargs = {
|
140 |
+
"title": "LightRAG Server API",
|
141 |
+
"description": "Providing API for LightRAG core, Web UI and Ollama Model Emulation"
|
142 |
+ "(With authentication)"
|
143 |
if api_key
|
144 |
else "",
|
145 |
+
"version": __api_version__,
|
146 |
+
"openapi_url": "/openapi.json", # Explicitly set OpenAPI schema URL
|
147 |
+
"docs_url": "/docs", # Explicitly set docs URL
|
148 |
+
"redoc_url": "/redoc", # Explicitly set redoc URL
|
149 |
+
"openapi_tags": [{"name": "api"}],
|
150 |
+
"lifespan": lifespan,
|
151 |
+
}
|
152 |
+
|
153 |
+
# Configure Swagger UI parameters
|
154 |
+
# Enable persistAuthorization and tryItOutEnabled for better user experience
|
155 |
+
app_kwargs["swagger_ui_parameters"] = {
|
156 |
+
"persistAuthorization": True,
|
157 |
+
"tryItOutEnabled": True,
|
158 |
+
}
|
159 |
+
|
160 |
+
app = FastAPI(**app_kwargs)
|
161 |
|
162 |
def get_cors_origins():
|
163 |
"""Get allowed origins from environment variable
|
|
|
177 |
allow_headers=["*"],
|
178 |
)
|
179 |
|
180 |
+
# Create combined auth dependency for all endpoints
|
181 |
+
combined_auth = get_combined_auth_dependency(api_key)
|
|
|
|
|
|
|
|
|
|
|
182 |
|
183 |
# Create working directory if it doesn't exist
|
184 |
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
|
|
|
424 |
"api_version": __api_version__,
|
425 |
}
|
426 |
|
427 |
+
@app.get("/health", dependencies=[Depends(combined_auth)])
|
428 |
async def get_status():
|
429 |
"""Get current system status"""
|
430 |
username = os.getenv("AUTH_USERNAME")
|
lightrag/api/routers/document_routes.py
CHANGED
@@ -505,6 +505,7 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
|
|
505 |
def create_document_routes(
|
506 |
rag: LightRAG, doc_manager: DocumentManager, api_key: Optional[str] = None
|
507 |
):
|
|
|
508 |
combined_auth = get_combined_auth_dependency(api_key)
|
509 |
|
510 |
@router.post("/scan", dependencies=[Depends(combined_auth)])
|
|
|
505 |
def create_document_routes(
|
506 |
rag: LightRAG, doc_manager: DocumentManager, api_key: Optional[str] = None
|
507 |
):
|
508 |
+
# Create combined auth dependency for document routes
|
509 |
combined_auth = get_combined_auth_dependency(api_key)
|
510 |
|
511 |
@router.post("/scan", dependencies=[Depends(combined_auth)])
|
lightrag/api/routers/ollama_api.py
CHANGED
@@ -11,7 +11,7 @@ import asyncio
|
|
11 |
from ascii_colors import trace_exception
|
12 |
from lightrag import LightRAG, QueryParam
|
13 |
from lightrag.utils import encode_string_by_tiktoken
|
14 |
-
from lightrag.api.utils_api import ollama_server_infos,
|
15 |
from fastapi import Depends
|
16 |
|
17 |
|
@@ -132,21 +132,17 @@ class OllamaAPI:
|
|
132 |
self.setup_routes()
|
133 |
|
134 |
def setup_routes(self):
|
135 |
-
# Create
|
136 |
-
|
137 |
-
# Create the dependency function with the request
|
138 |
-
api_key_dependency = get_api_key_dependency(self.api_key)
|
139 |
-
# Call the dependency function with the request
|
140 |
-
return await api_key_dependency(request)
|
141 |
|
142 |
@self.router.get(
|
143 |
-
"/version", dependencies=[Depends(
|
144 |
)
|
145 |
async def get_version():
|
146 |
"""Get Ollama version information"""
|
147 |
return OllamaVersionResponse(version="0.5.4")
|
148 |
|
149 |
-
@self.router.get("/tags", dependencies=[Depends(
|
150 |
async def get_tags():
|
151 |
"""Return available models acting as an Ollama server"""
|
152 |
return OllamaTagResponse(
|
@@ -170,7 +166,7 @@ class OllamaAPI:
|
|
170 |
)
|
171 |
|
172 |
@self.router.post(
|
173 |
-
"/generate", dependencies=[Depends(
|
174 |
)
|
175 |
async def generate(raw_request: Request, request: OllamaGenerateRequest):
|
176 |
"""Handle generate completion requests acting as an Ollama model
|
@@ -337,7 +333,7 @@ class OllamaAPI:
|
|
337 |
trace_exception(e)
|
338 |
raise HTTPException(status_code=500, detail=str(e))
|
339 |
|
340 |
-
@self.router.post("/chat", dependencies=[Depends(
|
341 |
async def chat(raw_request: Request, request: OllamaChatRequest):
|
342 |
"""Process chat completion requests acting as an Ollama model
|
343 |
Routes user queries through LightRAG by selecting query mode based on prefix indicators.
|
|
|
11 |
from ascii_colors import trace_exception
|
12 |
from lightrag import LightRAG, QueryParam
|
13 |
from lightrag.utils import encode_string_by_tiktoken
|
14 |
+
from lightrag.api.utils_api import ollama_server_infos, get_combined_auth_dependency
|
15 |
from fastapi import Depends
|
16 |
|
17 |
|
|
|
132 |
self.setup_routes()
|
133 |
|
134 |
def setup_routes(self):
|
135 |
+
# Create combined auth dependency for Ollama API routes
|
136 |
+
combined_auth = get_combined_auth_dependency(self.api_key)
|
|
|
|
|
|
|
|
|
137 |
|
138 |
@self.router.get(
|
139 |
+
"/version", dependencies=[Depends(combined_auth)]
|
140 |
)
|
141 |
async def get_version():
|
142 |
"""Get Ollama version information"""
|
143 |
return OllamaVersionResponse(version="0.5.4")
|
144 |
|
145 |
+
@self.router.get("/tags", dependencies=[Depends(combined_auth)])
|
146 |
async def get_tags():
|
147 |
"""Return available models acting as an Ollama server"""
|
148 |
return OllamaTagResponse(
|
|
|
166 |
)
|
167 |
|
168 |
@self.router.post(
|
169 |
+
"/generate", dependencies=[Depends(combined_auth)]
|
170 |
)
|
171 |
async def generate(raw_request: Request, request: OllamaGenerateRequest):
|
172 |
"""Handle generate completion requests acting as an Ollama model
|
|
|
333 |
trace_exception(e)
|
334 |
raise HTTPException(status_code=500, detail=str(e))
|
335 |
|
336 |
+
@self.router.post("/chat", dependencies=[Depends(combined_auth)])
|
337 |
async def chat(raw_request: Request, request: OllamaChatRequest):
|
338 |
"""Process chat completion requests acting as an Ollama model
|
339 |
Routes user queries through LightRAG by selecting query mode based on prefix indicators.
|
lightrag/api/utils_api.py
CHANGED
@@ -58,29 +58,43 @@ ollama_server_infos = OllamaServerInfos()
|
|
58 |
|
59 |
def get_combined_auth_dependency(api_key: Optional[str] = None):
|
60 |
"""
|
61 |
-
Create a combined authentication dependency that implements
|
|
|
62 |
|
63 |
Args:
|
64 |
api_key (Optional[str]): API key for validation
|
65 |
|
66 |
Returns:
|
67 |
-
Callable: A dependency function that implements
|
68 |
"""
|
69 |
# Use global whitelist_patterns and auth_configured variables
|
70 |
# whitelist_patterns and auth_configured are already initialized at module level
|
71 |
|
72 |
# Only calculate api_key_configured as it depends on the function parameter
|
73 |
api_key_configured = bool(api_key)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
|
75 |
async def combined_dependency(
|
76 |
request: Request,
|
77 |
-
token: str =
|
|
|
78 |
):
|
79 |
-
#
|
80 |
-
if not auth_configured and not api_key_configured:
|
81 |
-
return
|
82 |
-
|
83 |
-
# Check if request path is in whitelist
|
84 |
path = request.url.path
|
85 |
for pattern, is_prefix in whitelist_patterns:
|
86 |
if (is_prefix and path.startswith(pattern)) or (
|
@@ -88,35 +102,54 @@ def get_combined_auth_dependency(api_key: Optional[str] = None):
|
|
88 |
):
|
89 |
return # Whitelist path, allow access
|
90 |
|
91 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
if token:
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
if
|
99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
raise HTTPException(
|
101 |
-
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
102 |
)
|
103 |
-
|
104 |
-
#
|
105 |
if api_key_configured:
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
else:
|
116 |
-
raise HTTPException(
|
117 |
-
status_code=HTTP_403_FORBIDDEN,
|
118 |
-
detail="API Key required or use guest authentication.",
|
119 |
-
)
|
120 |
|
121 |
return combined_dependency
|
122 |
|
@@ -145,12 +178,12 @@ def get_api_key_dependency(api_key: Optional[str]):
|
|
145 |
|
146 |
return no_auth
|
147 |
|
148 |
-
# If API key is configured, use proper authentication
|
149 |
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
150 |
|
151 |
async def api_key_auth(
|
152 |
request: Request,
|
153 |
-
api_key_header_value: Optional[str] = Security(api_key_header),
|
154 |
):
|
155 |
# Check if request path is in whitelist
|
156 |
path = request.url.path
|
|
|
58 |
|
59 |
def get_combined_auth_dependency(api_key: Optional[str] = None):
|
60 |
"""
|
61 |
+
Create a combined authentication dependency that implements authentication logic
|
62 |
+
based on API key, OAuth2 token, and whitelist paths.
|
63 |
|
64 |
Args:
|
65 |
api_key (Optional[str]): API key for validation
|
66 |
|
67 |
Returns:
|
68 |
+
Callable: A dependency function that implements the authentication logic
|
69 |
"""
|
70 |
# Use global whitelist_patterns and auth_configured variables
|
71 |
# whitelist_patterns and auth_configured are already initialized at module level
|
72 |
|
73 |
# Only calculate api_key_configured as it depends on the function parameter
|
74 |
api_key_configured = bool(api_key)
|
75 |
+
|
76 |
+
# Create security dependencies with proper descriptions for Swagger UI
|
77 |
+
oauth2_scheme = OAuth2PasswordBearer(
|
78 |
+
tokenUrl="login",
|
79 |
+
auto_error=False,
|
80 |
+
description="OAuth2 Password Authentication"
|
81 |
+
)
|
82 |
+
|
83 |
+
# If API key is configured, create an API key header security
|
84 |
+
api_key_header = None
|
85 |
+
if api_key_configured:
|
86 |
+
api_key_header = APIKeyHeader(
|
87 |
+
name="X-API-Key",
|
88 |
+
auto_error=False,
|
89 |
+
description="API Key Authentication"
|
90 |
+
)
|
91 |
|
92 |
async def combined_dependency(
|
93 |
request: Request,
|
94 |
+
token: str = Security(oauth2_scheme),
|
95 |
+
api_key_header_value: Optional[str] = None if api_key_header is None else Security(api_key_header),
|
96 |
):
|
97 |
+
# 1. Check if path is in whitelist
|
|
|
|
|
|
|
|
|
98 |
path = request.url.path
|
99 |
for pattern, is_prefix in whitelist_patterns:
|
100 |
if (is_prefix and path.startswith(pattern)) or (
|
|
|
102 |
):
|
103 |
return # Whitelist path, allow access
|
104 |
|
105 |
+
# 2. Check for special endpoints (/health and Ollama API)
|
106 |
+
is_special_endpoint = path == "/health" or path.startswith("/api/")
|
107 |
+
if is_special_endpoint and not api_key_configured:
|
108 |
+
return # Special endpoint and no API key configured, allow access
|
109 |
+
|
110 |
+
# 3. Validate API key
|
111 |
+
if api_key_configured and api_key_header_value and api_key_header_value == api_key:
|
112 |
+
return # API key validation successful
|
113 |
+
|
114 |
+
# 4. Validate token
|
115 |
if token:
|
116 |
+
try:
|
117 |
+
token_info = auth_handler.validate_token(token)
|
118 |
+
# Accept guest token if no auth is configured
|
119 |
+
if not auth_configured and token_info.get("role") == "guest":
|
120 |
+
return
|
121 |
+
# Accept non-guest token if auth is configured
|
122 |
+
if auth_configured and token_info.get("role") != "guest":
|
123 |
+
return
|
124 |
+
|
125 |
+
# Token validation failed, immediately return 401 error
|
126 |
+
raise HTTPException(
|
127 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
128 |
+
detail="Invalid token. Please login again."
|
129 |
+
)
|
130 |
+
except HTTPException as e:
|
131 |
+
# If already a 401 error, re-raise it
|
132 |
+
if e.status_code == status.HTTP_401_UNAUTHORIZED:
|
133 |
+
raise
|
134 |
+
# For other exceptions, continue processing
|
135 |
+
|
136 |
+
# If token exists but validation failed (didn't return above), return 401
|
137 |
raise HTTPException(
|
138 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
139 |
+
detail="Invalid token. Please login again."
|
140 |
)
|
141 |
+
|
142 |
+
# 5. No token and API key validation failed, return 403 error
|
143 |
if api_key_configured:
|
144 |
+
raise HTTPException(
|
145 |
+
status_code=HTTP_403_FORBIDDEN,
|
146 |
+
detail="API Key required or login authentication required."
|
147 |
+
)
|
148 |
+
else:
|
149 |
+
raise HTTPException(
|
150 |
+
status_code=HTTP_403_FORBIDDEN,
|
151 |
+
detail="Login authentication required."
|
152 |
+
)
|
|
|
|
|
|
|
|
|
|
|
153 |
|
154 |
return combined_dependency
|
155 |
|
|
|
178 |
|
179 |
return no_auth
|
180 |
|
181 |
+
# If API key is configured, use proper authentication with Security for Swagger UI
|
182 |
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
183 |
|
184 |
async def api_key_auth(
|
185 |
request: Request,
|
186 |
+
api_key_header_value: Optional[str] = Security(api_key_header, description="API Key for authentication"),
|
187 |
):
|
188 |
# Check if request path is in whitelist
|
189 |
path = request.url.path
|