File size: 5,687 Bytes
824bf31 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
"""
Module: api.middleware.authentication
Description: Authentication middleware for API endpoints
Author: Anderson H. Silva
Date: 2025-01-24
License: Proprietary - All rights reserved
"""
from datetime import datetime, timedelta, timezone
from typing import Optional

import jwt
from fastapi import Depends, HTTPException, Request
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from src.core import get_logger, settings
class AuthenticationMiddleware:
    """Authentication middleware for API endpoints.

    A request is authenticated by the first rule that applies:

    1. The path is in ``_PUBLIC_PATHS`` (health check, docs, root): allowed.
    2. An ``X-API-Key`` header is present: validated by ``_validate_api_key``.
    3. An ``Authorization: Bearer <token>`` header is present: validated by
       ``_validate_jwt_token``.
    4. Otherwise the request is allowed only when
       ``settings.app_env == "development"``. In any other environment it is
       rejected with HTTP 401.
    """

    # Paths that never require authentication.
    _PUBLIC_PATHS = frozenset({"/health", "/health/", "/docs", "/openapi.json", "/"})

    # Minimum accepted API key length (format check only).
    _MIN_API_KEY_LENGTH = 32

    def __init__(self):
        """Initialize authentication middleware."""
        self.logger = get_logger(__name__)
        # NOTE(review): not referenced by __call__ below, which parses the
        # Authorization header manually. auto_error=False makes HTTPBearer
        # return None instead of raising when the header is missing.
        self.security = HTTPBearer(auto_error=False)

    async def __call__(self, request: Request):
        """Authenticate ``request``.

        Returns:
            ``True`` when the request may proceed.

        Raises:
            HTTPException: 401 when credentials are invalid, or when they are
                missing outside the development environment.
        """
        if request.url.path in self._PUBLIC_PATHS:
            return True

        # An API key takes precedence over a bearer token when both are sent.
        api_key = request.headers.get("X-API-Key")
        if api_key:
            return await self._validate_api_key(api_key, request)

        auth_header = request.headers.get("Authorization")
        if auth_header and auth_header.startswith("Bearer "):
            token = auth_header[len("Bearer "):]
            return await self._validate_jwt_token(token, request)

        # No credentials at all: tolerated only in development.
        if settings.app_env == "development":
            self.logger.warning(
                "unauthenticated_request_allowed",
                path=request.url.path,
                method=request.method,
                environment="development",
            )
            return True

        self.logger.warning(
            "unauthenticated_request_rejected",
            path=request.url.path,
            method=request.method,
        )
        raise HTTPException(
            status_code=401,
            detail="Authentication required",
            headers={"WWW-Authenticate": "Bearer"},
        )

    async def _validate_api_key(self, api_key: str, request: Request) -> bool:
        """Validate an API key.

        Only a format check (minimum length) is implemented, because there is
        no key store yet. A well-formed key is therefore accepted only when
        ``settings.app_env == "development"``. In any other environment it is
        rejected, so that an arbitrary 32-character header cannot bypass
        authentication.

        Returns:
            ``True`` when the key is accepted.

        Raises:
            HTTPException: 401 when the key is malformed, or when it cannot be
                verified outside development.
        """
        if not api_key or len(api_key) < self._MIN_API_KEY_LENGTH:
            self.logger.warning(
                "invalid_api_key_format",
                api_key_length=len(api_key) if api_key else 0,
                path=request.url.path,
            )
            raise HTTPException(
                status_code=401,
                detail="Invalid API key format",
            )

        # TODO: Implement proper API key validation against a key store
        # (compare with hmac.compare_digest to avoid timing side channels).
        # Until then, accepting any well-formed key is a development-only
        # convenience.
        if settings.app_env != "development":
            self.logger.warning(
                "api_key_validation_not_implemented",
                path=request.url.path,
                method=request.method,
            )
            raise HTTPException(
                status_code=401,
                detail="Invalid API key",
            )

        self.logger.info(
            "api_key_authentication_success",
            path=request.url.path,
            method=request.method,
        )
        return True

    async def _validate_jwt_token(self, token: str, request: Request) -> bool:
        """Validate a JWT bearer token and store its claims on ``request.state``.

        On success, sets ``request.state.user_id`` (``sub`` claim),
        ``request.state.user_email`` (``email`` claim) and
        ``request.state.user_roles`` (``roles`` claim, default ``[]``).

        Returns:
            ``True`` when the token is valid.

        Raises:
            HTTPException: 401 when the token is expired or otherwise invalid.
        """
        try:
            payload = jwt.decode(
                token,
                settings.jwt_secret_key.get_secret_value(),
                algorithms=[settings.jwt_algorithm],
            )
        except jwt.ExpiredSignatureError:
            self.logger.warning(
                "jwt_token_expired",
                path=request.url.path,
            )
            raise HTTPException(
                status_code=401,
                detail="Token has expired",
            )
        # PyJWT's base class for decode/validation failures. The former
        # `jwt.JWTError` does not exist in PyJWT, and evaluating it raised
        # AttributeError, which turned every invalid token into a 500.
        except jwt.InvalidTokenError as e:
            self.logger.warning(
                "jwt_validation_failed",
                error=str(e),
                path=request.url.path,
            )
            raise HTTPException(
                status_code=401,
                detail="Invalid token",
            )

        # Extra expiry check on top of jwt.decode's own. It uses an aware UTC
        # clock because datetime.utcnow().timestamp() interprets the naive
        # value as *local* time and skews the comparison by the host offset.
        exp = payload.get("exp")
        if exp is not None and datetime.now(timezone.utc).timestamp() > float(exp):
            raise HTTPException(
                status_code=401,
                detail="Token has expired",
            )

        request.state.user_id = payload.get("sub")
        request.state.user_email = payload.get("email")
        request.state.user_roles = payload.get("roles", [])

        self.logger.info(
            "jwt_authentication_success",
            user_id=request.state.user_id,
            path=request.url.path,
            method=request.method,
        )
        return True
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed in the token. The mapping is copied, not mutated.
        expires_delta: Token lifetime. When ``None`` or a zero ``timedelta``
            (both falsy), ``settings.jwt_access_token_expire_minutes`` is used.

    Returns:
        The token produced by ``jwt.encode``, signed with
        ``settings.jwt_secret_key`` using ``settings.jwt_algorithm``.
    """
    if expires_delta:
        lifetime = expires_delta
    else:
        lifetime = timedelta(minutes=settings.jwt_access_token_expire_minutes)

    # Use an aware UTC datetime. Calling .timestamp() on the naive value from
    # datetime.utcnow() treats it as local time and shifts `exp` by the host's
    # UTC offset (tokens expire too late or too early).
    expire = datetime.now(timezone.utc) + lifetime

    to_encode = data.copy()
    to_encode.update({"exp": expire.timestamp()})
    return jwt.encode(
        to_encode,
        settings.jwt_secret_key.get_secret_value(),
        algorithm=settings.jwt_algorithm,
    )
def get_current_user(request: Request) -> dict:
    """Return the caller identity recorded on ``request.state``.

    The values are those set by ``AuthenticationMiddleware._validate_jwt_token``.
    When they are absent (for example after API-key authentication or an
    unauthenticated development request, neither of which sets them),
    ``user_id`` and ``email`` default to ``None`` and ``roles`` to ``[]``.
    """
    state = request.state
    user_id = getattr(state, "user_id", None)
    email = getattr(state, "user_email", None)
    roles = getattr(state, "user_roles", [])
    return dict(user_id=user_id, email=email, roles=roles)