"""
Edge processing and symmetry correction for BackgroundFX Pro.

Fixes hair segmentation asymmetry and improves edge quality.
"""

import numpy as np
import cv2
import torch
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
from scipy import ndimage, signal
from scipy.spatial import distance
import logging

logger = logging.getLogger(__name__)


def _mask_to_uint8(mask: np.ndarray) -> np.ndarray:
    """Convert a [0, 1] float mask to uint8 in [0, 255] for OpenCV 8-bit APIs.

    Values are clipped first: casting floats outside the uint8 range with a bare
    ``(mask * 255).astype(np.uint8)`` does not saturate, so filter overshoot
    (e.g. 1.02) would otherwise turn into garbage values.
    """
    return (np.clip(mask, 0.0, 1.0) * 255).astype(np.uint8)


@dataclass
class EdgeConfig:
    """Configuration for edge processing."""
    edge_thickness: int = 3  # not read by the code in this module
    smoothing_iterations: int = 2  # not read by the code in this module
    # Mean absolute left/right difference above which masks are mirror-balanced.
    symmetry_threshold: float = 0.3
    # Threshold applied to the (masked) hair probability map.
    hair_detection_sensitivity: float = 0.7
    # Max distance (px) to an image edge for a mask-edge pixel to be snapped.
    refinement_radius: int = 5
    # Whether the final smoothing pass runs the self-guided filter.
    use_guided_filter: bool = True
    # cv2.bilateralFilter parameters used by EdgeRefinement.
    bilateral_d: int = 9
    bilateral_sigma_color: float = 75
    bilateral_sigma_space: float = 75
    # Elliptical kernel size for the final close/open morphology pass.
    morphology_kernel_size: int = 5
    edge_preservation_weight: float = 0.8  # not read by the code in this module


class EdgeProcessor:
    """Main edge processing and refinement system."""

    def __init__(self, config: Optional[EdgeConfig] = None):
        self.config = config or EdgeConfig()
        # Hand the *resolved* config to the sub-modules; passing the raw
        # argument would give them None when no config is supplied.
        self.hair_segmentation = HairSegmentation(self.config)
        self.edge_refinement = EdgeRefinement(self.config)
        self.symmetry_corrector = SymmetryCorrector(self.config)

    def process(self, image: np.ndarray, mask: np.ndarray,
                detect_hair: bool = True) -> np.ndarray:
        """Run the full edge pipeline on ``mask``.

        Args:
            image: Source image. The hair and texture stages convert it with
                ``cv2.COLOR_BGR2*``, so a BGR image is expected (presumably
                uint8 — TODO confirm with callers).
            mask: Foreground mask, presumably float in [0, 1] (it is scaled by
                255 before edge detection).
            detect_hair: Whether to run the hair segmentation stage.

        Returns:
            The refined mask, clipped to [0, 1].
        """
        # 1. Initial edge detection
        edges = self._detect_edges(mask)

        # 2. Hair-specific processing
        if detect_hair:
            hair_mask = self.hair_segmentation.segment(image, mask)
            mask = self._blend_hair_mask(mask, hair_mask)

        # 3. Symmetry correction
        mask = self.symmetry_corrector.correct(mask, image)

        # 4. Edge refinement
        mask = self.edge_refinement.refine(image, mask, edges)

        # 5. Final smoothing
        mask = self._final_smoothing(mask)

        return mask

    def _detect_edges(self, mask: np.ndarray) -> np.ndarray:
        """Return a [0, 1] edge map of ``mask`` from three Canny threshold pairs."""
        mask_uint8 = _mask_to_uint8(mask)

        # Multi-scale edge detection: union of low/medium/high thresholds.
        edges = np.maximum.reduce([
            cv2.Canny(mask_uint8, 50, 150),
            cv2.Canny(mask_uint8, 30, 100),
            cv2.Canny(mask_uint8, 70, 200),
        ])
        return edges / 255.0

    def _blend_hair_mask(self, original_mask: np.ndarray,
                         hair_mask: np.ndarray) -> np.ndarray:
        """Blend ``hair_mask`` into ``original_mask`` where hair_mask > 0.5.

        Inside hair regions the result is 0.7 * hair + 0.3 * original;
        elsewhere the original mask is kept unchanged.
        """
        hair_regions = hair_mask > 0.5
        alpha = 0.7  # Hair mask weight

        blended = original_mask.copy()
        blended[hair_regions] = (
            alpha * hair_mask[hair_regions] +
            (1 - alpha) * original_mask[hair_regions]
        )
        return blended

    def _final_smoothing(self, mask: np.ndarray) -> np.ndarray:
        """Edge-preserving smoothing, then morphological close/open, then clip."""
        if self.config.use_guided_filter:
            mask = self._guided_filter(mask, mask)

        ksize = self.config.morphology_kernel_size
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize))
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

        # The guided filter output (a * I + b) can overshoot [0, 1]; keep the
        # result a valid alpha mask.
        return np.clip(mask, 0.0, 1.0)

    def _guided_filter(self, input_img: np.ndarray, guidance: np.ndarray,
                       radius: int = 4, epsilon: float = 0.2**2) -> np.ndarray:
        """Apply a single-channel guided filter (He et al.).

        Args:
            input_img: Image to filter (p).
            guidance: Guidance image (I), same shape as ``input_img``.
            radius: Window radius r; the box window is (2r+1) x (2r+1).
            epsilon: Regularisation; larger values smooth more.

        Returns:
            The filtered image as float64.
        """
        p = np.asarray(input_img, dtype=np.float64)
        guide = np.asarray(guidance, dtype=np.float64)
        # An odd window centred on each pixel; an even (r, r) box is off-centre
        # and shifts the result by half a pixel.
        ksize = (2 * radius + 1, 2 * radius + 1)

        def box(x: np.ndarray) -> np.ndarray:
            return cv2.boxFilter(x, cv2.CV_64F, ksize)

        mean_I = box(guide)
        mean_p = box(p)
        cov_Ip = box(guide * p) - mean_I * mean_p
        var_I = box(guide * guide) - mean_I * mean_I

        a = cov_Ip / (var_I + epsilon)
        b = mean_p - a * mean_I

        return box(a) * guide + box(b)


class HairSegmentation:
    """Specialized hair segmentation module."""

    def __init__(self, config: Optional[EdgeConfig] = None):
        self.config = config if config is not None else EdgeConfig()
        self.hair_detector = HairDetector()

    def segment(self, image: np.ndarray, initial_mask: np.ndarray) -> np.ndarray:
        """Return a soft hair mask for ``image`` restricted to near ``initial_mask``."""
        # 1. Detect hair regions
        hair_probability = self.hair_detector.detect(image)

        # 2. Refine with initial mask
        hair_mask = self._refine_with_mask(hair_probability, initial_mask)

        # 3. Fix asymmetry specific to hair
        hair_mask = self._fix_hair_asymmetry(hair_mask, image)

        # 4. Enhance hair strands
        hair_mask = self._enhance_hair_strands(hair_mask, image)

        return hair_mask

    def _refine_with_mask(self, hair_prob: np.ndarray,
                          initial_mask: np.ndarray) -> np.ndarray:
        """Keep hair near the initial mask, threshold it, and soften the result."""
        # Only keep hair within or near the (dilated) initial mask.
        kernel = np.ones((15, 15), np.uint8)
        dilated_mask = cv2.dilate(initial_mask, kernel, iterations=2)

        refined = hair_prob * dilated_mask

        threshold = self.config.hair_detection_sensitivity
        hair_mask = (refined > threshold).astype(np.float32)

        return cv2.GaussianBlur(hair_mask, (5, 5), 1.0)

    def _fix_hair_asymmetry(self, mask: np.ndarray, image: np.ndarray) -> np.ndarray:
        """Mirror-average the left/right halves when they disagree too much.

        Halves are compared about the vertical image centre line. For odd
        widths the centre column is the mirror axis and is left unchanged.
        Returns a new array when balancing happens; ``mask`` is not modified.
        """
        width = mask.shape[1]
        half = width // 2
        if half == 0:
            return mask

        left = mask[:, :half]
        # For odd widths this skips the centre column so both halves match.
        right_mirrored = np.fliplr(mask[:, width - half:])

        asymmetry_score = np.mean(np.abs(left - right_mirrored))
        if asymmetry_score <= self.config.symmetry_threshold:
            return mask

        logger.info("Detected hair asymmetry: %.3f", asymmetry_score)

        balanced_left = 0.5 * left + 0.5 * right_mirrored
        balanced = mask.copy()
        balanced[:, :half] = balanced_left
        # The balanced right half is exactly the mirror image of the left half.
        balanced[:, width - half:] = np.fliplr(balanced_left)
        return balanced

    def _enhance_hair_strands(self, mask: np.ndarray, image: np.ndarray) -> np.ndarray:
        """Boost mask values where oriented Gabor responses are strong.

        Only pixels not already fully in the mask are raised (by at most 0.3),
        and the result is clipped to [0, 1].
        """
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image

        # Fine, oriented structures: max |response| over four Gabor orientations.
        gabor_responses = []
        for angle in (0, 45, 90, 135):
            theta = np.deg2rad(angle)
            kernel = cv2.getGaborKernel(
                (21, 21), 4.0, theta, 10.0, 0.5, 0, ktype=cv2.CV_32F
            )
            filtered = cv2.filter2D(gray, cv2.CV_32F, kernel)
            gabor_responses.append(np.abs(filtered))

        gabor_max = np.max(gabor_responses, axis=0)
        gabor_normalized = gabor_max / (np.max(gabor_max) + 1e-6)

        hair_enhancement = gabor_normalized * (1 - mask)
        return np.clip(mask + 0.3 * hair_enhancement, 0, 1)


class HairDetector:
    """Detects hair regions in images."""

    def detect(self, image: np.ndarray) -> np.ndarray:
        """Return a hair probability map combining colour and texture cues.

        The colour ranges below use OpenCV's 8-bit HSV convention (H in
        0-180), so ``image`` is presumably a uint8 BGR image — TODO confirm.
        """
        hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

        # Hair color detection in HSV: (lower, upper) bounds per colour.
        hair_colors = [
            # Black hair
            ((0, 0, 0), (180, 255, 30)),
            # Brown hair
            ((10, 20, 20), (20, 255, 100)),
            # Blonde hair
            ((15, 30, 50), (25, 255, 200)),
            # Red hair
            ((0, 50, 50), (10, 255, 150)),
        ]

        hair_masks = [
            cv2.inRange(hsv, np.array(lower), np.array(upper))
            for lower, upper in hair_colors
        ]
        color_mask = np.max(hair_masks, axis=0) / 255.0

        texture_mask = self._detect_hair_texture(image)

        hair_probability = 0.6 * color_mask + 0.4 * texture_mask
        return cv2.GaussianBlur(hair_probability, (7, 7), 2.0)

    def _detect_hair_texture(self, image: np.ndarray) -> np.ndarray:
        """Score hair-like texture: strong gradients with consistent orientation.

        Returns a map normalised so its maximum is (just under) 1.
        """
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image

        dx = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=3)
        dy = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=3)

        magnitude = np.sqrt(dx**2 + dy**2)
        orientation = np.arctan2(dy, dx)

        # Local orientation variance over a 9x9 window
        # (low variance = consistent = hair-like).
        # NOTE(review): this is a linear variance of angles. arctan2 wraps at
        # +/-pi, and the two edges of one strand have opposite gradient
        # directions, so consistent orientations can still score as high
        # variance. A doubled-angle circular variance would be more faithful,
        # but would change the score scale this pipeline is tuned to.
        window_size = 9
        kernel = np.ones((window_size, window_size)) / (window_size**2)
        orient_mean = cv2.filter2D(orientation, -1, kernel)
        orient_sq_mean = cv2.filter2D(orientation**2, -1, kernel)
        orient_var = orient_sq_mean - orient_mean**2

        texture_score = magnitude * np.exp(-orient_var)
        return texture_score / (np.max(texture_score) + 1e-6)


class EdgeRefinement:
    """Refines edges for better quality."""

    def __init__(self, config: Optional[EdgeConfig] = None):
        self.config = config if config is not None else EdgeConfig()

    def refine(self, image: np.ndarray, mask: np.ndarray,
               edges: np.ndarray) -> np.ndarray:
        """Smooth, snap, sharpen and feather the mask boundary."""
        # 1. Bilateral filtering for edge-aware smoothing
        refined = self._bilateral_smooth(mask, image)

        # 2. Snap to image edges
        refined = self._snap_to_edges(refined, image, edges)

        # 3. Subpixel refinement
        refined = self._subpixel_refinement(refined, image)

        # 4. Feathering
        refined = self._apply_feathering(refined)

        return refined

    def _bilateral_smooth(self, mask: np.ndarray, image: np.ndarray) -> np.ndarray:
        """Bilateral-filter the mask (8-bit internally) and return it in [0, 1]."""
        smoothed = cv2.bilateralFilter(
            _mask_to_uint8(mask),
            self.config.bilateral_d,
            self.config.bilateral_sigma_color,
            self.config.bilateral_sigma_space,
        )
        return smoothed / 255.0

    def _snap_to_edges(self, mask: np.ndarray, image: np.ndarray,
                       detected_edges: np.ndarray) -> np.ndarray:
        """Binarise mask-edge pixels that lie close to strong image edges.

        ``detected_edges`` is accepted for interface compatibility but is not
        used; image and mask edges are recomputed here with Canny.
        """
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image
        image_edges = cv2.Canny(gray, 50, 150) / 255.0
        mask_edges = cv2.Canny(_mask_to_uint8(mask), 50, 150) / 255.0

        # Distance from every pixel to the nearest image edge pixel.
        dist_to_image_edge = cv2.distanceTransform(
            (1 - image_edges).astype(np.uint8), cv2.DIST_L2, 5
        )

        # Neighbourhood of the mask boundary.
        edge_region = cv2.dilate(mask_edges, np.ones((5, 5), np.uint8)) > 0

        # Near an image edge: push the mask edge to a hard 0/1 decision.
        snap = (dist_to_image_edge < self.config.refinement_radius) & edge_region
        refined = mask.copy()
        refined[snap] = np.where(mask[snap] > 0.5, 1.0, 0.0)
        return refined

    def _subpixel_refinement(self, mask: np.ndarray, image: np.ndarray) -> np.ndarray:
        """Nudge mask values by 0.1 toward 0/1 where the image gradient is strong."""
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image

        grad_x = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=3)
        grad_y = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=3)
        grad_mag = np.sqrt(grad_x**2 + grad_y**2)
        grad_mag = grad_mag / (np.max(grad_mag) + 1e-6)

        refined = mask.copy()
        strong_gradient = grad_mag > 0.3
        values = mask[strong_gradient]
        refined[strong_gradient] = np.where(
            values > 0.5,
            np.minimum(values + 0.1, 1.0),
            np.maximum(values - 0.1, 0.0),
        )
        return refined

    def _apply_feathering(self, mask: np.ndarray, radius: int = 3) -> np.ndarray:
        """Replace a band of +/- ``radius`` px around the boundary with a linear ramp.

        The boundary is the 0.5 iso-line of ``mask``. Inside the band, alpha
        rises from 0 (``radius`` px outside) to 1 (``radius`` px inside);
        pixels outside the band keep their original values.
        """
        if radius <= 0:
            return mask

        mask_binary = (mask > 0.5).astype(np.uint8)

        # Foreground pixels: distance to nearest background pixel (0 elsewhere).
        dist_to_background = cv2.distanceTransform(mask_binary, cv2.DIST_L2, 5)
        # Background pixels: distance to nearest foreground pixel (0 elsewhere).
        dist_to_foreground = cv2.distanceTransform(1 - mask_binary, cv2.DIST_L2, 5)

        # Positive inside, negative outside. Each raw distance is 0 on the
        # "wrong" side, so only the signed combination limits the band to the
        # boundary.
        signed_dist = dist_to_background - dist_to_foreground
        feather_region = np.abs(signed_dist) <= radius

        if not np.any(feather_region):
            return mask

        alpha = np.clip(0.5 + signed_dist / (2.0 * radius), 0.0, 1.0)
        return np.where(feather_region, alpha, mask)


class SymmetryCorrector:
    """Corrects asymmetry in masks."""

    def __init__(self, config: Optional[EdgeConfig] = None):
        self.config = config if config is not None else EdgeConfig()

    def correct(self, mask: np.ndarray, image: np.ndarray) -> np.ndarray:
        """Balance the mask about its centre of mass if it is too asymmetric.

        ``image`` is accepted for interface compatibility but is not used.
        """
        center = self._find_center(mask)

        asymmetry_score = self._compute_asymmetry(mask, center)
        if asymmetry_score > self.config.symmetry_threshold:
            logger.info("Correcting asymmetry: %.3f", asymmetry_score)
            mask = self._balance_mask(mask, center)

        return mask

    def _find_center(self, mask: np.ndarray) -> int:
        """Return the x of the binarised mask's centre of mass (image centre if empty)."""
        moments = cv2.moments((mask > 0.5).astype(np.uint8))
        if moments['m00'] > 0:
            return int(moments['m10'] / moments['m00'])
        return mask.shape[1] // 2

    def _compute_asymmetry(self, mask: np.ndarray, center: int) -> float:
        """Mean |left - mirrored right| over the widest band symmetric about ``center``."""
        w = mask.shape[1]
        min_width = min(center, w - center)
        if min_width <= 0:
            return 0.0

        left = mask[:, center - min_width:center]
        right = mask[:, center:center + min_width]
        return np.mean(np.abs(left - np.fliplr(right)))

    def _balance_mask(self, mask: np.ndarray, center: int) -> np.ndarray:
        """Confidence-weighted mirror average of the two sides of ``center``.

        The side whose values are further from 0.5 on average gets more weight.
        A 5-px-wide horizontal blur then softens the seam at ``center``.
        """
        w = mask.shape[1]
        min_width = min(center, w - center)
        if min_width <= 0:
            return mask

        left = mask[:, center - min_width:center]
        right = mask[:, center:center + min_width]

        # Confidence = mean distance from the undecided value 0.5.
        left_confidence = np.mean(np.abs(left - 0.5))
        right_confidence = np.mean(np.abs(right - 0.5))
        total_conf = left_confidence + right_confidence + 1e-6
        left_weight = left_confidence / total_conf
        right_weight = right_confidence / total_conf

        balanced = mask.copy()
        balanced[:, center - min_width:center] = (
            left_weight * left + right_weight * np.fliplr(right)
        )
        balanced[:, center:center + min_width] = (
            right_weight * right + left_weight * np.fliplr(left)
        )

        # Smooth the centre seam with a horizontal (5 wide x 1 tall) Gaussian.
        seam_width = 5
        seam_start = max(0, center - seam_width)
        seam_end = min(w, center + seam_width)
        balanced[:, seam_start:seam_end] = cv2.GaussianBlur(
            balanced[:, seam_start:seam_end], (5, 1), 1.0
        )

        return balanced


# Export classes
__all__ = [
    'EdgeProcessor',
    'EdgeConfig',
    'HairSegmentation',
    'EdgeRefinement',
    'SymmetryCorrector',
    'HairDetector'
]