anderson-ufrj
feat(security): implement IP whitelist for production environments
f70869e
raw
history blame
8.57 kB
"""
Module: api.middleware.ip_whitelist
Description: IP whitelist middleware for production environments
Author: Anderson H. Silva
Date: 2025-01-25
License: Proprietary - All rights reserved
"""
from typing import Optional, List, Set
import ipaddress
from fastapi import Request, HTTPException, status
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from src.core import get_logger
from src.core.config import settings
from src.services.ip_whitelist_service import ip_whitelist_service
from src.db.session import get_session
# Module logger; the calls in this module pass an event name plus
# structured keyword fields (e.g. client_ip=..., path=...).
logger = get_logger(__name__)
class IPWhitelistMiddleware(BaseHTTPMiddleware):
    """
    IP whitelist middleware for production environments.

    Features:
    - Environment-based activation (``settings.is_production`` by default)
    - Path exclusions (exact, sub-path and trailing-``*`` prefix patterns)
    - Client IP extraction from common proxy/CDN headers, honoured only
      when the direct peer is a trusted proxy
    - Fail-closed (``strict_mode=True``) or fail-open behaviour

    Security note: ``X-Real-IP``, ``X-Forwarded-For``, ``CF-Connecting-IP``,
    ``True-Client-IP`` and ``Fastly-Client-IP`` can be set by any client.
    They are only consulted when the TCP peer is a trusted proxy, and that
    proxy is expected to overwrite (not pass through) the headers it is
    responsible for. Otherwise the peer address itself is checked.
    """

    # Headers that carry client-IP information. These are also the only
    # headers logged when the client IP cannot be determined.
    _IP_HEADERS = (
        "X-Real-IP",
        "X-Forwarded-For",
        "CF-Connecting-IP",
        "True-Client-IP",
        "Fastly-Client-IP",
    )

    def __init__(
        self,
        app,
        enabled: Optional[bool] = None,
        excluded_paths: Optional[List[str]] = None,
        strict_mode: bool = True,
        trusted_proxies: Optional[List[str]] = None,
    ):
        """
        Initialize IP whitelist middleware.

        Args:
            app: ASGI application being wrapped.
            enabled: Force enable/disable. ``None`` enables the middleware
                only when ``settings.is_production`` is true.
            excluded_paths: Paths to exclude from the whitelist check.
                ``None`` uses ``_get_default_excluded_paths()``. An explicit
                empty list excludes nothing.
            strict_mode: If True, reject requests whose IP cannot be
                determined, and fail closed (HTTP 500) when the check
                itself errors. If False, allow such requests through.
            trusted_proxies: IP addresses or CIDR ranges of reverse proxies
                whose forwarding headers may be trusted. ``None`` trusts
                private and loopback peers. An empty list trusts no proxy,
                so forwarding headers are never used.

        Raises:
            ValueError: If an entry of ``trusted_proxies`` is not a valid
                IP address or network.
        """
        super().__init__(app)
        # Auto-enable in production if not specified
        self.enabled = settings.is_production if enabled is None else enabled
        # `is None` (not `or`) so an explicit [] really means "no exclusions"
        if excluded_paths is None:
            excluded_paths = self._get_default_excluded_paths()
        self.excluded_paths: Set[str] = set(excluded_paths)
        self.strict_mode = strict_mode
        # strict=False accepts entries with host bits set, e.g. "10.0.0.1/8"
        self.trusted_proxies = (
            None
            if trusted_proxies is None
            else [
                ipaddress.ip_network(str(proxy).strip(), strict=False)
                for proxy in trusted_proxies
            ]
        )
        logger.info(
            "ip_whitelist_middleware_initialized",
            enabled=self.enabled,
            environment=settings.app_env,
            strict_mode=self.strict_mode
        )

    async def dispatch(self, request: Request, call_next):
        """
        Enforce the whitelist, then hand the request to the next handler.

        Only IP extraction and the whitelist lookup run inside the error
        guard. ``call_next`` is always invoked outside it, so exceptions
        raised by downstream handlers propagate normally instead of being
        reported as whitelist failures. It also ensures the request is
        never executed twice.
        """
        # Skip if disabled or the path is excluded
        if not self.enabled or self._should_skip(request.url.path):
            return await call_next(request)

        try:
            denial = await self._check_request(request)
        except Exception as e:
            logger.error(
                "ip_whitelist_error",
                error=str(e),
                exc_info=True
            )
            # In case of error, fail open or closed based on strict mode
            if self.strict_mode:
                return JSONResponse(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    content={
                        "detail": "IP whitelist check failed",
                        "error": "WHITELIST_CHECK_ERROR"
                    }
                )
            # Allow through on error
            return await call_next(request)

        if denial is not None:
            return denial
        return await call_next(request)

    async def _check_request(self, request: Request) -> Optional[JSONResponse]:
        """
        Run the whitelist check for ``request``.

        Returns:
            A 403 ``JSONResponse`` when the request must be rejected, or None
            when it may proceed. On success, ``request.state.client_ip`` and
            ``request.state.ip_whitelisted`` are set.
        """
        client_ip = self._get_client_ip(request)
        if not client_ip:
            if not self.strict_mode:
                # Allow through if can't determine IP
                logger.debug("ip_whitelist_no_client_ip_allowing")
                return None
            logger.warning(
                "ip_whitelist_no_client_ip",
                path=request.url.path,
                # Log only IP-related headers: dumping every header would
                # write Authorization tokens and cookies into the logs.
                headers={
                    name: request.headers.get(name)
                    for name in self._IP_HEADERS
                    if name in request.headers
                }
            )
            return JSONResponse(
                status_code=status.HTTP_403_FORBIDDEN,
                content={
                    "detail": "Client IP could not be determined",
                    "error": "IP_NOT_DETERMINED"
                }
            )

        async with get_session() as session:
            is_whitelisted = await ip_whitelist_service.check_ip(
                session,
                client_ip,
                environment=settings.app_env
            )

        if not is_whitelisted:
            logger.warning(
                "ip_whitelist_access_denied",
                client_ip=client_ip,
                path=request.url.path,
                method=request.method
            )
            return JSONResponse(
                status_code=status.HTTP_403_FORBIDDEN,
                content={
                    "detail": "Access denied: IP not whitelisted",
                    "error": "IP_NOT_WHITELISTED"
                }
            )

        # IP is whitelisted, proceed
        request.state.client_ip = client_ip
        request.state.ip_whitelisted = True
        return None

    def _get_default_excluded_paths(self) -> List[str]:
        """Get default paths to exclude from whitelist."""
        return [
            # Health checks
            "/health",
            "/healthz",
            "/ping",
            "/ready",
            # Documentation
            "/docs",
            "/redoc",
            "/openapi.json",
            # Metrics
            "/metrics",
            # Static assets
            "/static",
            "/favicon.ico",
            "/_next",
            # Public endpoints
            "/api/v1/auth/login",
            "/api/v1/auth/register",
            "/api/v1/auth/refresh",
            "/api/v1/public",
            # Webhook endpoints (they have their own auth)
            "/api/v1/webhooks/incoming"
        ]

    def _should_skip(self, path: str) -> bool:
        """
        Return True if ``path`` is exempt from the whitelist check.

        An entry matches the exact path or any sub-path of it, so
        ``"/static"`` covers ``"/static/app.js"``. An entry ending in ``*``
        matches any path that starts with the text before the ``*``.
        """
        if path in self.excluded_paths:
            return True
        return any(
            (excluded.endswith("*") and path.startswith(excluded[:-1]))
            or path.startswith(excluded + "/")
            for excluded in self.excluded_paths
        )

    def _get_client_ip(self, request: Request) -> Optional[str]:
        """
        Extract the client IP from the request.

        If the direct peer (``request.client.host``) is not a trusted proxy,
        its address is returned as-is. Forwarding headers are ignored in
        that case because the client could have forged them.

        If the peer is a trusted proxy, headers are tried in this order:
        1. X-Real-IP
        2. X-Forwarded-For (right-most hop that is not a trusted proxy)
        3. CF-Connecting-IP, True-Client-IP, Fastly-Client-IP
        When no header yields a usable address, the peer address is
        returned. Returns None when nothing is available at all.
        """
        peer = request.client.host if request.client else None
        if peer:
            peer_ip = self._parse_ip(peer)
            if peer_ip is None or not self._is_trusted_proxy(peer_ip):
                return peer
        # When there is no peer address (e.g. a Unix-socket transport) the
        # connection is presumably from a local reverse proxy, so headers
        # are consulted. NOTE(review): confirm this matches the deployment.
        return self._ip_from_headers(request.headers) or peer or None

    def _ip_from_headers(self, headers) -> Optional[str]:
        """Return the client IP advertised by trusted-proxy headers, if any."""
        # Try X-Real-IP first (nginx)
        real_ip = (headers.get("X-Real-IP") or "").strip()
        if self._is_valid_ip(real_ip):
            return real_ip

        # Try X-Forwarded-For (proxy chains)
        forwarded_for = headers.get("X-Forwarded-For")
        if forwarded_for:
            forwarded_ip = self._ip_from_forwarded_for(forwarded_for)
            if forwarded_ip:
                return forwarded_ip

        # CDN headers: CloudFlare, Akamai / CloudFlare Enterprise, Fastly
        for name in ("CF-Connecting-IP", "True-Client-IP", "Fastly-Client-IP"):
            value = (headers.get(name) or "").strip()
            if self._is_valid_ip(value):
                return value
        return None

    def _ip_from_forwarded_for(self, header_value: str) -> Optional[str]:
        """
        Return the client address from an X-Forwarded-For chain.

        Proxies append to the header, so only the right-hand entries were
        written by infrastructure. Anything further left may have been
        supplied by the client. Therefore walk right-to-left, skip hops that
        are trusted proxies, and return the first remaining hop (if valid).
        """
        for hop in reversed(header_value.split(",")):
            hop = hop.strip()
            if not hop:
                continue
            hop_ip = self._parse_ip(hop)
            if hop_ip is None:
                # Malformed entry: do not trust anything to its left
                return None
            if self._is_trusted_proxy(hop_ip):
                continue
            return hop if self._is_valid_ip(hop) else None
        return None

    def _is_trusted_proxy(self, ip_obj) -> bool:
        """
        Return True if ``ip_obj`` is a proxy whose headers may be trusted.

        ``ip_obj`` is an ``ipaddress`` address object. Without an explicit
        ``trusted_proxies`` list, private and loopback addresses are trusted.
        """
        # Unwrap IPv4-mapped IPv6 (::ffff:a.b.c.d) so IPv4 rules apply
        mapped = getattr(ip_obj, "ipv4_mapped", None)
        if mapped is not None:
            ip_obj = mapped
        if self.trusted_proxies is None:
            return ip_obj.is_private or ip_obj.is_loopback
        return any(ip_obj in network for network in self.trusted_proxies)

    @staticmethod
    def _parse_ip(value: Optional[str]):
        """Parse ``value`` into an ``ipaddress`` address, or None if invalid."""
        if not value:
            return None
        try:
            return ipaddress.ip_address(value.strip())
        except ValueError:
            return None

    def _is_valid_ip(self, ip: str) -> bool:
        """
        Return True if ``ip`` is an acceptable client address.

        The string must parse as an IP address. Outside development it must
        also be neither private nor loopback.
        """
        ip_obj = self._parse_ip(ip)
        if ip_obj is None:
            return False
        # Reject private/local IPs unless in development
        if not settings.is_development and (ip_obj.is_private or ip_obj.is_loopback):
            return False
        return True