anderson-ufrj commited on
Commit
44eae1d
·
1 Parent(s): 8273ba9

fix(ml): implement lazy initialization for MLTrainingPipeline

Browse files

- Replace global instance with lazy initialization function
- Fix settings.get() to use getattr() for Pydantic settings
- Update all imports to use get_training_pipeline() function
- Prevents initialization errors during module import

src/ml/__init__.py CHANGED
@@ -9,7 +9,6 @@ This module provides machine learning capabilities including:
9
 
10
  from src.ml.training_pipeline import (
11
  MLTrainingPipeline,
12
- training_pipeline,
13
  get_training_pipeline
14
  )
15
 
 
9
 
10
  from src.ml.training_pipeline import (
11
  MLTrainingPipeline,
 
12
  get_training_pipeline
13
  )
14
 
src/ml/ab_testing.py CHANGED
@@ -16,7 +16,7 @@ from scipy import stats
16
 
17
  from src.core import get_logger
18
  from src.core.cache import get_redis_client
19
- from src.ml.training_pipeline import training_pipeline
20
 
21
 
22
  logger = get_logger(__name__)
@@ -94,8 +94,9 @@ class ABTestFramework:
94
  raise ValueError("Traffic split must sum to 1.0")
95
 
96
  # Load models to verify they exist
97
- await training_pipeline.load_model(*model_a)
98
- await training_pipeline.load_model(*model_b)
 
99
 
100
  test_config = {
101
  "test_id": f"ab_test_{test_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
@@ -450,7 +451,8 @@ class ABTestFramework:
450
 
451
  # Promote winning model
452
  model_info = test_config[winner]
453
- success = await training_pipeline.promote_model(
 
454
  model_info["model_id"],
455
  model_info["version"],
456
  "production"
 
16
 
17
  from src.core import get_logger
18
  from src.core.cache import get_redis_client
19
+ from src.ml.training_pipeline import get_training_pipeline
20
 
21
 
22
  logger = get_logger(__name__)
 
94
  raise ValueError("Traffic split must sum to 1.0")
95
 
96
  # Load models to verify they exist
97
+ pipeline = get_training_pipeline()
98
+ await pipeline.load_model(*model_a)
99
+ await pipeline.load_model(*model_b)
100
 
101
  test_config = {
102
  "test_id": f"ab_test_{test_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
 
451
 
452
  # Promote winning model
453
  model_info = test_config[winner]
454
+ pipeline = get_training_pipeline()
455
+ success = await pipeline.promote_model(
456
  model_info["model_id"],
457
  model_info["version"],
458
  "production"
src/ml/training_pipeline.py CHANGED
@@ -52,7 +52,7 @@ class MLTrainingPipeline:
52
  """Initialize the training pipeline."""
53
  self.experiment_name = experiment_name
54
  self.mlflow_client = None
55
- self.models_dir = Path(settings.get("ML_MODELS_DIR", "./models"))
56
  self.models_dir.mkdir(exist_ok=True)
57
 
58
  # Supported algorithms
@@ -70,7 +70,7 @@ class MLTrainingPipeline:
70
  def _initialize_mlflow(self):
71
  """Initialize MLflow tracking."""
72
  try:
73
- mlflow.set_tracking_uri(settings.get("MLFLOW_TRACKING_URI", "file:./mlruns"))
74
  mlflow.set_experiment(self.experiment_name)
75
  self.mlflow_client = MlflowClient()
76
  logger.info(f"MLflow initialized with experiment: {self.experiment_name}")
@@ -514,10 +514,12 @@ class MLTrainingPipeline:
514
  return count
515
 
516
 
517
- # Global training pipeline instance
518
- training_pipeline = MLTrainingPipeline()
519
 
520
-
521
- async def get_training_pipeline() -> MLTrainingPipeline:
522
- """Get the global training pipeline instance."""
523
- return training_pipeline
 
 
 
52
  """Initialize the training pipeline."""
53
  self.experiment_name = experiment_name
54
  self.mlflow_client = None
55
+ self.models_dir = Path(getattr(settings, "ML_MODELS_DIR", "./models"))
56
  self.models_dir.mkdir(exist_ok=True)
57
 
58
  # Supported algorithms
 
70
  def _initialize_mlflow(self):
71
  """Initialize MLflow tracking."""
72
  try:
73
+ mlflow.set_tracking_uri(getattr(settings, "MLFLOW_TRACKING_URI", "file:./mlruns"))
74
  mlflow.set_experiment(self.experiment_name)
75
  self.mlflow_client = MlflowClient()
76
  logger.info(f"MLflow initialized with experiment: {self.experiment_name}")
 
514
  return count
515
 
516
 
517
+ # Global training pipeline instance (lazy initialization)
518
+ _training_pipeline = None
519
 
520
+ def get_training_pipeline() -> MLTrainingPipeline:
521
+ """Get or create the global training pipeline instance."""
522
+ global _training_pipeline
523
+ if _training_pipeline is None:
524
+ _training_pipeline = MLTrainingPipeline()
525
+ return _training_pipeline