Milin committed on
Commit ef7c03a · 1 Parent(s): 950eb48

feat(api): Add user authentication functionality

- Implement JWT-based user authentication logic
- Add login endpoint and token validation middleware
- Update API routes with authentication dependencies
- Add authentication-related environment variables
- Optimize requirements.txt with necessary dependencies

env.example CHANGED
@@ -148,3 +148,10 @@ QDRANT_URL=http://localhost:16333

  ### Redis
  REDIS_URI=redis://localhost:6379
+
+ # For jwt auth
+ AUTH_USERNAME=admin # login name
+ AUTH_PASSWORD=admin123 # password
+ TOKEN_SECRET=your-key # JWT key
+ TOKEN_EXPIRE_HOURS=4 # expire duration
+ WHITELIST_PATHS=/login,/health # white list
lightrag/api/README.md CHANGED
@@ -295,26 +295,32 @@ You can not change storage implementation selection after you add documents to L

  ### LightRag API Server Comand Line Options

- | Parameter | Default | Description |
- |-----------|---------|-------------|
- | --host | 0.0.0.0 | Server host |
- | --port | 9621 | Server port |
- | --working-dir | ./rag_storage | Working directory for RAG storage |
- | --input-dir | ./inputs | Directory containing input documents |
- | --max-async | 4 | Maximum async operations |
- | --max-tokens | 32768 | Maximum token size |
- | --timeout | 150 | Timeout in seconds. None for infinite timeout(not recommended) |
- | --log-level | INFO | Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) |
- | --verbose | - | Verbose debug output (True, Flase) |
- | --key | None | API key for authentication. Protects lightrag server against unauthorized access |
- | --ssl | False | Enable HTTPS |
- | --ssl-certfile | None | Path to SSL certificate file (required if --ssl is enabled) |
- | --ssl-keyfile | None | Path to SSL private key file (required if --ssl is enabled) |
- | --top-k | 50 | Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode. |
- | --cosine-threshold | 0.4 | The cossine threshold for nodes and relations retrieval, works with top-k to control the retrieval of nodes and relations. |
- | --llm-binding | ollama | LLM binding type (lollms, ollama, openai, openai-ollama, azure_openai) |
- | --embedding-binding | ollama | Embedding binding type (lollms, ollama, openai, azure_openai) |
- | auto-scan-at-startup | - | Scan input directory for new files and start indexing |
+ | Parameter | Default | Description |
+ |-------------------------|----------------|-----------------------------------------------------------------------------------------------------------------------------|
+ | --host | 0.0.0.0 | Server host |
+ | --port | 9621 | Server port |
+ | --working-dir | ./rag_storage | Working directory for RAG storage |
+ | --input-dir | ./inputs | Directory containing input documents |
+ | --max-async | 4 | Maximum async operations |
+ | --max-tokens | 32768 | Maximum token size |
+ | --timeout | 150 | Timeout in seconds. None for infinite timeout (not recommended) |
+ | --log-level | INFO | Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) |
+ | --verbose | - | Verbose debug output (True, False) |
+ | --key | None | API key for authentication. Protects lightrag server against unauthorized access |
+ | --ssl | False | Enable HTTPS |
+ | --ssl-certfile | None | Path to SSL certificate file (required if --ssl is enabled) |
+ | --ssl-keyfile | None | Path to SSL private key file (required if --ssl is enabled) |
+ | --top-k | 50 | Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode. |
+ | --cosine-threshold | 0.4 | The cosine threshold for nodes and relations retrieval, works with top-k to control the retrieval of nodes and relations. |
+ | --llm-binding | ollama | LLM binding type (lollms, ollama, openai, openai-ollama, azure_openai) |
+ | --embedding-binding | ollama | Embedding binding type (lollms, ollama, openai, azure_openai) |
+ | --auto-scan-at-startup | - | Scan input directory for new files and start indexing |
+ | --auth-username | - | Login username; JWT authentication is enabled when both username and password are set |
+ | --auth-password | - | Login password; JWT authentication is enabled when both username and password are set |
+ | --token-secret | - | Secret key used to sign JWT tokens |
+ | --token-expire-hours | 4 | Token lifetime in hours |
+ | --whitelist-paths | /login,/health | Comma-separated paths exempt from JWT authentication |
+

  ### Example Usage

@@ -387,6 +393,19 @@ Note: If you don't need the API functionality, you can install the base package
  pip install lightrag-hku
  ```

+ ## Authentication Endpoints
+
+ ### JWT Authentication Mechanism
+ LightRAG API Server implements JWT-based authentication using the HS256 algorithm. Authentication is active only when both `AUTH_USERNAME` and `AUTH_PASSWORD` are set. To enable it, set the following environment variables:
+ ```bash
+ # For jwt auth
+ AUTH_USERNAME=admin # login name (CLI: --auth-username)
+ AUTH_PASSWORD=admin123 # password (CLI: --auth-password)
+ TOKEN_SECRET=your-key # JWT signing key (CLI: --token-secret)
+ TOKEN_EXPIRE_HOURS=4 # token lifetime in hours (CLI: --token-expire-hours)
+ WHITELIST_PATHS=/login,/health # auth-exempt paths (CLI: --whitelist-paths)
+ ```
+

  ## API Endpoints

  All servers (LoLLMs, Ollama, OpenAI and Azure OpenAI) provide the same REST API endpoints for RAG functionality. When API Server is running, visit:
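Reviewer note: the README pairs each environment variable with a CLI flag. The sketch below shows one way that precedence could work (an explicit flag wins, otherwise the environment value, otherwise the documented default). It uses only the standard library; `env_default`, `build_auth_parser` and `auth_enabled` are hypothetical names for illustration and are not the project's actual helpers.

```python
import argparse
import os


def env_default(name, default, cast=str):
    """Hypothetical stand-in for an env lookup helper: env value if set, else default."""
    raw = os.environ.get(name)
    if raw is None or raw == "":
        return default
    try:
        return cast(raw)
    except ValueError:
        # Fall back to the default when the env value cannot be converted
        return default


def build_auth_parser():
    """Build a parser holding only the authentication options from the table."""
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("--auth-username", type=str,
                        default=env_default("AUTH_USERNAME", ""))
    parser.add_argument("--auth-password", type=str,
                        default=env_default("AUTH_PASSWORD", ""))
    parser.add_argument("--token-expire-hours", type=int,
                        default=env_default("TOKEN_EXPIRE_HOURS", 4, int))
    parser.add_argument("--whitelist-paths", type=str,
                        default=env_default("WHITELIST_PATHS", "/login,/health"))
    return parser


def auth_enabled(args):
    """Assumed rule: authentication is on only when both credentials are non-empty."""
    return bool(args.auth_username and args.auth_password)
```

Because the defaults are read when the parser is built, an environment change made afterwards is not picked up until a new parser is created.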
lightrag/api/auth.py ADDED
@@ -0,0 +1,42 @@
+ import os
+ from datetime import datetime, timedelta, timezone
+ import jwt
+ from fastapi import HTTPException, status
+ from pydantic import BaseModel
+
+
+ class TokenPayload(BaseModel):
+     sub: str
+     exp: datetime
+
+
+ class AuthHandler:
+     def __init__(self):
+         # The fallback secret is hard-coded; set TOKEN_SECRET in real deployments
+         self.secret = os.getenv("TOKEN_SECRET", "4f85ds4f56dsf46")
+         self.algorithm = "HS256"
+         self.expire_hours = int(os.getenv("TOKEN_EXPIRE_HOURS", 4))
+
+     def create_token(self, username: str) -> str:
+         expire = datetime.now(timezone.utc) + timedelta(hours=self.expire_hours)
+         payload = TokenPayload(sub=username, exp=expire)
+         return jwt.encode(payload.model_dump(), self.secret, algorithm=self.algorithm)
+
+     def validate_token(self, token: str) -> str:
+         try:
+             # jwt.decode verifies both the signature and the exp claim
+             payload = jwt.decode(token, self.secret, algorithms=[self.algorithm])
+             return payload["sub"]
+         except jwt.ExpiredSignatureError:
+             raise HTTPException(
+                 status_code=status.HTTP_401_UNAUTHORIZED,
+                 detail="Token expired"
+             )
+         except jwt.PyJWTError:
+             raise HTTPException(
+                 status_code=status.HTTP_401_UNAUTHORIZED,
+                 detail="Invalid token"
+             )
+
+
+ auth_handler = AuthHandler()
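Reviewer note: for readers unfamiliar with what an HS256 token contains, here is a minimal standard-library sketch of signing and verifying one. It is an illustration of the scheme only, not a replacement for PyJWT, and it does not validate the header. The function names `sign_hs256` and `verify_hs256` are hypothetical.

```python
import base64
import hashlib
import hmac
import json
import time


def _b64url(data: bytes) -> str:
    """Base64url-encode without padding, as JWT segments require."""
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")


def _b64url_decode(text: str) -> bytes:
    """Restore the stripped padding, then decode."""
    return base64.urlsafe_b64decode(text + "=" * (-len(text) % 4))


def sign_hs256(claims: dict, secret: str) -> str:
    """Return header.payload.signature, signed with HMAC-SHA256."""
    header = {"alg": "HS256", "typ": "JWT"}
    signing_input = ".".join(
        _b64url(json.dumps(part, separators=(",", ":")).encode("utf-8"))
        for part in (header, claims)
    )
    signature = hmac.new(
        secret.encode("utf-8"), signing_input.encode("ascii"), hashlib.sha256
    ).digest()
    return f"{signing_input}.{_b64url(signature)}"


def verify_hs256(token: str, secret: str, now=None) -> dict:
    """Check the signature and the exp claim; raise ValueError on failure."""
    header_b64, payload_b64, signature_b64 = token.split(".")
    expected = hmac.new(
        secret.encode("utf-8"),
        f"{header_b64}.{payload_b64}".encode("ascii"),
        hashlib.sha256,
    ).digest()
    # Constant-time comparison avoids leaking how many bytes matched
    if not hmac.compare_digest(expected, _b64url_decode(signature_b64)):
        raise ValueError("signature mismatch")
    claims = json.loads(_b64url_decode(payload_b64))
    current = time.time() if now is None else now
    if current >= claims["exp"]:
        raise ValueError("token expired")
    return claims
```

The `now` parameter exists only so the expiry check can be exercised deterministically. The payload is encoded but not encrypted, so anyone holding the token can read the `sub` claim.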
lightrag/api/lightrag_server.py CHANGED
@@ -5,6 +5,9 @@ LightRAG FastAPI Server
  from fastapi import (
      FastAPI,
      Depends,
+     HTTPException,
+     Request,
+     status
  )
  from fastapi.responses import FileResponse
  import asyncio
@@ -25,6 +28,7 @@ from .utils_api import (
      parse_args,
      get_default_host,
      display_splash_screen,
+     get_auth_dependency,
  )
  from lightrag import LightRAG
  from lightrag.types import GPTKeywordExtractionFormat
@@ -46,6 +50,8 @@ from lightrag.kg.shared_storage import (
      initialize_pipeline_status,
      get_all_update_flags_status,
  )
+ from fastapi.security import OAuth2PasswordRequestForm
+ from .auth import auth_handler

  # Load environment variables
  load_dotenv(override=True)
@@ -373,7 +379,29 @@ def create_app(args):
      ollama_api = OllamaAPI(rag, top_k=args.top_k)
      app.include_router(ollama_api.router, prefix="/api")

-     @app.get("/health", dependencies=[Depends(optional_api_key)])
+     @app.post("/login")
+     async def login(form_data: OAuth2PasswordRequestForm = Depends()):
+         username = os.getenv("AUTH_USERNAME")
+         password = os.getenv("AUTH_PASSWORD")
+
+         if not (username and password):
+             raise HTTPException(
+                 status_code=status.HTTP_501_NOT_IMPLEMENTED,
+                 detail="Authentication not configured"
+             )
+
+         if form_data.username != username or form_data.password != password:
+             raise HTTPException(
+                 status_code=status.HTTP_401_UNAUTHORIZED,
+                 detail="Incorrect credentials"
+             )
+
+         return {
+             "access_token": auth_handler.create_token(username),
+             "token_type": "bearer"
+         }
+
+     @app.get("/health", dependencies=[Depends(optional_api_key), Depends(get_auth_dependency())])
      async def get_status():
          """Get current system status"""
          # Get update flags status for all namespaces
@@ -414,6 +442,12 @@ def create_app(args):
      async def webui_root():
          return FileResponse(static_dir / "index.html")

+     @app.middleware("http")
+     async def debug_middleware(request: Request, call_next):
+         print(f"Request path: {request.url.path}")
+         response = await call_next(request)
+         return response
+
      return app

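Reviewer note: the `/login` handler compares credentials with `!=`, which can return as soon as the first differing character is found. A possible hardening, sketched below with the standard library only, is to compare in constant time. `credentials_match` is a hypothetical helper, not part of this commit.

```python
import hmac


def credentials_match(given_user: str, given_password: str,
                      expected_user: str, expected_password: str) -> bool:
    """Compare both fields in constant time and only then combine the results."""
    # Encode to bytes because hmac.compare_digest accepts str only for ASCII text
    user_ok = hmac.compare_digest(
        given_user.encode("utf-8"), expected_user.encode("utf-8")
    )
    password_ok = hmac.compare_digest(
        given_password.encode("utf-8"), expected_password.encode("utf-8")
    )
    # Both comparisons have already run, so no short-circuit reveals which one failed
    return user_ok and password_ok
```

In the handler, the existing `if` condition could then become `if not credentials_match(form_data.username, form_data.password, username, password):`, assuming the rest of the function stays unchanged.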
 
lightrag/api/requirements.txt CHANGED
@@ -8,3 +8,15 @@ python-multipart
  tenacity
  tiktoken
  uvicorn
+ tqdm
+ jiter
+ httpcore
+ distro
+ httpx
+ openai
+ asyncpg
+ neo4j
+ pytz
+ python-jose[cryptography]
+ passlib[bcrypt]
+ PyJWT
lightrag/api/routers/document_routes.py CHANGED
@@ -16,10 +16,9 @@ from pydantic import BaseModel, Field, field_validator

  from lightrag import LightRAG
  from lightrag.base import DocProcessingStatus, DocStatus
- from ..utils_api import get_api_key_dependency
+ from ..utils_api import get_api_key_dependency, get_auth_dependency

-
- router = APIRouter(prefix="/documents", tags=["documents"])
+ router = APIRouter(prefix="/documents", tags=["documents"], dependencies=[Depends(get_auth_dependency())])

  # Temporary file prefix
  temp_prefix = "__tmp__"
lightrag/api/routers/graph_routes.py CHANGED
@@ -6,9 +6,9 @@ from typing import Optional

  from fastapi import APIRouter, Depends

- from ..utils_api import get_api_key_dependency
+ from ..utils_api import get_api_key_dependency, get_auth_dependency

- router = APIRouter(tags=["graph"])
+ router = APIRouter(tags=["graph"], dependencies=[Depends(get_auth_dependency())])


  def create_graph_routes(rag, api_key: Optional[str] = None):
lightrag/api/routers/ollama_api.py CHANGED
@@ -1,4 +1,4 @@
- from fastapi import APIRouter, HTTPException, Request
+ from fastapi import APIRouter, HTTPException, Request, Depends
  from pydantic import BaseModel
  from typing import List, Dict, Any, Optional
  import logging
@@ -11,7 +11,7 @@ import asyncio
  from ascii_colors import trace_exception
  from lightrag import LightRAG, QueryParam
  from lightrag.utils import encode_string_by_tiktoken
- from ..utils_api import ollama_server_infos
+ from ..utils_api import ollama_server_infos, get_auth_dependency


  # query mode according to query prefix (bypass is not LightRAG quer mode)
@@ -126,7 +126,7 @@ class OllamaAPI:
          self.rag = rag
          self.ollama_server_infos = ollama_server_infos
          self.top_k = top_k
-         self.router = APIRouter(tags=["ollama"])
+         self.router = APIRouter(tags=["ollama"], dependencies=[Depends(get_auth_dependency())])
          self.setup_routes()

      def setup_routes(self):
lightrag/api/routers/query_routes.py CHANGED
@@ -8,12 +8,12 @@ from typing import Any, Dict, List, Literal, Optional

  from fastapi import APIRouter, Depends, HTTPException
  from lightrag.base import QueryParam
- from ..utils_api import get_api_key_dependency
+ from ..utils_api import get_api_key_dependency, get_auth_dependency
  from pydantic import BaseModel, Field, field_validator

  from ascii_colors import trace_exception

- router = APIRouter(tags=["query"])
+ router = APIRouter(tags=["query"], dependencies=[Depends(get_auth_dependency())])


  class QueryRequest(BaseModel):
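Reviewer note: with the document, graph, Ollama and query routers now depending on `get_auth_dependency()`, a client has to log in first and then send the token on every call. The sketch below builds those requests with the standard library without sending them. It assumes the `/login` route reads form fields named `username` and `password` (as `OAuth2PasswordRequestForm` does) and that protected routes expect an `Authorization: Bearer ...` header (as `OAuth2PasswordBearer` does). The helper names and the base URL are hypothetical.

```python
import urllib.parse
import urllib.request


def build_login_request(base_url: str, username: str, password: str) -> urllib.request.Request:
    """Build a form-encoded POST to /login; the caller decides when to send it."""
    body = urllib.parse.urlencode(
        {"username": username, "password": password}
    ).encode("ascii")
    return urllib.request.Request(
        f"{base_url.rstrip('/')}/login",
        data=body,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        method="POST",
    )


def bearer_headers(access_token: str) -> dict:
    """Headers to attach to calls against the protected routers."""
    return {"Authorization": f"Bearer {access_token}"}
```

Sending the login request with `urllib.request.urlopen` and reading `access_token` from the JSON response would yield the value to pass to `bearer_headers`.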
lightrag/api/utils_api.py CHANGED
@@ -9,10 +9,16 @@ import sys
  import logging
  from ascii_colors import ASCIIColors
  from lightrag.api import __api_version__
- from fastapi import HTTPException, Security
+ from fastapi import (
+     HTTPException,
+     Security,
+     Depends,
+     Request
+ )
  from dotenv import load_dotenv
- from fastapi.security import APIKeyHeader
+ from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
  from starlette.status import HTTP_403_FORBIDDEN
+ from .auth import auth_handler

  # Load environment variables
  load_dotenv(override=True)
@@ -31,6 +37,24 @@ class OllamaServerInfos:
  ollama_server_infos = OllamaServerInfos()


+ def get_auth_dependency():
+     whitelist = os.getenv("WHITELIST_PATHS", "").split(",")
+
+     async def dependency(
+         request: Request,
+         token: str = Depends(OAuth2PasswordBearer(tokenUrl="login", auto_error=False))
+     ):
+         if request.url.path in whitelist:
+             return
+
+         if not (os.getenv("AUTH_USERNAME") and os.getenv("AUTH_PASSWORD")):
+             return
+
+         auth_handler.validate_token(token)
+
+     return dependency
+
+
  def get_api_key_dependency(api_key: Optional[str]):
      """
      Create an API key dependency for route protection.
@@ -288,6 +312,38 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
          help="Embedding binding type (default: from env or ollama)",
      )

+     # Authentication configuration
+     parser.add_argument(
+         "--auth-username",
+         type=str,
+         default=get_env_value("AUTH_USERNAME", ""),
+         help="Login username (default: from env or empty)"
+     )
+     parser.add_argument(
+         "--auth-password",
+         type=str,
+         default=get_env_value("AUTH_PASSWORD", ""),
+         help="Login password (default: from env or empty)"
+     )
+     parser.add_argument(
+         "--token-secret",
+         type=str,
+         default=get_env_value("TOKEN_SECRET", ""),
+         help="JWT signing secret (default: from env or empty)"
+     )
+     parser.add_argument(
+         "--token-expire-hours",
+         type=int,
+         default=get_env_value("TOKEN_EXPIRE_HOURS", 4, int),
+         help="Token validity in hours (default: from env or 4)"
+     )
+     parser.add_argument(
+         "--whitelist-paths",
+         type=str,
+         default=get_env_value("WHITELIST_PATHS", "/login,/health"),
+         help="Comma-separated auth-exempt paths (default: from env or /login,/health)"
+     )
+
      args = parser.parse_args()

      # If in uvicorn mode and workers > 1, force it to 1 and log warning
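Reviewer note: `get_auth_dependency()` splits `WHITELIST_PATHS` on commas as-is, so an entry written as `" /health"` would not match, and an unset variable yields `[""]` rather than the `/login,/health` default used by the CLI option. The sketch below shows one possible way to normalise the list, using the standard library only. `parse_whitelist` and `is_exempt` are hypothetical helpers, and reusing the CLI default here is an assumption about the intended behaviour.

```python
import os

# Assumed default, taken from the --whitelist-paths option in this diff
DEFAULT_WHITELIST = "/login,/health"


def parse_whitelist(raw: str) -> frozenset:
    """Split on commas, trim whitespace, and drop empty entries."""
    return frozenset(part.strip() for part in raw.split(",") if part.strip())


def load_whitelist() -> frozenset:
    """Read WHITELIST_PATHS from the environment, falling back to the default."""
    return parse_whitelist(os.environ.get("WHITELIST_PATHS", DEFAULT_WHITELIST))


def is_exempt(path: str, whitelist) -> bool:
    """Exact-match check, mirroring the `in whitelist` test in the dependency."""
    return path in whitelist
```

Matching stays exact, so `/health/` with a trailing slash would still require a token unless it is listed explicitly.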