File size: 4,034 Bytes
daf23ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
"""Rate limiting middleware with distributed Redis backend"""
import math
import time
from collections.abc import Iterable
from typing import Optional

from fastapi import Request, HTTPException, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp

from src.core import get_logger
from src.services.rate_limit_service import get_rate_limiter
# Module-level logger obtained from the project's logging factory (src.core.get_logger).
logger = get_logger(__name__)
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Rate limiting middleware using distributed Redis backend.

    Every request whose path is not exempt is checked with the limiter returned
    by ``get_rate_limiter()``, keyed by client IP and request path. Rejected
    requests get a 429 JSON response; allowed requests are passed on and the
    response is decorated with informational ``X-RateLimit-*`` headers.

    Args:
        app: The wrapped ASGI application.
        exempt_paths: Exact request paths that bypass rate limiting. Defaults
            to ``DEFAULT_EXEMPT_PATHS`` (health, metrics and API docs). Note a
            plain string would be iterated character by character.
        trusted_proxies: Addresses of reverse proxies that are allowed to set
            ``X-Forwarded-For`` / ``X-Real-IP``. When ``None`` (the default,
            kept for backward compatibility) those headers are trusted from
            ANY peer, which lets a client choose its own identifier and evade
            the limit. Pass the proxy addresses in production.
    """

    DEFAULT_EXEMPT_PATHS = frozenset(
        {"/health", "/metrics", "/docs", "/redoc", "/openapi.json"}
    )

    def __init__(
        self,
        app: ASGIApp,
        exempt_paths: Optional[Iterable[str]] = None,
        trusted_proxies: Optional[Iterable[str]] = None,
    ):
        super().__init__(app)
        self.rate_limiter = get_rate_limiter()
        # Frozensets give O(1) membership tests and are built once, not per request.
        self.exempt_paths = (
            self.DEFAULT_EXEMPT_PATHS
            if exempt_paths is None
            else frozenset(exempt_paths)
        )
        self.trusted_proxies = (
            None if trusted_proxies is None else frozenset(trusted_proxies)
        )

    async def dispatch(self, request: Request, call_next):
        """Apply rate limiting to a request.

        Exempt paths are forwarded untouched. Otherwise the limiter decides;
        a denial short-circuits with a 429 response, an allowance forwards the
        request and annotates the downstream response with limit headers.
        """
        path = request.url.path

        # Skip rate limiting for health checks, metrics and API docs
        if path in self.exempt_paths:
            return await call_next(request)

        client_ip = self._get_client_ip(request)
        allowed, rate_info = await self.rate_limiter.is_allowed(
            identifier=client_ip,
            endpoint=path,
        )
        # Treat a missing info payload as "no extra details" instead of crashing.
        rate_info = rate_info or {}

        if not allowed:
            return self._reject(client_ip, path, rate_info)

        response = await call_next(request)
        self._apply_success_headers(response, rate_info)
        return response

    @staticmethod
    def _reject(client_ip: str, path: str, rate_info: dict) -> JSONResponse:
        """Log the violation and build the 429 response for a denied request."""
        reset_in = rate_info.get("reset_in")
        if reset_in is None:
            reset_in = 60
        # Retry-After must be a non-negative integer number of seconds (RFC 9110),
        # so round fractional values up rather than emitting e.g. "12.5".
        retry_after = max(0, math.ceil(float(reset_in)))

        logger.warning(
            f"Rate limit exceeded for {client_ip} on {path}",
            extra={
                "client_ip": client_ip,
                "endpoint": path,
                "rate_info": rate_info,
            },
        )
        return JSONResponse(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            content={
                "detail": "Rate limit exceeded",
                "error": rate_info.get("reason", "rate_limit_exceeded"),
                "retry_after": retry_after,
            },
            headers={
                "Retry-After": str(retry_after),
                "X-RateLimit-Limit": str(rate_info.get("limit", 60)),
                "X-RateLimit-Remaining": "0",
                "X-RateLimit-Reset": str(int(time.time()) + retry_after),
            },
        )

    @staticmethod
    def _apply_success_headers(response, rate_info: dict) -> None:
        """Copy whatever limit details the limiter reported onto *response*."""
        limits = rate_info.get("limits") or {}
        for label, key in (
            ("Minute", "per_minute"),
            ("Hour", "per_hour"),
            ("Day", "per_day"),
        ):
            if key in limits:
                response.headers[f"X-RateLimit-Limit-{label}"] = str(limits[key])

        if "minute_count" in rate_info:
            used = rate_info["minute_count"]
            response.headers["X-RateLimit-Used-Minute"] = str(used)
            # Remaining can only be derived when the per-minute limit is known;
            # clamp at zero so a concurrent overshoot never reports a negative.
            if "per_minute" in limits:
                response.headers["X-RateLimit-Remaining-Minute"] = str(
                    max(0, limits["per_minute"] - used)
                )

        if "burst_remaining" in rate_info:
            response.headers["X-RateLimit-Burst-Remaining"] = str(
                rate_info["burst_remaining"]
            )

    def _get_client_ip(self, request: Request) -> str:
        """Return the identifier used to key the rate limit for *request*.

        With ``trusted_proxies`` configured, forwarding headers are honoured
        only when the direct peer is a trusted proxy, and the right-most
        ``X-Forwarded-For`` hop that is not a trusted proxy is taken as the
        client (left-most entries are client-supplied and spoofable). Without
        it, the legacy behaviour applies: the left-most ``X-Forwarded-For``
        entry, then ``X-Real-IP``, then the socket peer. Returns ``"unknown"``
        when no address is available.
        """
        peer = request.client.host if request.client and request.client.host else None
        trusted = self.trusted_proxies

        if trusted is not None and peer not in trusted:
            # The peer is not one of our proxies, so any forwarding headers
            # it sent are attacker-controlled; key on the socket address.
            return peer or "unknown"

        # Check X-Forwarded-For header (reverse proxy)
        forwarded_for = request.headers.get("x-forwarded-for")
        if forwarded_for:
            hops = [hop.strip() for hop in forwarded_for.split(",") if hop.strip()]
            if hops:
                if trusted is None:
                    # Legacy: take the first IP (claimed original client).
                    return hops[0]
                # Walk back past our own proxies; the first foreign hop is the
                # last address a trusted proxy actually observed.
                for hop in reversed(hops):
                    if hop not in trusted:
                        return hop
                return hops[0]

        # Check X-Real-IP header (nginx)
        real_ip = request.headers.get("x-real-ip", "").strip()
        if real_ip:
            return real_ip

        # Fall back to client address
        return peer or "unknown"