yangdx commited on
Commit
37b1fa9
·
1 Parent(s): 7c2f5b5

Refactor authentication logic and Swagger UI config

Browse files

- Consolidate authentication dependencies
- Improve Swagger UI security parameters

lightrag/api/lightrag_server.py CHANGED
@@ -19,6 +19,7 @@ from contextlib import asynccontextmanager
19
  from dotenv import load_dotenv
20
  from lightrag.api.utils_api import (
21
  get_api_key_dependency,
 
22
  parse_args,
23
  get_default_host,
24
  display_splash_screen,
@@ -135,19 +136,28 @@ def create_app(args):
135
  await rag.finalize_storages()
136
 
137
  # Initialize FastAPI
138
- app = FastAPI(
139
- title="LightRAG API",
140
- description="API for querying text using LightRAG with separate storage and input directories"
141
  + "(With authentication)"
142
  if api_key
143
  else "",
144
- version=__api_version__,
145
- openapi_url="/openapi.json", # Explicitly set OpenAPI schema URL
146
- docs_url="/docs", # Explicitly set docs URL
147
- redoc_url="/redoc", # Explicitly set redoc URL
148
- openapi_tags=[{"name": "api"}],
149
- lifespan=lifespan,
150
- )
 
 
 
 
 
 
 
 
 
151
 
152
  def get_cors_origins():
153
  """Get allowed origins from environment variable
@@ -167,13 +177,8 @@ def create_app(args):
167
  allow_headers=["*"],
168
  )
169
 
170
- # Create the optional API key dependency
171
- # Create a dependency that passes the request to get_api_key_dependency
172
- async def optional_api_key_dependency(request: Request):
173
- # Create the dependency function with the request
174
- api_key_dependency = get_api_key_dependency(api_key)
175
- # Call the dependency function with the request
176
- return await api_key_dependency(request)
177
 
178
  # Create working directory if it doesn't exist
179
  Path(args.working_dir).mkdir(parents=True, exist_ok=True)
@@ -419,7 +424,7 @@ def create_app(args):
419
  "api_version": __api_version__,
420
  }
421
 
422
- @app.get("/health", dependencies=[Depends(optional_api_key_dependency)])
423
  async def get_status():
424
  """Get current system status"""
425
  username = os.getenv("AUTH_USERNAME")
 
19
  from dotenv import load_dotenv
20
  from lightrag.api.utils_api import (
21
  get_api_key_dependency,
22
+ get_combined_auth_dependency,
23
  parse_args,
24
  get_default_host,
25
  display_splash_screen,
 
136
  await rag.finalize_storages()
137
 
138
  # Initialize FastAPI
139
+ app_kwargs = {
140
+ "title": "LightRAG Server API",
141
+ "description": "Providing API for LightRAG core, Web UI and Ollama Model Emulation"
142
  + "(With authentication)"
143
  if api_key
144
  else "",
145
+ "version": __api_version__,
146
+ "openapi_url": "/openapi.json", # Explicitly set OpenAPI schema URL
147
+ "docs_url": "/docs", # Explicitly set docs URL
148
+ "redoc_url": "/redoc", # Explicitly set redoc URL
149
+ "openapi_tags": [{"name": "api"}],
150
+ "lifespan": lifespan,
151
+ }
152
+
153
+ # Configure Swagger UI parameters
154
+ # Enable persistAuthorization and tryItOutEnabled for better user experience
155
+ app_kwargs["swagger_ui_parameters"] = {
156
+ "persistAuthorization": True,
157
+ "tryItOutEnabled": True,
158
+ }
159
+
160
+ app = FastAPI(**app_kwargs)
161
 
162
  def get_cors_origins():
163
  """Get allowed origins from environment variable
 
177
  allow_headers=["*"],
178
  )
179
 
180
+ # Create combined auth dependency for all endpoints
181
+ combined_auth = get_combined_auth_dependency(api_key)
 
 
 
 
 
182
 
183
  # Create working directory if it doesn't exist
184
  Path(args.working_dir).mkdir(parents=True, exist_ok=True)
 
424
  "api_version": __api_version__,
425
  }
426
 
427
+ @app.get("/health", dependencies=[Depends(combined_auth)])
428
  async def get_status():
429
  """Get current system status"""
430
  username = os.getenv("AUTH_USERNAME")
lightrag/api/routers/document_routes.py CHANGED
@@ -505,6 +505,7 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
505
  def create_document_routes(
506
  rag: LightRAG, doc_manager: DocumentManager, api_key: Optional[str] = None
507
  ):
 
508
  combined_auth = get_combined_auth_dependency(api_key)
509
 
510
  @router.post("/scan", dependencies=[Depends(combined_auth)])
 
505
  def create_document_routes(
506
  rag: LightRAG, doc_manager: DocumentManager, api_key: Optional[str] = None
507
  ):
508
+ # Create combined auth dependency for document routes
509
  combined_auth = get_combined_auth_dependency(api_key)
510
 
511
  @router.post("/scan", dependencies=[Depends(combined_auth)])
lightrag/api/routers/ollama_api.py CHANGED
@@ -11,7 +11,7 @@ import asyncio
11
  from ascii_colors import trace_exception
12
  from lightrag import LightRAG, QueryParam
13
  from lightrag.utils import encode_string_by_tiktoken
14
- from lightrag.api.utils_api import ollama_server_infos, get_api_key_dependency
15
  from fastapi import Depends
16
 
17
 
@@ -132,21 +132,17 @@ class OllamaAPI:
132
  self.setup_routes()
133
 
134
  def setup_routes(self):
135
- # Create a dependency that passes the request to get_api_key_dependency
136
- async def optional_api_key_dependency(request: Request):
137
- # Create the dependency function with the request
138
- api_key_dependency = get_api_key_dependency(self.api_key)
139
- # Call the dependency function with the request
140
- return await api_key_dependency(request)
141
 
142
  @self.router.get(
143
- "/version", dependencies=[Depends(optional_api_key_dependency)]
144
  )
145
  async def get_version():
146
  """Get Ollama version information"""
147
  return OllamaVersionResponse(version="0.5.4")
148
 
149
- @self.router.get("/tags", dependencies=[Depends(optional_api_key_dependency)])
150
  async def get_tags():
151
  """Return available models acting as an Ollama server"""
152
  return OllamaTagResponse(
@@ -170,7 +166,7 @@ class OllamaAPI:
170
  )
171
 
172
  @self.router.post(
173
- "/generate", dependencies=[Depends(optional_api_key_dependency)]
174
  )
175
  async def generate(raw_request: Request, request: OllamaGenerateRequest):
176
  """Handle generate completion requests acting as an Ollama model
@@ -337,7 +333,7 @@ class OllamaAPI:
337
  trace_exception(e)
338
  raise HTTPException(status_code=500, detail=str(e))
339
 
340
- @self.router.post("/chat", dependencies=[Depends(optional_api_key_dependency)])
341
  async def chat(raw_request: Request, request: OllamaChatRequest):
342
  """Process chat completion requests acting as an Ollama model
343
  Routes user queries through LightRAG by selecting query mode based on prefix indicators.
 
11
  from ascii_colors import trace_exception
12
  from lightrag import LightRAG, QueryParam
13
  from lightrag.utils import encode_string_by_tiktoken
14
+ from lightrag.api.utils_api import ollama_server_infos, get_combined_auth_dependency
15
  from fastapi import Depends
16
 
17
 
 
132
  self.setup_routes()
133
 
134
  def setup_routes(self):
135
+ # Create combined auth dependency for Ollama API routes
136
+ combined_auth = get_combined_auth_dependency(self.api_key)
 
 
 
 
137
 
138
  @self.router.get(
139
+ "/version", dependencies=[Depends(combined_auth)]
140
  )
141
  async def get_version():
142
  """Get Ollama version information"""
143
  return OllamaVersionResponse(version="0.5.4")
144
 
145
+ @self.router.get("/tags", dependencies=[Depends(combined_auth)])
146
  async def get_tags():
147
  """Return available models acting as an Ollama server"""
148
  return OllamaTagResponse(
 
166
  )
167
 
168
  @self.router.post(
169
+ "/generate", dependencies=[Depends(combined_auth)]
170
  )
171
  async def generate(raw_request: Request, request: OllamaGenerateRequest):
172
  """Handle generate completion requests acting as an Ollama model
 
333
  trace_exception(e)
334
  raise HTTPException(status_code=500, detail=str(e))
335
 
336
+ @self.router.post("/chat", dependencies=[Depends(combined_auth)])
337
  async def chat(raw_request: Request, request: OllamaChatRequest):
338
  """Process chat completion requests acting as an Ollama model
339
  Routes user queries through LightRAG by selecting query mode based on prefix indicators.
lightrag/api/utils_api.py CHANGED
@@ -58,29 +58,43 @@ ollama_server_infos = OllamaServerInfos()
58
 
59
  def get_combined_auth_dependency(api_key: Optional[str] = None):
60
  """
61
- Create a combined authentication dependency that implements OR logic (pass through any authentication method)
 
62
 
63
  Args:
64
  api_key (Optional[str]): API key for validation
65
 
66
  Returns:
67
- Callable: A dependency function that implements OR authentication logic
68
  """
69
  # Use global whitelist_patterns and auth_configured variables
70
  # whitelist_patterns and auth_configured are already initialized at module level
71
 
72
  # Only calculate api_key_configured as it depends on the function parameter
73
  api_key_configured = bool(api_key)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  async def combined_dependency(
76
  request: Request,
77
- token: str = Depends(OAuth2PasswordBearer(tokenUrl="login", auto_error=False)),
 
78
  ):
79
- # If both authentication methods are not configured, allow access
80
- if not auth_configured and not api_key_configured:
81
- return
82
-
83
- # Check if request path is in whitelist
84
  path = request.url.path
85
  for pattern, is_prefix in whitelist_patterns:
86
  if (is_prefix and path.startswith(pattern)) or (
@@ -88,35 +102,54 @@ def get_combined_auth_dependency(api_key: Optional[str] = None):
88
  ):
89
  return # Whitelist path, allow access
90
 
91
- # Access with token
 
 
 
 
 
 
 
 
 
92
  if token:
93
- token_info = auth_handler.validate_token(token)
94
- if auth_configured:
95
- if token_info.get("role") != "guest" or not api_key_configured:
96
- return # Password authentication successful
97
- else:
98
- if token_info.get("role") == "guest":
99
- return # Guest authentication successful
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  raise HTTPException(
101
- status_code=status.HTTP_401_UNAUTHORIZED, detail="Token required"
 
102
  )
103
-
104
- # Try API key authentication (if configured)
105
  if api_key_configured:
106
- api_key_header = request.headers.get("X-API-Key")
107
- if api_key_header and api_key_header == api_key:
108
- return # API key authentication successful
109
- else:
110
- if auth_configured:
111
- raise HTTPException(
112
- status_code=HTTP_403_FORBIDDEN,
113
- detail="API Key required or use password authentication.",
114
- )
115
- else:
116
- raise HTTPException(
117
- status_code=HTTP_403_FORBIDDEN,
118
- detail="API Key required or use guest authentication.",
119
- )
120
 
121
  return combined_dependency
122
 
@@ -145,12 +178,12 @@ def get_api_key_dependency(api_key: Optional[str]):
145
 
146
  return no_auth
147
 
148
- # If API key is configured, use proper authentication
149
  api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
150
 
151
  async def api_key_auth(
152
  request: Request,
153
- api_key_header_value: Optional[str] = Security(api_key_header),
154
  ):
155
  # Check if request path is in whitelist
156
  path = request.url.path
 
58
 
59
  def get_combined_auth_dependency(api_key: Optional[str] = None):
60
  """
61
+ Create a combined authentication dependency that implements authentication logic
62
+ based on API key, OAuth2 token, and whitelist paths.
63
 
64
  Args:
65
  api_key (Optional[str]): API key for validation
66
 
67
  Returns:
68
+ Callable: A dependency function that implements the authentication logic
69
  """
70
  # Use global whitelist_patterns and auth_configured variables
71
  # whitelist_patterns and auth_configured are already initialized at module level
72
 
73
  # Only calculate api_key_configured as it depends on the function parameter
74
  api_key_configured = bool(api_key)
75
+
76
+ # Create security dependencies with proper descriptions for Swagger UI
77
+ oauth2_scheme = OAuth2PasswordBearer(
78
+ tokenUrl="login",
79
+ auto_error=False,
80
+ description="OAuth2 Password Authentication"
81
+ )
82
+
83
+ # If API key is configured, create an API key header security
84
+ api_key_header = None
85
+ if api_key_configured:
86
+ api_key_header = APIKeyHeader(
87
+ name="X-API-Key",
88
+ auto_error=False,
89
+ description="API Key Authentication"
90
+ )
91
 
92
  async def combined_dependency(
93
  request: Request,
94
+ token: str = Security(oauth2_scheme),
95
+ api_key_header_value: Optional[str] = None if api_key_header is None else Security(api_key_header),
96
  ):
97
+ # 1. Check if path is in whitelist
 
 
 
 
98
  path = request.url.path
99
  for pattern, is_prefix in whitelist_patterns:
100
  if (is_prefix and path.startswith(pattern)) or (
 
102
  ):
103
  return # Whitelist path, allow access
104
 
105
+ # 2. Check for special endpoints (/health and Ollama API)
106
+ is_special_endpoint = path == "/health" or path.startswith("/api/")
107
+ if is_special_endpoint and not api_key_configured:
108
+ return # Special endpoint and no API key configured, allow access
109
+
110
+ # 3. Validate API key
111
+ if api_key_configured and api_key_header_value and api_key_header_value == api_key:
112
+ return # API key validation successful
113
+
114
+ # 4. Validate token
115
  if token:
116
+ try:
117
+ token_info = auth_handler.validate_token(token)
118
+ # Accept guest token if no auth is configured
119
+ if not auth_configured and token_info.get("role") == "guest":
120
+ return
121
+ # Accept non-guest token if auth is configured
122
+ if auth_configured and token_info.get("role") != "guest":
123
+ return
124
+
125
+ # Token validation failed, immediately return 401 error
126
+ raise HTTPException(
127
+ status_code=status.HTTP_401_UNAUTHORIZED,
128
+ detail="Invalid token. Please login again."
129
+ )
130
+ except HTTPException as e:
131
+ # If already a 401 error, re-raise it
132
+ if e.status_code == status.HTTP_401_UNAUTHORIZED:
133
+ raise
134
+ # For other exceptions, continue processing
135
+
136
+ # If token exists but validation failed (didn't return above), return 401
137
  raise HTTPException(
138
+ status_code=status.HTTP_401_UNAUTHORIZED,
139
+ detail="Invalid token. Please login again."
140
  )
141
+
142
+ # 5. No token and API key validation failed, return 403 error
143
  if api_key_configured:
144
+ raise HTTPException(
145
+ status_code=HTTP_403_FORBIDDEN,
146
+ detail="API Key required or login authentication required."
147
+ )
148
+ else:
149
+ raise HTTPException(
150
+ status_code=HTTP_403_FORBIDDEN,
151
+ detail="Login authentication required."
152
+ )
 
 
 
 
 
153
 
154
  return combined_dependency
155
 
 
178
 
179
  return no_auth
180
 
181
+ # If API key is configured, use proper authentication with Security for Swagger UI
182
  api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
183
 
184
  async def api_key_auth(
185
  request: Request,
186
+ api_key_header_value: Optional[str] = Security(api_key_header, description="API Key for authentication"),
187
  ):
188
  # Check if request path is in whitelist
189
  path = request.url.path