"""
Module: api.middleware.rate_limiting
Description: Rate limiting middleware for API endpoints
Author: Anderson H. Silva
Date: 2025-01-24
License: Proprietary - All rights reserved
"""

import time
from collections import defaultdict
from typing import Dict, Tuple

from fastapi import HTTPException, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

from src.core import get_logger


class RateLimitMiddleware(BaseHTTPMiddleware):
    """Rate limiting middleware using sliding window algorithm.

    Each client (keyed by IP address) has its request timestamps tracked in
    three windows: the last minute, hour and day. A request is rejected with
    HTTP 429 as soon as any window already holds its configured maximum.
    Accepted responses carry ``X-RateLimit-*`` headers describing the limits
    and how much of each is left.

    All state lives in ``self.clients`` in process memory, so counters are
    per process and do not survive a restart.
    """

    # Length of each tracked window, in seconds.
    _WINDOW_SECONDS = {"minute": 60, "hour": 3600, "day": 86400}

    # Minimum number of seconds between two sweeps for idle clients.
    _SWEEP_INTERVAL = 60.0

    def __init__(
        self,
        app,
        calls: int = 60,
        period: int = 60,
        per_minute: int = 60,
        per_hour: int = 1000,
        per_day: int = 10000,
    ):
        """
        Initialize rate limiting middleware.

        Args:
            app: FastAPI application
            calls: Number of calls allowed per period (stored, but not
                consulted by the limit checks in this class)
            period: Time period in seconds (stored, but not consulted by
                the limit checks in this class)
            per_minute: Calls per minute
            per_hour: Calls per hour
            per_day: Calls per day
        """
        super().__init__(app)
        self.calls = calls
        self.period = period
        self.per_minute = per_minute
        self.per_hour = per_hour
        self.per_day = per_day

        # client IP -> window name -> timestamps of accepted requests.
        self.clients: Dict[str, Dict[str, list]] = defaultdict(
            lambda: {"minute": [], "hour": [], "day": []}
        )
        # Time of the last idle-client sweep; 0.0 makes the first request sweep.
        self._last_sweep = 0.0

        self.logger = get_logger(__name__)

    async def dispatch(self, request: Request, call_next):
        """Process request with rate limiting.

        Args:
            request: Incoming request.
            call_next: Callable that forwards the request down the stack.

        Returns:
            A 429 JSON response when the client is over any limit, otherwise
            the downstream response with ``X-RateLimit-*`` headers added.
        """
        client_ip = self._get_client_ip(request)
        current_time = time.time()

        self._evict_idle_clients(current_time)

        if not self._check_rate_limits(client_ip, current_time):
            self.logger.warning(
                "rate_limit_exceeded",
                client_ip=client_ip,
                path=request.url.path,
                method=request.method,
            )
            # Respond directly instead of raising HTTPException: exceptions
            # raised from middleware are not routed through the application's
            # HTTPException handler, so the client would receive a 500 rather
            # than the intended 429. The body mirrors the {"detail": ...}
            # shape that handler produces.
            return JSONResponse(
                status_code=429,
                content={"detail": "Rate limit exceeded. Too many requests."},
                headers={"Retry-After": "60"},
            )

        self._record_request(client_ip, current_time)

        response = await call_next(request)

        remaining = self._get_remaining_limits(client_ip, current_time)
        response.headers["X-RateLimit-Limit-Minute"] = str(self.per_minute)
        response.headers["X-RateLimit-Limit-Hour"] = str(self.per_hour)
        response.headers["X-RateLimit-Limit-Day"] = str(self.per_day)
        response.headers["X-RateLimit-Remaining-Minute"] = str(remaining["minute"])
        response.headers["X-RateLimit-Remaining-Hour"] = str(remaining["hour"])
        response.headers["X-RateLimit-Remaining-Day"] = str(remaining["day"])
        response.headers["X-RateLimit-Reset"] = str(int(current_time) + 60)

        return response

    def _get_client_ip(self, request: Request) -> str:
        """Get client IP address.

        Checks ``X-Forwarded-For`` (first entry), then ``X-Real-IP``, then the
        connection's peer address, and falls back to ``"unknown"``.
        """
        # NOTE(security): both headers are supplied by the caller. Unless a
        # trusted proxy overwrites them, a client can pick its own rate-limit
        # key by sending arbitrary values -- confirm the deployment does so.
        forwarded_for = request.headers.get("X-Forwarded-For")
        if forwarded_for:
            first_hop = forwarded_for.split(",")[0].strip()
            # A malformed header such as ", 10.0.0.1" has an empty first entry;
            # fall through rather than lumping such clients under the key "".
            if first_hop:
                return first_hop

        real_ip = request.headers.get("X-Real-IP")
        if real_ip and real_ip.strip():
            return real_ip.strip()

        return request.client.host if request.client else "unknown"

    def _limits(self) -> Dict[str, int]:
        """Return the configured maximum for each window, read at call time."""
        return {
            "minute": self.per_minute,
            "hour": self.per_hour,
            "day": self.per_day,
        }

    def _check_rate_limits(self, client_ip: str, current_time: float) -> bool:
        """Check if client is within rate limits.

        Expired timestamps are dropped first, so only requests still inside
        each window count.

        Returns:
            True when every window holds fewer requests than its limit.
        """
        client_data = self.clients[client_ip]
        self._clean_old_requests(client_data, current_time)

        return all(
            len(client_data[window]) < limit
            for window, limit in self._limits().items()
        )

    def _record_request(self, client_ip: str, current_time: float):
        """Record a request for rate limiting in every window."""
        client_data = self.clients[client_ip]
        for window in self._WINDOW_SECONDS:
            client_data[window].append(current_time)

    def _clean_old_requests(self, client_data: Dict[str, list], current_time: float):
        """Remove old requests outside the time windows.

        Each window's list in ``client_data`` is replaced by a new list that
        keeps only timestamps younger than that window's length.
        """
        for window, seconds in self._WINDOW_SECONDS.items():
            client_data[window] = [
                stamp
                for stamp in client_data[window]
                if current_time - stamp < seconds
            ]

    def _evict_idle_clients(self, current_time: float) -> None:
        """Forget clients that have no request left in any window.

        Without this, ``self.clients`` keeps an entry for every IP ever seen
        and grows for the lifetime of the process. The sweep visits every
        client, so it runs at most once per ``_SWEEP_INTERVAL`` seconds.
        """
        if current_time - self._last_sweep < self._SWEEP_INTERVAL:
            return
        self._last_sweep = current_time

        # Iterate over a snapshot of the keys because entries are deleted.
        for client_ip in list(self.clients):
            client_data = self.clients[client_ip]
            self._clean_old_requests(client_data, current_time)
            if not any(client_data.values()):
                del self.clients[client_ip]

    def _get_remaining_limits(self, client_ip: str, current_time: float) -> Dict[str, int]:
        """Get remaining requests for each time window.

        Returns:
            Mapping of window name to the number of requests still allowed,
            never below zero.
        """
        client_data = self.clients[client_ip]
        self._clean_old_requests(client_data, current_time)

        return {
            window: max(0, limit - len(client_data[window]))
            for window, limit in self._limits().items()
        }