File size: 5,687 Bytes
824bf31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
"""
Module: api.middleware.authentication
Description: Authentication middleware for API endpoints
Author: Anderson H. Silva
Date: 2025-01-24
License: Proprietary - All rights reserved
"""

from datetime import datetime, timedelta, timezone
from typing import Optional

from fastapi import Request, HTTPException, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
import jwt

from src.core import get_logger, settings


class AuthenticationMiddleware:
    """Authenticate incoming API requests.

    An instance is an async callable that takes the incoming ``Request``.
    It either returns ``True`` (request accepted) or raises a 401
    ``HTTPException``. Requests are evaluated in this order:

    1. Paths in ``_PUBLIC_PATHS`` (health check, docs, root) are always accepted.
    2. An ``X-API-Key`` header is validated by ``_validate_api_key``.
    3. An ``Authorization: Bearer <token>`` header is validated by
       ``_validate_jwt_token``.
    4. Requests without credentials are accepted only when
       ``settings.app_env == "development"``; otherwise they are rejected.
    """

    # Paths that never require credentials.
    _PUBLIC_PATHS = frozenset({"/health", "/health/", "/docs", "/openapi.json", "/"})

    # Shortest API key accepted by the format check in _validate_api_key.
    _MIN_API_KEY_LENGTH = 32

    def __init__(self):
        """Initialize the logger and an ``HTTPBearer`` helper."""
        self.logger = get_logger(__name__)
        # NOTE(review): not referenced by this class's own methods; kept in
        # case external code relies on the attribute.
        self.security = HTTPBearer(auto_error=False)

    async def __call__(self, request: Request):
        """Authenticate *request*.

        Returns:
            ``True`` when the request is accepted.

        Raises:
            HTTPException: 401 when credentials are invalid, or when they are
                missing and the app is not running in development.
        """
        # Skip authentication for health check and docs.
        if request.url.path in self._PUBLIC_PATHS:
            return True

        # An API key, when present, takes precedence over a bearer token.
        api_key = request.headers.get("X-API-Key")
        if api_key:
            return await self._validate_api_key(api_key, request)

        # The auth-scheme name is case-insensitive (RFC 7235 section 2.1), so
        # "Bearer", "bearer", etc. are all accepted. Whitespace around the
        # token is not part of it.
        auth_header = request.headers.get("Authorization")
        if auth_header:
            scheme, _, token = auth_header.partition(" ")
            if scheme.lower() == "bearer":
                return await self._validate_jwt_token(token.strip(), request)

        # For development, allow unauthenticated access.
        if settings.app_env == "development":
            self.logger.warning(
                "unauthenticated_request_allowed",
                path=request.url.path,
                method=request.method,
                environment="development"
            )
            return True

        # Any other environment requires authentication.
        self.logger.warning(
            "unauthenticated_request_rejected",
            path=request.url.path,
            method=request.method,
        )

        raise HTTPException(
            status_code=401,
            detail="Authentication required",
            headers={"WWW-Authenticate": "Bearer"}
        )

    async def _validate_api_key(self, api_key: str, request: Request) -> bool:
        """Validate an ``X-API-Key`` header value.

        Only the key's format is checked: any key at least
        ``_MIN_API_KEY_LENGTH`` characters long is accepted. No identity is
        written to ``request.state`` on this path.

        Returns:
            ``True`` when the key is accepted.

        Raises:
            HTTPException: 401 when the key is empty or too short.
        """
        # NOTE(review): SECURITY - no key lookup happens here. In every
        # environment, including production, any string of the minimum length
        # authenticates. Replace with a real check against stored keys before
        # exposing the API.
        if not api_key or len(api_key) < self._MIN_API_KEY_LENGTH:
            self.logger.warning(
                "invalid_api_key_format",
                api_key_length=len(api_key) if api_key else 0,
                path=request.url.path,
            )
            raise HTTPException(
                status_code=401,
                detail="Invalid API key format"
            )

        # TODO: Implement proper API key validation.
        self.logger.info(
            "api_key_authentication_success",
            path=request.url.path,
            method=request.method,
        )

        return True

    async def _validate_jwt_token(self, token: str, request: Request) -> bool:
        """Validate a bearer JWT and attach its claims to ``request.state``.

        On success, sets the following on ``request.state``:

        - ``user_id`` from the ``sub`` claim
        - ``user_email`` from the ``email`` claim
        - ``user_roles`` from the ``roles`` claim (default ``[]``)

        Returns:
            ``True`` when the token is valid.

        Raises:
            HTTPException: 401 when the token is expired, malformed, or fails
                signature/algorithm verification.
        """
        # Keep the try body to the decode call only, so HTTPExceptions raised
        # below are never routed through the JWT error handlers.
        try:
            payload = jwt.decode(
                token,
                settings.jwt_secret_key.get_secret_value(),
                algorithms=[settings.jwt_algorithm],
            )
        except jwt.ExpiredSignatureError as exc:
            self.logger.warning(
                "jwt_token_expired",
                path=request.url.path,
            )
            raise HTTPException(
                status_code=401,
                detail="Token has expired",
                headers={"WWW-Authenticate": "Bearer"},
            ) from exc
        except jwt.InvalidTokenError as exc:
            # InvalidTokenError is PyJWT's base class for token decode and
            # verification failures. PyJWT has no ``JWTError``; that name
            # belongs to python-jose.
            self.logger.warning(
                "jwt_validation_failed",
                error=str(exc),
                path=request.url.path,
            )
            raise HTTPException(
                status_code=401,
                detail="Invalid token",
                headers={"WWW-Authenticate": "Bearer"},
            ) from exc

        # Defensive re-check of ``exp``; PyJWT's decode already verifies it by
        # default. Compare against an aware UTC "now": calling .timestamp() on
        # a naive datetime.utcnow() value interprets it as *local* time.
        exp = payload.get("exp")
        if exp is not None and datetime.now(timezone.utc).timestamp() > float(exp):
            self.logger.warning(
                "jwt_token_expired",
                path=request.url.path,
            )
            raise HTTPException(
                status_code=401,
                detail="Token has expired",
                headers={"WWW-Authenticate": "Bearer"},
            )

        # Store user info in request state for get_current_user().
        request.state.user_id = payload.get("sub")
        request.state.user_email = payload.get("email")
        request.state.user_roles = payload.get("roles", [])

        self.logger.info(
            "jwt_authentication_success",
            user_id=request.state.user_id,
            path=request.url.path,
            method=request.method,
        )

        return True


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed (e.g. ``sub``, ``email``, ``roles``). The dict
            is copied, so the caller's object is not modified. Any ``exp``
            key in it is overwritten.
        expires_delta: Token lifetime. When ``None``, defaults to
            ``settings.jwt_access_token_expire_minutes`` minutes.

    Returns:
        The encoded token. It is signed with ``settings.jwt_secret_key``
        using ``settings.jwt_algorithm``.
    """
    to_encode = data.copy()

    # Test against None rather than truthiness, so an explicit timedelta(0)
    # (which is falsy) is honored instead of falling back to the default.
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.jwt_access_token_expire_minutes)

    # Use an aware UTC datetime: .timestamp() on the naive value returned by
    # datetime.utcnow() is interpreted as *local* time. That skews ``exp`` by
    # the server's UTC offset.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode["exp"] = expire.timestamp()

    return jwt.encode(
        to_encode,
        settings.jwt_secret_key.get_secret_value(),
        algorithm=settings.jwt_algorithm,
    )


def get_current_user(request: Request) -> dict:
    """Return the identity stored on *request* by JWT authentication.

    Reads ``user_id``, ``user_email`` and ``user_roles`` from
    ``request.state``. Attributes that were never set fall back to ``None``
    (``[]`` for roles). This happens, for example, for requests accepted via
    an API key or let through unauthenticated in development.
    """
    state = request.state
    user_id = getattr(state, "user_id", None)
    email = getattr(state, "user_email", None)
    roles = getattr(state, "user_roles", [])
    return dict(user_id=user_id, email=email, roles=roles)