""" Main processing pipeline for BackgroundFX Pro. Orchestrates the complete background removal and replacement workflow. """ import cv2 import numpy as np import torch from typing import Dict, List, Optional, Tuple, Union, Callable, Any from dataclasses import dataclass, field from enum import Enum from pathlib import Path import time import threading from queue import Queue import json import hashlib from concurrent.futures import ThreadPoolExecutor, Future from ..utils.logger import setup_logger from ..utils.device import DeviceManager from ..utils.config import ConfigManager from ..utils import TimeEstimator, MemoryMonitor from ..core.models import ModelFactory, ModelType from ..core.temporal import TemporalCoherence from ..core.quality import QualityAnalyzer from ..core.edge import EdgeRefinement from ..core.hair_segmentation import HairSegmentation from ..processing.matting import AlphaMatting, MattingConfig, CompositingEngine from ..processing.fallback import FallbackStrategy, FallbackLevel from ..processing.effects import BackgroundEffects, CompositeEffects, EffectType logger = setup_logger(__name__) class ProcessingMode(Enum): """Processing mode types.""" PHOTO = "photo" VIDEO = "video" REALTIME = "realtime" BATCH = "batch" class PipelineStage(Enum): """Pipeline processing stages.""" INITIALIZATION = "initialization" PREPROCESSING = "preprocessing" SEGMENTATION = "segmentation" MATTING = "matting" REFINEMENT = "refinement" EFFECTS = "effects" COMPOSITING = "compositing" POSTPROCESSING = "postprocessing" COMPLETE = "complete" @dataclass class PipelineConfig: """Configuration for the processing pipeline.""" # Model settings model_type: ModelType = ModelType.RMBG_1_4 use_gpu: bool = True device: Optional[str] = None # Processing settings mode: ProcessingMode = ProcessingMode.PHOTO enable_temporal: bool = True enable_hair_refinement: bool = True enable_edge_refinement: bool = True enable_fallback: bool = True # Quality settings quality_preset: str = "high" # low, medium, high, ultra target_resolution: Optional[Tuple[int, int]] = None maintain_aspect_ratio: bool = True # Matting settings matting_method: str = "auto" # auto, trimap, deep, guided matting_config: MattingConfig = field(default_factory=MattingConfig) # Effects settings background_blur: bool = False blur_strength: float = 15.0 apply_effects: List[EffectType] = field(default_factory=list) # Performance settings batch_size: int = 1 num_workers: int = 4 enable_caching: bool = True cache_size_mb: int = 500 # Output settings output_format: str = "png" # png, jpg, webp output_quality: int = 95 preserve_metadata: bool = True # Callbacks progress_callback: Optional[Callable[[float, str], None]] = None stage_callback: Optional[Callable[[PipelineStage, Dict], None]] = None @dataclass class PipelineResult: """Result from pipeline processing.""" success: bool output_image: Optional[np.ndarray] = None alpha_matte: Optional[np.ndarray] = None foreground: Optional[np.ndarray] = None background: Optional[np.ndarray] = None metadata: Dict[str, Any] = field(default_factory=dict) processing_time: float = 0.0 stages_completed: List[PipelineStage] = field(default_factory=list) errors: List[str] = field(default_factory=list) quality_score: float = 0.0 class ProcessingPipeline: """Main processing pipeline orchestrator.""" def __init__(self, config: Optional[PipelineConfig] = None): """ Initialize the processing pipeline. 
class ProcessingPipeline:
    """Main processing pipeline orchestrator."""

    def __init__(self, config: Optional[PipelineConfig] = None):
        """
        Initialize the processing pipeline.

        Args:
            config: Pipeline configuration
        """
        self.config = config or PipelineConfig()
        self.logger = setup_logger(f"{__name__}.ProcessingPipeline")

        # Initialize components
        self._initialize_components()

        # State management
        self.current_stage = PipelineStage.INITIALIZATION
        self.processing_stats = {}
        self.cache = {}
        self.is_processing = False

        # Thread pool for parallel processing
        self.executor = ThreadPoolExecutor(max_workers=self.config.num_workers)

        self.logger.info("Pipeline initialized successfully")

    def _initialize_components(self):
        """Initialize all pipeline components."""
        try:
            # Device management
            self.device_manager = DeviceManager()
            if self.config.device:
                self.device_manager.set_device(self.config.device)
            elif not self.config.use_gpu:
                self.device_manager.set_device('cpu')

            # Core components
            self.model_factory = ModelFactory()
            self.quality_analyzer = QualityAnalyzer()
            self.edge_refinement = EdgeRefinement()
            self.temporal_coherence = TemporalCoherence() if self.config.enable_temporal else None
            self.hair_segmentation = HairSegmentation() if self.config.enable_hair_refinement else None

            # Processing components
            self.alpha_matting = AlphaMatting(self.config.matting_config)
            self.compositing_engine = CompositingEngine()
            self.background_effects = BackgroundEffects()
            self.composite_effects = CompositeEffects()

            # Fallback strategy
            self.fallback_strategy = FallbackStrategy() if self.config.enable_fallback else None

            # Memory monitoring
            self.memory_monitor = MemoryMonitor()
            self.time_estimator = TimeEstimator()

            # Load model
            self._load_model()

        except Exception as e:
            self.logger.error(f"Component initialization failed: {e}")
            raise

    def _load_model(self):
        """Load the segmentation model."""
        try:
            self.logger.info(f"Loading model: {self.config.model_type.value}")
            self.model = self.model_factory.load_model(
                self.config.model_type,
                device=self.device_manager.get_device(),
                optimize=True
            )
            self.logger.info("Model loaded successfully")
        except Exception as e:
            self.logger.error(f"Model loading failed: {e}")
            if self.config.enable_fallback:
                self.logger.info("Attempting fallback model loading")
                self.config.model_type = ModelType.U2NET_LITE
                self.model = self.model_factory.load_model(
                    self.config.model_type,
                    device='cpu'
                )
            else:
                # Without a fallback there is no model to run; surface the error
                raise
    def process_image(self,
                      image: Union[np.ndarray, str, Path],
                      background: Optional[Union[np.ndarray, str, Path]] = None,
                      **kwargs) -> PipelineResult:
        """
        Process a single image through the pipeline.

        Args:
            image: Input image (array or path)
            background: Optional background image/path
            **kwargs: Additional processing parameters

        Returns:
            PipelineResult with processed image and metadata
        """
        start_time = time.time()
        self.is_processing = True
        result = PipelineResult(success=False)
        # Initialize up front so the except-path fallback never hits a NameError
        image_array = None
        bg_array = None

        try:
            # Stage 1: Initialization
            self._update_stage(PipelineStage.INITIALIZATION)
            image_array = self._load_image(image)
            bg_array = self._load_image(background) if background is not None else None

            # Generate cache key
            cache_key = self._generate_cache_key(image_array, kwargs)

            # Check cache
            if self.config.enable_caching and cache_key in self.cache:
                self.logger.info("Using cached result")
                cached_result = self.cache[cache_key]
                cached_result.processing_time = time.time() - start_time
                return cached_result

            # Stage 2: Preprocessing
            self._update_stage(PipelineStage.PREPROCESSING)
            preprocessed = self._preprocess_image(image_array)
            result.metadata['original_size'] = image_array.shape[:2]
            result.metadata['preprocessed_size'] = preprocessed.shape[:2]

            # Quality analysis
            quality_metrics = self.quality_analyzer.analyze_frame(preprocessed)
            result.metadata['quality_metrics'] = quality_metrics

            # Stage 3: Segmentation
            self._update_stage(PipelineStage.SEGMENTATION)
            segmentation_mask = self._segment_image(preprocessed)

            # Hair refinement if enabled
            if self.config.enable_hair_refinement and self.hair_segmentation:
                self.logger.info("Applying hair refinement")
                hair_mask = self.hair_segmentation.segment_hair(preprocessed)
                segmentation_mask = self._combine_masks(segmentation_mask, hair_mask)

            # Stage 4: Matting
            self._update_stage(PipelineStage.MATTING)
            matting_result = self.alpha_matting.process(
                preprocessed,
                segmentation_mask,
                method=self.config.matting_method
            )
            alpha_matte = matting_result['alpha']
            result.metadata['matting_confidence'] = matting_result['confidence']

            # Stage 5: Refinement
            self._update_stage(PipelineStage.REFINEMENT)
            if self.config.enable_edge_refinement:
                alpha_matte = self.edge_refinement.refine_edges(
                    preprocessed,
                    (alpha_matte * 255).astype(np.uint8)
                ) / 255.0

            # Resize alpha to original size if needed
            if preprocessed.shape[:2] != image_array.shape[:2]:
                alpha_matte = cv2.resize(
                    alpha_matte,
                    (image_array.shape[1], image_array.shape[0]),
                    interpolation=cv2.INTER_LINEAR
                )

            # Extract foreground
            foreground = self._extract_foreground(image_array, alpha_matte)

            # Stage 6: Background & Effects
            self._update_stage(PipelineStage.EFFECTS)
            if bg_array is not None:
                # Resize background to match image
                bg_array = self._resize_background(bg_array, image_array.shape[:2])

                # Apply background effects
                if self.config.background_blur:
                    bg_array = self.background_effects.apply_blur(
                        bg_array,
                        strength=self.config.blur_strength,
                        mask=1 - alpha_matte
                    )

                # Apply configured effects
                if self.config.apply_effects:
                    bg_array = self._apply_effects(bg_array, alpha_matte)
            else:
                # Create transparent background
                bg_array = np.zeros_like(image_array)

            # Stage 7: Compositing
            self._update_stage(PipelineStage.COMPOSITING)
            if EffectType.LIGHT_WRAP in self.config.apply_effects:
                foreground = self.background_effects.apply_light_wrap(
                    foreground, bg_array, alpha_matte
                )

            composited = self.compositing_engine.composite(
                foreground, bg_array, alpha_matte
            )

            # Apply post-composite effects
            if self.config.apply_effects:
                composited = self._apply_post_effects(composited, alpha_matte)

            # Stage 8: Postprocessing
            self._update_stage(PipelineStage.POSTPROCESSING)
            final_output = self._postprocess_image(composited, alpha_matte)

            # Calculate quality score
            result.quality_score = self._calculate_quality_score(
                final_output, alpha_matte, quality_metrics
            )

            # Build result
            result.success = True
            result.output_image = final_output
            result.alpha_matte = alpha_matte
            result.foreground = foreground
            result.background = bg_array
            result.stages_completed = list(PipelineStage)
            result.processing_time = time.time() - start_time

            # Cache result
            if self.config.enable_caching:
                self._cache_result(cache_key, result)

            # Complete
            self._update_stage(PipelineStage.COMPLETE)
            self.logger.info(f"Processing completed in {result.processing_time:.2f}s")

            # Update statistics
            self._update_statistics(result)

        except Exception as e:
            self.logger.error(f"Pipeline processing failed: {e}")
            result.errors.append(str(e))

            if (self.config.enable_fallback and self.fallback_strategy
                    and image_array is not None):
                self.logger.info("Attempting fallback processing")
                result = self._fallback_processing(image_array, bg_array)

        finally:
            self.is_processing = False

        return result
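
    # Usage sketch for a single image (hypothetical file paths):
    #
    #   pipeline = ProcessingPipeline(config)
    #   result = pipeline.process_image("portrait.jpg", background="beach.jpg")
    #   if result.success:
    #       cv2.imwrite("output.png", result.output_image)
    #   else:
    #       print("Errors:", result.errors)
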
    def _preprocess_image(self, image: np.ndarray) -> np.ndarray:
        """Preprocess image for optimal processing."""
        processed = image.copy()

        # Resize if needed
        if self.config.target_resolution:
            target_h, target_w = self.config.target_resolution
            h, w = image.shape[:2]

            if self.config.maintain_aspect_ratio:
                scale = min(target_w / w, target_h / h)
                new_w = int(w * scale)
                new_h = int(h * scale)
            else:
                new_w, new_h = target_w, target_h

            if (new_w, new_h) != (w, h):
                processed = cv2.resize(processed, (new_w, new_h),
                                       interpolation=cv2.INTER_AREA)

        # Apply quality-based preprocessing
        if self.config.quality_preset == "low":
            # Reduce noise for faster processing (colored variant, since
            # frames here are 3-channel BGR)
            processed = cv2.fastNlMeansDenoisingColored(processed, None, 10, 10, 7, 21)
        elif self.config.quality_preset in ["high", "ultra"]:
            # Enhance details
            processed = cv2.detailEnhance(processed, sigma_s=10, sigma_r=0.15)

        return processed

    def _segment_image(self, image: np.ndarray) -> np.ndarray:
        """Perform image segmentation."""
        try:
            # Use the loaded model for segmentation
            with torch.no_grad():
                # Prepare input
                input_tensor = self._prepare_input_tensor(image)

                # Run inference
                output = self.model(input_tensor)

                # Process output
                if isinstance(output, tuple):
                    output = output[0]

                # Convert to numpy mask
                mask = output.squeeze().cpu().numpy()

                # Threshold and convert to uint8
                mask = (mask > 0.5).astype(np.uint8) * 255

                # Resize to original size if needed
                if mask.shape[:2] != image.shape[:2]:
                    mask = cv2.resize(mask, (image.shape[1], image.shape[0]))

                return mask

        except Exception as e:
            self.logger.error(f"Segmentation failed: {e}")
            if self.config.enable_fallback:
                # Use basic segmentation as fallback
                from ..processing.fallback import ProcessingFallback
                fallback = ProcessingFallback()
                return fallback.basic_segmentation(image)
            raise

    def _prepare_input_tensor(self, image: np.ndarray) -> torch.Tensor:
        """Prepare image tensor for model input."""
        # Resize to model input size (typically 512x512 or 1024x1024)
        model_size = 512  # Default; should come from the model config
        resized = cv2.resize(image, (model_size, model_size))

        # Convert HWC uint8 to NCHW float in [0, 1].
        # NOTE: cv2 loads images as BGR; models trained on RGB with mean/std
        # normalization need a model-specific conversion step here.
        tensor = torch.from_numpy(resized.transpose(2, 0, 1)).float()
        tensor = tensor.unsqueeze(0) / 255.0

        # Move to device
        tensor = tensor.to(self.device_manager.get_device())

        return tensor

    def _combine_masks(self, mask1: np.ndarray, mask2: np.ndarray) -> np.ndarray:
        """Combine two masks intelligently."""
        # Convert to float for blending
        m1 = mask1.astype(np.float32) / 255.0
        m2 = mask2.astype(np.float32) / 255.0

        # Combine using per-pixel maximum (union)
        combined = np.maximum(m1, m2)

        # Convert back to uint8
        return (combined * 255).astype(np.uint8)
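
    # If a model does require RGB input with ImageNet normalization, a sketch
    # of the extra steps in _prepare_input_tensor would be (an assumption, not
    # something every supported model needs):
    #
    #   rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
    #   tensor = torch.from_numpy(rgb.transpose(2, 0, 1)).float() / 255.0
    #   mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    #   std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    #   tensor = ((tensor - mean) / std).unsqueeze(0)
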
    def _extract_foreground(self, image: np.ndarray, alpha: np.ndarray) -> np.ndarray:
        """Extract foreground using alpha matte."""
        if len(alpha.shape) == 2:
            alpha = np.expand_dims(alpha, axis=2)

        if alpha.shape[2] == 1:
            alpha = np.repeat(alpha, 3, axis=2)

        # Premultiply alpha
        foreground = image.astype(np.float32) * alpha

        return foreground.astype(np.uint8)

    def _resize_background(self, background: np.ndarray,
                           target_shape: Tuple[int, int]) -> np.ndarray:
        """Resize background to match target shape."""
        h, w = target_shape
        bg_h, bg_w = background.shape[:2]

        if (bg_h, bg_w) == (h, w):
            return background

        # Calculate scale to cover the entire image; ceil so rounding never
        # leaves the resized background a pixel short of the crop window
        scale = max(h / bg_h, w / bg_w)
        new_h = int(np.ceil(bg_h * scale))
        new_w = int(np.ceil(bg_w * scale))

        # Resize
        resized = cv2.resize(background, (new_w, new_h),
                             interpolation=cv2.INTER_LINEAR)

        # Center crop
        start_y = (new_h - h) // 2
        start_x = (new_w - w) // 2
        cropped = resized[start_y:start_y + h, start_x:start_x + w]

        return cropped

    def _apply_effects(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """Apply configured effects to image."""
        result = image.copy()

        for effect in self.config.apply_effects:
            if effect == EffectType.BOKEH:
                result = self.background_effects.apply_bokeh(result)
            elif effect == EffectType.VIGNETTE:
                result = self.background_effects.add_vignette(result)
            elif effect == EffectType.FILM_GRAIN:
                result = self.background_effects.add_film_grain(result)

        return result

    def _apply_post_effects(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """Apply post-composite effects."""
        result = image.copy()

        for effect in self.config.apply_effects:
            if effect == EffectType.SHADOW:
                result = self.background_effects.add_shadow(result, mask)
            elif effect == EffectType.REFLECTION:
                result = self.background_effects.add_reflection(result, mask)
            elif effect == EffectType.GLOW:
                result = self.background_effects.add_glow(result, mask)
            elif effect == EffectType.CHROMATIC_ABERRATION:
                result = self.background_effects.chromatic_aberration(result)

        return result

    def _postprocess_image(self, image: np.ndarray, alpha: np.ndarray) -> np.ndarray:
        """Apply final postprocessing."""
        result = image.copy()

        # Color correction based on quality preset
        if self.config.quality_preset in ["high", "ultra"]:
            # Auto color balance via histogram equalization of the L channel
            lab = cv2.cvtColor(result, cv2.COLOR_BGR2LAB)
            l, a, b = cv2.split(lab)
            l = cv2.equalizeHist(l)
            result = cv2.cvtColor(cv2.merge([l, a, b]), cv2.COLOR_LAB2BGR)

        # Sharpen if ultra quality
        if self.config.quality_preset == "ultra":
            kernel = np.array([[-1, -1, -1],
                               [-1,  9, -1],
                               [-1, -1, -1]])
            result = cv2.filter2D(result, -1, kernel)

        return result

    def _calculate_quality_score(self, image: np.ndarray,
                                 alpha: np.ndarray,
                                 metrics: Dict) -> float:
        """Calculate overall quality score."""
        scores = []

        # Edge quality
        edge_score = metrics.get('edge_clarity', 0.5)
        scores.append(edge_score)

        # Alpha matte quality (contrast): higher std indicates cleaner
        # foreground/background separation
        alpha_std = np.std(alpha)
        alpha_score = min(alpha_std * 2, 1.0)
        scores.append(alpha_score)

        # Overall image quality
        quality_score = metrics.get('overall_quality', 0.5)
        scores.append(quality_score)

        return float(np.mean(scores))

    def _load_image(self, source: Union[np.ndarray, str, Path]) -> np.ndarray:
        """Load image from various sources."""
        if isinstance(source, np.ndarray):
            return source

        path = Path(source) if not isinstance(source, Path) else source
        if not path.exists():
            raise FileNotFoundError(f"Image not found: {path}")

        image = cv2.imread(str(path))
        if image is None:
            raise ValueError(f"Failed to load image: {path}")

        return image
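
    # Compositing note: because _extract_foreground premultiplies by alpha,
    # an over-composite reduces to the following (a sketch of the standard
    # formula; the actual math lives in CompositingEngine):
    #
    #   out = fg_premultiplied + (1.0 - alpha) * bg
    #
    # which is equivalent to alpha * fg + (1 - alpha) * bg for straight alpha.
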
    def _generate_cache_key(self, image: np.ndarray, params: Dict) -> str:
        """Generate cache key for result."""
        # Create hash from image bytes and parameters; default=str keeps
        # non-JSON-serializable parameter values from raising
        hasher = hashlib.md5()
        hasher.update(image.tobytes())
        hasher.update(json.dumps(params, sort_keys=True, default=str).encode())
        return hasher.hexdigest()

    def _cache_result(self, key: str, result: PipelineResult):
        """Cache processing result."""
        self.cache[key] = result

        # Limit cache size
        cache_memory = sum(
            r.output_image.nbytes if r.output_image is not None else 0
            for r in self.cache.values()
        )
        max_bytes = self.config.cache_size_mb * 1024 * 1024

        if cache_memory > max_bytes:
            # Evict the oldest quarter of entries (dicts preserve insertion order)
            for old_key in list(self.cache.keys())[:len(self.cache) // 4]:
                del self.cache[old_key]

    def _update_stage(self, stage: PipelineStage):
        """Update current processing stage."""
        self.current_stage = stage

        if self.config.stage_callback:
            self.config.stage_callback(stage, {
                'timestamp': time.time(),
                'memory_usage': self.memory_monitor.get_usage()
            })

        if self.config.progress_callback:
            # +1 so the COMPLETE stage reports progress of exactly 1.0
            progress = (list(PipelineStage).index(stage) + 1) / len(PipelineStage)
            self.config.progress_callback(progress, stage.value)

    def _update_statistics(self, result: PipelineResult):
        """Update processing statistics."""
        if 'total_processed' not in self.processing_stats:
            self.processing_stats['total_processed'] = 0
            self.processing_stats['total_time'] = 0
            self.processing_stats['avg_quality'] = 0

        self.processing_stats['total_processed'] += 1
        self.processing_stats['total_time'] += result.processing_time
        self.processing_stats['avg_time'] = (
            self.processing_stats['total_time'] /
            self.processing_stats['total_processed']
        )

        # Update running average quality
        n = self.processing_stats['total_processed']
        old_avg = self.processing_stats['avg_quality']
        self.processing_stats['avg_quality'] = (
            (old_avg * (n - 1) + result.quality_score) / n
        )

    def _fallback_processing(self, image: np.ndarray,
                             background: Optional[np.ndarray]) -> PipelineResult:
        """Fallback processing when main pipeline fails."""
        from ..processing.fallback import ProcessingFallback

        result = PipelineResult(success=False)
        fallback = ProcessingFallback()

        try:
            # Basic segmentation
            mask = fallback.basic_segmentation(image)

            # Basic matting
            alpha = fallback.basic_matting(image, mask)

            # Simple composite if background provided
            if background is not None:
                background = self._resize_background(background, image.shape[:2])
                output = self.compositing_engine.composite(
                    image, background, alpha
                )
            else:
                output = image

            result.success = True
            result.output_image = output
            result.alpha_matte = alpha
            result.metadata['fallback_used'] = True

        except Exception as e:
            self.logger.error(f"Fallback processing also failed: {e}")
            result.errors.append(str(e))

        return result
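
    # For reference, a minimal GrabCut-based fallback of the kind
    # ProcessingFallback.basic_segmentation might implement (an assumption
    # about that module, not its actual code):
    #
    #   def basic_segmentation(image: np.ndarray) -> np.ndarray:
    #       mask = np.zeros(image.shape[:2], np.uint8)
    #       h, w = image.shape[:2]
    #       rect = (w // 8, h // 8, w * 3 // 4, h * 3 // 4)  # assume centered subject
    #       bgd = np.zeros((1, 65), np.float64)
    #       fgd = np.zeros((1, 65), np.float64)
    #       cv2.grabCut(image, mask, rect, bgd, fgd, 5, cv2.GC_INIT_WITH_RECT)
    #       return np.where((mask == 1) | (mask == 3), 255, 0).astype(np.uint8)
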
    def process_batch(self,
                      images: List[Union[np.ndarray, str, Path]],
                      background: Optional[Union[np.ndarray, str, Path]] = None,
                      **kwargs) -> List[PipelineResult]:
        """
        Process multiple images in batch.

        Args:
            images: List of input images
            background: Optional background for all images
            **kwargs: Additional processing parameters

        Returns:
            List of PipelineResults
        """
        results = []
        total = len(images)

        self.logger.info(f"Processing batch of {total} images")

        # Process in parallel using the thread pool
        futures = []
        for image in images:
            future = self.executor.submit(
                self.process_image, image, background, **kwargs
            )
            futures.append(future)

        # Collect results in submission order
        for i, future in enumerate(futures):
            try:
                result = future.result(timeout=30)
                results.append(result)

                if self.config.progress_callback:
                    progress = (i + 1) / total
                    self.config.progress_callback(
                        progress, f"Processed {i + 1}/{total}"
                    )
            except Exception as e:
                self.logger.error(f"Batch item {i} failed: {e}")
                results.append(PipelineResult(
                    success=False,
                    errors=[str(e)]
                ))

        return results

    def get_statistics(self) -> Dict[str, Any]:
        """Get processing statistics."""
        return {
            **self.processing_stats,
            'cache_size': len(self.cache),
            'current_stage': self.current_stage.value,
            'is_processing': self.is_processing,
            'device': str(self.device_manager.get_device()),
            'model_type': self.config.model_type.value
        }

    def clear_cache(self):
        """Clear the result cache."""
        self.cache.clear()
        self.logger.info("Cache cleared")

    def shutdown(self):
        """Shutdown the pipeline and clean up resources."""
        self.executor.shutdown(wait=True)
        self.clear_cache()

        # Clean up models
        if hasattr(self, 'model'):
            del self.model
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        self.logger.info("Pipeline shutdown complete")
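

# End-to-end usage sketch (hypothetical paths and an assumed package import
# path; this module uses relative imports, so it must be run as part of the
# package rather than as a standalone script):
#
#   from backgroundfx.pipeline import (
#       ProcessingPipeline, PipelineConfig, ProcessingMode,
#   )
#
#   pipeline = ProcessingPipeline(PipelineConfig(mode=ProcessingMode.BATCH))
#   try:
#       results = pipeline.process_batch(
#           ["a.jpg", "b.jpg"], background="studio.png"
#       )
#       print(pipeline.get_statistics())
#   finally:
#       pipeline.shutdown()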