cidadao.ai-backend / src /ml /ab_testing.py
anderson-ufrj
fix(ml): implement lazy initialization for MLTrainingPipeline
44eae1d
"""
A/B Testing Framework for ML Models
This module provides A/B testing capabilities for comparing model
performance in production environments.
"""
import asyncio
import hashlib
import json
import random
from datetime import datetime, timedelta
from enum import Enum
from typing import Dict, Any, List, Optional, Tuple, Union

import numpy as np
from scipy import stats

from src.core import get_logger
from src.core.cache import get_redis_client
from src.ml.training_pipeline import get_training_pipeline
# Module logger obtained from the project's get_logger, keyed by this module's name.
logger = get_logger(__name__)
class ABTestStatus(Enum):
    """Lifecycle state of an A/B test.

    Test configurations store the member's ``.value`` string (e.g.
    ``"running"``) rather than the enum member itself.
    """

    # Initial state assigned when a test is created.
    DRAFT = "draft"
    # Traffic is being allocated between the two models.
    RUNNING = "running"
    # Temporarily halted; a paused test may be started again.
    PAUSED = "paused"
    COMPLETED = "completed"
    # Terminal state set when a test is stopped.
    STOPPED = "stopped"
class TrafficAllocationStrategy(Enum):
    """How incoming requests are split between model A and model B.

    Stored in test configurations by its ``.value`` string.
    """

    # Fixed split, optionally sticky per user.
    RANDOM = "random"
    # Fixed split drawn with explicit probabilities.
    WEIGHTED = "weighted"
    # Mostly exploit the better-performing arm, explore a fraction of the time.
    EPSILON_GREEDY = "epsilon_greedy"
    # Bayesian bandit sampling from per-arm Beta posteriors.
    THOMPSON_SAMPLING = "thompson_sampling"
class ABTestFramework:
    """
    A/B Testing framework for ML models.

    Features:
    - Multiple allocation strategies (random, weighted, epsilon-greedy,
      Thompson sampling)
    - Statistical significance testing (chi-square on success counts)
    - Real-time performance tracking
    - Automatic winner selection
    - Configurable traffic split

    Each test is a JSON-serializable dict persisted in Redis under
    ``ab_test:<test_name>``. This instance also caches configs in memory:
    ``active_tests`` holds tests not stopped by this instance and
    ``test_results`` holds tests this instance has stopped.
    """

    # Keys of the two arms inside a test config (and its "results" section).
    _ARMS = ("model_a", "model_b")
    # Statuses treated as "live": listed as active and protected from re-creation.
    _LIVE_STATUSES = (ABTestStatus.RUNNING.value, ABTestStatus.PAUSED.value)
    # Number of hash buckets used for sticky per-user allocation.
    _HASH_BUCKETS = 10_000

    def __init__(self):
        """Initialize the A/B testing framework with empty in-memory caches."""
        # test_name -> config for tests that have not been stopped here
        self.active_tests: Dict[str, Dict[str, Any]] = {}
        # test_name -> config for tests stopped by this instance
        self.test_results: Dict[str, Dict[str, Any]] = {}

    async def create_test(
        self,
        test_name: str,
        model_a: Tuple[str, Optional[int]],  # (model_id, version)
        model_b: Tuple[str, Optional[int]],
        allocation_strategy: TrafficAllocationStrategy = TrafficAllocationStrategy.RANDOM,
        traffic_split: Tuple[float, float] = (0.5, 0.5),
        success_metric: str = "f1_score",
        minimum_sample_size: int = 1000,
        significance_level: float = 0.05,
        auto_stop: bool = True,
        duration_hours: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Create a new A/B test in DRAFT status and persist it to Redis.

        Args:
            test_name: Unique name for the test
            model_a: Model A (control) - (model_id, version)
            model_b: Model B (treatment) - (model_id, version)
            allocation_strategy: Traffic allocation strategy
            traffic_split: Traffic split between models (must sum to 1.0)
            success_metric: Metric name recorded in the config; the
                significance analysis itself uses the binary success counts
                reported through ``record_prediction``
            minimum_sample_size: Minimum total predictions before analysis
            significance_level: Statistical significance threshold
            auto_stop: Automatically stop when a winner is found
            duration_hours: Maximum test duration

        Returns:
            Test configuration

        Raises:
            ValueError: if a test with this name is cached here or is
                running/paused in Redis, or if ``traffic_split`` is not two
                shares in [0, 1] summing to 1.0.
        """
        if test_name in self.active_tests:
            raise ValueError(f"Test {test_name} already exists")
        if abs(sum(traffic_split) - 1.0) > 0.001:
            raise ValueError("Traffic split must sum to 1.0")
        if len(traffic_split) != 2 or not all(0.0 <= share <= 1.0 for share in traffic_split):
            raise ValueError("Traffic split must contain exactly two shares between 0 and 1")

        # A live test with the same name may exist in Redis (e.g. created by
        # another process); overwriting it would discard its collected results.
        existing = await self._load_test_config(test_name)
        if existing and existing.get("status") in self._LIVE_STATUSES:
            raise ValueError(f"Test {test_name} already exists")

        # Load models to verify they exist
        pipeline = get_training_pipeline()
        await pipeline.load_model(*model_a)
        await pipeline.load_model(*model_b)

        test_config = {
            "test_id": f"ab_test_{test_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
            "test_name": test_name,
            "model_a": {"model_id": model_a[0], "version": model_a[1]},
            "model_b": {"model_id": model_b[0], "version": model_b[1]},
            "allocation_strategy": allocation_strategy.value,
            "traffic_split": traffic_split,
            "success_metric": success_metric,
            "minimum_sample_size": minimum_sample_size,
            "significance_level": significance_level,
            "auto_stop": auto_stop,
            "status": ABTestStatus.DRAFT.value,
            "created_at": datetime.now().isoformat(),
            "start_time": None,
            "end_time": None,
            "duration_hours": duration_hours,
            "results": {
                "model_a": {"predictions": 0, "successes": 0, "metrics": {}},
                "model_b": {"predictions": 0, "successes": 0, "metrics": {}}
            }
        }

        # Initialize allocation-strategy-specific parameters
        if allocation_strategy == TrafficAllocationStrategy.EPSILON_GREEDY:
            test_config["epsilon"] = 0.1  # 10% exploration
        elif allocation_strategy == TrafficAllocationStrategy.THOMPSON_SAMPLING:
            # Beta(1, 1) is a uniform prior over each arm's success rate
            test_config["thompson_params"] = {
                "model_a": {"alpha": 1, "beta": 1},
                "model_b": {"alpha": 1, "beta": 1}
            }

        self.active_tests[test_name] = test_config
        await self._save_test_config(test_config)
        logger.info(f"Created A/B test: {test_name}")
        return test_config

    async def start_test(self, test_name: str) -> bool:
        """
        Move a DRAFT or PAUSED test to RUNNING and persist it.

        Raises:
            ValueError: if the test does not exist or is in another status.
        """
        if test_name not in self.active_tests:
            # Not cached here; fall back to Redis and cache it
            test_config = await self._load_test_config(test_name)
            if not test_config:
                raise ValueError(f"Test {test_name} not found")
            self.active_tests[test_name] = test_config

        test_config = self.active_tests[test_name]
        if test_config["status"] not in [ABTestStatus.DRAFT.value, ABTestStatus.PAUSED.value]:
            raise ValueError(f"Cannot start test in status {test_config['status']}")

        test_config["status"] = ABTestStatus.RUNNING.value
        test_config["start_time"] = datetime.now().isoformat()
        await self._save_test_config(test_config)
        logger.info(f"Started A/B test: {test_name}")
        return True

    async def allocate_model(
        self,
        test_name: str,
        user_id: Optional[str] = None
    ) -> Tuple[str, Optional[int]]:
        """
        Allocate a model for a request based on the test configuration.

        Args:
            test_name: Test name
            user_id: User identifier; with the RANDOM strategy it makes the
                allocation sticky for that user

        Returns:
            Tuple of (model_id, version)

        Raises:
            ValueError: if the test does not exist or is not running.
        """
        test_config = await self._get_test_config(test_name)
        if test_config["status"] != ABTestStatus.RUNNING.value:
            raise ValueError(f"Test {test_name} is not running")

        strategy = TrafficAllocationStrategy(test_config["allocation_strategy"])
        if strategy == TrafficAllocationStrategy.RANDOM:
            selected = await self._random_allocation(test_config, user_id)
        elif strategy == TrafficAllocationStrategy.WEIGHTED:
            selected = await self._weighted_allocation(test_config)
        elif strategy == TrafficAllocationStrategy.EPSILON_GREEDY:
            selected = await self._epsilon_greedy_allocation(test_config)
        elif strategy == TrafficAllocationStrategy.THOMPSON_SAMPLING:
            selected = await self._thompson_sampling_allocation(test_config)
        else:
            selected = "model_a"  # Defensive fallback for an unhandled strategy

        model_info = test_config[selected]
        return (model_info["model_id"], model_info["version"])

    async def _random_allocation(
        self,
        test_config: Dict[str, Any],
        user_id: Optional[str] = None
    ) -> str:
        """
        Random allocation with optional per-user consistency.

        With a ``user_id``, the arm is derived from a SHA-256 digest of the
        user id and test id. The result is stable across processes and
        restarts, unlike the built-in ``hash()``, which is salted per
        process for strings.
        """
        share_a = test_config["traffic_split"][0]
        if user_id:
            key = f"{user_id}{test_config['test_id']}".encode("utf-8")
            digest = hashlib.sha256(key).digest()
            bucket = int.from_bytes(digest[:8], "big") % self._HASH_BUCKETS
            return "model_a" if bucket < share_a * self._HASH_BUCKETS else "model_b"
        # Pure random
        return "model_a" if random.random() < share_a else "model_b"

    async def _weighted_allocation(self, test_config: Dict[str, Any]) -> str:
        """Weighted allocation based on the traffic split."""
        weights = np.asarray(test_config["traffic_split"], dtype=float)
        # np.random.choice requires p to sum to 1 within ~1e-8, much tighter
        # than the 0.001 tolerance accepted by create_test; renormalize.
        probabilities = weights / weights.sum()
        return str(np.random.choice(list(self._ARMS), p=probabilities))

    async def _epsilon_greedy_allocation(self, test_config: Dict[str, Any]) -> str:
        """Epsilon-greedy allocation: explore with probability epsilon, else exploit."""
        epsilon = test_config.get("epsilon", 0.1)
        if random.random() < epsilon:
            # Explore
            return random.choice(list(self._ARMS))
        # Exploit - choose the arm with the best observed success rate (ties -> A)
        results = test_config["results"]
        rate_a = results["model_a"]["successes"] / max(results["model_a"]["predictions"], 1)
        rate_b = results["model_b"]["successes"] / max(results["model_b"]["predictions"], 1)
        return "model_a" if rate_a >= rate_b else "model_b"

    async def _thompson_sampling_allocation(self, test_config: Dict[str, Any]) -> str:
        """Thompson sampling: draw from each arm's Beta posterior and pick the larger."""
        params = test_config["thompson_params"]
        sample_a = np.random.beta(params["model_a"]["alpha"], params["model_a"]["beta"])
        sample_b = np.random.beta(params["model_b"]["alpha"], params["model_b"]["beta"])
        return "model_a" if sample_a >= sample_b else "model_b"

    async def record_prediction(
        self,
        test_name: str,
        model_selection: str,  # "model_a" or "model_b"
        success: bool,
        prediction_metadata: Optional[Dict[str, Any]] = None
    ):
        """
        Record a prediction result for the test.

        Once the minimum sample size is reached, the test is analyzed. It is
        stopped if ``auto_stop`` is set and a winner is found. This happens
        only while the test is RUNNING, so late results for a stopped test
        are counted but do not re-stop it.

        Args:
            test_name: Test name
            model_selection: Which model was used ("model_a" or "model_b")
            success: Whether prediction was successful
            prediction_metadata: Additional metadata (accepted but currently unused)

        Raises:
            ValueError: if the test does not exist or ``model_selection`` is
                not one of the two arms.
        """
        test_config = await self._get_test_config(test_name)
        if model_selection not in self._ARMS:
            raise ValueError(
                f"Invalid model selection {model_selection!r}; expected 'model_a' or 'model_b'"
            )

        results = test_config["results"][model_selection]
        results["predictions"] += 1
        if success:
            results["successes"] += 1

        # Update the Beta posterior for Thompson sampling
        if test_config["allocation_strategy"] == TrafficAllocationStrategy.THOMPSON_SAMPLING.value:
            params = test_config["thompson_params"][model_selection]
            if success:
                params["alpha"] += 1
            else:
                params["beta"] += 1

        await self._save_test_config(test_config)

        # Without this guard, every prediction arriving after an auto-stop
        # would stop the test again and overwrite end_time/stop_reason.
        if test_config["status"] != ABTestStatus.RUNNING.value:
            return

        total_predictions = sum(
            test_config["results"][arm]["predictions"] for arm in self._ARMS
        )
        if total_predictions >= test_config["minimum_sample_size"]:
            analysis = await self.analyze_test(test_name)
            if test_config["auto_stop"] and analysis.get("winner"):
                await self.stop_test(test_name, reason="Winner found")

    async def analyze_test(self, test_name: str) -> Dict[str, Any]:
        """
        Analyze test results for statistical significance.

        The analysis is stored under ``latest_analysis`` and persisted.

        Returns:
            Analysis results including winner if found

        Raises:
            ValueError: if the test does not exist.
        """
        test_config = await self._get_test_config(test_name)
        analysis = self._compute_analysis(test_config)
        test_config["latest_analysis"] = analysis
        await self._save_test_config(test_config)
        return analysis

    def _compute_analysis(self, test_config: Dict[str, Any]) -> Dict[str, Any]:
        """Build the significance analysis for *test_config* without persisting it.

        All numeric values are plain Python types so the result is JSON-serializable.
        """
        results_a = test_config["results"]["model_a"]
        results_b = test_config["results"]["model_b"]

        # Conversion (success) rates
        rate_a = results_a["successes"] / max(results_a["predictions"], 1)
        rate_b = results_b["successes"] / max(results_b["predictions"], 1)

        # 2x2 table: rows are arms, columns are (successes, failures)
        contingency_table = np.array([
            [results_a["successes"], results_a["predictions"] - results_a["successes"]],
            [results_b["successes"], results_b["predictions"] - results_b["successes"]]
        ])
        # chi2_contingency raises ValueError when an expected frequency is zero,
        # i.e. when some row or column sums to zero (an arm with no predictions,
        # or every outcome in one column). Such a table carries no evidence of
        # a difference between the arms.
        if (contingency_table.sum(axis=0) == 0).any() or (contingency_table.sum(axis=1) == 0).any():
            chi2, p_value = 0.0, 1.0
        else:
            chi2_stat, p_raw, _dof, _expected = stats.chi2_contingency(contingency_table)
            chi2, p_value = float(chi2_stat), float(p_raw)

        ci_a = self._calculate_confidence_interval(results_a["successes"], results_a["predictions"])
        ci_b = self._calculate_confidence_interval(results_b["successes"], results_b["predictions"])

        # bool() matters: comparing a NumPy float yields numpy.bool_, which
        # json.dumps cannot serialize.
        significant = bool(p_value < test_config["significance_level"])
        winner = None
        if significant:
            winner = "model_a" if rate_a > rate_b else "model_b"

        # Relative lift of B over A, in percent
        lift = ((rate_b - rate_a) / rate_a * 100) if rate_a > 0 else 0

        return {
            "model_a": {
                "conversion_rate": rate_a,
                "confidence_interval": ci_a,
                "sample_size": results_a["predictions"]
            },
            "model_b": {
                "conversion_rate": rate_b,
                "confidence_interval": ci_b,
                "sample_size": results_b["predictions"]
            },
            "p_value": p_value,
            "chi_square": chi2,
            "significant": significant,
            "winner": winner,
            "lift": lift,
            "analysis_time": datetime.now().isoformat()
        }

    def _calculate_confidence_interval(
        self,
        successes: int,
        total: int,
        confidence_level: float = 0.95
    ) -> Tuple[float, float]:
        """Wilson score confidence interval for a conversion rate, clipped to [0, 1]."""
        if total == 0:
            return (0.0, 0.0)

        rate = successes / total
        z = float(stats.norm.ppf((1 + confidence_level) / 2))

        # Wilson score interval
        denominator = 1 + z ** 2 / total
        center = (rate + z ** 2 / (2 * total)) / denominator
        margin = z * float(np.sqrt(rate * (1 - rate) / total + z ** 2 / (4 * total ** 2))) / denominator
        return (max(0.0, center - margin), min(1.0, center + margin))

    async def stop_test(self, test_name: str, reason: str = "Manual stop") -> bool:
        """
        Stop an A/B test, attach its final analysis and persist it.

        The config moves from ``active_tests`` to ``test_results``.

        Raises:
            ValueError: if the test does not exist.
        """
        test_config = await self._get_test_config(test_name)

        test_config["status"] = ABTestStatus.STOPPED.value
        test_config["end_time"] = datetime.now().isoformat()
        test_config["stop_reason"] = reason

        # Analyze this very config object so the stored final analysis and
        # status cannot come from different copies.
        final_analysis = self._compute_analysis(test_config)
        test_config["latest_analysis"] = final_analysis
        test_config["final_analysis"] = final_analysis
        await self._save_test_config(test_config)

        # Move to completed tests
        self.test_results[test_name] = test_config
        self.active_tests.pop(test_name, None)

        logger.info(f"Stopped A/B test {test_name}: {reason}")
        return True

    async def get_test_status(self, test_name: str) -> Dict[str, Any]:
        """
        Get the current configuration/status of a test.

        For a running test, ``runtime_hours`` is filled in. If the test has
        exceeded ``duration_hours``, it is stopped and the stopped config is
        returned.

        Raises:
            ValueError: if the test does not exist.
        """
        test_config = await self._get_test_config(test_name)

        if test_config["status"] == ABTestStatus.RUNNING.value and test_config["start_time"]:
            start = datetime.fromisoformat(test_config["start_time"])
            runtime = (datetime.now() - start).total_seconds() / 3600
            test_config["runtime_hours"] = runtime

            if test_config.get("duration_hours") and runtime >= test_config["duration_hours"]:
                await self.stop_test(test_name, reason="Duration limit reached")
                # stop_test may have worked on a freshly loaded copy; return
                # the stopped config rather than the stale "running" one.
                test_config = self.test_results.get(test_name, test_config)

        return test_config

    async def promote_winner(self, test_name: str) -> bool:
        """
        Promote the winning model of a stopped test to production.

        Returns:
            The result of the training pipeline's ``promote_model`` call.

        Raises:
            ValueError: if the test is not stopped or has no winner.
        """
        test_config = self.test_results.get(test_name)
        if not test_config:
            # Try loading a completed test from Redis
            test_config = await self._load_test_config(test_name)
            if not test_config or test_config["status"] != ABTestStatus.STOPPED.value:
                raise ValueError(f"Test {test_name} not completed")

        final_analysis = test_config.get("final_analysis", {})
        winner = final_analysis.get("winner")
        if not winner:
            raise ValueError(f"No winner found for test {test_name}")

        model_info = test_config[winner]
        pipeline = get_training_pipeline()
        success = await pipeline.promote_model(
            model_info["model_id"],
            model_info["version"],
            "production"
        )
        if success:
            logger.info(f"Promoted {winner} from test {test_name} to production")
        return success

    async def _get_test_config(self, test_name: str) -> Dict[str, Any]:
        """Return the cached config for *test_name*, falling back to Redis.

        Raises:
            ValueError: if the test is neither cached nor stored in Redis.
        """
        test_config = self.active_tests.get(test_name)
        if not test_config:
            test_config = await self._load_test_config(test_name)
        if not test_config:
            raise ValueError(f"Test {test_name} not found")
        return test_config

    @staticmethod
    def _json_default(value: Any) -> Any:
        """Convert NumPy scalars (e.g. ``numpy.bool_``, ``numpy.int64``) for ``json.dumps``."""
        if isinstance(value, np.generic):
            return value.item()
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    async def _save_test_config(self, test_config: Dict[str, Any]):
        """Save a test configuration to Redis with a 90-day expiry."""
        redis_client = await get_redis_client()
        key = f"ab_test:{test_config['test_name']}"
        await redis_client.set(
            key,
            json.dumps(test_config, default=self._json_default),
            ex=86400 * 90  # 90 days
        )

    async def _load_test_config(self, test_name: str) -> Optional[Dict[str, Any]]:
        """Load a test configuration from Redis; ``None`` if absent."""
        redis_client = await get_redis_client()
        key = f"ab_test:{test_name}"
        data = await redis_client.get(key)
        return json.loads(data) if data else None

    async def list_active_tests(self) -> List[Dict[str, Any]]:
        """List running or paused tests stored in Redis, with summary fields."""
        redis_client = await get_redis_client()
        # NOTE(review): KEYS is O(N) and blocks Redis; consider SCAN if the
        # client returned by get_redis_client supports it.
        keys = await redis_client.keys("ab_test:*")

        active_tests = []
        for key in keys:
            data = await redis_client.get(key)
            if not data:
                continue
            test_config = json.loads(data)
            if test_config["status"] in self._LIVE_STATUSES:
                active_tests.append({
                    "test_name": test_config["test_name"],
                    "status": test_config["status"],
                    "model_a": test_config["model_a"]["model_id"],
                    "model_b": test_config["model_b"]["model_id"],
                    "start_time": test_config.get("start_time"),
                    "predictions": (test_config["results"]["model_a"]["predictions"] +
                                    test_config["results"]["model_b"]["predictions"])
                })
        return active_tests
# Process-wide framework instance, handed out by get_ab_testing().
ab_testing = ABTestFramework()


async def get_ab_testing() -> ABTestFramework:
    """Return the shared module-level ABTestFramework instance.

    Performs no I/O; the coroutine simply hands back ``ab_testing``.
    """
    framework = ab_testing
    return framework