# models/wrappers/matanyone_wrapper.py
import torch
import torch.nn.functional as F
from typing import Optional, Dict, Union

import numpy as np


class MatAnyOneWrapper:
    def __init__(self, core, device=None, config=None):
        """
        Initialize MatAnyone wrapper with enhanced configuration.

        Args:
            core: MatAnyone InferenceCore instance
            device: torch device (auto-detect if None)
            config: Optional configuration dict for processing parameters
        """
        self.core = core
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.config = config or {}

        # Default processing parameters
        self.threshold = self.config.get('threshold', 0.5)
        self.edge_refinement = self.config.get('edge_refinement', True)
        self.hair_refinement = self.config.get('hair_refinement', True)

        # Component weights for multi-layer processing
        self.component_weights = self.config.get('component_weights', {
            'base': 1.0,
            'hair': 1.2,
            'edge': 1.5,
            'detail': 1.1
        })

        # Initialize model
        try:
            self.core.model.to(self.device)
        except Exception:
            pass
        try:
            self.core.model.eval()
        except Exception:
            pass

    @torch.inference_mode()
    def step(self,
             image_tensor: torch.Tensor,
             mask_tensor: Optional[torch.Tensor] = None,
             objects: Optional[Dict] = None,
             first_frame_pred: bool = False,
             components: Optional[Dict[str, torch.Tensor]] = None,
             **kwargs) -> torch.Tensor:
        """
        Process a single frame with optional component masks.

        Args:
            image_tensor: (1,3,H,W) float32 [0..1] on self.device
            mask_tensor: (1,1,H,W) float32 [0..1] on self.device
            objects: Optional object tracking info
            first_frame_pred: Whether this is the first frame
            components: Optional dict with keys like 'hair', 'edge', 'detail'
                Each value is a (1,1,H,W) tensor
            **kwargs: Additional arguments for InferenceCore

        Returns:
            (1,1,H,W) float32 probabilities in [0..1]
        """
        # Ensure everything is on the correct device
        image_tensor = image_tensor.to(self.device, non_blocking=True)
        if mask_tensor is not None:
            mask_tensor = mask_tensor.to(self.device, non_blocking=True)

        # Process component masks if provided
        if components:
            components = {
                k: v.to(self.device, non_blocking=True)
                for k, v in components.items()
            }

        # Main inference call
        try:
            # Adapt to actual InferenceCore API
            out = self.core.step(
                image_tensor=image_tensor,
                mask_tensor=mask_tensor,
                first_frame_pred=first_frame_pred,
                objects=objects,
                **kwargs
            )
        except TypeError:
            # Fallback for different API signatures
            out = self.core.step(
                frame=image_tensor,
                mask=mask_tensor,
                **kwargs
            )

        # Normalize output shape
        out = self._normalize_output(out)

        # Apply component-based refinement if available
        if components:
            out = self._refine_with_components(out, components)

        # Apply edge refinement if enabled
        if self.edge_refinement and mask_tensor is not None:
            out = self._refine_edges(out, image_tensor, mask_tensor)

        return out.clamp_(0, 1)

    def _normalize_output(self, out: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        """Normalize output to a (1,1,H,W) tensor."""
        if isinstance(out, torch.Tensor):
            if out.ndim == 3:    # (1,H,W) → (1,1,H,W)
                out = out.unsqueeze(1)
            elif out.ndim == 2:  # (H,W) → (1,1,H,W)
                out = out.unsqueeze(0).unsqueeze(0)
        else:
            out = torch.as_tensor(out, dtype=torch.float32, device=self.device)
            if out.ndim == 2:
                out = out.unsqueeze(0).unsqueeze(0)
            elif out.ndim == 3:
                out = out.unsqueeze(1)
        return out
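
    # Usage sketch for `step` (assumptions: a constructed MatAnyone
    # InferenceCore named `my_core`, and tensor shapes matching the docstring
    # above; kept as a comment so importing this module has no side effects):
    #
    #   wrapper = MatAnyOneWrapper(core=my_core)
    #   frame = torch.rand(1, 3, 480, 854)   # (1,3,H,W) RGB in [0,1]
    #   seed = torch.zeros(1, 1, 480, 854)   # (1,1,H,W) first-frame mask
    #   alpha = wrapper.step(frame, seed, first_frame_pred=True)
    #   binary = wrapper.output_prob_to_mask(alpha)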
    def _refine_with_components(self,
                                base_mask: torch.Tensor,
                                components: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Refine mask using component layers (hair, edge, etc).

        Args:
            base_mask: (1,1,H,W) base alpha mask
            components: Dict of component masks

        Returns:
            Refined (1,1,H,W) mask
        """
        refined = base_mask.clone()

        # Apply hair refinement: where hair is detected, take the max of the
        # base alpha and the weighted hair alpha
        if 'hair' in components and self.hair_refinement:
            hair_mask = components['hair']
            weight = self.component_weights.get('hair', 1.0)
            refined = torch.where(
                hair_mask > 0.1,
                torch.maximum(refined, hair_mask * weight),
                refined
            )

        # Apply edge refinement
        if 'edge' in components:
            edge_mask = components['edge']
            weight = self.component_weights.get('edge', 1.0)
            refined = self._apply_edge_enhancement(refined, edge_mask, weight)

        # Apply detail mask if available: blend toward the detail weight,
        # using the detail mask itself as the blend factor (this pushes alpha
        # toward opacity in detail regions; the result is clamped below)
        if 'detail' in components:
            detail_mask = components['detail']
            weight = self.component_weights.get('detail', 1.0)
            refined = refined * (1 - detail_mask) + detail_mask * weight

        return refined.clamp_(0, 1)

    def _refine_edges(self,
                      mask: torch.Tensor,
                      image: torch.Tensor,
                      reference_mask: torch.Tensor) -> torch.Tensor:
        """
        Apply edge refinement using image gradients.

        Args:
            mask: (1,1,H,W) mask to refine
            image: (1,3,H,W) source image
            reference_mask: (1,1,H,W) reference mask (currently unused;
                kept for API symmetry with `step`)

        Returns:
            Edge-refined mask
        """
        # Convert to grayscale for gradient-based edge detection
        gray = image.mean(dim=1, keepdim=True)

        # Sobel filters for edge detection
        sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                               dtype=torch.float32, device=self.device)
        sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
                               dtype=torch.float32, device=self.device)
        sobel_x = sobel_x.view(1, 1, 3, 3)
        sobel_y = sobel_y.view(1, 1, 3, 3)

        # Apply Sobel filters and take the gradient magnitude
        edge_x = F.conv2d(gray, sobel_x, padding=1)
        edge_y = F.conv2d(gray, sobel_y, padding=1)
        edges = torch.sqrt(edge_x**2 + edge_y**2)

        # Normalize edge strength to [0,1]
        edges = edges / (edges.max() + 1e-7)

        # Box-smooth the mask (3x3 average) as the soft candidate
        kernel_size = 3
        refined = F.avg_pool2d(mask, kernel_size, stride=1, padding=1)

        # Blend toward the smoothed mask where image edges are strong
        # (up to 50% at maximum edge strength), softening hard matte boundaries
        alpha = 1 - edges * 0.5
        refined = mask * alpha + refined * (1 - alpha)

        return refined.clamp_(0, 1)

    def _apply_edge_enhancement(self,
                                mask: torch.Tensor,
                                edge_mask: torch.Tensor,
                                weight: float) -> torch.Tensor:
        """Apply edge enhancement using an edge mask."""
        # Dilate edges slightly (3x3 box blur) for smoother boundaries
        kernel = torch.ones(1, 1, 3, 3, device=self.device) / 9
        dilated_edges = F.conv2d(edge_mask, kernel, padding=1)

        # Raise the alpha to the weighted edge value wherever edges are present
        enhanced = torch.where(
            dilated_edges > 0.1,
            torch.maximum(mask, dilated_edges * weight),
            mask
        )
        return enhanced
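
    # Worked example of the hair branch in `_refine_with_components` above
    # (illustrative values, not model output; all masks are soft alphas in
    # [0,1]):
    #
    #   at a hair pixel: base alpha 0.40, components['hair'] = 0.70
    #   hair weight 1.2 → candidate 0.70 * 1.2 = 0.84; since 0.70 > 0.1 the
    #   pixel takes max(0.40, 0.84) = 0.84, and the final clamp keeps the
    #   result inside [0,1]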
    def process_batch(self,
                      images: torch.Tensor,
                      masks: Optional[torch.Tensor] = None,
                      components_batch: Optional[Dict[str, torch.Tensor]] = None,
                      **kwargs) -> torch.Tensor:
        """
        Process a batch of frames sequentially, treating frame 0 as the
        first-frame prediction.

        Args:
            images: (B,3,H,W) batch of images
            masks: Optional (B,1,H,W) batch of masks
            components_batch: Optional dict of component batches
            **kwargs: Additional arguments

        Returns:
            (B,1,H,W) batch of refined masks
        """
        batch_size = images.shape[0]
        results = []

        for i in range(batch_size):
            image = images[i:i+1]
            mask = masks[i:i+1] if masks is not None else None

            # Extract this frame's slice of each component batch
            components = None
            if components_batch:
                components = {
                    k: v[i:i+1] for k, v in components_batch.items()
                }

            # Process frame
            result = self.step(
                image,
                mask,
                components=components,
                first_frame_pred=(i == 0),
                **kwargs
            )
            results.append(result)

        return torch.cat(results, dim=0)

    def output_prob_to_mask(self, prob: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        """Convert a probability map to a binary mask via self.threshold."""
        if isinstance(prob, torch.Tensor):
            return (prob > self.threshold).float()
        t = torch.as_tensor(prob, device=self.device)
        return (t > self.threshold).float()

    def apply_morphology(self,
                         mask: torch.Tensor,
                         operation: str = 'close',
                         kernel_size: int = 5) -> torch.Tensor:
        """
        Apply morphological operations to clean up a binary mask.

        Args:
            mask: (1,1,H,W) binary mask tensor
            operation: 'close', 'open', 'dilate', or 'erode'
            kernel_size: Size of morphological kernel (odd sizes preserve
                the spatial shape)

        Returns:
            Processed mask
        """
        kernel = torch.ones(1, 1, kernel_size, kernel_size, device=self.device)

        # Note: 'close' deliberately falls through both blocks below
        # (dilation first, then erosion)
        if operation in ['close', 'dilate']:
            # Dilation: any neighbor set → pixel set
            mask = F.conv2d(mask, kernel, padding=kernel_size//2)
            mask = (mask > 0).float()

        if operation in ['close', 'erode']:
            # Erosion: all neighbors set → pixel set
            mask = F.conv2d(mask, kernel, padding=kernel_size//2)
            mask = (mask >= kernel_size**2).float()

        if operation == 'open':
            # Erosion followed by dilation
            mask = F.conv2d(mask, kernel, padding=kernel_size//2)
            mask = (mask >= kernel_size**2).float()
            mask = F.conv2d(mask, kernel, padding=kernel_size//2)
            mask = (mask > 0).float()

        return mask

    def get_alpha_matte(self,
                        image: torch.Tensor,
                        mask: torch.Tensor,
                        trimap: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Get alpha matte with optional trimap refinement.

        Args:
            image: (1,3,H,W) RGB image
            mask: (1,1,H,W) initial mask
            trimap: Optional (1,1,H,W) trimap (0=bg, 0.5=unknown, 1=fg)

        Returns:
            (1,1,H,W) refined alpha matte
        """
        # Process through MatAnyone
        alpha = self.step(image, mask)

        # Apply trimap constraints if provided (exact comparisons assume a
        # clean trimap containing only the values 0, 0.5, and 1)
        if trimap is not None:
            alpha = torch.where(trimap == 0, torch.zeros_like(alpha), alpha)
            alpha = torch.where(trimap == 1, torch.ones_like(alpha), alpha)

        return alpha

    def composite(self,
                  foreground: torch.Tensor,
                  background: torch.Tensor,
                  alpha: torch.Tensor) -> torch.Tensor:
        """
        Composite foreground over background using alpha.

        Args:
            foreground: (1,3,H,W) foreground image
            background: (1,3,H,W) background image
            alpha: (1,1,H,W) alpha matte

        Returns:
            (1,3,H,W) composited image
        """
        return foreground * alpha + background * (1 - alpha)
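

if __name__ == "__main__":
    # Smoke-test sketch. Assumptions: no MatAnyone checkpoint is loaded; the
    # hypothetical _StubCore below simply echoes the input mask in place of a
    # real InferenceCore, so only the wrapper's post-processing path (shape
    # normalization plus edge refinement) is exercised.
    class _StubCore:
        class _Model:
            def to(self, device):
                return self

            def eval(self):
                return self

        model = _Model()

        def step(self, image_tensor=None, mask_tensor=None, **kwargs):
            # Echo the mask (or an all-zero map) as the "prediction"
            if mask_tensor is not None:
                return mask_tensor
            return torch.zeros(1, 1, *image_tensor.shape[-2:],
                               device=image_tensor.device)

    wrapper = MatAnyOneWrapper(_StubCore(), device="cpu")
    frame = torch.rand(1, 3, 64, 64)
    seed = (torch.rand(1, 1, 64, 64) > 0.5).float()
    alpha = wrapper.step(frame, seed, first_frame_pred=True)
    comp = wrapper.composite(frame, torch.zeros_like(frame), alpha)
    print("alpha:", tuple(alpha.shape), "composite:", tuple(comp.shape))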