anderson-ufrj commited on
Commit
f70869e
·
1 Parent(s): 97c535b

feat(security): implement IP whitelist for production environments

Browse files

- Add IP whitelist service with CIDR support
- Create middleware for IP validation and enforcement
- Implement admin API endpoints for whitelist management
- Add database migration for ip_whitelists table
- Configure automatic defaults for different environments
- Support expiration dates and metadata for entries
- Include comprehensive unit tests for service logic

alembic/versions/006_add_ip_whitelist_table.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """add ip whitelist table
2
+
3
+ Revision ID: 006
4
+ Revises: 005
5
+ Create Date: 2025-01-25
6
+ """
7
+ from alembic import op
8
+ import sqlalchemy as sa
9
+ from sqlalchemy.dialects import postgresql
10
+
11
+ # revision identifiers
12
+ revision = '006'
13
+ down_revision = '005'
14
+ branch_labels = None
15
+ depends_on = None
16
+
17
+
18
+ def upgrade() -> None:
19
+ """Create ip_whitelists table."""
20
+ op.create_table(
21
+ 'ip_whitelists',
22
+ sa.Column('id', sa.String(64), primary_key=True),
23
+ sa.Column('ip_address', sa.String(45), nullable=False),
24
+ sa.Column('description', sa.String(255), nullable=True),
25
+ sa.Column('environment', sa.String(20), nullable=False, server_default='production'),
26
+ sa.Column('active', sa.Boolean(), nullable=False, server_default='true'),
27
+ sa.Column('created_by', sa.String(255), nullable=False),
28
+ sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()),
29
+ sa.Column('expires_at', sa.DateTime(timezone=True), nullable=True),
30
+ sa.Column('metadata', postgresql.JSON(), nullable=False, server_default='{}'),
31
+ sa.Column('is_cidr', sa.Boolean(), nullable=False, server_default='false'),
32
+ sa.Column('cidr_prefix', sa.Integer(), nullable=True),
33
+ )
34
+
35
+ # Create indexes
36
+ op.create_index(
37
+ 'ix_ip_whitelists_ip_address',
38
+ 'ip_whitelists',
39
+ ['ip_address']
40
+ )
41
+
42
+ op.create_index(
43
+ 'ix_ip_whitelists_environment',
44
+ 'ip_whitelists',
45
+ ['environment']
46
+ )
47
+
48
+ op.create_index(
49
+ 'ix_ip_whitelists_active',
50
+ 'ip_whitelists',
51
+ ['active']
52
+ )
53
+
54
+ op.create_index(
55
+ 'ix_ip_whitelists_expires_at',
56
+ 'ip_whitelists',
57
+ ['expires_at']
58
+ )
59
+
60
+ # Create unique constraint on ip + environment
61
+ op.create_unique_constraint(
62
+ 'uq_ip_whitelists_ip_environment',
63
+ 'ip_whitelists',
64
+ ['ip_address', 'environment']
65
+ )
66
+
67
+
68
+ def downgrade() -> None:
69
+ """Drop ip_whitelists table."""
70
+ op.drop_table('ip_whitelists')
src/api/app.py CHANGED
@@ -27,6 +27,8 @@ from src.api.middleware.logging_middleware import LoggingMiddleware
27
  from src.api.middleware.security import SecurityMiddleware
28
  from src.api.middleware.compression import CompressionMiddleware
29
  from src.api.middleware.metrics_middleware import MetricsMiddleware, setup_http_metrics
 
 
30
  from src.infrastructure.observability import (
31
  CorrelationMiddleware,
32
  tracing_manager,
@@ -185,6 +187,39 @@ add_compression_middleware(
185
  exclude_paths={"/health", "/metrics", "/health/metrics", "/api/v1/ws", "/api/v1/observability"}
186
  )
187
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
  # Custom OpenAPI schema
190
  def custom_openapi():
@@ -361,6 +396,22 @@ app.include_router(
361
  tags=["Notifications"]
362
  )
363
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
 
365
  # Global exception handler
366
  @app.exception_handler(CidadaoAIError)
 
27
  from src.api.middleware.security import SecurityMiddleware
28
  from src.api.middleware.compression import CompressionMiddleware
29
  from src.api.middleware.metrics_middleware import MetricsMiddleware, setup_http_metrics
30
+ from src.api.middleware.ip_whitelist import IPWhitelistMiddleware
31
+ from src.api.middleware.rate_limit import RateLimitMiddleware as RateLimitMiddlewareV2
32
  from src.infrastructure.observability import (
33
  CorrelationMiddleware,
34
  tracing_manager,
 
187
  exclude_paths={"/health", "/metrics", "/health/metrics", "/api/v1/ws", "/api/v1/observability"}
188
  )
189
 
190
+ # Add IP whitelist middleware (only in production)
191
+ if settings.is_production or settings.app_env == "staging":
192
+ app.add_middleware(
193
+ IPWhitelistMiddleware,
194
+ enabled=True,
195
+ excluded_paths=[
196
+ "/health",
197
+ "/healthz",
198
+ "/ping",
199
+ "/ready",
200
+ "/docs",
201
+ "/redoc",
202
+ "/openapi.json",
203
+ "/metrics",
204
+ "/static",
205
+ "/favicon.ico",
206
+ "/_next",
207
+ "/api/v1/auth/login",
208
+ "/api/v1/auth/register",
209
+ "/api/v1/auth/refresh",
210
+ "/api/v1/public",
211
+ "/api/v1/webhooks/incoming"
212
+ ],
213
+ strict_mode=False # Allow requests if IP can't be determined
214
+ )
215
+
216
+ # Add rate limiting middleware v2
217
+ app.add_middleware(
218
+ RateLimitMiddlewareV2,
219
+ default_tier="free",
220
+ strategy="sliding_window"
221
+ )
222
+
223
 
224
  # Custom OpenAPI schema
225
  def custom_openapi():
 
396
  tags=["Notifications"]
397
  )
398
 
399
+ # Import and include admin routes
400
+ from src.api.routes.admin import ip_whitelist as admin_ip_whitelist
401
+ from src.api.routes import api_keys
402
+
403
+ app.include_router(
404
+ admin_ip_whitelist.router,
405
+ prefix="/api/v1/admin",
406
+ tags=["Admin - IP Whitelist"]
407
+ )
408
+
409
+ app.include_router(
410
+ api_keys.router,
411
+ prefix="/api/v1",
412
+ tags=["API Keys"]
413
+ )
414
+
415
 
416
  # Global exception handler
417
  @app.exception_handler(CidadaoAIError)
src/api/middleware/__init__.py CHANGED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ API middleware package.
3
+
4
+ This module exports all middleware classes for easy import.
5
+ """
6
+
7
+ from .rate_limit import RateLimitMiddleware, rate_limit
8
+ from .webhook_verification import (
9
+ WebhookVerificationMiddleware,
10
+ WebhookSigner,
11
+ create_webhook_signature,
12
+ verify_webhook_signature
13
+ )
14
+ from .ip_whitelist import IPWhitelistMiddleware
15
+
16
+ __all__ = [
17
+ "RateLimitMiddleware",
18
+ "rate_limit",
19
+ "WebhookVerificationMiddleware",
20
+ "WebhookSigner",
21
+ "create_webhook_signature",
22
+ "verify_webhook_signature",
23
+ "IPWhitelistMiddleware"
24
+ ]
src/api/middleware/ip_whitelist.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Module: api.middleware.ip_whitelist
3
+ Description: IP whitelist middleware for production environments
4
+ Author: Anderson H. Silva
5
+ Date: 2025-01-25
6
+ License: Proprietary - All rights reserved
7
+ """
8
+
9
+ from typing import Optional, List, Set
10
+ import ipaddress
11
+
12
+ from fastapi import Request, HTTPException, status
13
+ from starlette.middleware.base import BaseHTTPMiddleware
14
+ from starlette.responses import JSONResponse
15
+
16
+ from src.core import get_logger
17
+ from src.core.config import settings
18
+ from src.services.ip_whitelist_service import ip_whitelist_service
19
+ from src.db.session import get_session
20
+
21
+ logger = get_logger(__name__)
22
+
23
+
24
+ class IPWhitelistMiddleware(BaseHTTPMiddleware):
25
+ """
26
+ IP whitelist middleware for production environments.
27
+
28
+ Features:
29
+ - Environment-based activation
30
+ - Path exclusions
31
+ - Multiple IP extraction methods
32
+ - Performance optimization with caching
33
+ """
34
+
35
+ def __init__(
36
+ self,
37
+ app,
38
+ enabled: Optional[bool] = None,
39
+ excluded_paths: Optional[List[str]] = None,
40
+ strict_mode: bool = True
41
+ ):
42
+ """
43
+ Initialize IP whitelist middleware.
44
+
45
+ Args:
46
+ app: FastAPI application
47
+ enabled: Force enable/disable (None = auto based on environment)
48
+ excluded_paths: Paths to exclude from whitelist check
49
+ strict_mode: If True, reject if can't determine IP
50
+ """
51
+ super().__init__(app)
52
+
53
+ # Auto-enable in production if not specified
54
+ if enabled is None:
55
+ self.enabled = settings.is_production
56
+ else:
57
+ self.enabled = enabled
58
+
59
+ self.excluded_paths = set(excluded_paths or self._get_default_excluded_paths())
60
+ self.strict_mode = strict_mode
61
+
62
+ logger.info(
63
+ "ip_whitelist_middleware_initialized",
64
+ enabled=self.enabled,
65
+ environment=settings.app_env,
66
+ strict_mode=self.strict_mode
67
+ )
68
+
69
+ async def dispatch(self, request: Request, call_next):
70
+ """Process request with IP whitelist check."""
71
+ # Skip if disabled
72
+ if not self.enabled:
73
+ return await call_next(request)
74
+
75
+ # Skip excluded paths
76
+ if self._should_skip(request.url.path):
77
+ return await call_next(request)
78
+
79
+ try:
80
+ # Extract client IP
81
+ client_ip = self._get_client_ip(request)
82
+
83
+ if not client_ip:
84
+ if self.strict_mode:
85
+ logger.warning(
86
+ "ip_whitelist_no_client_ip",
87
+ path=request.url.path,
88
+ headers=dict(request.headers)
89
+ )
90
+ return JSONResponse(
91
+ status_code=status.HTTP_403_FORBIDDEN,
92
+ content={
93
+ "detail": "Client IP could not be determined",
94
+ "error": "IP_NOT_DETERMINED"
95
+ }
96
+ )
97
+ else:
98
+ # Allow through if can't determine IP
99
+ logger.debug("ip_whitelist_no_client_ip_allowing")
100
+ return await call_next(request)
101
+
102
+ # Check whitelist
103
+ async with get_session() as session:
104
+ is_whitelisted = await ip_whitelist_service.check_ip(
105
+ session,
106
+ client_ip,
107
+ environment=settings.app_env
108
+ )
109
+
110
+ if not is_whitelisted:
111
+ logger.warning(
112
+ "ip_whitelist_access_denied",
113
+ client_ip=client_ip,
114
+ path=request.url.path,
115
+ method=request.method
116
+ )
117
+
118
+ return JSONResponse(
119
+ status_code=status.HTTP_403_FORBIDDEN,
120
+ content={
121
+ "detail": "Access denied: IP not whitelisted",
122
+ "error": "IP_NOT_WHITELISTED"
123
+ }
124
+ )
125
+
126
+ # IP is whitelisted, proceed
127
+ request.state.client_ip = client_ip
128
+ request.state.ip_whitelisted = True
129
+
130
+ return await call_next(request)
131
+
132
+ except Exception as e:
133
+ logger.error(
134
+ "ip_whitelist_error",
135
+ error=str(e),
136
+ exc_info=True
137
+ )
138
+
139
+ # In case of error, fail open or closed based on strict mode
140
+ if self.strict_mode:
141
+ return JSONResponse(
142
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
143
+ content={
144
+ "detail": "IP whitelist check failed",
145
+ "error": "WHITELIST_CHECK_ERROR"
146
+ }
147
+ )
148
+ else:
149
+ # Allow through on error
150
+ return await call_next(request)
151
+
152
+ def _get_default_excluded_paths(self) -> List[str]:
153
+ """Get default paths to exclude from whitelist."""
154
+ return [
155
+ # Health checks
156
+ "/health",
157
+ "/healthz",
158
+ "/ping",
159
+ "/ready",
160
+
161
+ # Documentation
162
+ "/docs",
163
+ "/redoc",
164
+ "/openapi.json",
165
+
166
+ # Metrics
167
+ "/metrics",
168
+
169
+ # Static assets
170
+ "/static",
171
+ "/favicon.ico",
172
+ "/_next",
173
+
174
+ # Public endpoints
175
+ "/api/v1/auth/login",
176
+ "/api/v1/auth/register",
177
+ "/api/v1/auth/refresh",
178
+ "/api/v1/public",
179
+
180
+ # Webhook endpoints (they have their own auth)
181
+ "/api/v1/webhooks/incoming"
182
+ ]
183
+
184
+ def _should_skip(self, path: str) -> bool:
185
+ """Check if path should skip whitelist check."""
186
+ # Exact match
187
+ if path in self.excluded_paths:
188
+ return True
189
+
190
+ # Prefix match
191
+ for excluded in self.excluded_paths:
192
+ if excluded.endswith("*") and path.startswith(excluded[:-1]):
193
+ return True
194
+ if path.startswith(excluded + "/"):
195
+ return True
196
+
197
+ return False
198
+
199
+ def _get_client_ip(self, request: Request) -> Optional[str]:
200
+ """
201
+ Extract client IP from request.
202
+
203
+ Tries multiple methods in order:
204
+ 1. X-Real-IP header
205
+ 2. X-Forwarded-For header
206
+ 3. CloudFlare headers
207
+ 4. Direct client connection
208
+ """
209
+ # Try X-Real-IP first (nginx)
210
+ real_ip = request.headers.get("X-Real-IP")
211
+ if real_ip and self._is_valid_ip(real_ip):
212
+ return real_ip
213
+
214
+ # Try X-Forwarded-For (proxy chains)
215
+ forwarded_for = request.headers.get("X-Forwarded-For")
216
+ if forwarded_for:
217
+ # Take the first IP in the chain
218
+ ips = [ip.strip() for ip in forwarded_for.split(",")]
219
+ for ip in ips:
220
+ if self._is_valid_ip(ip):
221
+ return ip
222
+
223
+ # Try CloudFlare headers
224
+ cf_ip = request.headers.get("CF-Connecting-IP")
225
+ if cf_ip and self._is_valid_ip(cf_ip):
226
+ return cf_ip
227
+
228
+ # Try True-Client-IP (Akamai, CloudFlare Enterprise)
229
+ true_client_ip = request.headers.get("True-Client-IP")
230
+ if true_client_ip and self._is_valid_ip(true_client_ip):
231
+ return true_client_ip
232
+
233
+ # Try Fastly header
234
+ fastly_ip = request.headers.get("Fastly-Client-IP")
235
+ if fastly_ip and self._is_valid_ip(fastly_ip):
236
+ return fastly_ip
237
+
238
+ # Fallback to direct connection
239
+ if request.client and request.client.host:
240
+ return request.client.host
241
+
242
+ return None
243
+
244
+ def _is_valid_ip(self, ip: str) -> bool:
245
+ """Check if string is a valid IP address."""
246
+ if not ip:
247
+ return False
248
+
249
+ try:
250
+ # Validate IP
251
+ ipaddress.ip_address(ip)
252
+
253
+ # Reject private/local IPs unless in development
254
+ if not settings.is_development:
255
+ ip_obj = ipaddress.ip_address(ip)
256
+ if ip_obj.is_private or ip_obj.is_loopback:
257
+ return False
258
+
259
+ return True
260
+ except ValueError:
261
+ return False
src/api/middleware/webhook_verification.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Module: api.middleware.webhook_verification
3
+ Description: Webhook signature verification middleware
4
+ Author: Anderson H. Silva
5
+ Date: 2025-01-25
6
+ License: Proprietary - All rights reserved
7
+ """
8
+
9
+ import hmac
10
+ import hashlib
11
+ import time
12
+ from typing import Optional, Dict, Any
13
+
14
+ from fastapi import Request, HTTPException, status
15
+ from starlette.middleware.base import BaseHTTPMiddleware
16
+ from starlette.responses import JSONResponse
17
+
18
+ from src.core import get_logger
19
+
20
+ logger = get_logger(__name__)
21
+
22
+
23
+ class WebhookVerificationMiddleware(BaseHTTPMiddleware):
24
+ """
25
+ Middleware for verifying incoming webhook signatures.
26
+
27
+ Protects endpoints that receive webhooks from external services.
28
+ """
29
+
30
+ def __init__(
31
+ self,
32
+ app,
33
+ webhook_paths: Optional[Dict[str, str]] = None,
34
+ max_timestamp_age: int = 300 # 5 minutes
35
+ ):
36
+ """
37
+ Initialize webhook verification middleware.
38
+
39
+ Args:
40
+ app: FastAPI application
41
+ webhook_paths: Dict of path -> secret mapping
42
+ max_timestamp_age: Maximum age of timestamp in seconds
43
+ """
44
+ super().__init__(app)
45
+ self.webhook_paths = webhook_paths or {}
46
+ self.max_timestamp_age = max_timestamp_age
47
+
48
+ async def dispatch(self, request: Request, call_next):
49
+ """Process request with webhook verification."""
50
+ # Check if this is a webhook path
51
+ if request.url.path not in self.webhook_paths:
52
+ return await call_next(request)
53
+
54
+ # Get the secret for this path
55
+ secret = self.webhook_paths[request.url.path]
56
+
57
+ try:
58
+ # Read body
59
+ body = await request.body()
60
+
61
+ # Verify signature
62
+ if not self._verify_signature(request, body, secret):
63
+ logger.warning(
64
+ "webhook_signature_verification_failed",
65
+ path=request.url.path,
66
+ headers=dict(request.headers)
67
+ )
68
+
69
+ return JSONResponse(
70
+ status_code=status.HTTP_401_UNAUTHORIZED,
71
+ content={
72
+ "detail": "Invalid webhook signature",
73
+ "error": "INVALID_SIGNATURE"
74
+ }
75
+ )
76
+
77
+ # Verify timestamp if present
78
+ if not self._verify_timestamp(request):
79
+ logger.warning(
80
+ "webhook_timestamp_verification_failed",
81
+ path=request.url.path
82
+ )
83
+
84
+ return JSONResponse(
85
+ status_code=status.HTTP_401_UNAUTHORIZED,
86
+ content={
87
+ "detail": "Webhook timestamp too old",
88
+ "error": "TIMESTAMP_TOO_OLD"
89
+ }
90
+ )
91
+
92
+ # Store raw body for handler
93
+ request.state.webhook_body = body
94
+
95
+ # Process request
96
+ return await call_next(request)
97
+
98
+ except Exception as e:
99
+ logger.error(
100
+ "webhook_verification_error",
101
+ error=str(e),
102
+ exc_info=True
103
+ )
104
+
105
+ return JSONResponse(
106
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
107
+ content={
108
+ "detail": "Webhook verification error",
109
+ "error": "VERIFICATION_ERROR"
110
+ }
111
+ )
112
+
113
+ def _verify_signature(
114
+ self,
115
+ request: Request,
116
+ body: bytes,
117
+ secret: str
118
+ ) -> bool:
119
+ """Verify webhook signature."""
120
+ # Get signature header - support multiple formats
121
+ signature_header = (
122
+ request.headers.get("X-Cidadao-Signature") or
123
+ request.headers.get("X-Webhook-Signature") or
124
+ request.headers.get("X-Hub-Signature-256") # GitHub format
125
+ )
126
+
127
+ if not signature_header:
128
+ logger.debug("No signature header found")
129
+ return False
130
+
131
+ # Parse signature
132
+ if "=" in signature_header:
133
+ algorithm, signature = signature_header.split("=", 1)
134
+ else:
135
+ algorithm = "sha256"
136
+ signature = signature_header
137
+
138
+ # Calculate expected signature
139
+ if algorithm == "sha256":
140
+ expected = hmac.new(
141
+ secret.encode(),
142
+ body,
143
+ hashlib.sha256
144
+ ).hexdigest()
145
+ elif algorithm == "sha1":
146
+ expected = hmac.new(
147
+ secret.encode(),
148
+ body,
149
+ hashlib.sha1
150
+ ).hexdigest()
151
+ else:
152
+ logger.warning(f"Unsupported signature algorithm: {algorithm}")
153
+ return False
154
+
155
+ # Compare signatures
156
+ return hmac.compare_digest(signature, expected)
157
+
158
+ def _verify_timestamp(self, request: Request) -> bool:
159
+ """Verify webhook timestamp is recent."""
160
+ timestamp_header = (
161
+ request.headers.get("X-Cidadao-Timestamp") or
162
+ request.headers.get("X-Webhook-Timestamp")
163
+ )
164
+
165
+ if not timestamp_header:
166
+ # No timestamp to verify
167
+ return True
168
+
169
+ try:
170
+ # Parse timestamp
171
+ if timestamp_header.isdigit():
172
+ # Unix timestamp
173
+ webhook_time = float(timestamp_header)
174
+ else:
175
+ # ISO format
176
+ from dateutil.parser import parse
177
+ webhook_time = parse(timestamp_header).timestamp()
178
+
179
+ # Check age
180
+ current_time = time.time()
181
+ age = abs(current_time - webhook_time)
182
+
183
+ return age <= self.max_timestamp_age
184
+
185
+ except Exception as e:
186
+ logger.error(f"Failed to parse timestamp: {e}")
187
+ return False
188
+
189
+
190
+ def create_webhook_signature(
191
+ payload: bytes,
192
+ secret: str,
193
+ algorithm: str = "sha256"
194
+ ) -> str:
195
+ """
196
+ Create webhook signature for outgoing webhooks.
197
+
198
+ Args:
199
+ payload: Request body
200
+ secret: Webhook secret
201
+ algorithm: Hash algorithm (sha256, sha1)
202
+
203
+ Returns:
204
+ Signature string with format "algorithm=signature"
205
+ """
206
+ if algorithm == "sha256":
207
+ signature = hmac.new(
208
+ secret.encode(),
209
+ payload,
210
+ hashlib.sha256
211
+ ).hexdigest()
212
+ elif algorithm == "sha1":
213
+ signature = hmac.new(
214
+ secret.encode(),
215
+ payload,
216
+ hashlib.sha1
217
+ ).hexdigest()
218
+ else:
219
+ raise ValueError(f"Unsupported algorithm: {algorithm}")
220
+
221
+ return f"{algorithm}={signature}"
222
+
223
+
224
+ def verify_webhook_signature(
225
+ signature: str,
226
+ payload: bytes,
227
+ secret: str
228
+ ) -> bool:
229
+ """
230
+ Verify webhook signature.
231
+
232
+ Args:
233
+ signature: Signature header value
234
+ payload: Request body
235
+ secret: Webhook secret
236
+
237
+ Returns:
238
+ True if signature is valid
239
+ """
240
+ try:
241
+ # Parse signature
242
+ if "=" in signature:
243
+ algorithm, sig = signature.split("=", 1)
244
+ else:
245
+ algorithm = "sha256"
246
+ sig = signature
247
+
248
+ # Generate expected signature
249
+ expected = create_webhook_signature(payload, secret, algorithm)
250
+
251
+ # Extract just the signature part
252
+ if "=" in expected:
253
+ _, expected_sig = expected.split("=", 1)
254
+ else:
255
+ expected_sig = expected
256
+
257
+ # Compare
258
+ return hmac.compare_digest(sig, expected_sig)
259
+
260
+ except Exception as e:
261
+ logger.error(f"Signature verification error: {e}")
262
+ return False
263
+
264
+
265
+ class WebhookSigner:
266
+ """Helper class for signing webhook requests."""
267
+
268
+ def __init__(self, secret: str, algorithm: str = "sha256"):
269
+ """Initialize webhook signer."""
270
+ self.secret = secret
271
+ self.algorithm = algorithm
272
+
273
+ def sign(self, payload: bytes) -> Dict[str, str]:
274
+ """
275
+ Generate webhook headers with signature.
276
+
277
+ Args:
278
+ payload: Request body
279
+
280
+ Returns:
281
+ Dict of headers to include in request
282
+ """
283
+ signature = create_webhook_signature(
284
+ payload,
285
+ self.secret,
286
+ self.algorithm
287
+ )
288
+
289
+ timestamp = str(int(time.time()))
290
+
291
+ return {
292
+ "X-Cidadao-Signature": signature,
293
+ "X-Cidadao-Timestamp": timestamp,
294
+ "X-Cidadao-Algorithm": self.algorithm
295
+ }
src/api/routes/admin/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Admin routes package."""
src/api/routes/admin/ip_whitelist.py ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Module: api.routes.admin.ip_whitelist
3
+ Description: Admin routes for managing IP whitelist
4
+ Author: Anderson H. Silva
5
+ Date: 2025-01-25
6
+ License: Proprietary - All rights reserved
7
+ """
8
+
9
+ from typing import List, Optional, Dict, Any
10
+ from datetime import datetime
11
+ import ipaddress
12
+
13
+ from fastapi import APIRouter, Depends, HTTPException, status, Query
14
+ from pydantic import BaseModel, Field, field_validator
15
+
16
+ from src.core import get_logger
17
+ from src.api.dependencies import require_admin, get_db
18
+ from src.services.ip_whitelist_service import ip_whitelist_service, IPWhitelist
19
+ from src.core.config import settings
20
+
21
+ logger = get_logger(__name__)
22
+
23
+ router = APIRouter(prefix="/ip-whitelist", tags=["Admin - IP Whitelist"])
24
+
25
+
26
+ class IPWhitelistRequest(BaseModel):
27
+ """Request to add IP to whitelist."""
28
+ ip_address: str = Field(..., description="IP address or CIDR range")
29
+ description: Optional[str] = Field(None, description="Description of the IP")
30
+ environment: str = Field(default="production", description="Environment")
31
+ expires_at: Optional[datetime] = Field(None, description="Expiration date")
32
+ metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")
33
+
34
+ @field_validator("ip_address")
35
+ @classmethod
36
+ def validate_ip(cls, v: str) -> str:
37
+ """Validate IP address or CIDR."""
38
+ try:
39
+ # Try as single IP
40
+ ipaddress.ip_address(v)
41
+ return v
42
+ except ValueError:
43
+ # Try as CIDR
44
+ try:
45
+ ipaddress.ip_network(v, strict=False)
46
+ return v
47
+ except ValueError:
48
+ raise ValueError(f"Invalid IP address or CIDR: {v}")
49
+
50
+ @field_validator("environment")
51
+ @classmethod
52
+ def validate_environment(cls, v: str) -> str:
53
+ """Validate environment."""
54
+ allowed = ["development", "staging", "production", "testing"]
55
+ if v not in allowed:
56
+ raise ValueError(f"Environment must be one of {allowed}")
57
+ return v
58
+
59
+
60
+ class IPWhitelistUpdateRequest(BaseModel):
61
+ """Request to update IP whitelist entry."""
62
+ active: Optional[bool] = None
63
+ description: Optional[str] = None
64
+ expires_at: Optional[datetime] = None
65
+ metadata: Optional[Dict[str, Any]] = None
66
+
67
+
68
+ class IPWhitelistResponse(BaseModel):
69
+ """IP whitelist entry response."""
70
+ id: str
71
+ ip_address: str
72
+ description: Optional[str]
73
+ environment: str
74
+ active: bool
75
+ is_cidr: bool
76
+ cidr_prefix: Optional[int]
77
+ created_by: str
78
+ created_at: datetime
79
+ expires_at: Optional[datetime]
80
+ metadata: Dict[str, Any]
81
+ is_expired: bool
82
+
83
+ @classmethod
84
+ def from_model(cls, model: IPWhitelist) -> "IPWhitelistResponse":
85
+ """Create response from model."""
86
+ return cls(
87
+ id=model.id,
88
+ ip_address=model.ip_address,
89
+ description=model.description,
90
+ environment=model.environment,
91
+ active=model.active,
92
+ is_cidr=model.is_cidr,
93
+ cidr_prefix=model.cidr_prefix,
94
+ created_by=model.created_by,
95
+ created_at=model.created_at,
96
+ expires_at=model.expires_at,
97
+ metadata=model.metadata,
98
+ is_expired=model.is_expired()
99
+ )
100
+
101
+
102
+ @router.post("/add", response_model=IPWhitelistResponse)
103
+ async def add_ip_to_whitelist(
104
+ request: IPWhitelistRequest,
105
+ admin_user=Depends(require_admin),
106
+ db=Depends(get_db)
107
+ ):
108
+ """
109
+ Add IP address or CIDR range to whitelist.
110
+
111
+ Requires admin privileges.
112
+ """
113
+ try:
114
+ is_cidr = "/" in request.ip_address
115
+
116
+ entry = await ip_whitelist_service.add_ip(
117
+ session=db,
118
+ ip_address=request.ip_address,
119
+ created_by=admin_user.get("email", "admin"),
120
+ description=request.description,
121
+ environment=request.environment,
122
+ expires_at=request.expires_at,
123
+ is_cidr=is_cidr,
124
+ metadata=request.metadata
125
+ )
126
+
127
+ logger.info(
128
+ "admin_ip_whitelist_added",
129
+ admin=admin_user.get("email"),
130
+ ip=request.ip_address,
131
+ environment=request.environment
132
+ )
133
+
134
+ return IPWhitelistResponse.from_model(entry)
135
+
136
+ except ValueError as e:
137
+ raise HTTPException(
138
+ status_code=status.HTTP_400_BAD_REQUEST,
139
+ detail=str(e)
140
+ )
141
+ except Exception as e:
142
+ logger.error(
143
+ "admin_ip_whitelist_add_error",
144
+ error=str(e),
145
+ exc_info=True
146
+ )
147
+ raise HTTPException(
148
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
149
+ detail="Failed to add IP to whitelist"
150
+ )
151
+
152
+
153
+ @router.delete("/remove/{ip_address}")
154
+ async def remove_ip_from_whitelist(
155
+ ip_address: str,
156
+ environment: str = Query(default="production"),
157
+ admin_user=Depends(require_admin),
158
+ db=Depends(get_db)
159
+ ):
160
+ """
161
+ Remove IP from whitelist.
162
+
163
+ Requires admin privileges.
164
+ """
165
+ removed = await ip_whitelist_service.remove_ip(
166
+ session=db,
167
+ ip_address=ip_address,
168
+ environment=environment
169
+ )
170
+
171
+ if not removed:
172
+ raise HTTPException(
173
+ status_code=status.HTTP_404_NOT_FOUND,
174
+ detail=f"IP not found in whitelist: {ip_address}"
175
+ )
176
+
177
+ logger.info(
178
+ "admin_ip_whitelist_removed",
179
+ admin=admin_user.get("email"),
180
+ ip=ip_address,
181
+ environment=environment
182
+ )
183
+
184
+ return {"status": "removed", "ip_address": ip_address}
185
+
186
+
187
+ @router.get("/list", response_model=List[IPWhitelistResponse])
188
+ async def list_whitelisted_ips(
189
+ environment: str = Query(default="production"),
190
+ include_expired: bool = Query(default=False),
191
+ admin_user=Depends(require_admin),
192
+ db=Depends(get_db)
193
+ ):
194
+ """
195
+ List all whitelisted IPs.
196
+
197
+ Requires admin privileges.
198
+ """
199
+ entries = await ip_whitelist_service.list_ips(
200
+ session=db,
201
+ environment=environment,
202
+ include_expired=include_expired
203
+ )
204
+
205
+ return [IPWhitelistResponse.from_model(entry) for entry in entries]
206
+
207
+
208
+ @router.get("/check/{ip_address}")
209
+ async def check_ip_whitelist(
210
+ ip_address: str,
211
+ environment: str = Query(default="production"),
212
+ admin_user=Depends(require_admin),
213
+ db=Depends(get_db)
214
+ ):
215
+ """
216
+ Check if IP is whitelisted.
217
+
218
+ Requires admin privileges.
219
+ """
220
+ try:
221
+ # Validate IP
222
+ ipaddress.ip_address(ip_address)
223
+ except ValueError:
224
+ raise HTTPException(
225
+ status_code=status.HTTP_400_BAD_REQUEST,
226
+ detail="Invalid IP address format"
227
+ )
228
+
229
+ is_whitelisted = await ip_whitelist_service.check_ip(
230
+ session=db,
231
+ ip_address=ip_address,
232
+ environment=environment
233
+ )
234
+
235
+ return {
236
+ "ip_address": ip_address,
237
+ "environment": environment,
238
+ "is_whitelisted": is_whitelisted
239
+ }
240
+
241
+
242
+ @router.put("/update/{ip_address}", response_model=IPWhitelistResponse)
243
+ async def update_whitelist_entry(
244
+ ip_address: str,
245
+ request: IPWhitelistUpdateRequest,
246
+ environment: str = Query(default="production"),
247
+ admin_user=Depends(require_admin),
248
+ db=Depends(get_db)
249
+ ):
250
+ """
251
+ Update whitelist entry.
252
+
253
+ Requires admin privileges.
254
+ """
255
+ entry = await ip_whitelist_service.update_ip(
256
+ session=db,
257
+ ip_address=ip_address,
258
+ environment=environment,
259
+ active=request.active,
260
+ description=request.description,
261
+ expires_at=request.expires_at,
262
+ metadata=request.metadata
263
+ )
264
+
265
+ if not entry:
266
+ raise HTTPException(
267
+ status_code=status.HTTP_404_NOT_FOUND,
268
+ detail=f"IP not found in whitelist: {ip_address}"
269
+ )
270
+
271
+ logger.info(
272
+ "admin_ip_whitelist_updated",
273
+ admin=admin_user.get("email"),
274
+ ip=ip_address,
275
+ active=entry.active
276
+ )
277
+
278
+ return IPWhitelistResponse.from_model(entry)
279
+
280
+
281
+ @router.post("/cleanup")
282
+ async def cleanup_expired_entries(
283
+ environment: Optional[str] = None,
284
+ admin_user=Depends(require_admin),
285
+ db=Depends(get_db)
286
+ ):
287
+ """
288
+ Remove expired whitelist entries.
289
+
290
+ Requires admin privileges.
291
+ """
292
+ count = await ip_whitelist_service.cleanup_expired(
293
+ session=db,
294
+ environment=environment
295
+ )
296
+
297
+ logger.info(
298
+ "admin_ip_whitelist_cleanup",
299
+ admin=admin_user.get("email"),
300
+ removed=count,
301
+ environment=environment
302
+ )
303
+
304
+ return {
305
+ "status": "cleaned",
306
+ "removed_count": count,
307
+ "environment": environment
308
+ }
309
+
310
+
311
+ @router.post("/initialize-defaults")
312
+ async def initialize_default_whitelist(
313
+ admin_user=Depends(require_admin),
314
+ db=Depends(get_db)
315
+ ):
316
+ """
317
+ Initialize default whitelist entries for current environment.
318
+
319
+ Requires admin privileges.
320
+ """
321
+ count = await ip_whitelist_service.initialize_defaults(
322
+ session=db,
323
+ created_by=admin_user.get("email", "admin")
324
+ )
325
+
326
+ logger.info(
327
+ "admin_ip_whitelist_defaults_initialized",
328
+ admin=admin_user.get("email"),
329
+ count=count,
330
+ environment=settings.app_env
331
+ )
332
+
333
+ return {
334
+ "status": "initialized",
335
+ "added_count": count,
336
+ "environment": settings.app_env,
337
+ "defaults": ip_whitelist_service.get_default_whitelist()
338
+ }
339
+
340
+
341
+ @router.get("/stats")
342
+ async def get_whitelist_stats(
343
+ admin_user=Depends(require_admin),
344
+ db=Depends(get_db)
345
+ ):
346
+ """
347
+ Get whitelist statistics.
348
+
349
+ Requires admin privileges.
350
+ """
351
+ environments = ["development", "staging", "production"]
352
+ stats = {}
353
+
354
+ for env in environments:
355
+ entries = await ip_whitelist_service.list_ips(
356
+ session=db,
357
+ environment=env,
358
+ include_expired=True
359
+ )
360
+
361
+ active = sum(1 for e in entries if e.active and not e.is_expired())
362
+ expired = sum(1 for e in entries if e.is_expired())
363
+ cidr_ranges = sum(1 for e in entries if e.is_cidr)
364
+
365
+ stats[env] = {
366
+ "total": len(entries),
367
+ "active": active,
368
+ "expired": expired,
369
+ "inactive": sum(1 for e in entries if not e.active),
370
+ "cidr_ranges": cidr_ranges,
371
+ "single_ips": len(entries) - cidr_ranges
372
+ }
373
+
374
+ return {
375
+ "statistics": stats,
376
+ "current_environment": settings.app_env
377
+ }
src/api/routes/webhooks.py ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Module: api.routes.webhooks
3
+ Description: Webhook endpoints for receiving external events
4
+ Author: Anderson H. Silva
5
+ Date: 2025-01-25
6
+ License: Proprietary - All rights reserved
7
+ """
8
+
9
+ from typing import Dict, Any, Optional
10
+ from datetime import datetime
11
+
12
+ from fastapi import APIRouter, Request, Depends, HTTPException, status, BackgroundTasks
13
+ from pydantic import BaseModel, Field
14
+
15
+ from src.core import get_logger
16
+ from src.api.dependencies import get_current_user
17
+ from src.services.webhook_service import WebhookConfig, WebhookEvent, webhook_service
18
+ from src.api.middleware.webhook_verification import verify_webhook_signature
19
+
20
+ logger = get_logger(__name__)
21
+
22
+ router = APIRouter(prefix="/webhooks", tags=["Webhooks"])
23
+
24
+
25
+ class IncomingWebhookPayload(BaseModel):
26
+ """Generic incoming webhook payload."""
27
+ event: str
28
+ timestamp: Optional[datetime] = None
29
+ data: Dict[str, Any]
30
+ signature: Optional[str] = None
31
+
32
+
33
+ class WebhookRegistrationRequest(BaseModel):
34
+ """Request to register a new webhook."""
35
+ url: str = Field(..., description="Webhook endpoint URL")
36
+ events: Optional[list[str]] = Field(None, description="Events to subscribe to (None = all)")
37
+ secret: Optional[str] = Field(None, description="Webhook secret for HMAC signing")
38
+ headers: Optional[Dict[str, str]] = Field(None, description="Custom headers")
39
+ active: bool = Field(True, description="Whether webhook is active")
40
+
41
+
42
+ class WebhookTestRequest(BaseModel):
43
+ """Request to test a webhook."""
44
+ url: str = Field(..., description="Webhook URL to test")
45
+ secret: Optional[str] = Field(None, description="Webhook secret if any")
46
+
47
+
48
+ @router.post("/incoming/github")
49
+ async def receive_github_webhook(
50
+ request: Request,
51
+ background_tasks: BackgroundTasks
52
+ ):
53
+ """
54
+ Receive webhooks from GitHub.
55
+
56
+ Requires webhook signature verification.
57
+ """
58
+ # Get raw body from request state (set by verification middleware)
59
+ body = getattr(request.state, "webhook_body", None)
60
+ if not body:
61
+ body = await request.body()
62
+
63
+ # Parse event type
64
+ event_type = request.headers.get("X-GitHub-Event", "unknown")
65
+
66
+ # Parse payload
67
+ try:
68
+ import json
69
+ payload = json.loads(body)
70
+ except Exception as e:
71
+ logger.error("Failed to parse GitHub webhook", error=str(e))
72
+ raise HTTPException(
73
+ status_code=status.HTTP_400_BAD_REQUEST,
74
+ detail="Invalid payload format"
75
+ )
76
+
77
+ # Log webhook received
78
+ logger.info(
79
+ "github_webhook_received",
80
+ event=event_type,
81
+ repository=payload.get("repository", {}).get("full_name"),
82
+ action=payload.get("action")
83
+ )
84
+
85
+ # Process webhook asynchronously
86
+ background_tasks.add_task(
87
+ process_github_webhook,
88
+ event_type,
89
+ payload
90
+ )
91
+
92
+ return {"status": "accepted", "event": event_type}
93
+
94
+
95
+ @router.post("/incoming/generic/{webhook_id}")
96
+ async def receive_generic_webhook(
97
+ webhook_id: str,
98
+ request: Request,
99
+ payload: IncomingWebhookPayload,
100
+ background_tasks: BackgroundTasks
101
+ ):
102
+ """
103
+ Receive generic webhooks with configurable verification.
104
+
105
+ The webhook_id should match a configured incoming webhook.
106
+ """
107
+ # Verify webhook ID exists and get configuration
108
+ # In production, this would look up from database
109
+ webhook_config = get_incoming_webhook_config(webhook_id)
110
+
111
+ if not webhook_config:
112
+ raise HTTPException(
113
+ status_code=status.HTTP_404_NOT_FOUND,
114
+ detail=f"Webhook configuration not found: {webhook_id}"
115
+ )
116
+
117
+ # Verify signature if secret is configured
118
+ if webhook_config.get("secret"):
119
+ body = await request.body()
120
+ signature = request.headers.get("X-Webhook-Signature")
121
+
122
+ if not signature:
123
+ raise HTTPException(
124
+ status_code=status.HTTP_401_UNAUTHORIZED,
125
+ detail="Missing webhook signature"
126
+ )
127
+
128
+ if not verify_webhook_signature(signature, body, webhook_config["secret"]):
129
+ raise HTTPException(
130
+ status_code=status.HTTP_401_UNAUTHORIZED,
131
+ detail="Invalid webhook signature"
132
+ )
133
+
134
+ # Log webhook
135
+ logger.info(
136
+ "generic_webhook_received",
137
+ webhook_id=webhook_id,
138
+ event=payload.event,
139
+ timestamp=payload.timestamp
140
+ )
141
+
142
+ # Process asynchronously
143
+ background_tasks.add_task(
144
+ process_generic_webhook,
145
+ webhook_id,
146
+ payload
147
+ )
148
+
149
+ return {
150
+ "status": "accepted",
151
+ "webhook_id": webhook_id,
152
+ "event": payload.event
153
+ }
154
+
155
+
156
+ @router.post("/register")
157
+ async def register_webhook(
158
+ request: WebhookRegistrationRequest,
159
+ current_user=Depends(get_current_user)
160
+ ):
161
+ """
162
+ Register a new outgoing webhook.
163
+
164
+ Requires authentication.
165
+ """
166
+ # Convert string events to enum
167
+ events = None
168
+ if request.events:
169
+ try:
170
+ events = [WebhookEvent(e) for e in request.events]
171
+ except ValueError as e:
172
+ raise HTTPException(
173
+ status_code=status.HTTP_400_BAD_REQUEST,
174
+ detail=f"Invalid event type: {e}"
175
+ )
176
+
177
+ # Create webhook config
178
+ config = WebhookConfig(
179
+ url=request.url,
180
+ events=events,
181
+ secret=request.secret,
182
+ headers=request.headers,
183
+ active=request.active
184
+ )
185
+
186
+ # Register webhook
187
+ webhook_service.add_webhook(config)
188
+
189
+ logger.info(
190
+ "webhook_registered",
191
+ user=current_user.get("email"),
192
+ url=request.url,
193
+ events=request.events
194
+ )
195
+
196
+ return {
197
+ "status": "registered",
198
+ "url": request.url,
199
+ "events": request.events,
200
+ "active": request.active
201
+ }
202
+
203
+
204
+ @router.delete("/unregister")
205
+ async def unregister_webhook(
206
+ url: str,
207
+ current_user=Depends(get_current_user)
208
+ ):
209
+ """
210
+ Unregister a webhook.
211
+
212
+ Requires authentication.
213
+ """
214
+ removed = webhook_service.remove_webhook(url)
215
+
216
+ if not removed:
217
+ raise HTTPException(
218
+ status_code=status.HTTP_404_NOT_FOUND,
219
+ detail=f"Webhook not found: {url}"
220
+ )
221
+
222
+ logger.info(
223
+ "webhook_unregistered",
224
+ user=current_user.get("email"),
225
+ url=url
226
+ )
227
+
228
+ return {"status": "unregistered", "url": url}
229
+
230
+
231
+ @router.get("/list")
232
+ async def list_webhooks(
233
+ current_user=Depends(get_current_user)
234
+ ):
235
+ """
236
+ List all registered outgoing webhooks.
237
+
238
+ Requires authentication.
239
+ """
240
+ webhooks = webhook_service.list_webhooks()
241
+
242
+ return {
243
+ "webhooks": [
244
+ {
245
+ "url": str(w.url),
246
+ "events": [e.value for e in w.events] if w.events else None,
247
+ "active": w.active,
248
+ "has_secret": bool(w.secret)
249
+ }
250
+ for w in webhooks
251
+ ],
252
+ "total": len(webhooks)
253
+ }
254
+
255
+
256
+ @router.post("/test")
257
+ async def test_webhook(
258
+ request: WebhookTestRequest,
259
+ background_tasks: BackgroundTasks,
260
+ current_user=Depends(get_current_user)
261
+ ):
262
+ """
263
+ Test a webhook endpoint.
264
+
265
+ Sends a test payload to verify webhook is working.
266
+ """
267
+ # Create temporary webhook config
268
+ config = WebhookConfig(
269
+ url=request.url,
270
+ secret=request.secret,
271
+ max_retries=1,
272
+ timeout=10
273
+ )
274
+
275
+ # Test webhook
276
+ delivery = await webhook_service.test_webhook(config)
277
+
278
+ logger.info(
279
+ "webhook_tested",
280
+ user=current_user.get("email"),
281
+ url=request.url,
282
+ success=delivery.success,
283
+ status_code=delivery.status_code
284
+ )
285
+
286
+ return {
287
+ "url": request.url,
288
+ "success": delivery.success,
289
+ "status_code": delivery.status_code,
290
+ "response": delivery.response_body,
291
+ "error": delivery.error,
292
+ "duration_ms": delivery.duration_ms
293
+ }
294
+
295
+
296
+ @router.get("/history")
297
+ async def get_webhook_history(
298
+ event: Optional[str] = None,
299
+ url: Optional[str] = None,
300
+ success: Optional[bool] = None,
301
+ limit: int = 100,
302
+ current_user=Depends(get_current_user)
303
+ ):
304
+ """
305
+ Get webhook delivery history.
306
+
307
+ Requires authentication.
308
+ """
309
+ # Convert event string to enum if provided
310
+ event_enum = None
311
+ if event:
312
+ try:
313
+ event_enum = WebhookEvent(event)
314
+ except ValueError:
315
+ raise HTTPException(
316
+ status_code=status.HTTP_400_BAD_REQUEST,
317
+ detail=f"Invalid event type: {event}"
318
+ )
319
+
320
+ history = webhook_service.get_delivery_history(
321
+ event=event_enum,
322
+ url=url,
323
+ success=success,
324
+ limit=limit
325
+ )
326
+
327
+ return {
328
+ "deliveries": [
329
+ {
330
+ "webhook_url": d.webhook_url,
331
+ "event": d.event.value,
332
+ "timestamp": d.timestamp.isoformat(),
333
+ "success": d.success,
334
+ "status_code": d.status_code,
335
+ "error": d.error,
336
+ "attempts": d.attempts,
337
+ "duration_ms": d.duration_ms
338
+ }
339
+ for d in history
340
+ ],
341
+ "total": len(history)
342
+ }
343
+
344
+
345
+ # Helper functions
346
+
347
+ def get_incoming_webhook_config(webhook_id: str) -> Optional[Dict[str, Any]]:
348
+ """Get configuration for incoming webhook."""
349
+ # In production, this would be from database
350
+ # For now, return mock config
351
+ configs = {
352
+ "test": {
353
+ "secret": "test-secret",
354
+ "active": True
355
+ },
356
+ "monitoring": {
357
+ "secret": "monitoring-secret",
358
+ "active": True
359
+ }
360
+ }
361
+
362
+ return configs.get(webhook_id)
363
+
364
+
365
+ async def process_github_webhook(event_type: str, payload: Dict[str, Any]):
366
+ """Process GitHub webhook asynchronously."""
367
+ try:
368
+ # Handle different GitHub events
369
+ if event_type == "push":
370
+ # Handle code push
371
+ logger.info("Processing GitHub push event")
372
+ elif event_type == "pull_request":
373
+ # Handle pull request
374
+ logger.info("Processing GitHub pull request event")
375
+ elif event_type == "issues":
376
+ # Handle issues
377
+ logger.info("Processing GitHub issues event")
378
+ # Add more event handlers as needed
379
+
380
+ except Exception as e:
381
+ logger.error(
382
+ "Failed to process GitHub webhook",
383
+ event=event_type,
384
+ error=str(e),
385
+ exc_info=True
386
+ )
387
+
388
+
389
+ async def process_generic_webhook(webhook_id: str, payload: IncomingWebhookPayload):
390
+ """Process generic webhook asynchronously."""
391
+ try:
392
+ # Route to appropriate handler based on webhook_id
393
+ logger.info(
394
+ "Processing generic webhook",
395
+ webhook_id=webhook_id,
396
+ event=payload.event
397
+ )
398
+
399
+ # Add specific processing logic here
400
+
401
+ except Exception as e:
402
+ logger.error(
403
+ "Failed to process generic webhook",
404
+ webhook_id=webhook_id,
405
+ error=str(e),
406
+ exc_info=True
407
+ )
src/core/config.py CHANGED
@@ -178,6 +178,11 @@ class Settings(BaseSettings):
178
  rate_limit_per_hour: int = Field(default=1000, description="Rate limit per hour")
179
  rate_limit_per_day: int = Field(default=10000, description="Rate limit per day")
180
 
 
 
 
 
 
181
  # Celery
182
  celery_broker_url: str = Field(
183
  default="redis://localhost:6379/1",
 
178
  rate_limit_per_hour: int = Field(default=1000, description="Rate limit per hour")
179
  rate_limit_per_day: int = Field(default=10000, description="Rate limit per day")
180
 
181
+ # IP Whitelist
182
+ ip_whitelist_enabled: bool = Field(default=True, description="Enable IP whitelist in production")
183
+ ip_whitelist_strict: bool = Field(default=False, description="Strict mode - reject if IP unknown")
184
+ ip_whitelist_cache_ttl: int = Field(default=300, description="IP whitelist cache TTL seconds")
185
+
186
  # Celery
187
  celery_broker_url: str = Field(
188
  default="redis://localhost:6379/1",
src/services/ip_whitelist_service.py ADDED
@@ -0,0 +1,435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Module: services.ip_whitelist_service
3
+ Description: IP whitelist management for production environments
4
+ Author: Anderson H. Silva
5
+ Date: 2025-01-25
6
+ License: Proprietary - All rights reserved
7
+ """
8
+
9
+ import ipaddress
10
+ from typing import List, Optional, Set, Dict, Any
11
+ from datetime import datetime, timezone
12
+ import json
13
+
14
+ from src.core import get_logger
15
+ from src.infrastructure.cache import cache_service
16
+ from src.core.config import settings
17
+ from src.models.base import BaseModel
18
+ from sqlalchemy import Column, String, Boolean, DateTime, Integer, JSON
19
+ from sqlalchemy.ext.asyncio import AsyncSession
20
+ from sqlalchemy import select, delete
21
+
22
+ logger = get_logger(__name__)
23
+
24
+
25
+ class IPWhitelist(BaseModel):
26
+ """IP whitelist entry model."""
27
+ __tablename__ = "ip_whitelists"
28
+
29
+ id = Column(String(64), primary_key=True)
30
+ ip_address = Column(String(45), nullable=False, unique=True) # IPv4 or IPv6
31
+ description = Column(String(255))
32
+ environment = Column(String(20), nullable=False, default="production")
33
+ active = Column(Boolean, default=True)
34
+ created_by = Column(String(255), nullable=False)
35
+ created_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc))
36
+ expires_at = Column(DateTime(timezone=True), nullable=True)
37
+ metadata = Column(JSON, default=dict)
38
+
39
+ # CIDR support
40
+ is_cidr = Column(Boolean, default=False)
41
+ cidr_prefix = Column(Integer, nullable=True)
42
+
43
+ def is_expired(self) -> bool:
44
+ """Check if whitelist entry is expired."""
45
+ if not self.expires_at:
46
+ return False
47
+ return datetime.now(timezone.utc) > self.expires_at
48
+
49
+ def matches(self, ip: str) -> bool:
50
+ """Check if IP matches this whitelist entry."""
51
+ if not self.active or self.is_expired():
52
+ return False
53
+
54
+ try:
55
+ if self.is_cidr:
56
+ # CIDR range check
57
+ network = ipaddress.ip_network(f"{self.ip_address}/{self.cidr_prefix}")
58
+ return ipaddress.ip_address(ip) in network
59
+ else:
60
+ # Exact match
61
+ return self.ip_address == ip
62
+ except ValueError:
63
+ logger.error(f"Invalid IP address format: {ip}")
64
+ return False
65
+
66
+
67
+ class IPWhitelistService:
68
+ """Service for managing IP whitelists."""
69
+
70
+ def __init__(self):
71
+ """Initialize IP whitelist service."""
72
+ self._cache_key_prefix = "ip_whitelist"
73
+ self._cache_ttl = 300 # 5 minutes
74
+ self._whitelist_cache: Optional[Set[str]] = None
75
+ self._cidr_cache: Optional[List[tuple]] = None
76
+ self._last_cache_update: Optional[datetime] = None
77
+
78
+ async def add_ip(
79
+ self,
80
+ session: AsyncSession,
81
+ ip_address: str,
82
+ created_by: str,
83
+ description: Optional[str] = None,
84
+ environment: str = "production",
85
+ expires_at: Optional[datetime] = None,
86
+ is_cidr: bool = False,
87
+ metadata: Optional[Dict[str, Any]] = None
88
+ ) -> IPWhitelist:
89
+ """Add IP address or CIDR range to whitelist."""
90
+ try:
91
+ # Parse and validate IP/CIDR
92
+ if is_cidr or "/" in ip_address:
93
+ network = ipaddress.ip_network(ip_address, strict=False)
94
+ ip_str = str(network.network_address)
95
+ cidr_prefix = network.prefixlen
96
+ is_cidr = True
97
+ else:
98
+ # Validate single IP
99
+ ip_obj = ipaddress.ip_address(ip_address)
100
+ ip_str = str(ip_obj)
101
+ cidr_prefix = None
102
+ is_cidr = False
103
+
104
+ except ValueError as e:
105
+ logger.error(f"Invalid IP address format: {ip_address}")
106
+ raise ValueError(f"Invalid IP address format: {ip_address}") from e
107
+
108
+ # Check if already exists
109
+ existing = await session.execute(
110
+ select(IPWhitelist).where(
111
+ IPWhitelist.ip_address == ip_str,
112
+ IPWhitelist.environment == environment
113
+ )
114
+ )
115
+ if existing.scalar_one_or_none():
116
+ raise ValueError(f"IP address already whitelisted: {ip_str}")
117
+
118
+ # Create whitelist entry
119
+ entry = IPWhitelist(
120
+ id=f"{environment}:{ip_str}",
121
+ ip_address=ip_str,
122
+ description=description,
123
+ environment=environment,
124
+ created_by=created_by,
125
+ expires_at=expires_at,
126
+ is_cidr=is_cidr,
127
+ cidr_prefix=cidr_prefix,
128
+ metadata=metadata or {}
129
+ )
130
+
131
+ session.add(entry)
132
+ await session.commit()
133
+
134
+ # Invalidate cache
135
+ await self._invalidate_cache()
136
+
137
+ logger.info(
138
+ "ip_whitelist_added",
139
+ ip=ip_str,
140
+ environment=environment,
141
+ is_cidr=is_cidr,
142
+ created_by=created_by
143
+ )
144
+
145
+ return entry
146
+
147
+ async def remove_ip(
148
+ self,
149
+ session: AsyncSession,
150
+ ip_address: str,
151
+ environment: str = "production"
152
+ ) -> bool:
153
+ """Remove IP from whitelist."""
154
+ result = await session.execute(
155
+ delete(IPWhitelist).where(
156
+ IPWhitelist.ip_address == ip_address,
157
+ IPWhitelist.environment == environment
158
+ )
159
+ )
160
+ await session.commit()
161
+
162
+ if result.rowcount > 0:
163
+ await self._invalidate_cache()
164
+ logger.info(
165
+ "ip_whitelist_removed",
166
+ ip=ip_address,
167
+ environment=environment
168
+ )
169
+ return True
170
+
171
+ return False
172
+
173
+ async def check_ip(
174
+ self,
175
+ session: AsyncSession,
176
+ ip_address: str,
177
+ environment: str = "production"
178
+ ) -> bool:
179
+ """Check if IP is whitelisted."""
180
+ # Check cache first
181
+ cache_key = f"{self._cache_key_prefix}:{environment}:check:{ip_address}"
182
+ cached = await cache_service.get(cache_key)
183
+ if cached is not None:
184
+ return cached
185
+
186
+ # Load whitelist if needed
187
+ await self._ensure_cache_loaded(session, environment)
188
+
189
+ # Check exact matches first
190
+ if self._whitelist_cache and ip_address in self._whitelist_cache:
191
+ await cache_service.set(cache_key, True, ttl=self._cache_ttl)
192
+ return True
193
+
194
+ # Check CIDR ranges
195
+ if self._cidr_cache:
196
+ for cidr_ip, prefix, expires_at in self._cidr_cache:
197
+ if expires_at and datetime.now(timezone.utc) > expires_at:
198
+ continue
199
+
200
+ try:
201
+ network = ipaddress.ip_network(f"{cidr_ip}/{prefix}")
202
+ if ipaddress.ip_address(ip_address) in network:
203
+ await cache_service.set(cache_key, True, ttl=self._cache_ttl)
204
+ return True
205
+ except ValueError:
206
+ continue
207
+
208
+ # Not whitelisted
209
+ await cache_service.set(cache_key, False, ttl=self._cache_ttl)
210
+ return False
211
+
212
+ async def list_ips(
213
+ self,
214
+ session: AsyncSession,
215
+ environment: str = "production",
216
+ include_expired: bool = False
217
+ ) -> List[IPWhitelist]:
218
+ """List all whitelisted IPs."""
219
+ query = select(IPWhitelist).where(
220
+ IPWhitelist.environment == environment
221
+ )
222
+
223
+ if not include_expired:
224
+ now = datetime.now(timezone.utc)
225
+ query = query.where(
226
+ (IPWhitelist.expires_at.is_(None)) |
227
+ (IPWhitelist.expires_at > now)
228
+ )
229
+
230
+ result = await session.execute(query)
231
+ return list(result.scalars().all())
232
+
233
+ async def update_ip(
234
+ self,
235
+ session: AsyncSession,
236
+ ip_address: str,
237
+ environment: str = "production",
238
+ active: Optional[bool] = None,
239
+ description: Optional[str] = None,
240
+ expires_at: Optional[datetime] = None,
241
+ metadata: Optional[Dict[str, Any]] = None
242
+ ) -> Optional[IPWhitelist]:
243
+ """Update whitelist entry."""
244
+ result = await session.execute(
245
+ select(IPWhitelist).where(
246
+ IPWhitelist.ip_address == ip_address,
247
+ IPWhitelist.environment == environment
248
+ )
249
+ )
250
+ entry = result.scalar_one_or_none()
251
+
252
+ if not entry:
253
+ return None
254
+
255
+ if active is not None:
256
+ entry.active = active
257
+ if description is not None:
258
+ entry.description = description
259
+ if expires_at is not None:
260
+ entry.expires_at = expires_at
261
+ if metadata is not None:
262
+ entry.metadata = metadata
263
+
264
+ await session.commit()
265
+ await self._invalidate_cache()
266
+
267
+ logger.info(
268
+ "ip_whitelist_updated",
269
+ ip=ip_address,
270
+ environment=environment,
271
+ active=entry.active
272
+ )
273
+
274
+ return entry
275
+
276
+ async def cleanup_expired(
277
+ self,
278
+ session: AsyncSession,
279
+ environment: Optional[str] = None
280
+ ) -> int:
281
+ """Remove expired whitelist entries."""
282
+ query = delete(IPWhitelist).where(
283
+ IPWhitelist.expires_at < datetime.now(timezone.utc)
284
+ )
285
+
286
+ if environment:
287
+ query = query.where(IPWhitelist.environment == environment)
288
+
289
+ result = await session.execute(query)
290
+ await session.commit()
291
+
292
+ if result.rowcount > 0:
293
+ await self._invalidate_cache()
294
+ logger.info(
295
+ "ip_whitelist_cleanup",
296
+ removed=result.rowcount,
297
+ environment=environment
298
+ )
299
+
300
+ return result.rowcount
301
+
302
+ async def _ensure_cache_loaded(
303
+ self,
304
+ session: AsyncSession,
305
+ environment: str
306
+ ) -> None:
307
+ """Ensure whitelist is loaded in cache."""
308
+ # Check if cache is still valid
309
+ if (
310
+ self._last_cache_update and
311
+ (datetime.now(timezone.utc) - self._last_cache_update).total_seconds() < self._cache_ttl
312
+ ):
313
+ return
314
+
315
+ # Load from database
316
+ now = datetime.now(timezone.utc)
317
+ result = await session.execute(
318
+ select(IPWhitelist).where(
319
+ IPWhitelist.environment == environment,
320
+ IPWhitelist.active == True,
321
+ (IPWhitelist.expires_at.is_(None)) | (IPWhitelist.expires_at > now)
322
+ )
323
+ )
324
+
325
+ entries = result.scalars().all()
326
+
327
+ # Separate exact IPs and CIDR ranges
328
+ self._whitelist_cache = set()
329
+ self._cidr_cache = []
330
+
331
+ for entry in entries:
332
+ if entry.is_cidr:
333
+ self._cidr_cache.append((
334
+ entry.ip_address,
335
+ entry.cidr_prefix,
336
+ entry.expires_at
337
+ ))
338
+ else:
339
+ self._whitelist_cache.add(entry.ip_address)
340
+
341
+ self._last_cache_update = datetime.now(timezone.utc)
342
+
343
+ logger.debug(
344
+ "ip_whitelist_cache_loaded",
345
+ environment=environment,
346
+ exact_ips=len(self._whitelist_cache),
347
+ cidr_ranges=len(self._cidr_cache)
348
+ )
349
+
350
+ async def _invalidate_cache(self) -> None:
351
+ """Invalidate the whitelist cache."""
352
+ self._whitelist_cache = None
353
+ self._cidr_cache = None
354
+ self._last_cache_update = None
355
+
356
+ # Clear Redis cache patterns
357
+ pattern = f"{self._cache_key_prefix}:*"
358
+ await cache_service.delete_pattern(pattern)
359
+
360
+ def get_default_whitelist(self) -> List[str]:
361
+ """Get default whitelist based on environment."""
362
+ defaults = []
363
+
364
+ # Always allow localhost
365
+ defaults.extend([
366
+ "127.0.0.1",
367
+ "::1",
368
+ "localhost"
369
+ ])
370
+
371
+ # Development environment
372
+ if settings.is_development:
373
+ defaults.extend([
374
+ "10.0.0.0/8", # Private network
375
+ "172.16.0.0/12", # Private network
376
+ "192.168.0.0/16" # Private network
377
+ ])
378
+
379
+ # Production environment - add known services
380
+ if settings.is_production:
381
+ # Vercel IPs (example - would need real ranges)
382
+ defaults.extend([
383
+ "76.76.21.0/24", # Vercel edge network (example)
384
+ "76.223.0.0/16" # Vercel edge network (example)
385
+ ])
386
+
387
+ # HuggingFace Spaces IPs (example - would need real ranges)
388
+ defaults.extend([
389
+ "34.0.0.0/8", # Google Cloud (where HF runs)
390
+ "35.0.0.0/8" # Google Cloud
391
+ ])
392
+
393
+ # Monitoring services
394
+ defaults.extend([
395
+ "52.0.0.0/8" # AWS (for monitoring)
396
+ ])
397
+
398
+ return defaults
399
+
400
+ async def initialize_defaults(
401
+ self,
402
+ session: AsyncSession,
403
+ created_by: str = "system"
404
+ ) -> int:
405
+ """Initialize default whitelist entries."""
406
+ defaults = self.get_default_whitelist()
407
+ count = 0
408
+
409
+ for ip in defaults:
410
+ try:
411
+ is_cidr = "/" in ip
412
+ await self.add_ip(
413
+ session=session,
414
+ ip_address=ip,
415
+ created_by=created_by,
416
+ description="Default whitelist entry",
417
+ environment=settings.app_env,
418
+ is_cidr=is_cidr
419
+ )
420
+ count += 1
421
+ except ValueError:
422
+ # Already exists or invalid
423
+ continue
424
+
425
+ logger.info(
426
+ "ip_whitelist_defaults_initialized",
427
+ count=count,
428
+ environment=settings.app_env
429
+ )
430
+
431
+ return count
432
+
433
+
434
+ # Global instance
435
+ ip_whitelist_service = IPWhitelistService()
tests/unit/services/test_ip_whitelist_service.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for IP whitelist service."""
2
+
3
+ import pytest
4
+ from datetime import datetime, timedelta, timezone
5
+ from sqlalchemy.ext.asyncio import AsyncSession
6
+ from unittest.mock import Mock, AsyncMock, patch
7
+
8
+ from src.services.ip_whitelist_service import IPWhitelistService, IPWhitelist
9
+ from src.core.config import settings
10
+
11
+
12
+ @pytest.fixture
13
+ def ip_whitelist_service():
14
+ """Create IP whitelist service instance."""
15
+ return IPWhitelistService()
16
+
17
+
18
+ @pytest.fixture
19
+ async def mock_db_session():
20
+ """Create mock database session."""
21
+ session = AsyncMock(spec=AsyncSession)
22
+ session.commit = AsyncMock()
23
+ session.execute = AsyncMock()
24
+ session.add = Mock()
25
+ return session
26
+
27
+
28
+ class TestIPWhitelistService:
29
+ """Test IP whitelist service."""
30
+
31
+ async def test_add_single_ip(self, ip_whitelist_service, mock_db_session):
32
+ """Test adding a single IP address."""
33
+ # Mock query result
34
+ mock_db_session.execute.return_value.scalar_one_or_none.return_value = None
35
+
36
+ # Add IP
37
+ entry = await ip_whitelist_service.add_ip(
38
+ session=mock_db_session,
39
+ ip_address="192.168.1.100",
40
+ created_by="[email protected]",
41
+ description="Test IP",
42
+ environment="production"
43
+ )
44
+
45
+ # Verify
46
+ assert entry.ip_address == "192.168.1.100"
47
+ assert entry.created_by == "[email protected]"
48
+ assert entry.description == "Test IP"
49
+ assert entry.environment == "production"
50
+ assert entry.is_cidr is False
51
+ assert entry.active is True
52
+
53
+ # Verify database operations
54
+ mock_db_session.add.assert_called_once()
55
+ mock_db_session.commit.assert_called_once()
56
+
57
+ async def test_add_cidr_range(self, ip_whitelist_service, mock_db_session):
58
+ """Test adding a CIDR range."""
59
+ # Mock query result
60
+ mock_db_session.execute.return_value.scalar_one_or_none.return_value = None
61
+
62
+ # Add CIDR
63
+ entry = await ip_whitelist_service.add_ip(
64
+ session=mock_db_session,
65
+ ip_address="10.0.0.0/24",
66
+ created_by="[email protected]",
67
+ description="Test subnet",
68
+ environment="production",
69
+ is_cidr=True
70
+ )
71
+
72
+ # Verify
73
+ assert entry.ip_address == "10.0.0.0"
74
+ assert entry.is_cidr is True
75
+ assert entry.cidr_prefix == 24
76
+ assert entry.active is True
77
+
78
+ async def test_add_duplicate_ip_fails(self, ip_whitelist_service, mock_db_session):
79
+ """Test adding duplicate IP fails."""
80
+ # Mock existing entry
81
+ existing = Mock(spec=IPWhitelist)
82
+ mock_db_session.execute.return_value.scalar_one_or_none.return_value = existing
83
+
84
+ # Try to add duplicate
85
+ with pytest.raises(ValueError, match="already whitelisted"):
86
+ await ip_whitelist_service.add_ip(
87
+ session=mock_db_session,
88
+ ip_address="192.168.1.100",
89
+ created_by="[email protected]"
90
+ )
91
+
92
+ async def test_add_invalid_ip_fails(self, ip_whitelist_service, mock_db_session):
93
+ """Test adding invalid IP fails."""
94
+ with pytest.raises(ValueError, match="Invalid IP address format"):
95
+ await ip_whitelist_service.add_ip(
96
+ session=mock_db_session,
97
+ ip_address="not.an.ip.address",
98
+ created_by="[email protected]"
99
+ )
100
+
101
+ async def test_check_ip_exact_match(self, ip_whitelist_service, mock_db_session):
102
+ """Test checking IP with exact match."""
103
+ # Mock whitelist entries
104
+ entries = [
105
+ Mock(
106
+ ip_address="192.168.1.100",
107
+ is_cidr=False,
108
+ active=True,
109
+ expires_at=None
110
+ ),
111
+ Mock(
112
+ ip_address="10.0.0.0",
113
+ is_cidr=True,
114
+ cidr_prefix=24,
115
+ active=True,
116
+ expires_at=None
117
+ )
118
+ ]
119
+
120
+ mock_db_session.execute.return_value.scalars.return_value.all.return_value = entries
121
+
122
+ # Force cache reload
123
+ ip_whitelist_service._last_cache_update = None
124
+
125
+ # Check whitelisted IP
126
+ result = await ip_whitelist_service.check_ip(
127
+ session=mock_db_session,
128
+ ip_address="192.168.1.100",
129
+ environment="production"
130
+ )
131
+ assert result is True
132
+
133
+ async def test_check_ip_cidr_match(self, ip_whitelist_service, mock_db_session):
134
+ """Test checking IP within CIDR range."""
135
+ # Mock whitelist entries
136
+ entries = [
137
+ Mock(
138
+ ip_address="10.0.0.0",
139
+ is_cidr=True,
140
+ cidr_prefix=24,
141
+ active=True,
142
+ expires_at=None
143
+ )
144
+ ]
145
+
146
+ mock_db_session.execute.return_value.scalars.return_value.all.return_value = entries
147
+
148
+ # Force cache reload
149
+ ip_whitelist_service._last_cache_update = None
150
+
151
+ # Check IP in range
152
+ result = await ip_whitelist_service.check_ip(
153
+ session=mock_db_session,
154
+ ip_address="10.0.0.50",
155
+ environment="production"
156
+ )
157
+ assert result is True
158
+
159
+ # Check IP outside range
160
+ result = await ip_whitelist_service.check_ip(
161
+ session=mock_db_session,
162
+ ip_address="10.0.1.50",
163
+ environment="production"
164
+ )
165
+ assert result is False
166
+
167
+ async def test_check_ip_expired_entry(self, ip_whitelist_service, mock_db_session):
168
+ """Test expired entries are ignored."""
169
+ # Mock expired entry
170
+ entries = [
171
+ Mock(
172
+ ip_address="192.168.1.100",
173
+ is_cidr=False,
174
+ active=True,
175
+ expires_at=datetime.now(timezone.utc) - timedelta(hours=1)
176
+ )
177
+ ]
178
+
179
+ mock_db_session.execute.return_value.scalars.return_value.all.return_value = entries
180
+
181
+ # Force cache reload
182
+ ip_whitelist_service._last_cache_update = None
183
+
184
+ # Check expired IP
185
+ result = await ip_whitelist_service.check_ip(
186
+ session=mock_db_session,
187
+ ip_address="192.168.1.100",
188
+ environment="production"
189
+ )
190
+ assert result is False
191
+
192
+ async def test_remove_ip(self, ip_whitelist_service, mock_db_session):
193
+ """Test removing IP from whitelist."""
194
+ # Mock delete result
195
+ mock_result = Mock()
196
+ mock_result.rowcount = 1
197
+ mock_db_session.execute.return_value = mock_result
198
+
199
+ # Remove IP
200
+ result = await ip_whitelist_service.remove_ip(
201
+ session=mock_db_session,
202
+ ip_address="192.168.1.100",
203
+ environment="production"
204
+ )
205
+
206
+ assert result is True
207
+ mock_db_session.commit.assert_called_once()
208
+
209
+ async def test_update_ip(self, ip_whitelist_service, mock_db_session):
210
+ """Test updating whitelist entry."""
211
+ # Mock existing entry
212
+ entry = Mock(spec=IPWhitelist)
213
+ entry.ip_address = "192.168.1.100"
214
+ entry.active = True
215
+ entry.description = "Old description"
216
+
217
+ mock_db_session.execute.return_value.scalar_one_or_none.return_value = entry
218
+
219
+ # Update entry
220
+ result = await ip_whitelist_service.update_ip(
221
+ session=mock_db_session,
222
+ ip_address="192.168.1.100",
223
+ environment="production",
224
+ active=False,
225
+ description="New description"
226
+ )
227
+
228
+ assert result is not None
229
+ assert entry.active is False
230
+ assert entry.description == "New description"
231
+ mock_db_session.commit.assert_called_once()
232
+
233
+ def test_get_default_whitelist_development(self, ip_whitelist_service):
234
+ """Test default whitelist for development."""
235
+ with patch.object(settings, 'is_development', True):
236
+ defaults = ip_whitelist_service.get_default_whitelist()
237
+
238
+ # Should include localhost and private networks
239
+ assert "127.0.0.1" in defaults
240
+ assert "::1" in defaults
241
+ assert "10.0.0.0/8" in defaults
242
+ assert "192.168.0.0/16" in defaults
243
+
244
+ def test_get_default_whitelist_production(self, ip_whitelist_service):
245
+ """Test default whitelist for production."""
246
+ with patch.object(settings, 'is_production', True):
247
+ defaults = ip_whitelist_service.get_default_whitelist()
248
+
249
+ # Should include localhost and service IPs
250
+ assert "127.0.0.1" in defaults
251
+ assert "::1" in defaults
252
+ # Should have cloud provider ranges
253
+ assert any("76." in ip for ip in defaults) # Vercel
254
+ assert any("34." in ip for ip in defaults) # Google Cloud
255
+
256
+ async def test_cleanup_expired(self, ip_whitelist_service, mock_db_session):
257
+ """Test cleaning up expired entries."""
258
+ # Mock delete result
259
+ mock_result = Mock()
260
+ mock_result.rowcount = 5
261
+ mock_db_session.execute.return_value = mock_result
262
+
263
+ # Cleanup
264
+ count = await ip_whitelist_service.cleanup_expired(
265
+ session=mock_db_session,
266
+ environment="production"
267
+ )
268
+
269
+ assert count == 5
270
+ mock_db_session.commit.assert_called_once()
271
+
272
+ def test_ip_whitelist_model_matches(self):
273
+ """Test IPWhitelist model matching logic."""
274
+ # Test exact match
275
+ entry = IPWhitelist(
276
+ id="test",
277
+ ip_address="192.168.1.100",
278
+ is_cidr=False,
279
+ active=True,
280
+ created_by="test"
281
+ )
282
+ assert entry.matches("192.168.1.100") is True
283
+ assert entry.matches("192.168.1.101") is False
284
+
285
+ # Test CIDR match
286
+ entry_cidr = IPWhitelist(
287
+ id="test",
288
+ ip_address="10.0.0.0",
289
+ is_cidr=True,
290
+ cidr_prefix=24,
291
+ active=True,
292
+ created_by="test"
293
+ )
294
+ assert entry_cidr.matches("10.0.0.1") is True
295
+ assert entry_cidr.matches("10.0.0.255") is True
296
+ assert entry_cidr.matches("10.0.1.1") is False
297
+
298
+ # Test inactive entry
299
+ entry.active = False
300
+ assert entry.matches("192.168.1.100") is False