anderson-ufrj
committed on
Commit
·
f70869e
1
Parent(s):
97c535b
feat(security): implement IP whitelist for production environments
Browse files
- Add IP whitelist service with CIDR support
- Create middleware for IP validation and enforcement
- Implement admin API endpoints for whitelist management
- Add database migration for ip_whitelists table
- Configure automatic defaults for different environments
- Support expiration dates and metadata for entries
- Include comprehensive unit tests for service logic
- alembic/versions/006_add_ip_whitelist_table.py +70 -0
- src/api/app.py +51 -0
- src/api/middleware/__init__.py +24 -0
- src/api/middleware/ip_whitelist.py +261 -0
- src/api/middleware/webhook_verification.py +295 -0
- src/api/routes/admin/__init__.py +1 -0
- src/api/routes/admin/ip_whitelist.py +377 -0
- src/api/routes/webhooks.py +407 -0
- src/core/config.py +5 -0
- src/services/ip_whitelist_service.py +435 -0
- tests/unit/services/test_ip_whitelist_service.py +300 -0
alembic/versions/006_add_ip_whitelist_table.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""add ip whitelist table
|
| 2 |
+
|
| 3 |
+
Revision ID: 006
|
| 4 |
+
Revises: 005
|
| 5 |
+
Create Date: 2025-01-25
|
| 6 |
+
"""
|
| 7 |
+
from alembic import op
|
| 8 |
+
import sqlalchemy as sa
|
| 9 |
+
from sqlalchemy.dialects import postgresql
|
| 10 |
+
|
| 11 |
+
# revision identifiers
|
| 12 |
+
revision = '006'
|
| 13 |
+
down_revision = '005'
|
| 14 |
+
branch_labels = None
|
| 15 |
+
depends_on = None
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def upgrade() -> None:
    """Create the ip_whitelists table, its lookup indexes, and the
    (ip_address, environment) uniqueness constraint."""
    op.create_table(
        'ip_whitelists',
        sa.Column('id', sa.String(64), primary_key=True),
        sa.Column('ip_address', sa.String(45), nullable=False),
        sa.Column('description', sa.String(255), nullable=True),
        sa.Column('environment', sa.String(20), nullable=False, server_default='production'),
        sa.Column('active', sa.Boolean(), nullable=False, server_default='true'),
        sa.Column('created_by', sa.String(255), nullable=False),
        sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()),
        sa.Column('expires_at', sa.DateTime(timezone=True), nullable=True),
        sa.Column('metadata', postgresql.JSON(), nullable=False, server_default='{}'),
        sa.Column('is_cidr', sa.Boolean(), nullable=False, server_default='false'),
        sa.Column('cidr_prefix', sa.Integer(), nullable=True),
    )

    # One single-column index per frequently-filtered field.
    for column in ('ip_address', 'environment', 'active', 'expires_at'):
        op.create_index(
            f'ix_ip_whitelists_{column}',
            'ip_whitelists',
            [column]
        )

    # An IP address may appear at most once per environment.
    op.create_unique_constraint(
        'uq_ip_whitelists_ip_environment',
        'ip_whitelists',
        ['ip_address', 'environment']
    )
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def downgrade() -> None:
    """Drop ip_whitelists table.

    Dropping the table also removes its indexes and the unique
    constraint, so no separate drop calls are needed.
    """
    op.drop_table('ip_whitelists')
|
src/api/app.py
CHANGED
|
@@ -27,6 +27,8 @@ from src.api.middleware.logging_middleware import LoggingMiddleware
|
|
| 27 |
from src.api.middleware.security import SecurityMiddleware
|
| 28 |
from src.api.middleware.compression import CompressionMiddleware
|
| 29 |
from src.api.middleware.metrics_middleware import MetricsMiddleware, setup_http_metrics
|
|
|
|
|
|
|
| 30 |
from src.infrastructure.observability import (
|
| 31 |
CorrelationMiddleware,
|
| 32 |
tracing_manager,
|
|
@@ -185,6 +187,39 @@ add_compression_middleware(
|
|
| 185 |
exclude_paths={"/health", "/metrics", "/health/metrics", "/api/v1/ws", "/api/v1/observability"}
|
| 186 |
)
|
| 187 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
|
| 189 |
# Custom OpenAPI schema
|
| 190 |
def custom_openapi():
|
|
@@ -361,6 +396,22 @@ app.include_router(
|
|
| 361 |
tags=["Notifications"]
|
| 362 |
)
|
| 363 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 364 |
|
| 365 |
# Global exception handler
|
| 366 |
@app.exception_handler(CidadaoAIError)
|
|
|
|
| 27 |
from src.api.middleware.security import SecurityMiddleware
|
| 28 |
from src.api.middleware.compression import CompressionMiddleware
|
| 29 |
from src.api.middleware.metrics_middleware import MetricsMiddleware, setup_http_metrics
|
| 30 |
+
from src.api.middleware.ip_whitelist import IPWhitelistMiddleware
|
| 31 |
+
from src.api.middleware.rate_limit import RateLimitMiddleware as RateLimitMiddlewareV2
|
| 32 |
from src.infrastructure.observability import (
|
| 33 |
CorrelationMiddleware,
|
| 34 |
tracing_manager,
|
|
|
|
| 187 |
exclude_paths={"/health", "/metrics", "/health/metrics", "/api/v1/ws", "/api/v1/observability"}
|
| 188 |
)
|
| 189 |
|
| 190 |
+
# Add IP whitelist middleware (only in production)
|
| 191 |
+
if settings.is_production or settings.app_env == "staging":
|
| 192 |
+
app.add_middleware(
|
| 193 |
+
IPWhitelistMiddleware,
|
| 194 |
+
enabled=True,
|
| 195 |
+
excluded_paths=[
|
| 196 |
+
"/health",
|
| 197 |
+
"/healthz",
|
| 198 |
+
"/ping",
|
| 199 |
+
"/ready",
|
| 200 |
+
"/docs",
|
| 201 |
+
"/redoc",
|
| 202 |
+
"/openapi.json",
|
| 203 |
+
"/metrics",
|
| 204 |
+
"/static",
|
| 205 |
+
"/favicon.ico",
|
| 206 |
+
"/_next",
|
| 207 |
+
"/api/v1/auth/login",
|
| 208 |
+
"/api/v1/auth/register",
|
| 209 |
+
"/api/v1/auth/refresh",
|
| 210 |
+
"/api/v1/public",
|
| 211 |
+
"/api/v1/webhooks/incoming"
|
| 212 |
+
],
|
| 213 |
+
strict_mode=False # Allow requests if IP can't be determined
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
# Add rate limiting middleware v2
|
| 217 |
+
app.add_middleware(
|
| 218 |
+
RateLimitMiddlewareV2,
|
| 219 |
+
default_tier="free",
|
| 220 |
+
strategy="sliding_window"
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
|
| 224 |
# Custom OpenAPI schema
|
| 225 |
def custom_openapi():
|
|
|
|
| 396 |
tags=["Notifications"]
|
| 397 |
)
|
| 398 |
|
| 399 |
+
# Import and include admin routes
|
| 400 |
+
from src.api.routes.admin import ip_whitelist as admin_ip_whitelist
|
| 401 |
+
from src.api.routes import api_keys
|
| 402 |
+
|
| 403 |
+
app.include_router(
|
| 404 |
+
admin_ip_whitelist.router,
|
| 405 |
+
prefix="/api/v1/admin",
|
| 406 |
+
tags=["Admin - IP Whitelist"]
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
app.include_router(
|
| 410 |
+
api_keys.router,
|
| 411 |
+
prefix="/api/v1",
|
| 412 |
+
tags=["API Keys"]
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
|
| 416 |
# Global exception handler
|
| 417 |
@app.exception_handler(CidadaoAIError)
|
src/api/middleware/__init__.py
CHANGED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
API middleware package.
|
| 3 |
+
|
| 4 |
+
This module exports all middleware classes for easy import.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from .rate_limit import RateLimitMiddleware, rate_limit
|
| 8 |
+
from .webhook_verification import (
|
| 9 |
+
WebhookVerificationMiddleware,
|
| 10 |
+
WebhookSigner,
|
| 11 |
+
create_webhook_signature,
|
| 12 |
+
verify_webhook_signature
|
| 13 |
+
)
|
| 14 |
+
from .ip_whitelist import IPWhitelistMiddleware
|
| 15 |
+
|
| 16 |
+
__all__ = [
|
| 17 |
+
"RateLimitMiddleware",
|
| 18 |
+
"rate_limit",
|
| 19 |
+
"WebhookVerificationMiddleware",
|
| 20 |
+
"WebhookSigner",
|
| 21 |
+
"create_webhook_signature",
|
| 22 |
+
"verify_webhook_signature",
|
| 23 |
+
"IPWhitelistMiddleware"
|
| 24 |
+
]
|
src/api/middleware/ip_whitelist.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Module: api.middleware.ip_whitelist
|
| 3 |
+
Description: IP whitelist middleware for production environments
|
| 4 |
+
Author: Anderson H. Silva
|
| 5 |
+
Date: 2025-01-25
|
| 6 |
+
License: Proprietary - All rights reserved
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from typing import Optional, List, Set
|
| 10 |
+
import ipaddress
|
| 11 |
+
|
| 12 |
+
from fastapi import Request, HTTPException, status
|
| 13 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
| 14 |
+
from starlette.responses import JSONResponse
|
| 15 |
+
|
| 16 |
+
from src.core import get_logger
|
| 17 |
+
from src.core.config import settings
|
| 18 |
+
from src.services.ip_whitelist_service import ip_whitelist_service
|
| 19 |
+
from src.db.session import get_session
|
| 20 |
+
|
| 21 |
+
logger = get_logger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class IPWhitelistMiddleware(BaseHTTPMiddleware):
    """
    IP whitelist middleware for production environments.

    Requests whose client IP is not present in the whitelist (checked via
    ``ip_whitelist_service``) are rejected with HTTP 403.

    Features:
    - Environment-based activation (auto-enabled in production)
    - Path exclusions (exact, trailing-``*`` wildcard, and nested-prefix)
    - Multiple IP extraction methods (proxy/CDN headers, direct connection)
    - Configurable fail-open / fail-closed behavior via ``strict_mode``
    """

    def __init__(
        self,
        app,
        enabled: Optional[bool] = None,
        excluded_paths: Optional[List[str]] = None,
        strict_mode: bool = True
    ):
        """
        Initialize IP whitelist middleware.

        Args:
            app: FastAPI application
            enabled: Force enable/disable (None = auto based on environment)
            excluded_paths: Paths to exclude from whitelist check
            strict_mode: If True, reject when the client IP can't be
                determined or the whitelist check errors (fail closed);
                if False, allow such requests through (fail open).
        """
        super().__init__(app)

        # Auto-enable in production if not explicitly specified.
        self.enabled = settings.is_production if enabled is None else enabled

        self.excluded_paths = set(excluded_paths or self._get_default_excluded_paths())
        self.strict_mode = strict_mode

        logger.info(
            "ip_whitelist_middleware_initialized",
            enabled=self.enabled,
            environment=settings.app_env,
            strict_mode=self.strict_mode
        )

    async def dispatch(self, request: Request, call_next):
        """Process request with IP whitelist check."""
        # Skip if disabled
        if not self.enabled:
            return await call_next(request)

        # Skip excluded paths
        if self._should_skip(request.url.path):
            return await call_next(request)

        try:
            # Extract client IP
            client_ip = self._get_client_ip(request)

            if not client_ip:
                if self.strict_mode:
                    logger.warning(
                        "ip_whitelist_no_client_ip",
                        path=request.url.path,
                        headers=dict(request.headers)
                    )
                    return JSONResponse(
                        status_code=status.HTTP_403_FORBIDDEN,
                        content={
                            "detail": "Client IP could not be determined",
                            "error": "IP_NOT_DETERMINED"
                        }
                    )
                else:
                    # Fail open: allow through if the IP can't be determined.
                    logger.debug("ip_whitelist_no_client_ip_allowing")
                    return await call_next(request)

            # Check whitelist against the current environment.
            async with get_session() as session:
                is_whitelisted = await ip_whitelist_service.check_ip(
                    session,
                    client_ip,
                    environment=settings.app_env
                )

            if not is_whitelisted:
                logger.warning(
                    "ip_whitelist_access_denied",
                    client_ip=client_ip,
                    path=request.url.path,
                    method=request.method
                )

                return JSONResponse(
                    status_code=status.HTTP_403_FORBIDDEN,
                    content={
                        "detail": "Access denied: IP not whitelisted",
                        "error": "IP_NOT_WHITELISTED"
                    }
                )

            # IP is whitelisted; expose it to downstream handlers.
            request.state.client_ip = client_ip
            request.state.ip_whitelisted = True

            return await call_next(request)

        except Exception as e:
            logger.error(
                "ip_whitelist_error",
                error=str(e),
                exc_info=True
            )

            # On unexpected errors, fail open or closed per strict_mode.
            if self.strict_mode:
                return JSONResponse(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    content={
                        "detail": "IP whitelist check failed",
                        "error": "WHITELIST_CHECK_ERROR"
                    }
                )
            else:
                # Allow through on error
                return await call_next(request)

    def _get_default_excluded_paths(self) -> List[str]:
        """Get default paths to exclude from whitelist."""
        return [
            # Health checks
            "/health",
            "/healthz",
            "/ping",
            "/ready",

            # Documentation
            "/docs",
            "/redoc",
            "/openapi.json",

            # Metrics
            "/metrics",

            # Static assets
            "/static",
            "/favicon.ico",
            "/_next",

            # Public endpoints
            "/api/v1/auth/login",
            "/api/v1/auth/register",
            "/api/v1/auth/refresh",
            "/api/v1/public",

            # Webhook endpoints (they have their own auth)
            "/api/v1/webhooks/incoming"
        ]

    def _should_skip(self, path: str) -> bool:
        """Check if path should skip whitelist check.

        Matches exact paths, ``prefix*`` wildcard entries, and any path
        nested under an excluded prefix (``/static`` matches ``/static/x``).
        """
        # Exact match
        if path in self.excluded_paths:
            return True

        # Prefix match
        for excluded in self.excluded_paths:
            if excluded.endswith("*") and path.startswith(excluded[:-1]):
                return True
            if path.startswith(excluded + "/"):
                return True

        return False

    def _get_client_ip(self, request: Request) -> Optional[str]:
        """
        Extract client IP from request.

        Tries multiple methods in order:
        1. X-Real-IP header
        2. X-Forwarded-For header
        3. CloudFlare / Akamai / Fastly headers
        4. Direct client connection

        NOTE(security): these headers are client-controlled unless a
        trusted reverse proxy strips or overwrites them; deploy this
        middleware only behind such a proxy.
        """
        # Try X-Real-IP first (nginx)
        real_ip = request.headers.get("X-Real-IP")
        if real_ip and self._is_valid_ip(real_ip):
            return real_ip

        # Try X-Forwarded-For (proxy chains): first valid IP in the chain.
        forwarded_for = request.headers.get("X-Forwarded-For")
        if forwarded_for:
            for ip in (part.strip() for part in forwarded_for.split(",")):
                if self._is_valid_ip(ip):
                    return ip

        # Try CloudFlare headers
        cf_ip = request.headers.get("CF-Connecting-IP")
        if cf_ip and self._is_valid_ip(cf_ip):
            return cf_ip

        # Try True-Client-IP (Akamai, CloudFlare Enterprise)
        true_client_ip = request.headers.get("True-Client-IP")
        if true_client_ip and self._is_valid_ip(true_client_ip):
            return true_client_ip

        # Try Fastly header
        fastly_ip = request.headers.get("Fastly-Client-IP")
        if fastly_ip and self._is_valid_ip(fastly_ip):
            return fastly_ip

        # Fallback to direct connection
        if request.client and request.client.host:
            return request.client.host

        return None

    def _is_valid_ip(self, ip: str) -> bool:
        """Check if string is a valid IP address (public only outside development)."""
        if not ip:
            return False

        try:
            # Parse once and reuse the parsed object (the original parsed twice).
            ip_obj = ipaddress.ip_address(ip)
        except ValueError:
            return False

        # Reject private/local IPs unless in development, so spoofed
        # internal addresses in proxy headers cannot pass the check.
        if not settings.is_development and (ip_obj.is_private or ip_obj.is_loopback):
            return False

        return True
|
src/api/middleware/webhook_verification.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Module: api.middleware.webhook_verification
|
| 3 |
+
Description: Webhook signature verification middleware
|
| 4 |
+
Author: Anderson H. Silva
|
| 5 |
+
Date: 2025-01-25
|
| 6 |
+
License: Proprietary - All rights reserved
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import hmac
|
| 10 |
+
import hashlib
|
| 11 |
+
import time
|
| 12 |
+
from typing import Optional, Dict, Any
|
| 13 |
+
|
| 14 |
+
from fastapi import Request, HTTPException, status
|
| 15 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
| 16 |
+
from starlette.responses import JSONResponse
|
| 17 |
+
|
| 18 |
+
from src.core import get_logger
|
| 19 |
+
|
| 20 |
+
logger = get_logger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class WebhookVerificationMiddleware(BaseHTTPMiddleware):
    """
    Middleware for verifying incoming webhook signatures.

    Protects endpoints that receive webhooks from external services.
    Signatures are HMAC hex digests carried in X-Cidadao-Signature,
    X-Webhook-Signature, or X-Hub-Signature-256 headers; an optional
    timestamp header is checked for freshness.
    """

    def __init__(
        self,
        app,
        webhook_paths: Optional[Dict[str, str]] = None,
        max_timestamp_age: int = 300  # 5 minutes
    ):
        """
        Initialize webhook verification middleware.

        Args:
            app: FastAPI application
            webhook_paths: Dict of path -> secret mapping; only these
                exact paths are verified, all others pass through.
            max_timestamp_age: Maximum age of timestamp in seconds
        """
        super().__init__(app)
        self.webhook_paths = webhook_paths or {}
        self.max_timestamp_age = max_timestamp_age

    async def dispatch(self, request: Request, call_next):
        """Process request with webhook verification.

        Returns 401 on signature or timestamp failure, 500 on any
        unexpected verification error; otherwise forwards the request.
        """
        # Check if this is a webhook path (exact match only)
        if request.url.path not in self.webhook_paths:
            return await call_next(request)

        # Get the secret for this path
        secret = self.webhook_paths[request.url.path]

        try:
            # Read body. NOTE(review): this consumes the request stream;
            # the raw bytes are stashed in request.state.webhook_body
            # below so handlers should read it from there.
            body = await request.body()

            # Verify signature
            if not self._verify_signature(request, body, secret):
                # NOTE(review): logging all request headers may leak
                # credentials (e.g. Authorization) into logs — consider
                # redacting sensitive headers.
                logger.warning(
                    "webhook_signature_verification_failed",
                    path=request.url.path,
                    headers=dict(request.headers)
                )

                return JSONResponse(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    content={
                        "detail": "Invalid webhook signature",
                        "error": "INVALID_SIGNATURE"
                    }
                )

            # Verify timestamp if present (replay protection)
            if not self._verify_timestamp(request):
                logger.warning(
                    "webhook_timestamp_verification_failed",
                    path=request.url.path
                )

                return JSONResponse(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    content={
                        "detail": "Webhook timestamp too old",
                        "error": "TIMESTAMP_TOO_OLD"
                    }
                )

            # Store raw body for handler
            request.state.webhook_body = body

            # Process request
            return await call_next(request)

        except Exception as e:
            logger.error(
                "webhook_verification_error",
                error=str(e),
                exc_info=True
            )

            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content={
                    "detail": "Webhook verification error",
                    "error": "VERIFICATION_ERROR"
                }
            )

    def _verify_signature(
        self,
        request: Request,
        body: bytes,
        secret: str
    ) -> bool:
        """Verify webhook signature.

        Accepts "algorithm=hex" headers or a bare hex digest (assumed
        sha256). Comparison uses hmac.compare_digest to avoid timing
        attacks.
        """
        # Get signature header - support multiple formats
        signature_header = (
            request.headers.get("X-Cidadao-Signature") or
            request.headers.get("X-Webhook-Signature") or
            request.headers.get("X-Hub-Signature-256")  # GitHub format
        )

        if not signature_header:
            logger.debug("No signature header found")
            return False

        # Parse signature: split "algorithm=hex"; bare value defaults to sha256
        if "=" in signature_header:
            algorithm, signature = signature_header.split("=", 1)
        else:
            algorithm = "sha256"
            signature = signature_header

        # Calculate expected signature.
        # NOTE(review): sha1 is cryptographically weak; presumably kept
        # only for legacy providers — confirm whether it can be dropped.
        if algorithm == "sha256":
            expected = hmac.new(
                secret.encode(),
                body,
                hashlib.sha256
            ).hexdigest()
        elif algorithm == "sha1":
            expected = hmac.new(
                secret.encode(),
                body,
                hashlib.sha1
            ).hexdigest()
        else:
            logger.warning(f"Unsupported signature algorithm: {algorithm}")
            return False

        # Compare signatures (constant-time)
        return hmac.compare_digest(signature, expected)

    def _verify_timestamp(self, request: Request) -> bool:
        """Verify webhook timestamp is recent.

        Missing timestamp headers pass (returns True); present ones must
        be within max_timestamp_age seconds of now in either direction
        (abs() also rejects timestamps too far in the future).
        """
        timestamp_header = (
            request.headers.get("X-Cidadao-Timestamp") or
            request.headers.get("X-Webhook-Timestamp")
        )

        if not timestamp_header:
            # No timestamp to verify
            return True

        try:
            # Parse timestamp. NOTE(review): isdigit() rejects fractional
            # unix timestamps like "1706123.5", which then fall to the
            # ISO branch; the ISO branch needs the third-party dateutil
            # package — confirm it is a declared dependency.
            if timestamp_header.isdigit():
                # Unix timestamp
                webhook_time = float(timestamp_header)
            else:
                # ISO format
                from dateutil.parser import parse
                webhook_time = parse(timestamp_header).timestamp()

            # Check age
            current_time = time.time()
            age = abs(current_time - webhook_time)

            return age <= self.max_timestamp_age

        except Exception as e:
            logger.error(f"Failed to parse timestamp: {e}")
            return False
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def create_webhook_signature(
    payload: bytes,
    secret: str,
    algorithm: str = "sha256"
) -> str:
    """
    Create webhook signature for outgoing webhooks.

    Args:
        payload: Request body
        secret: Webhook secret
        algorithm: Hash algorithm (sha256, sha1)

    Returns:
        Signature string with format "algorithm=signature"

    Raises:
        ValueError: If the algorithm is not supported.
    """
    # Supported digests; sha1 is weak and kept only for legacy consumers.
    digests = {"sha256": hashlib.sha256, "sha1": hashlib.sha1}

    digestmod = digests.get(algorithm)
    if digestmod is None:
        raise ValueError(f"Unsupported algorithm: {algorithm}")

    signature = hmac.new(secret.encode(), payload, digestmod).hexdigest()
    return f"{algorithm}={signature}"
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def verify_webhook_signature(
    signature: str,
    payload: bytes,
    secret: str
) -> bool:
    """
    Verify webhook signature.

    Args:
        signature: Signature header value ("algorithm=hex" or bare hex,
            in which case sha256 is assumed)
        payload: Request body
        secret: Webhook secret

    Returns:
        True if signature is valid
    """
    try:
        # Split "algorithm=hex"; a bare value defaults to sha256.
        algorithm, sep, provided = signature.partition("=")
        if not sep:
            algorithm, provided = "sha256", signature

        # Recompute the expected signature with the same algorithm.
        expected = create_webhook_signature(payload, secret, algorithm)

        # Strip the leading "algorithm=" prefix, if any.
        _, sep, expected_hex = expected.partition("=")
        if not sep:
            expected_hex = expected

        # Constant-time comparison.
        return hmac.compare_digest(provided, expected_hex)

    except Exception as e:
        logger.error(f"Signature verification error: {e}")
        return False
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
class WebhookSigner:
    """Helper class for signing webhook requests."""

    def __init__(self, secret: str, algorithm: str = "sha256"):
        """Store the shared secret and hash algorithm used for signing."""
        self.secret = secret
        self.algorithm = algorithm

    def sign(self, payload: bytes) -> Dict[str, str]:
        """
        Generate webhook headers with signature.

        Args:
            payload: Request body

        Returns:
            Dict of headers to include in request (signature, current
            unix timestamp, and the algorithm used)
        """
        return {
            "X-Cidadao-Signature": create_webhook_signature(
                payload, self.secret, self.algorithm
            ),
            "X-Cidadao-Timestamp": str(int(time.time())),
            "X-Cidadao-Algorithm": self.algorithm,
        }
|
src/api/routes/admin/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Admin routes package."""
|
src/api/routes/admin/ip_whitelist.py
ADDED
|
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Module: api.routes.admin.ip_whitelist
|
| 3 |
+
Description: Admin routes for managing IP whitelist
|
| 4 |
+
Author: Anderson H. Silva
|
| 5 |
+
Date: 2025-01-25
|
| 6 |
+
License: Proprietary - All rights reserved
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from typing import List, Optional, Dict, Any
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
import ipaddress
|
| 12 |
+
|
| 13 |
+
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
| 14 |
+
from pydantic import BaseModel, Field, field_validator
|
| 15 |
+
|
| 16 |
+
from src.core import get_logger
|
| 17 |
+
from src.api.dependencies import require_admin, get_db
|
| 18 |
+
from src.services.ip_whitelist_service import ip_whitelist_service, IPWhitelist
|
| 19 |
+
from src.core.config import settings
|
| 20 |
+
|
| 21 |
+
logger = get_logger(__name__)
|
| 22 |
+
|
| 23 |
+
router = APIRouter(prefix="/ip-whitelist", tags=["Admin - IP Whitelist"])
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class IPWhitelistRequest(BaseModel):
    """Payload for adding an IP address or CIDR range to the whitelist."""
    ip_address: str = Field(..., description="IP address or CIDR range")
    description: Optional[str] = Field(None, description="Description of the IP")
    environment: str = Field(default="production", description="Environment")
    expires_at: Optional[datetime] = Field(None, description="Expiration date")
    metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")

    @field_validator("ip_address")
    @classmethod
    def validate_ip(cls, v: str) -> str:
        """Accept either a plain IP address or a CIDR block; reject anything else."""
        # Try the two accepted forms in order; the first successful parse wins.
        parsers = (
            ipaddress.ip_address,
            lambda s: ipaddress.ip_network(s, strict=False),
        )
        for parse in parsers:
            try:
                parse(v)
            except ValueError:
                continue
            return v
        raise ValueError(f"Invalid IP address or CIDR: {v}")

    @field_validator("environment")
    @classmethod
    def validate_environment(cls, v: str) -> str:
        """Restrict the target environment to the known set."""
        allowed = ["development", "staging", "production", "testing"]
        if v in allowed:
            return v
        raise ValueError(f"Environment must be one of {allowed}")
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class IPWhitelistUpdateRequest(BaseModel):
    """Request to update IP whitelist entry.

    All fields are optional; the values are forwarded verbatim to
    ``ip_whitelist_service.update_ip``.
    NOTE(review): presumably the service treats None as "leave unchanged" —
    confirm against ip_whitelist_service.update_ip.
    """
    active: Optional[bool] = None
    description: Optional[str] = None
    expires_at: Optional[datetime] = None
    metadata: Optional[Dict[str, Any]] = None
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class IPWhitelistResponse(BaseModel):
    """IP whitelist entry as returned by the admin API."""
    id: str
    ip_address: str
    description: Optional[str]
    environment: str
    active: bool
    is_cidr: bool
    cidr_prefix: Optional[int]
    created_by: str
    created_at: datetime
    expires_at: Optional[datetime]
    metadata: Dict[str, Any]
    is_expired: bool

    @classmethod
    def from_model(cls, model: IPWhitelist) -> "IPWhitelistResponse":
        """Build a response object from a persistence-layer entry."""
        # Copy the like-named attributes straight across; only is_expired
        # is computed (it is a method on the model, a plain bool here).
        field_names = (
            "id", "ip_address", "description", "environment", "active",
            "is_cidr", "cidr_prefix", "created_by", "created_at",
            "expires_at", "metadata",
        )
        copied = {name: getattr(model, name) for name in field_names}
        return cls(is_expired=model.is_expired(), **copied)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
@router.post("/add", response_model=IPWhitelistResponse)
async def add_ip_to_whitelist(
    request: IPWhitelistRequest,
    admin_user=Depends(require_admin),
    db=Depends(get_db)
):
    """
    Add an IP address or CIDR range to the whitelist.

    Requires admin privileges. Returns the created entry; 400 on
    validation errors from the service, 500 on unexpected failures.
    """
    try:
        # A "/" in the value marks a CIDR range rather than a single address.
        cidr_entry = "/" in request.ip_address

        record = await ip_whitelist_service.add_ip(
            session=db,
            ip_address=request.ip_address,
            created_by=admin_user.get("email", "admin"),
            description=request.description,
            environment=request.environment,
            expires_at=request.expires_at,
            is_cidr=cidr_entry,
            metadata=request.metadata
        )

        logger.info(
            "admin_ip_whitelist_added",
            admin=admin_user.get("email"),
            ip=request.ip_address,
            environment=request.environment
        )

        return IPWhitelistResponse.from_model(record)

    except ValueError as e:
        # Service-level validation problems surface as 400s.
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    except Exception as e:
        logger.error("admin_ip_whitelist_add_error", error=str(e), exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to add IP to whitelist"
        )
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
@router.delete("/remove/{ip_address}")
async def remove_ip_from_whitelist(
    ip_address: str,
    environment: str = Query(default="production"),
    admin_user=Depends(require_admin),
    db=Depends(get_db)
):
    """
    Remove an IP from the whitelist.

    Requires admin privileges. 404 when the IP is not present.
    """
    deleted = await ip_whitelist_service.remove_ip(
        session=db, ip_address=ip_address, environment=environment
    )

    if deleted:
        logger.info(
            "admin_ip_whitelist_removed",
            admin=admin_user.get("email"),
            ip=ip_address,
            environment=environment
        )
        return {"status": "removed", "ip_address": ip_address}

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"IP not found in whitelist: {ip_address}"
    )
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
@router.get("/list", response_model=List[IPWhitelistResponse])
async def list_whitelisted_ips(
    environment: str = Query(default="production"),
    include_expired: bool = Query(default=False),
    admin_user=Depends(require_admin),
    db=Depends(get_db)
):
    """
    List all whitelisted IPs for an environment.

    Requires admin privileges. Expired entries are filtered out unless
    include_expired is set.
    """
    records = await ip_whitelist_service.list_ips(
        session=db, environment=environment, include_expired=include_expired
    )

    responses = []
    for record in records:
        responses.append(IPWhitelistResponse.from_model(record))
    return responses
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
@router.get("/check/{ip_address}")
async def check_ip_whitelist(
    ip_address: str,
    environment: str = Query(default="production"),
    admin_user=Depends(require_admin),
    db=Depends(get_db)
):
    """
    Check whether a single IP is whitelisted.

    Requires admin privileges. 400 if the value is not a valid IP
    (CIDR ranges are not accepted here).
    """
    try:
        ipaddress.ip_address(ip_address)
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid IP address format"
        )

    allowed = await ip_whitelist_service.check_ip(
        session=db, ip_address=ip_address, environment=environment
    )

    return {
        "ip_address": ip_address,
        "environment": environment,
        "is_whitelisted": allowed
    }
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
@router.put("/update/{ip_address}", response_model=IPWhitelistResponse)
async def update_whitelist_entry(
    ip_address: str,
    request: IPWhitelistUpdateRequest,
    environment: str = Query(default="production"),
    admin_user=Depends(require_admin),
    db=Depends(get_db)
):
    """
    Update a whitelist entry in place.

    Requires admin privileges. 404 when the IP is not present.
    """
    record = await ip_whitelist_service.update_ip(
        session=db,
        ip_address=ip_address,
        environment=environment,
        active=request.active,
        description=request.description,
        expires_at=request.expires_at,
        metadata=request.metadata
    )

    if record is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"IP not found in whitelist: {ip_address}"
        )

    logger.info(
        "admin_ip_whitelist_updated",
        admin=admin_user.get("email"),
        ip=ip_address,
        active=record.active
    )

    return IPWhitelistResponse.from_model(record)
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
@router.post("/cleanup")
async def cleanup_expired_entries(
    environment: Optional[str] = None,
    admin_user=Depends(require_admin),
    db=Depends(get_db)
):
    """
    Remove expired whitelist entries.

    Requires admin privileges. When environment is None the cleanup is
    delegated to the service with no environment filter.
    """
    removed_count = await ip_whitelist_service.cleanup_expired(
        session=db, environment=environment
    )

    logger.info(
        "admin_ip_whitelist_cleanup",
        admin=admin_user.get("email"),
        removed=removed_count,
        environment=environment
    )

    return {
        "status": "cleaned",
        "removed_count": removed_count,
        "environment": environment
    }
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
@router.post("/initialize-defaults")
async def initialize_default_whitelist(
    admin_user=Depends(require_admin),
    db=Depends(get_db)
):
    """
    Initialize default whitelist entries for the current environment.

    Requires admin privileges. Returns how many entries were added plus
    the default list the service works from.
    """
    added = await ip_whitelist_service.initialize_defaults(
        session=db,
        created_by=admin_user.get("email", "admin")
    )

    logger.info(
        "admin_ip_whitelist_defaults_initialized",
        admin=admin_user.get("email"),
        count=added,
        environment=settings.app_env
    )

    return {
        "status": "initialized",
        "added_count": added,
        "environment": settings.app_env,
        "defaults": ip_whitelist_service.get_default_whitelist()
    }
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
@router.get("/stats")
async def get_whitelist_stats(
    admin_user=Depends(require_admin),
    db=Depends(get_db)
):
    """
    Get whitelist statistics, broken down per environment.

    Requires admin privileges.

    Returns, for each environment, counts of total / active / expired /
    inactive entries and the split between CIDR ranges and single IPs,
    plus the currently configured environment.
    """
    # Consistency fix: include "testing", which the IPWhitelistRequest
    # validator accepts as a valid environment; previously entries created
    # for it never showed up in the statistics.
    environments = ["development", "staging", "production", "testing"]
    stats = {}

    for env in environments:
        entries = await ip_whitelist_service.list_ips(
            session=db,
            environment=env,
            include_expired=True
        )

        # "active" means flagged active AND not past its expiry date.
        active = sum(1 for e in entries if e.active and not e.is_expired())
        expired = sum(1 for e in entries if e.is_expired())
        cidr_ranges = sum(1 for e in entries if e.is_cidr)

        stats[env] = {
            "total": len(entries),
            "active": active,
            "expired": expired,
            "inactive": sum(1 for e in entries if not e.active),
            "cidr_ranges": cidr_ranges,
            "single_ips": len(entries) - cidr_ranges
        }

    return {
        "statistics": stats,
        "current_environment": settings.app_env
    }
|
src/api/routes/webhooks.py
ADDED
|
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Module: api.routes.webhooks
|
| 3 |
+
Description: Webhook endpoints for receiving external events
|
| 4 |
+
Author: Anderson H. Silva
|
| 5 |
+
Date: 2025-01-25
|
| 6 |
+
License: Proprietary - All rights reserved
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from typing import Dict, Any, Optional
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
|
| 12 |
+
from fastapi import APIRouter, Request, Depends, HTTPException, status, BackgroundTasks
|
| 13 |
+
from pydantic import BaseModel, Field
|
| 14 |
+
|
| 15 |
+
from src.core import get_logger
|
| 16 |
+
from src.api.dependencies import get_current_user
|
| 17 |
+
from src.services.webhook_service import WebhookConfig, WebhookEvent, webhook_service
|
| 18 |
+
from src.api.middleware.webhook_verification import verify_webhook_signature
|
| 19 |
+
|
| 20 |
+
logger = get_logger(__name__)
|
| 21 |
+
|
| 22 |
+
router = APIRouter(prefix="/webhooks", tags=["Webhooks"])
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class IncomingWebhookPayload(BaseModel):
    """Generic incoming webhook payload."""
    # Event name chosen by the sender (free-form string).
    event: str
    # When the event occurred on the sender's side; optional.
    timestamp: Optional[datetime] = None
    # Arbitrary event data; schema is sender-defined.
    data: Dict[str, Any]
    # NOTE(review): signature verification in receive_generic_webhook reads
    # the X-Webhook-Signature header, not this field — this field appears
    # unused; confirm before relying on it.
    signature: Optional[str] = None
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class WebhookRegistrationRequest(BaseModel):
    """Request to register a new webhook.

    Event names must match WebhookEvent values; they are validated and
    converted in register_webhook.
    """
    url: str = Field(..., description="Webhook endpoint URL")
    events: Optional[list[str]] = Field(None, description="Events to subscribe to (None = all)")
    secret: Optional[str] = Field(None, description="Webhook secret for HMAC signing")
    headers: Optional[Dict[str, str]] = Field(None, description="Custom headers")
    active: bool = Field(True, description="Whether webhook is active")
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class WebhookTestRequest(BaseModel):
    """Request to test a webhook.

    Used by the /webhooks/test endpoint to build a throwaway
    WebhookConfig (single retry, short timeout).
    """
    url: str = Field(..., description="Webhook URL to test")
    secret: Optional[str] = Field(None, description="Webhook secret if any")
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@router.post("/incoming/github")
async def receive_github_webhook(
    request: Request,
    background_tasks: BackgroundTasks
):
    """
    Receive webhooks from GitHub.

    Signature verification is expected to happen in the webhook
    verification middleware, which stashes the raw body on
    ``request.state.webhook_body``. Processing is deferred to a
    background task so GitHub gets a fast acknowledgement.
    """
    # Prefer the raw body cached by the verification middleware; fall back
    # to reading the request stream directly (Starlette caches it).
    body = getattr(request.state, "webhook_body", None)
    if not body:
        body = await request.body()

    # GitHub identifies the event kind via this header.
    event_type = request.headers.get("X-GitHub-Event", "unknown")

    # Parse payload
    try:
        import json
        payload = json.loads(body)
    except Exception as e:
        logger.error("Failed to parse GitHub webhook", error=str(e))
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid payload format"
        )

    # Fix: a syntactically valid JSON body may still not be an object
    # (e.g. a bare list or string); the .get() calls below would raise
    # AttributeError and return a 500 instead of a 400.
    if not isinstance(payload, dict):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid payload format"
        )

    # Fix: payload.get("repository", {}) returns None (not {}) when the
    # key is present with a null value; guard with "or {}".
    logger.info(
        "github_webhook_received",
        event=event_type,
        repository=(payload.get("repository") or {}).get("full_name"),
        action=payload.get("action")
    )

    # Process webhook asynchronously
    background_tasks.add_task(
        process_github_webhook,
        event_type,
        payload
    )

    return {"status": "accepted", "event": event_type}
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
@router.post("/incoming/generic/{webhook_id}")
async def receive_generic_webhook(
    webhook_id: str,
    request: Request,
    payload: IncomingWebhookPayload,
    background_tasks: BackgroundTasks
):
    """
    Receive generic webhooks with configurable verification.

    The webhook_id must match a configured incoming webhook. If the
    configuration declares a secret, the request must carry a valid
    signature in the X-Webhook-Signature header. Processing happens in
    a background task.
    """
    # Verify webhook ID exists and get configuration.
    # In production, this would look up from database.
    webhook_config = get_incoming_webhook_config(webhook_id)

    if not webhook_config:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Webhook configuration not found: {webhook_id}"
        )

    # Fix: honour the configuration's "active" flag — previously an
    # inactive webhook configuration was still accepted and processed.
    # Responding 404 (rather than 403) avoids leaking that the id exists.
    if not webhook_config.get("active", True):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Webhook configuration not found: {webhook_id}"
        )

    # Verify signature if a secret is configured. Reading the body here
    # after Pydantic already parsed it is safe: Starlette caches it.
    if webhook_config.get("secret"):
        body = await request.body()
        signature = request.headers.get("X-Webhook-Signature")

        if not signature:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Missing webhook signature"
            )

        if not verify_webhook_signature(signature, body, webhook_config["secret"]):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid webhook signature"
            )

    logger.info(
        "generic_webhook_received",
        webhook_id=webhook_id,
        event=payload.event,
        timestamp=payload.timestamp
    )

    # Process asynchronously
    background_tasks.add_task(
        process_generic_webhook,
        webhook_id,
        payload
    )

    return {
        "status": "accepted",
        "webhook_id": webhook_id,
        "event": payload.event
    }
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
@router.post("/register")
async def register_webhook(
    request: WebhookRegistrationRequest,
    current_user=Depends(get_current_user)
):
    """
    Register a new outgoing webhook.

    Requires authentication. Unknown event names fail fast with a 400.
    """
    # Translate event name strings into WebhookEvent members up front.
    subscribed = None
    if request.events:
        try:
            subscribed = [WebhookEvent(name) for name in request.events]
        except ValueError as e:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid event type: {e}"
            )

    # Build and register the outgoing webhook configuration.
    webhook_service.add_webhook(
        WebhookConfig(
            url=request.url,
            events=subscribed,
            secret=request.secret,
            headers=request.headers,
            active=request.active
        )
    )

    logger.info(
        "webhook_registered",
        user=current_user.get("email"),
        url=request.url,
        events=request.events
    )

    return {
        "status": "registered",
        "url": request.url,
        "events": request.events,
        "active": request.active
    }
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
@router.delete("/unregister")
async def unregister_webhook(
    url: str,
    current_user=Depends(get_current_user)
):
    """
    Unregister a webhook by URL.

    Requires authentication. 404 when no webhook matches the URL.
    """
    if not webhook_service.remove_webhook(url):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Webhook not found: {url}"
        )

    logger.info(
        "webhook_unregistered",
        user=current_user.get("email"),
        url=url
    )

    return {"status": "unregistered", "url": url}
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
@router.get("/list")
async def list_webhooks(
    current_user=Depends(get_current_user)
):
    """
    List all registered outgoing webhooks.

    Requires authentication. Secrets are never returned, only a boolean
    flag indicating whether one is set.
    """
    registered = webhook_service.list_webhooks()

    summaries = []
    for hook in registered:
        summaries.append({
            "url": str(hook.url),
            "events": [e.value for e in hook.events] if hook.events else None,
            "active": hook.active,
            "has_secret": bool(hook.secret)
        })

    return {
        "webhooks": summaries,
        "total": len(registered)
    }
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
@router.post("/test")
async def test_webhook(
    request: WebhookTestRequest,
    background_tasks: BackgroundTasks,
    current_user=Depends(get_current_user)
):
    """
    Test a webhook endpoint by sending a probe payload.

    Requires authentication. Uses a throwaway config with a single
    attempt and a short timeout.
    (background_tasks is currently unused; kept for interface stability.)
    """
    probe = WebhookConfig(
        url=request.url,
        secret=request.secret,
        max_retries=1,
        timeout=10
    )

    delivery = await webhook_service.test_webhook(probe)

    logger.info(
        "webhook_tested",
        user=current_user.get("email"),
        url=request.url,
        success=delivery.success,
        status_code=delivery.status_code
    )

    return {
        "url": request.url,
        "success": delivery.success,
        "status_code": delivery.status_code,
        "response": delivery.response_body,
        "error": delivery.error,
        "duration_ms": delivery.duration_ms
    }
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
@router.get("/history")
async def get_webhook_history(
    event: Optional[str] = None,
    url: Optional[str] = None,
    success: Optional[bool] = None,
    limit: int = 100,
    current_user=Depends(get_current_user)
):
    """
    Get webhook delivery history, optionally filtered.

    Requires authentication. An unknown event name yields a 400.
    """
    # Map the event name to its enum member before querying.
    event_filter = None
    if event:
        try:
            event_filter = WebhookEvent(event)
        except ValueError:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid event type: {event}"
            )

    deliveries = webhook_service.get_delivery_history(
        event=event_filter,
        url=url,
        success=success,
        limit=limit
    )

    serialized = [
        {
            "webhook_url": d.webhook_url,
            "event": d.event.value,
            "timestamp": d.timestamp.isoformat(),
            "success": d.success,
            "status_code": d.status_code,
            "error": d.error,
            "attempts": d.attempts,
            "duration_ms": d.duration_ms
        }
        for d in deliveries
    ]

    return {
        "deliveries": serialized,
        "total": len(serialized)
    }
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
# Helper functions
|
| 346 |
+
|
| 347 |
+
def get_incoming_webhook_config(webhook_id: str) -> Optional[Dict[str, Any]]:
    """Look up the configuration for an incoming webhook.

    Placeholder implementation: a static in-memory table stands in for
    the database lookup used in production. Returns None for unknown ids.
    """
    # NOTE: the secrets below are placeholders only; real secrets must
    # come from configuration / secret storage, never source code.
    known = {
        "test": {
            "secret": "test-secret",
            "active": True
        },
        "monitoring": {
            "secret": "monitoring-secret",
            "active": True
        }
    }
    return known.get(webhook_id)
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
async def process_github_webhook(event_type: str, payload: Dict[str, Any]):
    """Process GitHub webhook asynchronously.

    Unknown events are silently ignored, matching the original
    if/elif chain; errors are logged and swallowed (best-effort task).
    """
    # Table of the events we currently acknowledge.
    handled = {
        "push": "Processing GitHub push event",
        "pull_request": "Processing GitHub pull request event",
        "issues": "Processing GitHub issues event",
    }
    try:
        note = handled.get(event_type)
        if note is not None:
            logger.info(note)
        # Add more event handlers as needed
    except Exception as e:
        logger.error(
            "Failed to process GitHub webhook",
            event=event_type,
            error=str(e),
            exc_info=True
        )
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
async def process_generic_webhook(webhook_id: str, payload: IncomingWebhookPayload):
    """Process generic webhook asynchronously.

    Routing placeholder: for now we only record receipt; concrete
    per-webhook handlers plug in here. Errors are logged and swallowed
    (best-effort background task).
    """
    context = {"webhook_id": webhook_id, "event": payload.event}
    try:
        logger.info("Processing generic webhook", **context)
        # Add specific processing logic here
    except Exception as e:
        logger.error(
            "Failed to process generic webhook",
            webhook_id=webhook_id,
            error=str(e),
            exc_info=True
        )
|
src/core/config.py
CHANGED
|
@@ -178,6 +178,11 @@ class Settings(BaseSettings):
|
|
| 178 |
rate_limit_per_hour: int = Field(default=1000, description="Rate limit per hour")
|
| 179 |
rate_limit_per_day: int = Field(default=10000, description="Rate limit per day")
|
| 180 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
# Celery
|
| 182 |
celery_broker_url: str = Field(
|
| 183 |
default="redis://localhost:6379/1",
|
|
|
|
| 178 |
rate_limit_per_hour: int = Field(default=1000, description="Rate limit per hour")
|
| 179 |
rate_limit_per_day: int = Field(default=10000, description="Rate limit per day")
|
| 180 |
|
| 181 |
+
# IP Whitelist
|
| 182 |
+
ip_whitelist_enabled: bool = Field(default=True, description="Enable IP whitelist in production")
|
| 183 |
+
ip_whitelist_strict: bool = Field(default=False, description="Strict mode - reject if IP unknown")
|
| 184 |
+
ip_whitelist_cache_ttl: int = Field(default=300, description="IP whitelist cache TTL seconds")
|
| 185 |
+
|
| 186 |
# Celery
|
| 187 |
celery_broker_url: str = Field(
|
| 188 |
default="redis://localhost:6379/1",
|
src/services/ip_whitelist_service.py
ADDED
|
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Module: services.ip_whitelist_service
|
| 3 |
+
Description: IP whitelist management for production environments
|
| 4 |
+
Author: Anderson H. Silva
|
| 5 |
+
Date: 2025-01-25
|
| 6 |
+
License: Proprietary - All rights reserved
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import ipaddress
|
| 10 |
+
from typing import List, Optional, Set, Dict, Any
|
| 11 |
+
from datetime import datetime, timezone
|
| 12 |
+
import json
|
| 13 |
+
|
| 14 |
+
from src.core import get_logger
|
| 15 |
+
from src.infrastructure.cache import cache_service
|
| 16 |
+
from src.core.config import settings
|
| 17 |
+
from src.models.base import BaseModel
|
| 18 |
+
from sqlalchemy import Column, String, Boolean, DateTime, Integer, JSON
|
| 19 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 20 |
+
from sqlalchemy import select, delete
|
| 21 |
+
|
| 22 |
+
logger = get_logger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class IPWhitelist(BaseModel):
    """IP whitelist entry model.

    Stores either a single IP address or, when ``is_cidr`` is set, the
    network address + prefix length of a CIDR range.
    """
    __tablename__ = "ip_whitelists"

    # Surrogate key of the form "<environment>:<ip>" (assigned by the service layer).
    id = Column(String(64), primary_key=True)
    ip_address = Column(String(45), nullable=False, unique=True)  # IPv4 or IPv6
    description = Column(String(255))
    environment = Column(String(20), nullable=False, default="production")
    active = Column(Boolean, default=True)
    created_by = Column(String(255), nullable=False)
    created_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc))
    expires_at = Column(DateTime(timezone=True), nullable=True)
    # NOTE(review): "metadata" is a reserved attribute name on SQLAlchemy
    # declarative models (it shadows Base.metadata) and raises
    # InvalidRequestError at class-definition time if BaseModel is a
    # declarative base — confirm, and consider mapping it as e.g.
    # `extra_data = Column("metadata", JSON, default=dict)`.
    metadata = Column(JSON, default=dict)

    # CIDR support
    is_cidr = Column(Boolean, default=False)
    cidr_prefix = Column(Integer, nullable=True)

    def is_expired(self) -> bool:
        """Check if whitelist entry is expired.

        Entries with no expiry never expire.
        """
        if not self.expires_at:
            return False
        # expires_at is declared DateTime(timezone=True), so comparing against
        # an aware "now" is safe for DB-loaded rows; a naive expires_at set in
        # Python code would raise TypeError here — callers should pass aware
        # datetimes.
        return datetime.now(timezone.utc) > self.expires_at

    def matches(self, ip: str) -> bool:
        """Check if IP matches this whitelist entry.

        Returns False for inactive or expired entries, and for malformed
        input rather than raising.
        """
        if not self.active or self.is_expired():
            return False

        try:
            if self.is_cidr:
                # CIDR range check: ip_address holds the network address,
                # cidr_prefix the prefix length.
                network = ipaddress.ip_network(f"{self.ip_address}/{self.cidr_prefix}")
                return ipaddress.ip_address(ip) in network
            else:
                # Exact match
                return self.ip_address == ip
        except ValueError:
            logger.error(f"Invalid IP address format: {ip}")
            return False
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class IPWhitelistService:
    """Service for managing IP whitelists.

    Whitelist entries live in the ``ip_whitelists`` table and are mirrored
    in two cache layers:

    * an in-process cache (``_whitelist_cache`` / ``_cidr_cache``) holding
      the active, non-expired entries for one environment at a time,
      refreshed at most every ``_cache_ttl`` seconds; and
    * per-IP lookup results stored in the shared ``cache_service``.

    Both layers are invalidated whenever the table is mutated.
    """

    def __init__(self) -> None:
        """Initialize IP whitelist service."""
        self._cache_key_prefix = "ip_whitelist"
        self._cache_ttl = 300  # 5 minutes
        self._whitelist_cache: Optional[Set[str]] = None
        # Pre-parsed CIDR entries as (ip_network, expires_at) tuples so
        # check_ip() does not re-parse networks on every lookup.
        self._cidr_cache: Optional[List[tuple]] = None
        self._last_cache_update: Optional[datetime] = None
        # Environment the in-process cache was built for.  Without tracking
        # this, a cache loaded for one environment would silently be served
        # for another (check_ip takes an environment argument).
        self._cache_environment: Optional[str] = None

    async def add_ip(
        self,
        session: AsyncSession,
        ip_address: str,
        created_by: str,
        description: Optional[str] = None,
        environment: str = "production",
        expires_at: Optional[datetime] = None,
        is_cidr: bool = False,
        metadata: Optional[Dict[str, Any]] = None
    ) -> IPWhitelist:
        """Add an IP address or CIDR range to the whitelist.

        Args:
            session: Async database session.
            ip_address: Single IP ("1.2.3.4") or CIDR range ("10.0.0.0/24").
            created_by: Identifier of the actor creating the entry.
            description: Optional human-readable note.
            environment: Environment the entry applies to.
            expires_at: Optional expiry timestamp (timezone-aware).
            is_cidr: Force CIDR handling; also inferred from "/" in the input.
            metadata: Optional extra attributes stored as JSON.

        Returns:
            The persisted IPWhitelist entry.

        Raises:
            ValueError: If the address is malformed or already whitelisted
                for the given environment.
        """
        try:
            # Parse and validate; CIDR input is normalized to its network
            # address + prefix length (strict=False tolerates host bits).
            if is_cidr or "/" in ip_address:
                network = ipaddress.ip_network(ip_address, strict=False)
                ip_str = str(network.network_address)
                cidr_prefix = network.prefixlen
                is_cidr = True
            else:
                ip_obj = ipaddress.ip_address(ip_address)
                ip_str = str(ip_obj)
                cidr_prefix = None
                is_cidr = False

        except ValueError as e:
            logger.error(f"Invalid IP address format: {ip_address}")
            raise ValueError(f"Invalid IP address format: {ip_address}") from e

        # Reject duplicates for the same environment.
        existing = await session.execute(
            select(IPWhitelist).where(
                IPWhitelist.ip_address == ip_str,
                IPWhitelist.environment == environment
            )
        )
        if existing.scalar_one_or_none():
            raise ValueError(f"IP address already whitelisted: {ip_str}")

        # Create whitelist entry
        entry = IPWhitelist(
            id=f"{environment}:{ip_str}",
            ip_address=ip_str,
            description=description,
            environment=environment,
            created_by=created_by,
            expires_at=expires_at,
            is_cidr=is_cidr,
            cidr_prefix=cidr_prefix,
            metadata=metadata or {}
        )

        session.add(entry)
        await session.commit()

        # Both cache layers are now stale.
        await self._invalidate_cache()

        logger.info(
            "ip_whitelist_added",
            ip=ip_str,
            environment=environment,
            is_cidr=is_cidr,
            created_by=created_by
        )

        return entry

    async def remove_ip(
        self,
        session: AsyncSession,
        ip_address: str,
        environment: str = "production"
    ) -> bool:
        """Remove an IP from the whitelist.

        Returns:
            True if an entry was deleted, False if nothing matched.
        """
        result = await session.execute(
            delete(IPWhitelist).where(
                IPWhitelist.ip_address == ip_address,
                IPWhitelist.environment == environment
            )
        )
        await session.commit()

        if result.rowcount > 0:
            await self._invalidate_cache()
            logger.info(
                "ip_whitelist_removed",
                ip=ip_address,
                environment=environment
            )
            return True

        return False

    async def check_ip(
        self,
        session: AsyncSession,
        ip_address: str,
        environment: str = "production"
    ) -> bool:
        """Check whether an IP is whitelisted for an environment.

        Per-IP results are memoized in the shared cache for ``_cache_ttl``
        seconds; the exact-IP set and the CIDR ranges come from the
        in-process cache loaded for this environment.
        """
        # Check the shared per-IP cache first.
        cache_key = f"{self._cache_key_prefix}:{environment}:check:{ip_address}"
        cached = await cache_service.get(cache_key)
        if cached is not None:
            return cached

        # (Re)load the in-process whitelist for this environment if needed.
        await self._ensure_cache_loaded(session, environment)

        # Exact matches first (O(1) set lookup).
        if self._whitelist_cache and ip_address in self._whitelist_cache:
            await cache_service.set(cache_key, True, ttl=self._cache_ttl)
            return True

        # Then CIDR ranges (networks were pre-parsed at cache-load time).
        if self._cidr_cache:
            now = datetime.now(timezone.utc)
            try:
                ip_obj = ipaddress.ip_address(ip_address)
            except ValueError:
                ip_obj = None

            if ip_obj is not None:
                for network, expires_at in self._cidr_cache:
                    # Skip entries that expired since the cache was loaded.
                    if expires_at and now > expires_at:
                        continue
                    if ip_obj in network:
                        await cache_service.set(cache_key, True, ttl=self._cache_ttl)
                        return True

        # Not whitelisted
        await cache_service.set(cache_key, False, ttl=self._cache_ttl)
        return False

    async def list_ips(
        self,
        session: AsyncSession,
        environment: str = "production",
        include_expired: bool = False
    ) -> List[IPWhitelist]:
        """List all whitelisted IPs for an environment.

        Args:
            include_expired: When False (default), expired entries are
                filtered out.
        """
        query = select(IPWhitelist).where(
            IPWhitelist.environment == environment
        )

        if not include_expired:
            now = datetime.now(timezone.utc)
            query = query.where(
                (IPWhitelist.expires_at.is_(None)) |
                (IPWhitelist.expires_at > now)
            )

        result = await session.execute(query)
        return list(result.scalars().all())

    async def update_ip(
        self,
        session: AsyncSession,
        ip_address: str,
        environment: str = "production",
        active: Optional[bool] = None,
        description: Optional[str] = None,
        expires_at: Optional[datetime] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> Optional[IPWhitelist]:
        """Update a whitelist entry.

        Only fields passed as non-None are changed; there is deliberately
        no way to clear expires_at back to None through this method.

        Returns:
            The updated entry, or None if no entry matched.
        """
        result = await session.execute(
            select(IPWhitelist).where(
                IPWhitelist.ip_address == ip_address,
                IPWhitelist.environment == environment
            )
        )
        entry = result.scalar_one_or_none()

        if not entry:
            return None

        if active is not None:
            entry.active = active
        if description is not None:
            entry.description = description
        if expires_at is not None:
            entry.expires_at = expires_at
        if metadata is not None:
            entry.metadata = metadata

        await session.commit()
        await self._invalidate_cache()

        logger.info(
            "ip_whitelist_updated",
            ip=ip_address,
            environment=environment,
            active=entry.active
        )

        return entry

    async def cleanup_expired(
        self,
        session: AsyncSession,
        environment: Optional[str] = None
    ) -> int:
        """Remove expired whitelist entries.

        Args:
            environment: Restrict cleanup to one environment; None cleans all.

        Returns:
            Number of entries removed.
        """
        query = delete(IPWhitelist).where(
            IPWhitelist.expires_at < datetime.now(timezone.utc)
        )

        if environment:
            query = query.where(IPWhitelist.environment == environment)

        result = await session.execute(query)
        await session.commit()

        if result.rowcount > 0:
            await self._invalidate_cache()
            logger.info(
                "ip_whitelist_cleanup",
                removed=result.rowcount,
                environment=environment
            )

        return result.rowcount

    async def _ensure_cache_loaded(
        self,
        session: AsyncSession,
        environment: str
    ) -> None:
        """Ensure the in-process whitelist cache is loaded for *environment*.

        Reloads when the cache is stale or was built for a different
        environment (previously only the timestamp was checked, so entries
        for one environment could be served for another).
        """
        if (
            self._cache_environment == environment and
            self._last_cache_update and
            (datetime.now(timezone.utc) - self._last_cache_update).total_seconds() < self._cache_ttl
        ):
            return

        # Load active, non-expired entries from the database.
        now = datetime.now(timezone.utc)
        result = await session.execute(
            select(IPWhitelist).where(
                IPWhitelist.environment == environment,
                IPWhitelist.active == True,  # noqa: E712 - SQLAlchemy column comparison
                (IPWhitelist.expires_at.is_(None)) | (IPWhitelist.expires_at > now)
            )
        )

        entries = result.scalars().all()

        # Separate exact IPs from CIDR ranges; parse networks once here so
        # check_ip() works with ready ip_network objects.
        self._whitelist_cache = set()
        self._cidr_cache = []

        for entry in entries:
            if entry.is_cidr:
                try:
                    network = ipaddress.ip_network(f"{entry.ip_address}/{entry.cidr_prefix}")
                except ValueError:
                    # A malformed row should not poison every lookup.
                    logger.error(f"Invalid IP address format: {entry.ip_address}/{entry.cidr_prefix}")
                    continue
                self._cidr_cache.append((network, entry.expires_at))
            else:
                self._whitelist_cache.add(entry.ip_address)

        self._cache_environment = environment
        self._last_cache_update = datetime.now(timezone.utc)

        logger.debug(
            "ip_whitelist_cache_loaded",
            environment=environment,
            exact_ips=len(self._whitelist_cache),
            cidr_ranges=len(self._cidr_cache)
        )

    async def _invalidate_cache(self) -> None:
        """Invalidate both the in-process cache and the shared per-IP cache."""
        self._whitelist_cache = None
        self._cidr_cache = None
        self._last_cache_update = None
        self._cache_environment = None

        # Clear Redis cache patterns
        pattern = f"{self._cache_key_prefix}:*"
        await cache_service.delete_pattern(pattern)

    def get_default_whitelist(self) -> List[str]:
        """Get default whitelist entries based on environment.

        Note: "localhost" is intentionally absent — it is a hostname, not an
        IP literal, so add_ip() always rejected it; only loopback addresses
        are listed.
        """
        # Always allow loopback (IPv4 and IPv6).
        defaults = [
            "127.0.0.1",
            "::1",
        ]

        # Development environment: allow all private networks.
        if settings.is_development:
            defaults.extend([
                "10.0.0.0/8",      # Private network
                "172.16.0.0/12",   # Private network
                "192.168.0.0/16"   # Private network
            ])

        # Production environment - add known services
        if settings.is_production:
            # Vercel IPs (example - would need real ranges)
            defaults.extend([
                "76.76.21.0/24",   # Vercel edge network (example)
                "76.223.0.0/16"    # Vercel edge network (example)
            ])

            # HuggingFace Spaces IPs (example - would need real ranges)
            defaults.extend([
                "34.0.0.0/8",      # Google Cloud (where HF runs)
                "35.0.0.0/8"       # Google Cloud
            ])

            # Monitoring services
            defaults.extend([
                "52.0.0.0/8"       # AWS (for monitoring)
            ])

        return defaults

    async def initialize_defaults(
        self,
        session: AsyncSession,
        created_by: str = "system"
    ) -> int:
        """Seed the whitelist with environment-appropriate defaults.

        Entries that already exist (or are otherwise rejected by add_ip)
        are skipped silently.

        Returns:
            Number of entries actually created.
        """
        defaults = self.get_default_whitelist()
        count = 0

        for ip in defaults:
            try:
                await self.add_ip(
                    session=session,
                    ip_address=ip,
                    created_by=created_by,
                    description="Default whitelist entry",
                    environment=settings.app_env,
                    is_cidr="/" in ip
                )
                count += 1
            except ValueError:
                # Already exists or invalid
                continue

        logger.info(
            "ip_whitelist_defaults_initialized",
            count=count,
            environment=settings.app_env
        )

        return count
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
# Module-level singleton; the service keeps only cache state, so one shared
# instance is safe for all callers (middleware, admin routes).
ip_whitelist_service = IPWhitelistService()
|
tests/unit/services/test_ip_whitelist_service.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for IP whitelist service."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
from datetime import datetime, timedelta, timezone
|
| 5 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 6 |
+
from unittest.mock import Mock, AsyncMock, patch
|
| 7 |
+
|
| 8 |
+
from src.services.ip_whitelist_service import IPWhitelistService, IPWhitelist
|
| 9 |
+
from src.core.config import settings
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@pytest.fixture
def ip_whitelist_service():
    """Provide a fresh IPWhitelistService instance for each test."""
    service = IPWhitelistService()
    return service
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@pytest.fixture
def mock_db_session():
    """Create a mock async database session.

    Defined as a plain (sync) fixture: constructing AsyncMock needs no
    awaiting, and an ``async def`` under a bare ``@pytest.fixture`` hands
    tests a coroutine object instead of the mock unless pytest-asyncio
    fixture support is configured.  The sync form works under any
    asyncio mode.
    """
    session = AsyncMock(spec=AsyncSession)
    session.commit = AsyncMock()
    session.execute = AsyncMock()
    session.add = Mock()  # session.add is synchronous in SQLAlchemy
    return session
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class TestIPWhitelistService:
    """Test IP whitelist service.

    NOTE(review): the async test methods carry no @pytest.mark.asyncio —
    presumably pytest-asyncio runs in auto mode; confirm in the project's
    pytest configuration, otherwise these tests are collected but the
    coroutines never execute.
    NOTE(review): check_ip() internally calls the shared cache_service,
    which is not mocked here — confirm the unit-test environment stubs it
    (otherwise these tests may touch a real cache backend).
    """

    async def test_add_single_ip(self, ip_whitelist_service, mock_db_session):
        """Test adding a single IP address."""
        # Mock query result: no existing entry, so the duplicate check passes.
        mock_db_session.execute.return_value.scalar_one_or_none.return_value = None

        # Add IP
        entry = await ip_whitelist_service.add_ip(
            session=mock_db_session,
            ip_address="192.168.1.100",
            created_by="admin@example.com",
            description="Test IP",
            environment="production"
        )

        # Verify
        assert entry.ip_address == "192.168.1.100"
        assert entry.created_by == "admin@example.com"
        assert entry.description == "Test IP"
        assert entry.environment == "production"
        assert entry.is_cidr is False
        # NOTE(review): SQLAlchemy Column defaults are applied at flush time,
        # not at object construction — with a mocked session nothing flushes,
        # so `entry.active is True` presumably relies on BaseModel applying
        # Python-side defaults in __init__; confirm.
        assert entry.active is True

        # Verify database operations
        mock_db_session.add.assert_called_once()
        mock_db_session.commit.assert_called_once()

    async def test_add_cidr_range(self, ip_whitelist_service, mock_db_session):
        """Test adding a CIDR range."""
        # Mock query result
        mock_db_session.execute.return_value.scalar_one_or_none.return_value = None

        # Add CIDR
        entry = await ip_whitelist_service.add_ip(
            session=mock_db_session,
            ip_address="10.0.0.0/24",
            created_by="admin@example.com",
            description="Test subnet",
            environment="production",
            is_cidr=True
        )

        # Verify: the service stores the normalized network address + prefix.
        assert entry.ip_address == "10.0.0.0"
        assert entry.is_cidr is True
        assert entry.cidr_prefix == 24
        assert entry.active is True

    async def test_add_duplicate_ip_fails(self, ip_whitelist_service, mock_db_session):
        """Test adding duplicate IP fails."""
        # Mock existing entry so the duplicate check trips.
        existing = Mock(spec=IPWhitelist)
        mock_db_session.execute.return_value.scalar_one_or_none.return_value = existing

        # Try to add duplicate
        with pytest.raises(ValueError, match="already whitelisted"):
            await ip_whitelist_service.add_ip(
                session=mock_db_session,
                ip_address="192.168.1.100",
                created_by="admin@example.com"
            )

    async def test_add_invalid_ip_fails(self, ip_whitelist_service, mock_db_session):
        """Test adding invalid IP fails."""
        # Validation happens before any DB access, so no mock setup is needed.
        with pytest.raises(ValueError, match="Invalid IP address format"):
            await ip_whitelist_service.add_ip(
                session=mock_db_session,
                ip_address="not.an.ip.address",
                created_by="admin@example.com"
            )

    async def test_check_ip_exact_match(self, ip_whitelist_service, mock_db_session):
        """Test checking IP with exact match."""
        # Mock whitelist entries returned by the cache-load query.
        entries = [
            Mock(
                ip_address="192.168.1.100",
                is_cidr=False,
                active=True,
                expires_at=None
            ),
            Mock(
                ip_address="10.0.0.0",
                is_cidr=True,
                cidr_prefix=24,
                active=True,
                expires_at=None
            )
        ]

        mock_db_session.execute.return_value.scalars.return_value.all.return_value = entries

        # Force cache reload
        ip_whitelist_service._last_cache_update = None

        # Check whitelisted IP
        result = await ip_whitelist_service.check_ip(
            session=mock_db_session,
            ip_address="192.168.1.100",
            environment="production"
        )
        assert result is True

    async def test_check_ip_cidr_match(self, ip_whitelist_service, mock_db_session):
        """Test checking IP within CIDR range."""
        # Mock whitelist entries
        entries = [
            Mock(
                ip_address="10.0.0.0",
                is_cidr=True,
                cidr_prefix=24,
                active=True,
                expires_at=None
            )
        ]

        mock_db_session.execute.return_value.scalars.return_value.all.return_value = entries

        # Force cache reload
        ip_whitelist_service._last_cache_update = None

        # Check IP in range (10.0.0.50 is inside 10.0.0.0/24)
        result = await ip_whitelist_service.check_ip(
            session=mock_db_session,
            ip_address="10.0.0.50",
            environment="production"
        )
        assert result is True

        # Check IP outside range (10.0.1.50 is outside 10.0.0.0/24)
        result = await ip_whitelist_service.check_ip(
            session=mock_db_session,
            ip_address="10.0.1.50",
            environment="production"
        )
        assert result is False

    async def test_check_ip_expired_entry(self, ip_whitelist_service, mock_db_session):
        """Test expired entries are ignored."""
        # Mock an entry that expired one hour ago.
        entries = [
            Mock(
                ip_address="192.168.1.100",
                is_cidr=False,
                active=True,
                expires_at=datetime.now(timezone.utc) - timedelta(hours=1)
            )
        ]

        mock_db_session.execute.return_value.scalars.return_value.all.return_value = entries

        # Force cache reload
        ip_whitelist_service._last_cache_update = None

        # Check expired IP
        result = await ip_whitelist_service.check_ip(
            session=mock_db_session,
            ip_address="192.168.1.100",
            environment="production"
        )
        assert result is False

    async def test_remove_ip(self, ip_whitelist_service, mock_db_session):
        """Test removing IP from whitelist."""
        # Mock delete result with one affected row.
        mock_result = Mock()
        mock_result.rowcount = 1
        mock_db_session.execute.return_value = mock_result

        # Remove IP
        result = await ip_whitelist_service.remove_ip(
            session=mock_db_session,
            ip_address="192.168.1.100",
            environment="production"
        )

        assert result is True
        mock_db_session.commit.assert_called_once()

    async def test_update_ip(self, ip_whitelist_service, mock_db_session):
        """Test updating whitelist entry."""
        # Mock existing entry
        entry = Mock(spec=IPWhitelist)
        entry.ip_address = "192.168.1.100"
        entry.active = True
        entry.description = "Old description"

        mock_db_session.execute.return_value.scalar_one_or_none.return_value = entry

        # Update entry
        result = await ip_whitelist_service.update_ip(
            session=mock_db_session,
            ip_address="192.168.1.100",
            environment="production",
            active=False,
            description="New description"
        )

        assert result is not None
        assert entry.active is False
        assert entry.description == "New description"
        mock_db_session.commit.assert_called_once()

    def test_get_default_whitelist_development(self, ip_whitelist_service):
        """Test default whitelist for development."""
        with patch.object(settings, 'is_development', True):
            defaults = ip_whitelist_service.get_default_whitelist()

            # Should include localhost and private networks
            assert "127.0.0.1" in defaults
            assert "::1" in defaults
            assert "10.0.0.0/8" in defaults
            assert "192.168.0.0/16" in defaults

    def test_get_default_whitelist_production(self, ip_whitelist_service):
        """Test default whitelist for production."""
        with patch.object(settings, 'is_production', True):
            defaults = ip_whitelist_service.get_default_whitelist()

            # Should include localhost and service IPs
            assert "127.0.0.1" in defaults
            assert "::1" in defaults
            # Should have cloud provider ranges
            # NOTE(review): substring checks like "76." are loose — e.g. any
            # entry containing "76." matches; tighten to startswith if the
            # ranges stabilize.
            assert any("76." in ip for ip in defaults)  # Vercel
            assert any("34." in ip for ip in defaults)  # Google Cloud

    async def test_cleanup_expired(self, ip_whitelist_service, mock_db_session):
        """Test cleaning up expired entries."""
        # Mock delete result reporting five removed rows.
        mock_result = Mock()
        mock_result.rowcount = 5
        mock_db_session.execute.return_value = mock_result

        # Cleanup
        count = await ip_whitelist_service.cleanup_expired(
            session=mock_db_session,
            environment="production"
        )

        assert count == 5
        mock_db_session.commit.assert_called_once()

    def test_ip_whitelist_model_matches(self):
        """Test IPWhitelist model matching logic."""
        # Test exact match
        entry = IPWhitelist(
            id="test",
            ip_address="192.168.1.100",
            is_cidr=False,
            active=True,
            created_by="test"
        )
        assert entry.matches("192.168.1.100") is True
        assert entry.matches("192.168.1.101") is False

        # Test CIDR match
        entry_cidr = IPWhitelist(
            id="test",
            ip_address="10.0.0.0",
            is_cidr=True,
            cidr_prefix=24,
            active=True,
            created_by="test"
        )
        assert entry_cidr.matches("10.0.0.1") is True
        assert entry_cidr.matches("10.0.0.255") is True
        assert entry_cidr.matches("10.0.1.1") is False

        # Test inactive entry: matches() must return False regardless of IP.
        entry.active = False
        assert entry.matches("192.168.1.100") is False
|