Merge pull request #1000 from lcjqyml/feat_login-jwt
Browse files- env.example +7 -0
- lightrag/api/README.md +13 -0
- lightrag/api/auth.py +41 -0
- lightrag/api/lightrag_server.py +24 -4
- lightrag/api/requirements.txt +10 -0
- lightrag/api/routers/document_routes.py +6 -3
- lightrag/api/routers/graph_routes.py +2 -2
- lightrag/api/routers/query_routes.py +2 -2
- lightrag/api/utils_api.py +21 -2
env.example
CHANGED
@@ -148,3 +148,10 @@ QDRANT_URL=http://localhost:16333
|
|
148 |
|
149 |
### Redis
|
150 |
REDIS_URI=redis://localhost:6379
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
|
149 |
### Redis
|
150 |
REDIS_URI=redis://localhost:6379
|
151 |
+
|
152 |
+
# For jwt auth
|
153 |
+
AUTH_USERNAME=admin # login name
|
154 |
+
AUTH_PASSWORD=admin123 # password
|
155 |
+
TOKEN_SECRET=your-key # JWT key
|
156 |
+
TOKEN_EXPIRE_HOURS=4 # expire duration
|
157 |
+
WHITELIST_PATHS=/login,/health # white list
|
lightrag/api/README.md
CHANGED
@@ -387,6 +387,19 @@ Note: If you don't need the API functionality, you can install the base package
|
|
387 |
pip install lightrag-hku
|
388 |
```
|
389 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
390 |
## API Endpoints
|
391 |
|
392 |
All servers (LoLLMs, Ollama, OpenAI and Azure OpenAI) provide the same REST API endpoints for RAG functionality. When API Server is running, visit:
|
|
|
387 |
pip install lightrag-hku
|
388 |
```
|
389 |
|
390 |
+
## Authentication Endpoints
|
391 |
+
|
392 |
+
### JWT Authentication Mechanism
|
393 |
+
LightRAG API Server implements JWT-based authentication using HS256 algorithm. To enable secure access control, the following environment variables are required:
|
394 |
+
```bash
|
395 |
+
# For jwt auth
|
396 |
+
AUTH_USERNAME=admin # login name
|
397 |
+
AUTH_PASSWORD=admin123 # password
|
398 |
+
TOKEN_SECRET=your-key # JWT key
|
399 |
+
TOKEN_EXPIRE_HOURS=4 # expire duration
|
400 |
+
WHITELIST_PATHS=/api1,/api2 # white list. /login,/health,/docs,/redoc,/openapi.json are whitelisted by default.
|
401 |
+
```
|
402 |
+
|
403 |
## API Endpoints
|
404 |
|
405 |
All servers (LoLLMs, Ollama, OpenAI and Azure OpenAI) provide the same REST API endpoints for RAG functionality. When API Server is running, visit:
|
lightrag/api/auth.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from datetime import datetime, timedelta
|
3 |
+
import jwt
|
4 |
+
from fastapi import HTTPException, status
|
5 |
+
from pydantic import BaseModel
|
6 |
+
|
7 |
+
|
8 |
+
class TokenPayload(BaseModel):
|
9 |
+
sub: str
|
10 |
+
exp: datetime
|
11 |
+
|
12 |
+
|
13 |
+
class AuthHandler:
|
14 |
+
def __init__(self):
|
15 |
+
self.secret = os.getenv("TOKEN_SECRET", "4f85ds4f56dsf46")
|
16 |
+
self.algorithm = "HS256"
|
17 |
+
self.expire_hours = int(os.getenv("TOKEN_EXPIRE_HOURS", 4))
|
18 |
+
|
19 |
+
def create_token(self, username: str) -> str:
|
20 |
+
expire = datetime.utcnow() + timedelta(hours=self.expire_hours)
|
21 |
+
payload = TokenPayload(sub=username, exp=expire)
|
22 |
+
return jwt.encode(payload.dict(), self.secret, algorithm=self.algorithm)
|
23 |
+
|
24 |
+
def validate_token(self, token: str) -> str:
|
25 |
+
try:
|
26 |
+
payload = jwt.decode(token, self.secret, algorithms=[self.algorithm])
|
27 |
+
expire_timestamp = payload["exp"]
|
28 |
+
expire_time = datetime.utcfromtimestamp(expire_timestamp)
|
29 |
+
|
30 |
+
if datetime.utcnow() > expire_time:
|
31 |
+
raise HTTPException(
|
32 |
+
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired"
|
33 |
+
)
|
34 |
+
return payload["sub"]
|
35 |
+
except jwt.PyJWTError:
|
36 |
+
raise HTTPException(
|
37 |
+
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
|
38 |
+
)
|
39 |
+
|
40 |
+
|
41 |
+
auth_handler = AuthHandler()
|
lightrag/api/lightrag_server.py
CHANGED
@@ -2,10 +2,7 @@
|
|
2 |
LightRAG FastAPI Server
|
3 |
"""
|
4 |
|
5 |
-
from fastapi import
|
6 |
-
FastAPI,
|
7 |
-
Depends,
|
8 |
-
)
|
9 |
import asyncio
|
10 |
import os
|
11 |
import logging
|
@@ -45,6 +42,8 @@ from lightrag.kg.shared_storage import (
|
|
45 |
initialize_pipeline_status,
|
46 |
get_all_update_flags_status,
|
47 |
)
|
|
|
|
|
48 |
|
49 |
# Load environment variables
|
50 |
# Updated to use the .env that is inside the current folder
|
@@ -372,6 +371,27 @@ def create_app(args):
|
|
372 |
ollama_api = OllamaAPI(rag, top_k=args.top_k)
|
373 |
app.include_router(ollama_api.router, prefix="/api")
|
374 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
375 |
@app.get("/health", dependencies=[Depends(optional_api_key)])
|
376 |
async def get_status():
|
377 |
"""Get current system status"""
|
|
|
2 |
LightRAG FastAPI Server
|
3 |
"""
|
4 |
|
5 |
+
from fastapi import FastAPI, Depends, HTTPException, status
|
|
|
|
|
|
|
6 |
import asyncio
|
7 |
import os
|
8 |
import logging
|
|
|
42 |
initialize_pipeline_status,
|
43 |
get_all_update_flags_status,
|
44 |
)
|
45 |
+
from fastapi.security import OAuth2PasswordRequestForm
|
46 |
+
from .auth import auth_handler
|
47 |
|
48 |
# Load environment variables
|
49 |
# Updated to use the .env that is inside the current folder
|
|
|
371 |
ollama_api = OllamaAPI(rag, top_k=args.top_k)
|
372 |
app.include_router(ollama_api.router, prefix="/api")
|
373 |
|
374 |
+
@app.post("/login")
|
375 |
+
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
|
376 |
+
username = os.getenv("AUTH_USERNAME")
|
377 |
+
password = os.getenv("AUTH_PASSWORD")
|
378 |
+
|
379 |
+
if not (username and password):
|
380 |
+
raise HTTPException(
|
381 |
+
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
382 |
+
detail="Authentication not configured",
|
383 |
+
)
|
384 |
+
|
385 |
+
if form_data.username != username or form_data.password != password:
|
386 |
+
raise HTTPException(
|
387 |
+
status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect credentials"
|
388 |
+
)
|
389 |
+
|
390 |
+
return {
|
391 |
+
"access_token": auth_handler.create_token(username),
|
392 |
+
"token_type": "bearer",
|
393 |
+
}
|
394 |
+
|
395 |
@app.get("/health", dependencies=[Depends(optional_api_key)])
|
396 |
async def get_status():
|
397 |
"""Get current system status"""
|
lightrag/api/requirements.txt
CHANGED
@@ -1,10 +1,20 @@
|
|
1 |
aiofiles
|
2 |
ascii_colors
|
|
|
|
|
3 |
fastapi
|
|
|
|
|
|
|
4 |
numpy
|
|
|
|
|
5 |
pipmaster
|
|
|
6 |
python-dotenv
|
|
|
7 |
python-multipart
|
|
|
8 |
tenacity
|
9 |
tiktoken
|
10 |
uvicorn
|
|
|
1 |
aiofiles
|
2 |
ascii_colors
|
3 |
+
asyncpg
|
4 |
+
distro
|
5 |
fastapi
|
6 |
+
httpcore
|
7 |
+
httpx
|
8 |
+
jiter
|
9 |
numpy
|
10 |
+
openai
|
11 |
+
passlib[bcrypt]
|
12 |
pipmaster
|
13 |
+
PyJWT
|
14 |
python-dotenv
|
15 |
+
python-jose[cryptography]
|
16 |
python-multipart
|
17 |
+
pytz
|
18 |
tenacity
|
19 |
tiktoken
|
20 |
uvicorn
|
lightrag/api/routers/document_routes.py
CHANGED
@@ -16,10 +16,13 @@ from pydantic import BaseModel, Field, field_validator
|
|
16 |
|
17 |
from lightrag import LightRAG
|
18 |
from lightrag.base import DocProcessingStatus, DocStatus
|
19 |
-
from ..utils_api import get_api_key_dependency
|
20 |
|
21 |
-
|
22 |
-
|
|
|
|
|
|
|
23 |
|
24 |
# Temporary file prefix
|
25 |
temp_prefix = "__tmp__"
|
|
|
16 |
|
17 |
from lightrag import LightRAG
|
18 |
from lightrag.base import DocProcessingStatus, DocStatus
|
19 |
+
from ..utils_api import get_api_key_dependency, get_auth_dependency
|
20 |
|
21 |
+
router = APIRouter(
|
22 |
+
prefix="/documents",
|
23 |
+
tags=["documents"],
|
24 |
+
dependencies=[Depends(get_auth_dependency())],
|
25 |
+
)
|
26 |
|
27 |
# Temporary file prefix
|
28 |
temp_prefix = "__tmp__"
|
lightrag/api/routers/graph_routes.py
CHANGED
@@ -6,9 +6,9 @@ from typing import Optional
|
|
6 |
|
7 |
from fastapi import APIRouter, Depends
|
8 |
|
9 |
-
from ..utils_api import get_api_key_dependency
|
10 |
|
11 |
-
router = APIRouter(tags=["graph"])
|
12 |
|
13 |
|
14 |
def create_graph_routes(rag, api_key: Optional[str] = None):
|
|
|
6 |
|
7 |
from fastapi import APIRouter, Depends
|
8 |
|
9 |
+
from ..utils_api import get_api_key_dependency, get_auth_dependency
|
10 |
|
11 |
+
router = APIRouter(tags=["graph"], dependencies=[Depends(get_auth_dependency())])
|
12 |
|
13 |
|
14 |
def create_graph_routes(rag, api_key: Optional[str] = None):
|
lightrag/api/routers/query_routes.py
CHANGED
@@ -8,12 +8,12 @@ from typing import Any, Dict, List, Literal, Optional
|
|
8 |
|
9 |
from fastapi import APIRouter, Depends, HTTPException
|
10 |
from lightrag.base import QueryParam
|
11 |
-
from ..utils_api import get_api_key_dependency
|
12 |
from pydantic import BaseModel, Field, field_validator
|
13 |
|
14 |
from ascii_colors import trace_exception
|
15 |
|
16 |
-
router = APIRouter(tags=["query"])
|
17 |
|
18 |
|
19 |
class QueryRequest(BaseModel):
|
|
|
8 |
|
9 |
from fastapi import APIRouter, Depends, HTTPException
|
10 |
from lightrag.base import QueryParam
|
11 |
+
from ..utils_api import get_api_key_dependency, get_auth_dependency
|
12 |
from pydantic import BaseModel, Field, field_validator
|
13 |
|
14 |
from ascii_colors import trace_exception
|
15 |
|
16 |
+
router = APIRouter(tags=["query"], dependencies=[Depends(get_auth_dependency())])
|
17 |
|
18 |
|
19 |
class QueryRequest(BaseModel):
|
lightrag/api/utils_api.py
CHANGED
@@ -9,10 +9,11 @@ import sys
|
|
9 |
import logging
|
10 |
from ascii_colors import ASCIIColors
|
11 |
from lightrag.api import __api_version__
|
12 |
-
from fastapi import HTTPException, Security
|
13 |
from dotenv import load_dotenv
|
14 |
-
from fastapi.security import APIKeyHeader
|
15 |
from starlette.status import HTTP_403_FORBIDDEN
|
|
|
16 |
|
17 |
# Load environment variables
|
18 |
load_dotenv(override=True)
|
@@ -31,6 +32,24 @@ class OllamaServerInfos:
|
|
31 |
ollama_server_infos = OllamaServerInfos()
|
32 |
|
33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
def get_api_key_dependency(api_key: Optional[str]):
|
35 |
"""
|
36 |
Create an API key dependency for route protection.
|
|
|
9 |
import logging
|
10 |
from ascii_colors import ASCIIColors
|
11 |
from lightrag.api import __api_version__
|
12 |
+
from fastapi import HTTPException, Security, Depends, Request
|
13 |
from dotenv import load_dotenv
|
14 |
+
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
|
15 |
from starlette.status import HTTP_403_FORBIDDEN
|
16 |
+
from .auth import auth_handler
|
17 |
|
18 |
# Load environment variables
|
19 |
load_dotenv(override=True)
|
|
|
32 |
ollama_server_infos = OllamaServerInfos()
|
33 |
|
34 |
|
35 |
+
def get_auth_dependency():
|
36 |
+
whitelist = os.getenv("WHITELIST_PATHS", "").split(",")
|
37 |
+
|
38 |
+
async def dependency(
|
39 |
+
request: Request,
|
40 |
+
token: str = Depends(OAuth2PasswordBearer(tokenUrl="login", auto_error=False)),
|
41 |
+
):
|
42 |
+
if request.url.path in whitelist:
|
43 |
+
return
|
44 |
+
|
45 |
+
if not (os.getenv("AUTH_USERNAME") and os.getenv("AUTH_PASSWORD")):
|
46 |
+
return
|
47 |
+
|
48 |
+
auth_handler.validate_token(token)
|
49 |
+
|
50 |
+
return dependency
|
51 |
+
|
52 |
+
|
53 |
def get_api_key_dependency(api_key: Optional[str]):
|
54 |
"""
|
55 |
Create an API key dependency for route protection.
|