| """ |
| Configuration Management Module |
| ============================== |
| |
| Centralized configuration management for BackgroundFX Pro. |
| Handles settings, model paths, quality parameters, and environment variables. |
| |
| Features: |
| - YAML and JSON configuration files |
| - Environment variable integration |
| - Model path management (works with checkpoints/ folder) |
| - Quality thresholds and processing parameters |
| - Development vs Production configurations |
| - Runtime configuration updates |
| |
| Author: BackgroundFX Pro Team |
| License: MIT |
| """ |
|
|
| import os |
| import yaml |
| import json |
| from typing import Dict, Any, Optional, Union |
| from pathlib import Path |
| from dataclasses import dataclass, field |
| import logging |
| from copy import deepcopy |
|
|
| logger = logging.getLogger(__name__) |
|
|
@dataclass
class ModelConfig:
    """Configuration for a single AI model backend (e.g. 'sam2', 'matanyone')."""
    name: str                                   # model identifier, matches its key in ConfigManager.models
    path: Optional[str] = None                  # checkpoint file path; None when no local file (e.g. API-backed)
    device: str = "auto"                        # device string; "auto" defers the choice (may be overridden via env)
    enabled: bool = True                        # whether this model may be used by the pipeline
    fallback: bool = False                      # True marks this model as a fallback option
    parameters: Dict[str, Any] = field(default_factory=dict)  # model-specific options (merged on config load)
|
|
@dataclass
class QualityConfig:
    """Quality-assessment thresholds used by the processing pipeline."""
    min_detection_confidence: float = 0.5         # must lie in [0, 1]; enforced by validate_configuration()
    min_edge_quality: float = 0.3                 # minimum acceptable edge-quality score
    min_mask_coverage: float = 0.05               # minimum mask coverage — presumably a frame fraction; confirm
    max_asymmetry_score: float = 0.8              # upper bound on the asymmetry score
    temporal_consistency_threshold: float = 0.05  # frame-to-frame consistency tolerance
    matanyone_quality_threshold: float = 0.3      # minimum quality for MatAnyone results
|
|
@dataclass
class ProcessingConfig:
    """Processing pipeline configuration."""
    batch_size: int = 1                     # frames per batch; must be >= 1 (enforced by validate_configuration())
    max_resolution: tuple = (1920, 1080)    # resolution cap — presumably (width, height); confirm with consumers
    temporal_smoothing: bool = True         # enable mask smoothing across frames
    edge_refinement: bool = True            # enable the edge-refinement pass
    fallback_enabled: bool = True           # allow falling back to alternative models
    cache_enabled: bool = True              # enable caching of results
|
|
@dataclass
class VideoConfig:
    """Video output configuration."""
    output_format: str = "mp4"        # container format of output files
    output_quality: str = "high"      # quality preset name
    preserve_audio: bool = True       # keep the source audio in the output — per name; confirm with encoder code
    fps_limit: Optional[int] = None   # cap on output FPS; None means no limit
    codec: str = "h264"               # video codec used for encoding
|
|
class ConfigManager:
    """Central configuration manager for BackgroundFX Pro.

    Owns the model registry plus quality/processing/video settings,
    resolves model checkpoint paths under ``checkpoints_dir``, and loads
    and saves configuration from YAML/JSON files and environment variables.

    Fixes over the previous revision:
    - ``_apply_config_data`` now merges model ``parameters`` dicts instead
      of replacing them (the old ``elif key == 'parameters'`` branch was
      dead code because ``hasattr`` is always true for a dataclass field).
    - ``to_dict`` deep-copies nested parameter dicts and emits
      ``max_resolution`` as a list so the export is safe to mutate and
      round-trips through ``yaml.safe_dump``/``yaml.safe_load``.
    - ``save_to_file`` uses ``yaml.safe_dump`` so the YAML output never
      contains python-specific tags that ``load_from_file`` cannot parse.
    """

    def __init__(self, config_dir: str = ".", checkpoints_dir: str = "checkpoints"):
        """Create a manager and register the default model configurations.

        Args:
            config_dir: Directory where configuration files live.
            checkpoints_dir: Directory searched for model checkpoint files.
        """
        self.config_dir = Path(config_dir)
        self.checkpoints_dir = Path(checkpoints_dir)

        # Per-section configuration objects.
        self.models: Dict[str, ModelConfig] = {}
        self.quality = QualityConfig()
        self.processing = ProcessingConfig()
        self.video = VideoConfig()

        # Global runtime flags.
        self.debug_mode = False
        self.environment = "development"

        self._initialize_default_configs()

    def _initialize_default_configs(self):
        """Register the built-in models: sam2, matanyone, traditional_cv."""
        # SAM2 segmentation model; checkpoint resolved from checkpoints_dir.
        self.models['sam2'] = ModelConfig(
            name='sam2',
            path=self._find_model_path('sam2', ['sam2_hiera_large.pt', 'sam2_hiera_base.pt']),
            device='auto',
            enabled=True,
            fallback=False,
            parameters={
                'model_type': 'vit_l',
                'checkpoint': None,
                'multimask_output': False,
                'use_checkpoint': True
            }
        )

        # MatAnyone matting model; served via the Hugging Face API, so no local path.
        self.models['matanyone'] = ModelConfig(
            name='matanyone',
            path=None,
            device='auto',
            enabled=True,
            fallback=False,
            parameters={
                'use_hf_api': True,
                'hf_model': 'PeiqingYang/MatAnyone',
                'api_timeout': 60,
                'quality_threshold': 0.3,
                'fallback_enabled': True
            }
        )

        # Traditional CV methods; CPU-only fallback path.
        self.models['traditional_cv'] = ModelConfig(
            name='traditional_cv',
            path=None,
            device='cpu',
            enabled=True,
            fallback=True,
            parameters={
                'methods': ['canny', 'color_detection', 'texture_analysis'],
                'edge_threshold': [50, 150],
                'color_ranges': {
                    'dark_hair': [[0, 0, 0], [180, 255, 80]],
                    'brown_hair': [[8, 50, 20], [25, 255, 200]]
                }
            }
        )

    def _find_model_path(self, model_name: str, possible_files: list) -> Optional[str]:
        """Locate a model checkpoint file.

        Each candidate filename is tried directly under ``checkpoints_dir``
        and then under a ``checkpoints_dir/<model_name>/`` subdirectory.

        Args:
            model_name: Model identifier (also the subdirectory name).
            possible_files: Candidate filenames, tried in order.

        Returns:
            The first existing path as a string, or None when nothing is found.
        """
        try:
            for filename in possible_files:
                # Checkpoint directly in checkpoints_dir.
                full_path = self.checkpoints_dir / filename
                if full_path.exists():
                    logger.info(f"✅ Found {model_name} at: {full_path}")
                    return str(full_path)

                # Checkpoint inside a per-model subdirectory.
                model_subdir = self.checkpoints_dir / model_name / filename
                if model_subdir.exists():
                    logger.info(f"✅ Found {model_name} at: {model_subdir}")
                    return str(model_subdir)

            logger.warning(f"⚠️ {model_name} model not found in {self.checkpoints_dir}")
            return None

        except Exception as e:
            logger.error(f"❌ Error finding {model_name}: {e}")
            return None

    def load_from_file(self, config_path: str) -> bool:
        """Load configuration from a YAML or JSON file.

        Args:
            config_path: Path to a ``.yaml``/``.yml`` or ``.json`` file.

        Returns:
            True when the file was parsed and applied, False otherwise.
        """
        try:
            config_path = Path(config_path)

            if not config_path.exists():
                logger.warning(f"⚠️ Config file not found: {config_path}")
                return False

            # Pick the parser from the file extension.
            if config_path.suffix.lower() in ['.yaml', '.yml']:
                with open(config_path, 'r') as f:
                    config_data = yaml.safe_load(f)
            elif config_path.suffix.lower() == '.json':
                with open(config_path, 'r') as f:
                    config_data = json.load(f)
            else:
                logger.error(f"❌ Unsupported config format: {config_path.suffix}")
                return False

            self._apply_config_data(config_data)
            logger.info(f"✅ Configuration loaded from: {config_path}")
            return True

        except Exception as e:
            logger.error(f"❌ Failed to load config from {config_path}: {e}")
            return False

    def _apply_config_data(self, config_data: Dict[str, Any]):
        """Apply a parsed configuration dict onto the current settings.

        Unknown keys are ignored. Model ``parameters`` dicts are merged
        (not replaced) so defaults that the file does not mention survive
        a partial override.

        Raises:
            Exception: re-raised after logging when application fails.
        """
        try:
            # Model sections: merge parameter dicts, set known attributes.
            if 'models' in config_data:
                for model_name, model_config in config_data['models'].items():
                    if model_name in self.models:
                        for key, value in model_config.items():
                            if key == 'parameters':
                                # Merge instead of replacing: 'parameters' is a
                                # dataclass field, so hasattr() is always true
                                # and a plain setattr would wipe the defaults.
                                self.models[model_name].parameters.update(value)
                            elif hasattr(self.models[model_name], key):
                                setattr(self.models[model_name], key, value)

            if 'quality' in config_data:
                for key, value in config_data['quality'].items():
                    if hasattr(self.quality, key):
                        setattr(self.quality, key, value)

            if 'processing' in config_data:
                for key, value in config_data['processing'].items():
                    if hasattr(self.processing, key):
                        # JSON/YAML round-trip tuples as lists; normalize back.
                        if key == 'max_resolution' and isinstance(value, list):
                            value = tuple(value)
                        setattr(self.processing, key, value)

            if 'video' in config_data:
                for key, value in config_data['video'].items():
                    if hasattr(self.video, key):
                        setattr(self.video, key, value)

            # Top-level runtime flags.
            if 'environment' in config_data:
                self.environment = config_data['environment']

            if 'debug_mode' in config_data:
                self.debug_mode = config_data['debug_mode']

        except Exception as e:
            logger.error(f"❌ Error applying config data: {e}")
            raise

    def load_from_environment(self):
        """Override settings from environment variables.

        Recognized variables: SAM2_MODEL_PATH, HUGGINGFACE_TOKEN,
        TORCH_DEVICE/DEVICE, BATCH_SIZE, MIN_DETECTION_CONFIDENCE,
        ENVIRONMENT/ENV, DEBUG/DEBUG_MODE. Malformed numeric values are
        logged and skipped so one bad variable does not abort the rest.
        """
        try:
            # Explicit SAM2 checkpoint path (only applied when it exists).
            sam2_path = os.getenv('SAM2_MODEL_PATH')
            if sam2_path and Path(sam2_path).exists():
                self.models['sam2'].path = sam2_path

            # Hugging Face token for the MatAnyone API.
            hf_token = os.getenv('HUGGINGFACE_TOKEN')
            if hf_token:
                self.models['matanyone'].parameters['hf_token'] = hf_token

            # Device override applies only to models still set to 'auto'.
            device = os.getenv('TORCH_DEVICE', os.getenv('DEVICE'))
            if device:
                for model in self.models.values():
                    if model.device == 'auto':
                        model.device = device

            # Numeric overrides: parsed defensively per variable.
            batch_size = os.getenv('BATCH_SIZE')
            if batch_size:
                try:
                    self.processing.batch_size = int(batch_size)
                except ValueError:
                    logger.warning(f"⚠️ Ignoring invalid BATCH_SIZE: {batch_size!r}")

            min_confidence = os.getenv('MIN_DETECTION_CONFIDENCE')
            if min_confidence:
                try:
                    self.quality.min_detection_confidence = float(min_confidence)
                except ValueError:
                    logger.warning(f"⚠️ Ignoring invalid MIN_DETECTION_CONFIDENCE: {min_confidence!r}")

            env_mode = os.getenv('ENVIRONMENT', os.getenv('ENV'))
            if env_mode:
                self.environment = env_mode

            debug = os.getenv('DEBUG', os.getenv('DEBUG_MODE'))
            if debug:
                self.debug_mode = debug.lower() in ['true', '1', 'yes']

            logger.info("✅ Environment variables loaded")

        except Exception as e:
            logger.error(f"❌ Error loading environment variables: {e}")

    def save_to_file(self, config_path: str, format: str = 'yaml') -> bool:
        """Save the current configuration to a YAML or JSON file.

        Args:
            config_path: Destination path; parent directories are created.
            format: 'yaml'/'yml' or 'json'.

        Returns:
            True on success, False on failure or unsupported format.
        """
        try:
            config_path = Path(config_path)
            config_path.parent.mkdir(parents=True, exist_ok=True)

            config_data = self.to_dict()

            if format.lower() in ['yaml', 'yml']:
                with open(config_path, 'w') as f:
                    # safe_dump emits plain YAML (no !!python/... tags), so the
                    # file round-trips through yaml.safe_load in load_from_file.
                    yaml.safe_dump(config_data, f, default_flow_style=False, indent=2)
            elif format.lower() == 'json':
                with open(config_path, 'w') as f:
                    json.dump(config_data, f, indent=2)
            else:
                logger.error(f"❌ Unsupported save format: {format}")
                return False

            logger.info(f"✅ Configuration saved to: {config_path}")
            return True

        except Exception as e:
            logger.error(f"❌ Failed to save config to {config_path}: {e}")
            return False

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the configuration to plain, JSON/YAML-safe types.

        Nested parameter dicts are deep-copied so mutating the returned
        structure cannot alter live configuration, and ``max_resolution``
        is emitted as a list so both JSON and yaml.safe_dump can encode it.
        """
        return {
            'models': {
                name: {
                    'name': config.name,
                    'path': config.path,
                    'device': config.device,
                    'enabled': config.enabled,
                    'fallback': config.fallback,
                    # Deep copy: callers must not be able to mutate live
                    # model parameters through the exported dict.
                    'parameters': deepcopy(config.parameters)
                } for name, config in self.models.items()
            },
            'quality': {
                'min_detection_confidence': self.quality.min_detection_confidence,
                'min_edge_quality': self.quality.min_edge_quality,
                'min_mask_coverage': self.quality.min_mask_coverage,
                'max_asymmetry_score': self.quality.max_asymmetry_score,
                'temporal_consistency_threshold': self.quality.temporal_consistency_threshold,
                'matanyone_quality_threshold': self.quality.matanyone_quality_threshold
            },
            'processing': {
                'batch_size': self.processing.batch_size,
                # A list (not a tuple) survives yaml.safe_dump and JSON.
                'max_resolution': list(self.processing.max_resolution),
                'temporal_smoothing': self.processing.temporal_smoothing,
                'edge_refinement': self.processing.edge_refinement,
                'fallback_enabled': self.processing.fallback_enabled,
                'cache_enabled': self.processing.cache_enabled
            },
            'video': {
                'output_format': self.video.output_format,
                'output_quality': self.video.output_quality,
                'preserve_audio': self.video.preserve_audio,
                'fps_limit': self.video.fps_limit,
                'codec': self.video.codec
            },
            'environment': self.environment,
            'debug_mode': self.debug_mode
        }

    def get_model_config(self, model_name: str) -> Optional[ModelConfig]:
        """Return the configuration for *model_name*, or None when unknown."""
        return self.models.get(model_name)

    def is_model_enabled(self, model_name: str) -> bool:
        """Return True when *model_name* exists and is enabled."""
        model = self.models.get(model_name)
        return model.enabled if model else False

    def get_enabled_models(self) -> Dict[str, ModelConfig]:
        """Return all enabled models, keyed by name."""
        return {name: config for name, config in self.models.items() if config.enabled}

    def get_fallback_models(self) -> Dict[str, ModelConfig]:
        """Return all enabled models that are marked as fallbacks."""
        return {name: config for name, config in self.models.items()
                if config.enabled and config.fallback}

    def update_model_path(self, model_name: str, path: str) -> bool:
        """Update a model's checkpoint path if the model and file both exist.

        Returns:
            True when the path was updated, False otherwise.
        """
        if model_name in self.models:
            if Path(path).exists():
                self.models[model_name].path = path
                logger.info(f"✅ Updated {model_name} path: {path}")
                return True
            else:
                logger.error(f"❌ Model path does not exist: {path}")
                return False
        else:
            logger.error(f"❌ Unknown model: {model_name}")
            return False

    def validate_configuration(self) -> Dict[str, Any]:
        """Validate the current configuration.

        Returns:
            A dict with 'valid' (bool), 'errors' and 'warnings' (lists of
            strings), and per-model 'model_status' details.
        """
        validation_results = {
            'valid': True,
            'errors': [],
            'warnings': [],
            'model_status': {}
        }

        try:
            # Per-model check: enabled models with a configured path must
            # point at an existing file.
            for name, config in self.models.items():
                model_status = {'enabled': config.enabled, 'path_exists': True, 'issues': []}

                if config.enabled and config.path:
                    if not Path(config.path).exists():
                        model_status['path_exists'] = False
                        model_status['issues'].append(f"Model file not found: {config.path}")
                        validation_results['errors'].append(f"{name}: Model file not found")
                        validation_results['valid'] = False

                validation_results['model_status'][name] = model_status

            # Range checks on numeric settings.
            if not 0 <= self.quality.min_detection_confidence <= 1:
                validation_results['errors'].append("min_detection_confidence must be between 0 and 1")
                validation_results['valid'] = False

            if self.processing.batch_size < 1:
                validation_results['errors'].append("batch_size must be >= 1")
                validation_results['valid'] = False

            # Soft warnings: a usable setup needs at least one enabled model
            # and ideally a fallback.
            enabled_models = self.get_enabled_models()
            if not enabled_models:
                validation_results['warnings'].append("No models are enabled")

            fallback_models = self.get_fallback_models()
            if not fallback_models:
                validation_results['warnings'].append("No fallback models configured")

            logger.info(f"✅ Configuration validation completed: {'Valid' if validation_results['valid'] else 'Invalid'}")

        except Exception as e:
            validation_results['valid'] = False
            validation_results['errors'].append(f"Validation error: {str(e)}")
            logger.error(f"❌ Configuration validation failed: {e}")

        return validation_results

    def create_runtime_config(self) -> Dict[str, Any]:
        """Build the runtime configuration consumed by the processing pipeline.

        Note: 'models' holds live ModelConfig objects (only enabled ones),
        not serialized dicts.
        """
        return {
            'models': self.get_enabled_models(),
            'quality_thresholds': {
                'min_confidence': self.quality.min_detection_confidence,
                'min_edge_quality': self.quality.min_edge_quality,
                'temporal_threshold': self.quality.temporal_consistency_threshold,
                'matanyone_threshold': self.quality.matanyone_quality_threshold
            },
            'processing_options': {
                'batch_size': self.processing.batch_size,
                'temporal_smoothing': self.processing.temporal_smoothing,
                'edge_refinement': self.processing.edge_refinement,
                'fallback_enabled': self.processing.fallback_enabled,
                'cache_enabled': self.processing.cache_enabled
            },
            'video_settings': {
                'format': self.video.output_format,
                'quality': self.video.output_quality,
                'preserve_audio': self.video.preserve_audio,
                'codec': self.video.codec
            },
            'debug_mode': self.debug_mode
        }
|
|
| |
# Module-level singleton, created lazily by get_config().
_config_manager: Optional[ConfigManager] = None
|
|
def get_config(config_dir: str = ".", checkpoints_dir: str = "checkpoints") -> ConfigManager:
    """Return the global configuration manager, creating it on first use.

    On first call the manager is constructed, environment variables are
    applied, and the first default config file found (config.yaml,
    config.yml, config.json) is loaded. Subsequent calls return the cached
    instance and ignore the arguments.

    Args:
        config_dir: Directory searched for the default config file.
        checkpoints_dir: Directory searched for model checkpoints.
    """
    global _config_manager
    if _config_manager is None:
        _config_manager = ConfigManager(config_dir, checkpoints_dir)

        _config_manager.load_from_environment()

        # Resolve candidates against the manager's config_dir (a bare
        # Path(name) would search the process CWD and silently ignore a
        # non-default config_dir argument).
        config_files = ['config.yaml', 'config.yml', 'config.json']
        for config_file in config_files:
            candidate = _config_manager.config_dir / config_file
            if candidate.exists():
                _config_manager.load_from_file(str(candidate))
                break

    return _config_manager
|
|
def load_config(config_path: str) -> ConfigManager:
    """Load *config_path* into the global manager and return the manager."""
    manager = get_config()
    manager.load_from_file(config_path)
    return manager
|
|
def get_model_config(model_name: str) -> Optional[ModelConfig]:
    """Look up *model_name* on the global configuration manager."""
    manager = get_config()
    return manager.get_model_config(model_name)
|
|
def is_model_enabled(model_name: str) -> bool:
    """Return True when *model_name* is enabled in the global configuration."""
    manager = get_config()
    return manager.is_model_enabled(model_name)
|
|
def get_quality_thresholds() -> QualityConfig:
    """Return the quality section of the global configuration."""
    manager = get_config()
    return manager.quality
|
|
def get_processing_config() -> ProcessingConfig:
    """Return the processing section of the global configuration."""
    manager = get_config()
    return manager.processing