zrguo commited on
Commit
d00af94
·
unverified ·
2 Parent(s): b7d0d48 bd11593

Merge pull request #1000 from lcjqyml/feat_login-jwt

Browse files
env.example CHANGED
@@ -148,3 +148,10 @@ QDRANT_URL=http://localhost:16333
148
 
149
  ### Redis
150
  REDIS_URI=redis://localhost:6379
 
 
 
 
 
 
 
 
148
 
149
  ### Redis
150
  REDIS_URI=redis://localhost:6379
151
+
152
+ # For jwt auth
153
+ AUTH_USERNAME=admin # login name
154
+ AUTH_PASSWORD=admin123 # password
155
+ TOKEN_SECRET=your-key # JWT key
156
+ TOKEN_EXPIRE_HOURS=4 # expire duration
157
+ WHITELIST_PATHS=/login,/health # white list
lightrag/api/README.md CHANGED
@@ -387,6 +387,19 @@ Note: If you don't need the API functionality, you can install the base package
387
  pip install lightrag-hku
388
  ```
389
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390
  ## API Endpoints
391
 
392
  All servers (LoLLMs, Ollama, OpenAI and Azure OpenAI) provide the same REST API endpoints for RAG functionality. When API Server is running, visit:
 
387
  pip install lightrag-hku
388
  ```
389
 
390
+ ## Authentication Endpoints
391
+
392
+ ### JWT Authentication Mechanism
393
+ LightRAG API Server implements JWT-based authentication using HS256 algorithm. To enable secure access control, the following environment variables are required:
394
+ ```bash
395
+ # For jwt auth
396
+ AUTH_USERNAME=admin # login name
397
+ AUTH_PASSWORD=admin123 # password
398
+ TOKEN_SECRET=your-key # JWT key
399
+ TOKEN_EXPIRE_HOURS=4 # expire duration
400
+ WHITELIST_PATHS=/api1,/api2 # white list. /login,/health,/docs,/redoc,/openapi.json are whitelisted by default.
401
+ ```
402
+
403
  ## API Endpoints
404
 
405
  All servers (LoLLMs, Ollama, OpenAI and Azure OpenAI) provide the same REST API endpoints for RAG functionality. When API Server is running, visit:
lightrag/api/auth.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from datetime import datetime, timedelta
3
+ import jwt
4
+ from fastapi import HTTPException, status
5
+ from pydantic import BaseModel
6
+
7
+
8
+ class TokenPayload(BaseModel):
9
+ sub: str
10
+ exp: datetime
11
+
12
+
13
+ class AuthHandler:
14
+ def __init__(self):
15
+ self.secret = os.getenv("TOKEN_SECRET", "4f85ds4f56dsf46")
16
+ self.algorithm = "HS256"
17
+ self.expire_hours = int(os.getenv("TOKEN_EXPIRE_HOURS", 4))
18
+
19
+ def create_token(self, username: str) -> str:
20
+ expire = datetime.utcnow() + timedelta(hours=self.expire_hours)
21
+ payload = TokenPayload(sub=username, exp=expire)
22
+ return jwt.encode(payload.dict(), self.secret, algorithm=self.algorithm)
23
+
24
+ def validate_token(self, token: str) -> str:
25
+ try:
26
+ payload = jwt.decode(token, self.secret, algorithms=[self.algorithm])
27
+ expire_timestamp = payload["exp"]
28
+ expire_time = datetime.utcfromtimestamp(expire_timestamp)
29
+
30
+ if datetime.utcnow() > expire_time:
31
+ raise HTTPException(
32
+ status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired"
33
+ )
34
+ return payload["sub"]
35
+ except jwt.PyJWTError:
36
+ raise HTTPException(
37
+ status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
38
+ )
39
+
40
+
41
+ auth_handler = AuthHandler()
lightrag/api/lightrag_server.py CHANGED
@@ -2,10 +2,7 @@
2
  LightRAG FastAPI Server
3
  """
4
 
5
- from fastapi import (
6
- FastAPI,
7
- Depends,
8
- )
9
  import asyncio
10
  import os
11
  import logging
@@ -45,6 +42,8 @@ from lightrag.kg.shared_storage import (
45
  initialize_pipeline_status,
46
  get_all_update_flags_status,
47
  )
 
 
48
 
49
  # Load environment variables
50
  # Updated to use the .env that is inside the current folder
@@ -372,6 +371,27 @@ def create_app(args):
372
  ollama_api = OllamaAPI(rag, top_k=args.top_k)
373
  app.include_router(ollama_api.router, prefix="/api")
374
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
375
  @app.get("/health", dependencies=[Depends(optional_api_key)])
376
  async def get_status():
377
  """Get current system status"""
 
2
  LightRAG FastAPI Server
3
  """
4
 
5
+ from fastapi import FastAPI, Depends, HTTPException, status
 
 
 
6
  import asyncio
7
  import os
8
  import logging
 
42
  initialize_pipeline_status,
43
  get_all_update_flags_status,
44
  )
45
+ from fastapi.security import OAuth2PasswordRequestForm
46
+ from .auth import auth_handler
47
 
48
  # Load environment variables
49
  # Updated to use the .env that is inside the current folder
 
371
  ollama_api = OllamaAPI(rag, top_k=args.top_k)
372
  app.include_router(ollama_api.router, prefix="/api")
373
 
374
+ @app.post("/login")
375
+ async def login(form_data: OAuth2PasswordRequestForm = Depends()):
376
+ username = os.getenv("AUTH_USERNAME")
377
+ password = os.getenv("AUTH_PASSWORD")
378
+
379
+ if not (username and password):
380
+ raise HTTPException(
381
+ status_code=status.HTTP_501_NOT_IMPLEMENTED,
382
+ detail="Authentication not configured",
383
+ )
384
+
385
+ if form_data.username != username or form_data.password != password:
386
+ raise HTTPException(
387
+ status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect credentials"
388
+ )
389
+
390
+ return {
391
+ "access_token": auth_handler.create_token(username),
392
+ "token_type": "bearer",
393
+ }
394
+
395
  @app.get("/health", dependencies=[Depends(optional_api_key)])
396
  async def get_status():
397
  """Get current system status"""
lightrag/api/requirements.txt CHANGED
@@ -1,10 +1,20 @@
1
  aiofiles
2
  ascii_colors
 
 
3
  fastapi
 
 
 
4
  numpy
 
 
5
  pipmaster
 
6
  python-dotenv
 
7
  python-multipart
 
8
  tenacity
9
  tiktoken
10
  uvicorn
 
1
  aiofiles
2
  ascii_colors
3
+ asyncpg
4
+ distro
5
  fastapi
6
+ httpcore
7
+ httpx
8
+ jiter
9
  numpy
10
+ openai
11
+ passlib[bcrypt]
12
  pipmaster
13
+ PyJWT
14
  python-dotenv
15
+ python-jose[cryptography]
16
  python-multipart
17
+ pytz
18
  tenacity
19
  tiktoken
20
  uvicorn
lightrag/api/routers/document_routes.py CHANGED
@@ -16,10 +16,13 @@ from pydantic import BaseModel, Field, field_validator
16
 
17
  from lightrag import LightRAG
18
  from lightrag.base import DocProcessingStatus, DocStatus
19
- from ..utils_api import get_api_key_dependency
20
 
21
-
22
- router = APIRouter(prefix="/documents", tags=["documents"])
 
 
 
23
 
24
  # Temporary file prefix
25
  temp_prefix = "__tmp__"
 
16
 
17
  from lightrag import LightRAG
18
  from lightrag.base import DocProcessingStatus, DocStatus
19
+ from ..utils_api import get_api_key_dependency, get_auth_dependency
20
 
21
+ router = APIRouter(
22
+ prefix="/documents",
23
+ tags=["documents"],
24
+ dependencies=[Depends(get_auth_dependency())],
25
+ )
26
 
27
  # Temporary file prefix
28
  temp_prefix = "__tmp__"
lightrag/api/routers/graph_routes.py CHANGED
@@ -6,9 +6,9 @@ from typing import Optional
6
 
7
  from fastapi import APIRouter, Depends
8
 
9
- from ..utils_api import get_api_key_dependency
10
 
11
- router = APIRouter(tags=["graph"])
12
 
13
 
14
  def create_graph_routes(rag, api_key: Optional[str] = None):
 
6
 
7
  from fastapi import APIRouter, Depends
8
 
9
+ from ..utils_api import get_api_key_dependency, get_auth_dependency
10
 
11
+ router = APIRouter(tags=["graph"], dependencies=[Depends(get_auth_dependency())])
12
 
13
 
14
  def create_graph_routes(rag, api_key: Optional[str] = None):
lightrag/api/routers/query_routes.py CHANGED
@@ -8,12 +8,12 @@ from typing import Any, Dict, List, Literal, Optional
8
 
9
  from fastapi import APIRouter, Depends, HTTPException
10
  from lightrag.base import QueryParam
11
- from ..utils_api import get_api_key_dependency
12
  from pydantic import BaseModel, Field, field_validator
13
 
14
  from ascii_colors import trace_exception
15
 
16
- router = APIRouter(tags=["query"])
17
 
18
 
19
  class QueryRequest(BaseModel):
 
8
 
9
  from fastapi import APIRouter, Depends, HTTPException
10
  from lightrag.base import QueryParam
11
+ from ..utils_api import get_api_key_dependency, get_auth_dependency
12
  from pydantic import BaseModel, Field, field_validator
13
 
14
  from ascii_colors import trace_exception
15
 
16
+ router = APIRouter(tags=["query"], dependencies=[Depends(get_auth_dependency())])
17
 
18
 
19
  class QueryRequest(BaseModel):
lightrag/api/utils_api.py CHANGED
@@ -9,10 +9,11 @@ import sys
9
  import logging
10
  from ascii_colors import ASCIIColors
11
  from lightrag.api import __api_version__
12
- from fastapi import HTTPException, Security
13
  from dotenv import load_dotenv
14
- from fastapi.security import APIKeyHeader
15
  from starlette.status import HTTP_403_FORBIDDEN
 
16
 
17
  # Load environment variables
18
  load_dotenv(override=True)
@@ -31,6 +32,24 @@ class OllamaServerInfos:
31
  ollama_server_infos = OllamaServerInfos()
32
 
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  def get_api_key_dependency(api_key: Optional[str]):
35
  """
36
  Create an API key dependency for route protection.
 
9
  import logging
10
  from ascii_colors import ASCIIColors
11
  from lightrag.api import __api_version__
12
+ from fastapi import HTTPException, Security, Depends, Request
13
  from dotenv import load_dotenv
14
+ from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
15
  from starlette.status import HTTP_403_FORBIDDEN
16
+ from .auth import auth_handler
17
 
18
  # Load environment variables
19
  load_dotenv(override=True)
 
32
  ollama_server_infos = OllamaServerInfos()
33
 
34
 
35
+ def get_auth_dependency():
36
+ whitelist = os.getenv("WHITELIST_PATHS", "").split(",")
37
+
38
+ async def dependency(
39
+ request: Request,
40
+ token: str = Depends(OAuth2PasswordBearer(tokenUrl="login", auto_error=False)),
41
+ ):
42
+ if request.url.path in whitelist:
43
+ return
44
+
45
+ if not (os.getenv("AUTH_USERNAME") and os.getenv("AUTH_PASSWORD")):
46
+ return
47
+
48
+ auth_handler.validate_token(token)
49
+
50
+ return dependency
51
+
52
+
53
  def get_api_key_dependency(api_key: Optional[str]):
54
  """
55
  Create an API key dependency for route protection.