neural-thinker's picture
feat: clean HuggingFace deployment with essential files only
824bf31
raw
history blame
5.64 kB
"""
Module: api.middleware.rate_limiting
Description: Rate limiting middleware for API endpoints
Author: Anderson H. Silva
Date: 2025-01-24
License: Proprietary - All rights reserved
"""
import time
from typing import Dict, Tuple
from collections import defaultdict
from fastapi import Request, Response, HTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from src.core import get_logger
class RateLimitMiddleware(BaseHTTPMiddleware):
"""Rate limiting middleware using sliding window algorithm."""
def __init__(
self,
app,
calls: int = 60,
period: int = 60,
per_minute: int = 60,
per_hour: int = 1000,
per_day: int = 10000,
):
"""
Initialize rate limiting middleware.
Args:
app: FastAPI application
calls: Number of calls allowed per period
period: Time period in seconds
per_minute: Calls per minute
per_hour: Calls per hour
per_day: Calls per day
"""
super().__init__(app)
self.calls = calls
self.period = period
self.per_minute = per_minute
self.per_hour = per_hour
self.per_day = per_day
# Storage for rate limit data
self.clients: Dict[str, Dict[str, list]] = defaultdict(lambda: {
"minute": [],
"hour": [],
"day": []
})
self.logger = get_logger(__name__)
async def dispatch(self, request: Request, call_next):
"""Process request with rate limiting."""
client_ip = self._get_client_ip(request)
current_time = time.time()
# Check rate limits
if not self._check_rate_limits(client_ip, current_time):
self.logger.warning(
"rate_limit_exceeded",
client_ip=client_ip,
path=request.url.path,
method=request.method,
)
raise HTTPException(
status_code=429,
detail="Rate limit exceeded. Too many requests.",
headers={"Retry-After": "60"}
)
# Record the request
self._record_request(client_ip, current_time)
# Process request
response = await call_next(request)
# Add rate limit headers
limits = self._get_remaining_limits(client_ip, current_time)
response.headers["X-RateLimit-Limit-Minute"] = str(self.per_minute)
response.headers["X-RateLimit-Limit-Hour"] = str(self.per_hour)
response.headers["X-RateLimit-Limit-Day"] = str(self.per_day)
response.headers["X-RateLimit-Remaining-Minute"] = str(limits["minute"])
response.headers["X-RateLimit-Remaining-Hour"] = str(limits["hour"])
response.headers["X-RateLimit-Remaining-Day"] = str(limits["day"])
response.headers["X-RateLimit-Reset"] = str(int(current_time) + 60)
return response
def _get_client_ip(self, request: Request) -> str:
"""Get client IP address."""
# Check for forwarded headers first
forwarded_for = request.headers.get("X-Forwarded-For")
if forwarded_for:
return forwarded_for.split(",")[0].strip()
real_ip = request.headers.get("X-Real-IP")
if real_ip:
return real_ip
# Fall back to direct connection
return request.client.host if request.client else "unknown"
def _check_rate_limits(self, client_ip: str, current_time: float) -> bool:
"""Check if client is within rate limits."""
client_data = self.clients[client_ip]
# Clean old requests
self._clean_old_requests(client_data, current_time)
# Check each time window
if len(client_data["minute"]) >= self.per_minute:
return False
if len(client_data["hour"]) >= self.per_hour:
return False
if len(client_data["day"]) >= self.per_day:
return False
return True
def _record_request(self, client_ip: str, current_time: float):
"""Record a request for rate limiting."""
client_data = self.clients[client_ip]
client_data["minute"].append(current_time)
client_data["hour"].append(current_time)
client_data["day"].append(current_time)
def _clean_old_requests(self, client_data: Dict[str, list], current_time: float):
"""Remove old requests outside the time windows."""
# Clean minute window
client_data["minute"] = [
t for t in client_data["minute"]
if current_time - t < 60
]
# Clean hour window
client_data["hour"] = [
t for t in client_data["hour"]
if current_time - t < 3600
]
# Clean day window
client_data["day"] = [
t for t in client_data["day"]
if current_time - t < 86400
]
def _get_remaining_limits(self, client_ip: str, current_time: float) -> Dict[str, int]:
"""Get remaining requests for each time window."""
client_data = self.clients[client_ip]
self._clean_old_requests(client_data, current_time)
return {
"minute": max(0, self.per_minute - len(client_data["minute"])),
"hour": max(0, self.per_hour - len(client_data["hour"])),
"day": max(0, self.per_day - len(client_data["day"])),
}