anderson-ufrj commited on
Commit
6eb1f60
·
1 Parent(s): dd1b2de

test(infra): add infrastructure component tests

Browse files

- Test metrics collection system
- Test health checking functionality
- Test circuit breaker pattern
- Test retry policies and backoff
- Test database connection pooling
- Add message queue tests

Files changed (1) hide show
  1. tests/unit/test_infrastructure.py +490 -0
tests/unit/test_infrastructure.py ADDED
@@ -0,0 +1,490 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for infrastructure components."""
2
+ import pytest
3
+ import asyncio
4
+ from unittest.mock import MagicMock, patch, AsyncMock
5
+ from datetime import datetime, timedelta
6
+ import json
7
+
8
+ from src.infrastructure.monitoring import (
9
+ MetricsCollector,
10
+ HealthChecker,
11
+ PerformanceMonitor,
12
+ ResourceMonitor,
13
+ AlertManager
14
+ )
15
+ from src.infrastructure.database import (
16
+ DatabasePool,
17
+ ConnectionManager,
18
+ TransactionManager,
19
+ QueryOptimizer
20
+ )
21
+ from src.infrastructure.message_queue import (
22
+ MessageQueue,
23
+ MessageBroker,
24
+ EventBus,
25
+ MessagePriority
26
+ )
27
+ from src.infrastructure.circuit_breaker import (
28
+ CircuitBreaker,
29
+ CircuitState,
30
+ CircuitBreakerConfig
31
+ )
32
+ from src.infrastructure.retry import (
33
+ RetryPolicy,
34
+ ExponentialBackoff,
35
+ RetryManager
36
+ )
37
+
38
+
39
+ class TestMetricsCollector:
40
+ """Test metrics collection system."""
41
+
42
+ @pytest.fixture
43
+ def metrics_collector(self):
44
+ """Create metrics collector instance."""
45
+ return MetricsCollector()
46
+
47
+ def test_counter_metric(self, metrics_collector):
48
+ """Test counter metric collection."""
49
+ # Increment counter
50
+ metrics_collector.increment("api_requests_total", labels={"endpoint": "/health"})
51
+ metrics_collector.increment("api_requests_total", labels={"endpoint": "/health"})
52
+ metrics_collector.increment("api_requests_total", labels={"endpoint": "/api/v1/users"})
53
+
54
+ # Get metric value
55
+ value = metrics_collector.get_metric_value(
56
+ "api_requests_total",
57
+ labels={"endpoint": "/health"}
58
+ )
59
+
60
+ assert value == 2
61
+
62
+ def test_gauge_metric(self, metrics_collector):
63
+ """Test gauge metric collection."""
64
+ # Set gauge values
65
+ metrics_collector.set_gauge("active_connections", 10)
66
+ metrics_collector.set_gauge("active_connections", 15)
67
+ metrics_collector.set_gauge("active_connections", 12)
68
+
69
+ value = metrics_collector.get_metric_value("active_connections")
70
+ assert value == 12
71
+
72
+ def test_histogram_metric(self, metrics_collector):
73
+ """Test histogram metric collection."""
74
+ # Record durations
75
+ durations = [0.1, 0.2, 0.15, 0.3, 0.25, 0.4, 0.2, 0.18]
76
+
77
+ for duration in durations:
78
+ metrics_collector.record_duration(
79
+ "request_duration_seconds",
80
+ duration,
81
+ labels={"method": "GET"}
82
+ )
83
+
84
+ stats = metrics_collector.get_histogram_stats("request_duration_seconds")
85
+
86
+ assert stats["count"] == len(durations)
87
+ assert 0.1 <= stats["mean"] <= 0.3
88
+ assert stats["p50"] > 0 # Median
89
+ assert stats["p95"] > stats["p50"] # 95th percentile > median
90
+
91
+ def test_metric_labels(self, metrics_collector):
92
+ """Test metric labeling."""
93
+ # Same metric with different labels
94
+ metrics_collector.increment("errors_total", labels={"type": "database"})
95
+ metrics_collector.increment("errors_total", labels={"type": "api"})
96
+ metrics_collector.increment("errors_total", labels={"type": "api"})
97
+
98
+ db_errors = metrics_collector.get_metric_value(
99
+ "errors_total",
100
+ labels={"type": "database"}
101
+ )
102
+ api_errors = metrics_collector.get_metric_value(
103
+ "errors_total",
104
+ labels={"type": "api"}
105
+ )
106
+
107
+ assert db_errors == 1
108
+ assert api_errors == 2
109
+
110
+
111
+ class TestHealthChecker:
112
+ """Test health checking system."""
113
+
114
+ @pytest.fixture
115
+ def health_checker(self):
116
+ """Create health checker instance."""
117
+ return HealthChecker()
118
+
119
+ @pytest.mark.asyncio
120
+ async def test_component_health_check(self, health_checker):
121
+ """Test individual component health checks."""
122
+ # Register health check functions
123
+ async def database_check():
124
+ return {"status": "healthy", "latency_ms": 5}
125
+
126
+ async def redis_check():
127
+ return {"status": "healthy", "latency_ms": 2}
128
+
129
+ health_checker.register_check("database", database_check)
130
+ health_checker.register_check("redis", redis_check)
131
+
132
+ # Run health checks
133
+ results = await health_checker.check_all()
134
+
135
+ assert results["status"] == "healthy"
136
+ assert results["components"]["database"]["status"] == "healthy"
137
+ assert results["components"]["redis"]["status"] == "healthy"
138
+
139
+ @pytest.mark.asyncio
140
+ async def test_unhealthy_component(self, health_checker):
141
+ """Test handling unhealthy components."""
142
+ async def failing_check():
143
+ raise Exception("Connection failed")
144
+
145
+ async def healthy_check():
146
+ return {"status": "healthy"}
147
+
148
+ health_checker.register_check("failing_service", failing_check)
149
+ health_checker.register_check("healthy_service", healthy_check)
150
+
151
+ results = await health_checker.check_all()
152
+
153
+ assert results["status"] == "degraded"
154
+ assert results["components"]["failing_service"]["status"] == "unhealthy"
155
+ assert results["components"]["healthy_service"]["status"] == "healthy"
156
+
157
+ @pytest.mark.asyncio
158
+ async def test_health_check_timeout(self, health_checker):
159
+ """Test health check timeout handling."""
160
+ async def slow_check():
161
+ await asyncio.sleep(10) # Longer than timeout
162
+ return {"status": "healthy"}
163
+
164
+ health_checker.register_check("slow_service", slow_check, timeout=1)
165
+
166
+ results = await health_checker.check_all()
167
+
168
+ assert results["components"]["slow_service"]["status"] == "timeout"
169
+ assert results["status"] == "degraded"
170
+
171
+
172
+ class TestPerformanceMonitor:
173
+ """Test performance monitoring."""
174
+
175
+ @pytest.fixture
176
+ def perf_monitor(self):
177
+ """Create performance monitor instance."""
178
+ return PerformanceMonitor()
179
+
180
+ @pytest.mark.asyncio
181
+ async def test_operation_timing(self, perf_monitor):
182
+ """Test timing operations."""
183
+ async with perf_monitor.measure("database_query"):
184
+ await asyncio.sleep(0.1)
185
+
186
+ async with perf_monitor.measure("api_call"):
187
+ await asyncio.sleep(0.05)
188
+
189
+ stats = perf_monitor.get_stats()
190
+
191
+ assert "database_query" in stats
192
+ assert stats["database_query"]["count"] == 1
193
+ assert stats["database_query"]["avg_duration"] >= 0.1
194
+
195
+ assert "api_call" in stats
196
+ assert stats["api_call"]["avg_duration"] >= 0.05
197
+
198
+ def test_throughput_calculation(self, perf_monitor):
199
+ """Test throughput calculation."""
200
+ # Record multiple operations
201
+ for _ in range(100):
202
+ perf_monitor.record_operation("process_request")
203
+
204
+ # Calculate throughput
205
+ throughput = perf_monitor.calculate_throughput("process_request", window_seconds=1)
206
+
207
+ assert throughput > 0
208
+ assert throughput <= 100 # Can't be more than recorded
209
+
210
+
211
+ class TestCircuitBreaker:
212
+ """Test circuit breaker pattern."""
213
+
214
+ @pytest.fixture
215
+ def circuit_breaker(self):
216
+ """Create circuit breaker instance."""
217
+ config = CircuitBreakerConfig(
218
+ failure_threshold=3,
219
+ recovery_timeout=5,
220
+ expected_exception=Exception
221
+ )
222
+ return CircuitBreaker("test_service", config)
223
+
224
+ @pytest.mark.asyncio
225
+ async def test_circuit_breaker_closed_state(self, circuit_breaker):
226
+ """Test circuit breaker in closed (normal) state."""
227
+ async def success_operation():
228
+ return "success"
229
+
230
+ result = await circuit_breaker.call(success_operation)
231
+
232
+ assert result == "success"
233
+ assert circuit_breaker.state == CircuitState.CLOSED
234
+ assert circuit_breaker.failure_count == 0
235
+
236
+ @pytest.mark.asyncio
237
+ async def test_circuit_breaker_opens_on_failures(self, circuit_breaker):
238
+ """Test circuit breaker opens after threshold failures."""
239
+ async def failing_operation():
240
+ raise Exception("Operation failed")
241
+
242
+ # Fail multiple times
243
+ for _ in range(3):
244
+ with pytest.raises(Exception):
245
+ await circuit_breaker.call(failing_operation)
246
+
247
+ assert circuit_breaker.state == CircuitState.OPEN
248
+ assert circuit_breaker.failure_count == 3
249
+
250
+ # Should reject calls when open
251
+ with pytest.raises(Exception) as exc_info:
252
+ await circuit_breaker.call(failing_operation)
253
+ assert "Circuit breaker is OPEN" in str(exc_info.value)
254
+
255
+ @pytest.mark.asyncio
256
+ async def test_circuit_breaker_half_open_recovery(self, circuit_breaker):
257
+ """Test circuit breaker recovery through half-open state."""
258
+ # Open the circuit
259
+ circuit_breaker.state = CircuitState.OPEN
260
+ circuit_breaker.last_failure_time = datetime.now() - timedelta(seconds=10)
261
+
262
+ async def success_operation():
263
+ return "recovered"
264
+
265
+ # Should enter half-open and try operation
266
+ result = await circuit_breaker.call(success_operation)
267
+
268
+ assert result == "recovered"
269
+ assert circuit_breaker.state == CircuitState.CLOSED
270
+ assert circuit_breaker.failure_count == 0
271
+
272
+
273
+ class TestRetryPolicy:
274
+ """Test retry policies and backoff strategies."""
275
+
276
+ def test_exponential_backoff(self):
277
+ """Test exponential backoff calculation."""
278
+ backoff = ExponentialBackoff(
279
+ initial_delay=1,
280
+ max_delay=60,
281
+ multiplier=2
282
+ )
283
+
284
+ # Calculate delays for successive retries
285
+ delays = [backoff.get_delay(i) for i in range(5)]
286
+
287
+ assert delays[0] == 1 # Initial delay
288
+ assert delays[1] == 2 # 1 * 2
289
+ assert delays[2] == 4 # 2 * 2
290
+ assert delays[3] == 8 # 4 * 2
291
+ assert delays[4] == 16 # 8 * 2
292
+
293
+ # Test max delay cap
294
+ delay_10 = backoff.get_delay(10)
295
+ assert delay_10 == 60 # Capped at max_delay
296
+
297
+ @pytest.mark.asyncio
298
+ async def test_retry_manager(self):
299
+ """Test retry manager with policy."""
300
+ policy = RetryPolicy(
301
+ max_attempts=3,
302
+ backoff=ExponentialBackoff(initial_delay=0.1),
303
+ retryable_exceptions=(ValueError,)
304
+ )
305
+
306
+ retry_manager = RetryManager(policy)
307
+
308
+ attempt_count = 0
309
+
310
+ async def flaky_operation():
311
+ nonlocal attempt_count
312
+ attempt_count += 1
313
+ if attempt_count < 3:
314
+ raise ValueError("Temporary failure")
315
+ return "success"
316
+
317
+ result = await retry_manager.execute(flaky_operation)
318
+
319
+ assert result == "success"
320
+ assert attempt_count == 3
321
+
322
+ @pytest.mark.asyncio
323
+ async def test_retry_with_non_retryable_exception(self):
324
+ """Test retry skips non-retryable exceptions."""
325
+ policy = RetryPolicy(
326
+ max_attempts=3,
327
+ retryable_exceptions=(ValueError,)
328
+ )
329
+
330
+ retry_manager = RetryManager(policy)
331
+
332
+ async def failing_operation():
333
+ raise TypeError("Non-retryable error")
334
+
335
+ with pytest.raises(TypeError):
336
+ await retry_manager.execute(failing_operation)
337
+
338
+
339
+ class TestDatabasePool:
340
+ """Test database connection pooling."""
341
+
342
+ @pytest.fixture
343
+ def db_pool(self):
344
+ """Create database pool instance."""
345
+ return DatabasePool(
346
+ min_size=2,
347
+ max_size=10,
348
+ max_idle_time=300
349
+ )
350
+
351
+ @pytest.mark.asyncio
352
+ async def test_connection_acquisition(self, db_pool):
353
+ """Test getting connections from pool."""
354
+ # Mock connection
355
+ mock_conn = AsyncMock()
356
+ mock_conn.is_closed.return_value = False
357
+
358
+ with patch.object(db_pool, '_create_connection', return_value=mock_conn):
359
+ # Get connection
360
+ async with db_pool.acquire() as conn:
361
+ assert conn is not None
362
+ assert conn == mock_conn
363
+
364
+ # Connection should be returned to pool
365
+ assert db_pool.size > 0
366
+
367
+ @pytest.mark.asyncio
368
+ async def test_connection_pool_limits(self, db_pool):
369
+ """Test pool size limits."""
370
+ connections = []
371
+
372
+ # Mock connection creation
373
+ with patch.object(db_pool, '_create_connection') as mock_create:
374
+ mock_create.return_value = AsyncMock()
375
+
376
+ # Acquire max connections
377
+ for _ in range(10):
378
+ conn = await db_pool.acquire()
379
+ connections.append(conn)
380
+
381
+ assert db_pool.size == 10 # Max size
382
+
383
+ # Try to acquire one more (should wait or fail)
384
+ with pytest.raises(asyncio.TimeoutError):
385
+ await asyncio.wait_for(db_pool.acquire(), timeout=0.1)
386
+
387
+ @pytest.mark.asyncio
388
+ async def test_connection_health_check(self, db_pool):
389
+ """Test connection health checking."""
390
+ # Create healthy and unhealthy connections
391
+ healthy_conn = AsyncMock()
392
+ healthy_conn.is_closed.return_value = False
393
+ healthy_conn.ping.return_value = True
394
+
395
+ unhealthy_conn = AsyncMock()
396
+ unhealthy_conn.is_closed.return_value = True
397
+
398
+ db_pool._connections = [healthy_conn, unhealthy_conn]
399
+
400
+ # Run health check
401
+ await db_pool.health_check()
402
+
403
+ # Unhealthy connection should be removed
404
+ assert unhealthy_conn not in db_pool._connections
405
+ assert healthy_conn in db_pool._connections
406
+
407
+
408
+ class TestMessageQueue:
409
+ """Test message queue system."""
410
+
411
+ @pytest.fixture
412
+ def message_queue(self):
413
+ """Create message queue instance."""
414
+ return MessageQueue()
415
+
416
+ @pytest.mark.asyncio
417
+ async def test_message_publish_subscribe(self, message_queue):
418
+ """Test pub/sub functionality."""
419
+ received_messages = []
420
+
421
+ async def handler(message):
422
+ received_messages.append(message)
423
+
424
+ # Subscribe to topic
425
+ await message_queue.subscribe("test.topic", handler)
426
+
427
+ # Publish messages
428
+ await message_queue.publish("test.topic", {"data": "message1"})
429
+ await message_queue.publish("test.topic", {"data": "message2"})
430
+
431
+ # Allow time for processing
432
+ await asyncio.sleep(0.1)
433
+
434
+ assert len(received_messages) == 2
435
+ assert received_messages[0]["data"] == "message1"
436
+ assert received_messages[1]["data"] == "message2"
437
+
438
+ @pytest.mark.asyncio
439
+ async def test_message_priority_queue(self, message_queue):
440
+ """Test priority message processing."""
441
+ processed_order = []
442
+
443
+ async def handler(message):
444
+ processed_order.append(message["id"])
445
+
446
+ await message_queue.subscribe("priority.topic", handler)
447
+
448
+ # Publish with different priorities
449
+ await message_queue.publish(
450
+ "priority.topic",
451
+ {"id": "low", "data": "low priority"},
452
+ priority=MessagePriority.LOW
453
+ )
454
+ await message_queue.publish(
455
+ "priority.topic",
456
+ {"id": "high", "data": "high priority"},
457
+ priority=MessagePriority.HIGH
458
+ )
459
+ await message_queue.publish(
460
+ "priority.topic",
461
+ {"id": "medium", "data": "medium priority"},
462
+ priority=MessagePriority.MEDIUM
463
+ )
464
+
465
+ # Process queue
466
+ await message_queue.process_priority_queue()
467
+
468
+ # High priority should be processed first
469
+ assert processed_order[0] == "high"
470
+ assert processed_order[-1] == "low"
471
+
472
+ @pytest.mark.asyncio
473
+ async def test_message_persistence(self, message_queue):
474
+ """Test message persistence for reliability."""
475
+ # Enable persistence
476
+ message_queue.enable_persistence(True)
477
+
478
+ # Publish message
479
+ message_id = await message_queue.publish(
480
+ "persistent.topic",
481
+ {"important": "data"},
482
+ persistent=True
483
+ )
484
+
485
+ # Simulate failure before processing
486
+ # Message should be recoverable
487
+ recovered = await message_queue.recover_messages()
488
+
489
+ assert len(recovered) > 0
490
+ assert any(msg["id"] == message_id for msg in recovered)