MogensR commited on
Commit
d51bd78
·
1 Parent(s): 30d0dd7

Update core/models.py

Browse files
Files changed (1) hide show
  1. core/models.py +255 -259
core/models.py CHANGED
@@ -1,27 +1,36 @@
 
1
  """
2
  Model management and optimization for BackgroundFX Pro.
3
  Fixes MatAnyone quality issues and manages model loading.
4
  """
5
 
6
- import torch
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
- from typing import Dict, Any, Optional, Tuple, List
10
  from dataclasses import dataclass
11
- import numpy as np
 
12
  from pathlib import Path
13
- import logging
 
14
  import gc
15
- from functools import lru_cache
16
  import warnings
17
 
 
 
 
 
 
18
  logger = logging.getLogger(__name__)
19
 
20
 
 
 
 
 
21
  @dataclass
22
  class ModelConfig:
23
  """Configuration for model management."""
24
  sam2_checkpoint: str = "checkpoints/sam2_hiera_large.pt"
 
25
  matanyone_checkpoint: str = "checkpoints/matanyone_v2.pth"
26
  device: str = "cuda"
27
  dtype: torch.dtype = torch.float16
@@ -36,100 +45,101 @@ class ModelConfig:
36
 
37
  class ModelCache:
38
  """Intelligent model caching system."""
39
-
40
  def __init__(self, max_size: int = 5):
41
- self.cache = {}
42
  self.max_size = max_size
43
- self.access_count = {}
44
- self.memory_usage = {}
45
-
46
  def add(self, key: str, model: Any, memory_size: float):
47
  """Add model to cache with memory tracking."""
48
- if len(self.cache) >= self.max_size:
49
- # Remove least recently used
50
  lru_key = min(self.access_count, key=self.access_count.get)
51
  self.remove(lru_key)
52
-
53
  self.cache[key] = model
54
  self.access_count[key] = 0
55
  self.memory_usage[key] = memory_size
56
-
57
  def get(self, key: str) -> Optional[Any]:
58
  """Get model from cache."""
59
  if key in self.cache:
60
  self.access_count[key] += 1
61
  return self.cache[key]
62
  return None
63
-
64
  def remove(self, key: str):
65
  """Remove model from cache and free memory."""
66
  if key in self.cache:
67
  model = self.cache[key]
68
  del self.cache[key]
69
- del self.access_count[key]
70
- del self.memory_usage[key]
71
-
72
  # Force cleanup
73
- del model
 
 
 
74
  gc.collect()
75
  if torch.cuda.is_available():
76
  torch.cuda.empty_cache()
77
-
78
  def clear(self):
79
  """Clear entire cache."""
80
- keys = list(self.cache.keys())
81
- for key in keys:
82
  self.remove(key)
83
 
84
 
 
 
 
 
85
  class MatAnyoneModel(nn.Module):
86
  """Enhanced MatAnyone model with quality fixes."""
87
-
88
  def __init__(self, config: ModelConfig):
89
  super().__init__()
90
  self.config = config
91
- self.base_model = None
92
  self.quality_enhancer = QualityEnhancer() if config.enable_quality_fixes else None
93
  self.loaded = False
94
-
95
  def load(self):
96
  """Load MatAnyone model with optimizations."""
97
  if self.loaded:
98
  return
99
-
100
  try:
101
- # Load checkpoint
102
  checkpoint_path = Path(self.config.matanyone_checkpoint)
103
  if not checkpoint_path.exists():
104
  logger.warning(f"MatAnyone checkpoint not found at {checkpoint_path}")
105
  return
106
-
107
- # Load model weights
108
- state_dict = torch.load(
109
- checkpoint_path,
110
- map_location=self.config.device
111
- )
112
-
113
- # Initialize base model (placeholder - replace with actual MatAnyone architecture)
114
  self.base_model = self._build_matanyone_architecture()
115
-
116
- # Load weights with compatibility fixes
117
  self._load_weights_safe(state_dict)
118
-
119
- # Optimize model
120
  if self.config.optimize_memory:
121
  self._optimize_model()
122
-
123
  self.loaded = True
124
  logger.info("MatAnyone model loaded successfully")
125
-
126
  except Exception as e:
127
  logger.error(f"Failed to load MatAnyone model: {e}")
128
  self.loaded = False
129
-
130
  def _build_matanyone_architecture(self) -> nn.Module:
131
- """Build MatAnyone architecture."""
132
- # This is a placeholder - replace with actual MatAnyone architecture
133
  class MatAnyoneBase(nn.Module):
134
  def __init__(self):
135
  super().__init__()
@@ -147,300 +157,248 @@ def __init__(self):
147
  nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
148
  nn.ReLU(),
149
  nn.Conv2d(64, 4, 3, padding=1),
150
- nn.Sigmoid()
151
  )
152
-
153
  def forward(self, x):
154
  features = self.encoder(x)
155
  output = self.decoder(features)
156
  return output
157
-
158
- return MatAnyoneBase().to(self.config.device)
159
-
 
 
 
160
  def _load_weights_safe(self, state_dict: Dict):
161
  """Safely load weights with compatibility handling."""
 
 
 
162
  model_dict = self.base_model.state_dict()
163
-
164
- # Filter compatible weights
165
  compatible_dict = {}
166
  for k, v in state_dict.items():
167
- # Remove module prefix if present
168
- if k.startswith('module.'):
169
- k = k[7:]
170
-
171
- if k in model_dict and model_dict[k].shape == v.shape:
172
- compatible_dict[k] = v
173
  else:
174
  logger.warning(f"Skipping incompatible weight: {k}")
175
-
176
- # Load compatible weights
177
  model_dict.update(compatible_dict)
178
  self.base_model.load_state_dict(model_dict, strict=False)
179
-
180
  logger.info(f"Loaded {len(compatible_dict)}/{len(state_dict)} weights")
181
-
182
  def _optimize_model(self):
183
  """Optimize model for inference."""
184
- if not self.base_model:
185
  return
186
-
187
  self.base_model.eval()
188
-
189
- # Convert to half precision if using GPU
190
- if self.config.dtype == torch.float16 and self.config.device != "cpu":
191
- self.base_model = self.base_model.half()
192
-
193
- # Disable gradient computation
194
- for param in self.base_model.parameters():
195
- param.requires_grad = False
196
-
197
- # TensorRT optimization (if available)
198
  if self.config.use_tensorrt:
199
  try:
200
  self._optimize_with_tensorrt()
201
  except Exception as e:
202
  logger.warning(f"TensorRT optimization failed: {e}")
203
-
 
 
 
 
204
  def forward(self, image: torch.Tensor, mask: torch.Tensor) -> Dict[str, torch.Tensor]:
205
  """Enhanced forward pass with quality fixes."""
206
  if not self.loaded:
207
  self.load()
208
-
209
- if not self.base_model:
210
- return {'alpha': mask, 'foreground': image}
211
-
212
- # Prepare input
213
  x = torch.cat([image, mask.unsqueeze(1)], dim=1)
214
-
215
- # Fix input quality issues
216
  if self.config.matanyone_enhancement:
217
  x = self._preprocess_input(x)
218
-
219
- # Forward pass with mixed precision
220
- with torch.cuda.amp.autocast(enabled=self.config.use_amp):
221
  output = self.base_model(x)
222
-
223
- # Parse output
224
  alpha = output[:, 3:4, :, :]
225
  foreground = output[:, :3, :, :]
226
-
227
- # Apply quality enhancement
228
  if self.quality_enhancer:
229
  alpha = self.quality_enhancer.enhance_alpha(alpha, mask)
230
  foreground = self.quality_enhancer.enhance_foreground(foreground, image)
231
-
232
- # Post-process to fix common MatAnyone issues
233
  alpha = self._fix_matanyone_artifacts(alpha, mask)
234
-
235
  return {
236
- 'alpha': alpha,
237
- 'foreground': foreground,
238
- 'confidence': self._compute_confidence(alpha, mask)
239
  }
240
-
241
  def _preprocess_input(self, x: torch.Tensor) -> torch.Tensor:
242
  """Preprocess input to improve MatAnyone quality."""
243
- # Denoise input
244
- if x.shape[2] > 64: # Only for reasonable resolutions
245
  x = self._bilateral_filter_torch(x)
246
-
247
- # Normalize properly
248
  x = torch.clamp(x, 0, 1)
249
-
250
- # Enhance edges in mask channel
251
  mask_channel = x[:, 3:4, :, :]
252
  mask_enhanced = self._enhance_mask_edges(mask_channel)
253
  x = torch.cat([x[:, :3, :, :], mask_enhanced], dim=1)
254
-
255
  return x
256
-
257
- def _fix_matanyone_artifacts(self, alpha: torch.Tensor,
258
- original_mask: torch.Tensor) -> torch.Tensor:
259
  """Fix common MatAnyone artifacts."""
260
- # Fix edge bleeding
261
  alpha = self._fix_edge_bleeding(alpha, original_mask)
262
-
263
- # Fix transparency issues
264
  alpha = self._fix_transparency_issues(alpha)
265
-
266
- # Ensure consistency with original mask
267
  alpha = self._ensure_mask_consistency(alpha, original_mask)
268
-
269
  return alpha
270
-
271
- def _fix_edge_bleeding(self, alpha: torch.Tensor,
272
- original_mask: torch.Tensor) -> torch.Tensor:
273
  """Fix edge bleeding artifacts."""
274
- # Detect edges
275
  edges = self._detect_edges_torch(original_mask)
276
-
277
- # Create edge mask
278
  edge_mask = F.max_pool2d(edges, kernel_size=5, stride=1, padding=2)
279
-
280
- # Refine alpha near edges
281
  alpha_refined = alpha.clone()
282
  edge_region = edge_mask > 0.1
283
-
284
- # Apply guided filter near edges
285
  if edge_region.any():
286
  alpha_refined[edge_region] = (
287
- 0.7 * alpha[edge_region] +
288
- 0.3 * original_mask.unsqueeze(1).expand_as(alpha)[edge_region]
289
  )
290
-
291
  return alpha_refined
292
-
293
  def _fix_transparency_issues(self, alpha: torch.Tensor) -> torch.Tensor:
294
  """Fix transparency artifacts."""
295
- # Identify problematic transparency values
296
  mid_range = (alpha > 0.2) & (alpha < 0.8)
297
-
298
- # Push mid-range values toward 0 or 1
299
  alpha_fixed = alpha.clone()
300
  alpha_fixed[mid_range] = torch.where(
301
  alpha[mid_range] > 0.5,
302
  torch.clamp(alpha[mid_range] * 1.2, max=1.0),
303
- torch.clamp(alpha[mid_range] * 0.8, min=0.0)
304
  )
305
-
306
- # Smooth transitions
307
  alpha_fixed = F.gaussian_blur(alpha_fixed, kernel_size=(3, 3))
308
-
309
  return alpha_fixed
310
-
311
- def _ensure_mask_consistency(self, alpha: torch.Tensor,
312
- original_mask: torch.Tensor) -> torch.Tensor:
313
  """Ensure consistency with original mask."""
314
- # Expand mask dimensions if needed
315
  if original_mask.dim() == 2:
316
  original_mask = original_mask.unsqueeze(0).unsqueeze(0)
317
  elif original_mask.dim() == 3:
318
  original_mask = original_mask.unsqueeze(1)
319
-
320
- # Where original mask is 0, alpha should also be 0
321
  alpha = torch.where(original_mask < 0.1, torch.zeros_like(alpha), alpha)
322
-
323
- # Where original mask is 1, alpha should be close to 1
324
  alpha = torch.where(original_mask > 0.9, torch.ones_like(alpha) * 0.95, alpha)
325
-
326
  return alpha
327
-
328
- def _compute_confidence(self, alpha: torch.Tensor,
329
- original_mask: torch.Tensor) -> torch.Tensor:
330
  """Compute confidence score for the output."""
331
- # Expand dimensions if needed
332
  if original_mask.dim() < alpha.dim():
333
  original_mask = original_mask.unsqueeze(1).expand_as(alpha)
334
-
335
- # Compute similarity
336
  diff = torch.abs(alpha - original_mask)
337
  confidence = 1.0 - torch.mean(diff, dim=(1, 2, 3))
338
-
339
  return confidence
340
-
341
  def _bilateral_filter_torch(self, x: torch.Tensor) -> torch.Tensor:
342
- """Apply bilateral filter in PyTorch."""
343
- # Simple approximation using Gaussian blur
344
- # For true bilateral filtering, would need custom CUDA kernel
345
  return F.gaussian_blur(x, kernel_size=(5, 5))
346
-
347
  def _enhance_mask_edges(self, mask: torch.Tensor) -> torch.Tensor:
348
  """Enhance edges in mask channel."""
349
- # Detect edges
350
  edges = self._detect_edges_torch(mask)
351
-
352
- # Enhance mask with edges
353
- enhanced = mask + 0.3 * edges
354
- enhanced = torch.clamp(enhanced, 0, 1)
355
-
356
  return enhanced
357
-
358
  def _detect_edges_torch(self, x: torch.Tensor) -> torch.Tensor:
359
  """Detect edges using Sobel filters."""
360
- # Sobel kernels
361
- sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
362
- dtype=x.dtype, device=x.device).view(1, 1, 3, 3)
363
- sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
364
- dtype=x.dtype, device=x.device).view(1, 1, 3, 3)
365
-
366
- # Apply Sobel filters
367
  edges_x = F.conv2d(x, sobel_x, padding=1)
368
  edges_y = F.conv2d(x, sobel_y, padding=1)
369
-
370
- # Compute edge magnitude
371
  edges = torch.sqrt(edges_x ** 2 + edges_y ** 2)
372
-
373
  return edges
374
 
375
 
 
 
 
 
376
  class SAM2Model:
377
  """SAM2 model wrapper with optimizations."""
378
-
379
  def __init__(self, config: ModelConfig):
380
  self.config = config
381
  self.model = None
382
  self.predictor = None
383
  self.loaded = False
384
-
385
  def load(self):
386
  """Load SAM2 model."""
387
  if self.loaded:
388
  return
389
-
390
  try:
391
- # Import SAM2 (assuming it's installed)
392
  from sam2.build_sam import build_sam2
393
  from sam2.sam2_image_predictor import SAM2ImagePredictor
394
-
395
- # Build model
396
  self.model = build_sam2(
397
- config_file="sam2_hiera_l.yaml",
398
  ckpt_path=self.config.sam2_checkpoint,
399
- device=self.config.device
400
  )
401
-
402
- # Create predictor
403
  self.predictor = SAM2ImagePredictor(self.model)
404
-
405
  self.loaded = True
406
  logger.info("SAM2 model loaded successfully")
407
-
408
  except Exception as e:
409
  logger.error(f"Failed to load SAM2 model: {e}")
410
  self.loaded = False
411
-
412
  def predict(self, image: np.ndarray, prompts: Optional[Dict] = None) -> np.ndarray:
413
  """Generate segmentation mask."""
414
  if not self.loaded:
415
  self.load()
416
-
417
- if not self.predictor:
418
  return np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
419
-
420
- # Set image
421
  self.predictor.set_image(image)
422
-
423
- # Use prompts if provided, otherwise use automatic segmentation
424
  if prompts:
425
  masks, scores, _ = self.predictor.predict(
426
- point_coords=prompts.get('points'),
427
- point_labels=prompts.get('labels'),
428
- box=prompts.get('box'),
429
- multimask_output=True
430
  )
431
- # Select best mask
432
- mask = masks[np.argmax(scores)]
433
  else:
434
- # Automatic segmentation
435
- masks = self.predictor.generate_auto_masks(image)
436
- mask = masks[0] if len(masks) > 0 else np.zeros_like(image[:, :, 0])
437
-
 
 
 
 
438
  return mask
439
 
440
 
 
 
 
 
441
  class QualityEnhancer(nn.Module):
442
  """Neural quality enhancement module."""
443
-
444
  def __init__(self):
445
  super().__init__()
446
  self.alpha_refiner = nn.Sequential(
@@ -449,97 +407,87 @@ def __init__(self):
449
  nn.Conv2d(16, 16, 3, padding=1),
450
  nn.ReLU(),
451
  nn.Conv2d(16, 1, 3, padding=1),
452
- nn.Sigmoid()
453
  )
454
-
455
  self.foreground_enhancer = nn.Sequential(
456
  nn.Conv2d(3, 32, 3, padding=1),
457
  nn.ReLU(),
458
  nn.Conv2d(32, 32, 3, padding=1),
459
  nn.ReLU(),
460
  nn.Conv2d(32, 3, 3, padding=1),
461
- nn.Tanh()
462
  )
463
-
464
- def enhance_alpha(self, alpha: torch.Tensor,
465
- original_mask: torch.Tensor) -> torch.Tensor:
466
  """Enhance alpha channel quality."""
467
- # Refine with neural network
468
  refined = self.alpha_refiner(alpha)
469
-
470
- # Blend with original for stability
471
- enhanced = 0.7 * refined + 0.3 * alpha
472
-
473
- return torch.clamp(enhanced, 0, 1)
474
-
475
- def enhance_foreground(self, foreground: torch.Tensor,
476
- original_image: torch.Tensor) -> torch.Tensor:
477
  """Enhance foreground quality."""
478
- # Compute residual
479
  residual = self.foreground_enhancer(foreground)
480
-
481
- # Add residual
482
- enhanced = foreground + 0.1 * residual
483
-
484
- return torch.clamp(enhanced, 0, 1)
485
 
486
 
 
 
 
 
487
  class ModelManager:
488
  """Central model management system."""
489
-
490
  def __init__(self, config: Optional[ModelConfig] = None):
491
  self.config = config or ModelConfig()
492
  self.cache = ModelCache(max_size=self.config.cache_size)
493
- self.models = {}
494
-
495
- # Initialize models
496
  self.sam2 = SAM2Model(self.config)
497
  self.matanyone = MatAnyoneModel(self.config)
498
-
499
  def load_all(self):
500
  """Load all models."""
501
  logger.info("Loading all models...")
502
  self.sam2.load()
503
  self.matanyone.load()
504
  logger.info("All models loaded")
505
-
506
- def get_sam2(self) -> SAM2Model:
507
- """Get SAM2 model."""
508
  if not self.sam2.loaded:
509
  self.sam2.load()
510
  return self.sam2
511
-
512
- def get_matanyone(self) -> MatAnyoneModel:
513
- """Get MatAnyone model."""
514
  if not self.matanyone.loaded:
515
  self.matanyone.load()
516
  return self.matanyone
517
-
518
- def process_frame(self, image: np.ndarray,
519
- mask: Optional[np.ndarray] = None) -> Dict[str, Any]:
520
- """Process single frame through pipeline."""
521
- # Convert to tensor
522
  image_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float() / 255.0
523
  image_tensor = image_tensor.to(self.config.device)
524
-
525
- # Get or generate mask
526
  if mask is None:
527
  mask = self.sam2.predict(image)
528
-
529
  mask_tensor = torch.from_numpy(mask).float().to(self.config.device)
530
-
531
- # Process with MatAnyone
532
  result = self.matanyone(image_tensor, mask_tensor)
533
-
534
- # Convert back to numpy
535
  output = {
536
- 'alpha': result['alpha'].squeeze().cpu().numpy(),
537
- 'foreground': result['foreground'].squeeze().permute(1, 2, 0).cpu().numpy() * 255,
538
- 'confidence': result['confidence'].cpu().numpy()
539
  }
540
-
541
  return output
542
-
543
  def cleanup(self):
544
  """Cleanup models and free memory."""
545
  self.cache.clear()
@@ -548,12 +496,60 @@ def cleanup(self):
548
  torch.cuda.empty_cache()
549
 
550
 
551
- # Export classes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
552
  __all__ = [
553
- 'ModelManager',
554
- 'SAM2Model',
555
- 'MatAnyoneModel',
556
- 'ModelConfig',
557
- 'ModelCache',
558
- 'QualityEnhancer'
559
- ]
 
 
 
1
+ #!/usr/bin/env python3
2
  """
3
  Model management and optimization for BackgroundFX Pro.
4
  Fixes MatAnyone quality issues and manages model loading.
5
  """
6
 
 
 
 
 
7
  from dataclasses import dataclass
8
+ from enum import Enum
9
+ from functools import lru_cache
10
  from pathlib import Path
11
+ from typing import Dict, Any, Optional, Tuple, List
12
+
13
  import gc
14
+ import logging
15
  import warnings
16
 
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
  logger = logging.getLogger(__name__)
23
 
24
 
25
+ # -------------------------------
26
+ # Configuration & Caching
27
+ # -------------------------------
28
+
29
  @dataclass
30
  class ModelConfig:
31
  """Configuration for model management."""
32
  sam2_checkpoint: str = "checkpoints/sam2_hiera_large.pt"
33
+ sam2_config: str = "configs/sam2_hiera_l.yaml" # path to SAM2 config file
34
  matanyone_checkpoint: str = "checkpoints/matanyone_v2.pth"
35
  device: str = "cuda"
36
  dtype: torch.dtype = torch.float16
 
45
 
46
  class ModelCache:
47
  """Intelligent model caching system."""
48
+
49
  def __init__(self, max_size: int = 5):
50
+ self.cache: Dict[str, Any] = {}
51
  self.max_size = max_size
52
+ self.access_count: Dict[str, int] = {}
53
+ self.memory_usage: Dict[str, float] = {}
54
+
55
  def add(self, key: str, model: Any, memory_size: float):
56
  """Add model to cache with memory tracking."""
57
+ if len(self.cache) >= self.max_size and self.access_count:
 
58
  lru_key = min(self.access_count, key=self.access_count.get)
59
  self.remove(lru_key)
60
+
61
  self.cache[key] = model
62
  self.access_count[key] = 0
63
  self.memory_usage[key] = memory_size
64
+
65
  def get(self, key: str) -> Optional[Any]:
66
  """Get model from cache."""
67
  if key in self.cache:
68
  self.access_count[key] += 1
69
  return self.cache[key]
70
  return None
71
+
72
  def remove(self, key: str):
73
  """Remove model from cache and free memory."""
74
  if key in self.cache:
75
  model = self.cache[key]
76
  del self.cache[key]
77
+ self.access_count.pop(key, None)
78
+ self.memory_usage.pop(key, None)
79
+
80
  # Force cleanup
81
+ try:
82
+ del model
83
+ except Exception:
84
+ pass
85
  gc.collect()
86
  if torch.cuda.is_available():
87
  torch.cuda.empty_cache()
88
+
89
  def clear(self):
90
  """Clear entire cache."""
91
+ for key in list(self.cache.keys()):
 
92
  self.remove(key)
93
 
94
 
95
+ # -------------------------------
96
+ # MatAnyone model (enhanced)
97
+ # -------------------------------
98
+
99
  class MatAnyoneModel(nn.Module):
100
  """Enhanced MatAnyone model with quality fixes."""
101
+
102
  def __init__(self, config: ModelConfig):
103
  super().__init__()
104
  self.config = config
105
+ self.base_model: Optional[nn.Module] = None
106
  self.quality_enhancer = QualityEnhancer() if config.enable_quality_fixes else None
107
  self.loaded = False
108
+
109
  def load(self):
110
  """Load MatAnyone model with optimizations."""
111
  if self.loaded:
112
  return
113
+
114
  try:
 
115
  checkpoint_path = Path(self.config.matanyone_checkpoint)
116
  if not checkpoint_path.exists():
117
  logger.warning(f"MatAnyone checkpoint not found at {checkpoint_path}")
118
  return
119
+
120
+ # Load weights
121
+ state_dict = torch.load(checkpoint_path, map_location=self.config.device)
122
+
123
+ # Build model (placeholder architecture)
 
 
 
124
  self.base_model = self._build_matanyone_architecture()
125
+
126
+ # Load filtered weights
127
  self._load_weights_safe(state_dict)
128
+
129
+ # Optimize
130
  if self.config.optimize_memory:
131
  self._optimize_model()
132
+
133
  self.loaded = True
134
  logger.info("MatAnyone model loaded successfully")
135
+
136
  except Exception as e:
137
  logger.error(f"Failed to load MatAnyone model: {e}")
138
  self.loaded = False
139
+
140
  def _build_matanyone_architecture(self) -> nn.Module:
141
+ """Build MatAnyone architecture (placeholder)."""
142
+
143
  class MatAnyoneBase(nn.Module):
144
  def __init__(self):
145
  super().__init__()
 
157
  nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
158
  nn.ReLU(),
159
  nn.Conv2d(64, 4, 3, padding=1),
160
+ nn.Sigmoid(),
161
  )
162
+
163
  def forward(self, x):
164
  features = self.encoder(x)
165
  output = self.decoder(features)
166
  return output
167
+
168
+ model = MatAnyoneBase().to(self.config.device)
169
+ if self.config.dtype == torch.float16 and "cuda" in str(self.config.device).lower() and torch.cuda.is_available():
170
+ model = model.half()
171
+ return model
172
+
173
  def _load_weights_safe(self, state_dict: Dict):
174
  """Safely load weights with compatibility handling."""
175
+ if self.base_model is None:
176
+ return
177
+
178
  model_dict = self.base_model.state_dict()
179
+
 
180
  compatible_dict = {}
181
  for k, v in state_dict.items():
182
+ k_clean = k[7:] if k.startswith("module.") else k
183
+ if k_clean in model_dict and model_dict[k_clean].shape == v.shape:
184
+ compatible_dict[k_clean] = v
 
 
 
185
  else:
186
  logger.warning(f"Skipping incompatible weight: {k}")
187
+
 
188
  model_dict.update(compatible_dict)
189
  self.base_model.load_state_dict(model_dict, strict=False)
 
190
  logger.info(f"Loaded {len(compatible_dict)}/{len(state_dict)} weights")
191
+
192
  def _optimize_model(self):
193
  """Optimize model for inference."""
194
+ if self.base_model is None:
195
  return
196
+
197
  self.base_model.eval()
198
+
199
+ for p in self.base_model.parameters():
200
+ p.requires_grad = False
201
+
 
 
 
 
 
 
202
  if self.config.use_tensorrt:
203
  try:
204
  self._optimize_with_tensorrt()
205
  except Exception as e:
206
  logger.warning(f"TensorRT optimization failed: {e}")
207
+
208
+ def _optimize_with_tensorrt(self):
209
+ """Placeholder for optional TensorRT optimization."""
210
+ raise NotImplementedError("TensorRT path not implemented")
211
+
212
  def forward(self, image: torch.Tensor, mask: torch.Tensor) -> Dict[str, torch.Tensor]:
213
  """Enhanced forward pass with quality fixes."""
214
  if not self.loaded:
215
  self.load()
216
+
217
+ if self.base_model is None:
218
+ return {"alpha": mask.unsqueeze(1), "foreground": image, "confidence": torch.tensor([0.0], device=image.device)}
219
+
220
+ # Concatenate image (3ch) + mask (1ch) => 4ch
221
  x = torch.cat([image, mask.unsqueeze(1)], dim=1)
222
+
223
+ # Quality enhancements
224
  if self.config.matanyone_enhancement:
225
  x = self._preprocess_input(x)
226
+
227
+ amp_enabled = self.config.use_amp and torch.cuda.is_available() and "cuda" in str(self.config.device).lower()
228
+ with torch.cuda.amp.autocast(enabled=amp_enabled):
229
  output = self.base_model(x)
230
+
 
231
  alpha = output[:, 3:4, :, :]
232
  foreground = output[:, :3, :, :]
233
+
 
234
  if self.quality_enhancer:
235
  alpha = self.quality_enhancer.enhance_alpha(alpha, mask)
236
  foreground = self.quality_enhancer.enhance_foreground(foreground, image)
237
+
 
238
  alpha = self._fix_matanyone_artifacts(alpha, mask)
239
+
240
  return {
241
+ "alpha": alpha,
242
+ "foreground": foreground,
243
+ "confidence": self._compute_confidence(alpha, mask),
244
  }
245
+
246
  def _preprocess_input(self, x: torch.Tensor) -> torch.Tensor:
247
  """Preprocess input to improve MatAnyone quality."""
248
+ if x.shape[2] > 64:
 
249
  x = self._bilateral_filter_torch(x)
 
 
250
  x = torch.clamp(x, 0, 1)
251
+
252
+ # Enhance mask edges (last channel)
253
  mask_channel = x[:, 3:4, :, :]
254
  mask_enhanced = self._enhance_mask_edges(mask_channel)
255
  x = torch.cat([x[:, :3, :, :], mask_enhanced], dim=1)
 
256
  return x
257
+
258
+ def _fix_matanyone_artifacts(self, alpha: torch.Tensor, original_mask: torch.Tensor) -> torch.Tensor:
 
259
  """Fix common MatAnyone artifacts."""
 
260
  alpha = self._fix_edge_bleeding(alpha, original_mask)
 
 
261
  alpha = self._fix_transparency_issues(alpha)
 
 
262
  alpha = self._ensure_mask_consistency(alpha, original_mask)
 
263
  return alpha
264
+
265
+ def _fix_edge_bleeding(self, alpha: torch.Tensor, original_mask: torch.Tensor) -> torch.Tensor:
 
266
  """Fix edge bleeding artifacts."""
 
267
  edges = self._detect_edges_torch(original_mask)
 
 
268
  edge_mask = F.max_pool2d(edges, kernel_size=5, stride=1, padding=2)
269
+
 
270
  alpha_refined = alpha.clone()
271
  edge_region = edge_mask > 0.1
 
 
272
  if edge_region.any():
273
  alpha_refined[edge_region] = (
274
+ 0.7 * alpha[edge_region] + 0.3 * original_mask.unsqueeze(1).expand_as(alpha)[edge_region]
 
275
  )
 
276
  return alpha_refined
277
+
278
  def _fix_transparency_issues(self, alpha: torch.Tensor) -> torch.Tensor:
279
  """Fix transparency artifacts."""
 
280
  mid_range = (alpha > 0.2) & (alpha < 0.8)
 
 
281
  alpha_fixed = alpha.clone()
282
  alpha_fixed[mid_range] = torch.where(
283
  alpha[mid_range] > 0.5,
284
  torch.clamp(alpha[mid_range] * 1.2, max=1.0),
285
+ torch.clamp(alpha[mid_range] * 0.8, min=0.0),
286
  )
 
 
287
  alpha_fixed = F.gaussian_blur(alpha_fixed, kernel_size=(3, 3))
 
288
  return alpha_fixed
289
+
290
+ def _ensure_mask_consistency(self, alpha: torch.Tensor, original_mask: torch.Tensor) -> torch.Tensor:
 
291
  """Ensure consistency with original mask."""
 
292
  if original_mask.dim() == 2:
293
  original_mask = original_mask.unsqueeze(0).unsqueeze(0)
294
  elif original_mask.dim() == 3:
295
  original_mask = original_mask.unsqueeze(1)
296
+
 
297
  alpha = torch.where(original_mask < 0.1, torch.zeros_like(alpha), alpha)
 
 
298
  alpha = torch.where(original_mask > 0.9, torch.ones_like(alpha) * 0.95, alpha)
 
299
  return alpha
300
+
301
+ def _compute_confidence(self, alpha: torch.Tensor, original_mask: torch.Tensor) -> torch.Tensor:
 
302
  """Compute confidence score for the output."""
 
303
  if original_mask.dim() < alpha.dim():
304
  original_mask = original_mask.unsqueeze(1).expand_as(alpha)
 
 
305
  diff = torch.abs(alpha - original_mask)
306
  confidence = 1.0 - torch.mean(diff, dim=(1, 2, 3))
 
307
  return confidence
308
+
309
  def _bilateral_filter_torch(self, x: torch.Tensor) -> torch.Tensor:
310
+ """Approximate bilateral filter via Gaussian blur."""
 
 
311
  return F.gaussian_blur(x, kernel_size=(5, 5))
312
+
313
  def _enhance_mask_edges(self, mask: torch.Tensor) -> torch.Tensor:
314
  """Enhance edges in mask channel."""
 
315
  edges = self._detect_edges_torch(mask)
316
+ enhanced = torch.clamp(mask + 0.3 * edges, 0, 1)
 
 
 
 
317
  return enhanced
318
+
319
  def _detect_edges_torch(self, x: torch.Tensor) -> torch.Tensor:
320
  """Detect edges using Sobel filters."""
321
+ sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=x.dtype, device=x.device).view(1, 1, 3, 3)
322
+ sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=x.dtype, device=x.device).view(1, 1, 3, 3)
 
 
 
 
 
323
  edges_x = F.conv2d(x, sobel_x, padding=1)
324
  edges_y = F.conv2d(x, sobel_y, padding=1)
 
 
325
  edges = torch.sqrt(edges_x ** 2 + edges_y ** 2)
 
326
  return edges
327
 
328
 
329
+ # -------------------------------
330
+ # SAM2 wrapper
331
+ # -------------------------------
332
+
333
  class SAM2Model:
334
  """SAM2 model wrapper with optimizations."""
335
+
336
  def __init__(self, config: ModelConfig):
337
  self.config = config
338
  self.model = None
339
  self.predictor = None
340
  self.loaded = False
341
+
342
  def load(self):
343
  """Load SAM2 model."""
344
  if self.loaded:
345
  return
346
+
347
  try:
 
348
  from sam2.build_sam import build_sam2
349
  from sam2.sam2_image_predictor import SAM2ImagePredictor
350
+
 
351
  self.model = build_sam2(
352
+ config_file=self.config.sam2_config,
353
  ckpt_path=self.config.sam2_checkpoint,
354
+ device=self.config.device,
355
  )
 
 
356
  self.predictor = SAM2ImagePredictor(self.model)
357
+
358
  self.loaded = True
359
  logger.info("SAM2 model loaded successfully")
360
+
361
  except Exception as e:
362
  logger.error(f"Failed to load SAM2 model: {e}")
363
  self.loaded = False
364
+
365
  def predict(self, image: np.ndarray, prompts: Optional[Dict] = None) -> np.ndarray:
366
  """Generate segmentation mask."""
367
  if not self.loaded:
368
  self.load()
369
+
370
+ if self.predictor is None:
371
  return np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
372
+
 
373
  self.predictor.set_image(image)
374
+
 
375
  if prompts:
376
  masks, scores, _ = self.predictor.predict(
377
+ point_coords=prompts.get("points"),
378
+ point_labels=prompts.get("labels"),
379
+ box=prompts.get("box"),
380
+ multimask_output=True,
381
  )
382
+ mask = masks[int(np.argmax(scores))]
 
383
  else:
384
+ # Fallback automatic segmentation (API may differ by version)
385
+ try:
386
+ masks = self.predictor.generate_auto_masks(image)
387
+ mask = masks[0] if len(masks) > 0 else np.zeros_like(image[:, :, 0])
388
+ except Exception:
389
+ # As a conservative fallback, return empty mask
390
+ mask = np.zeros_like(image[:, :, 0])
391
+
392
  return mask
393
 
394
 
395
+ # -------------------------------
396
+ # Quality enhancer
397
+ # -------------------------------
398
+
399
  class QualityEnhancer(nn.Module):
400
  """Neural quality enhancement module."""
401
+
402
  def __init__(self):
403
  super().__init__()
404
  self.alpha_refiner = nn.Sequential(
 
407
  nn.Conv2d(16, 16, 3, padding=1),
408
  nn.ReLU(),
409
  nn.Conv2d(16, 1, 3, padding=1),
410
+ nn.Sigmoid(),
411
  )
412
+
413
  self.foreground_enhancer = nn.Sequential(
414
  nn.Conv2d(3, 32, 3, padding=1),
415
  nn.ReLU(),
416
  nn.Conv2d(32, 32, 3, padding=1),
417
  nn.ReLU(),
418
  nn.Conv2d(32, 3, 3, padding=1),
419
+ nn.Tanh(),
420
  )
421
+
422
+ def enhance_alpha(self, alpha: torch.Tensor, original_mask: torch.Tensor) -> torch.Tensor:
 
423
  """Enhance alpha channel quality."""
 
424
  refined = self.alpha_refiner(alpha)
425
+ enhanced = torch.clamp(0.7 * refined + 0.3 * alpha, 0, 1)
426
+ return enhanced
427
+
428
+ def enhance_foreground(self, foreground: torch.Tensor, original_image: torch.Tensor) -> torch.Tensor:
 
 
 
 
429
  """Enhance foreground quality."""
 
430
  residual = self.foreground_enhancer(foreground)
431
+ enhanced = torch.clamp(foreground + 0.1 * residual, -1, 1)
432
+ # If inputs are [0,1], clamp to [0,1]
433
+ if foreground.min() >= 0.0 and foreground.max() <= 1.0:
434
+ enhanced = torch.clamp(enhanced, 0.0, 1.0)
435
+ return enhanced
436
 
437
 
438
+ # -------------------------------
439
+ # Model Manager
440
+ # -------------------------------
441
+
442
  class ModelManager:
443
  """Central model management system."""
444
+
445
  def __init__(self, config: Optional[ModelConfig] = None):
446
  self.config = config or ModelConfig()
447
  self.cache = ModelCache(max_size=self.config.cache_size)
448
+
449
+ # Instantiate default models
 
450
  self.sam2 = SAM2Model(self.config)
451
  self.matanyone = MatAnyoneModel(self.config)
452
+
453
  def load_all(self):
454
  """Load all models."""
455
  logger.info("Loading all models...")
456
  self.sam2.load()
457
  self.matanyone.load()
458
  logger.info("All models loaded")
459
+
460
+ def get_sam2(self) -> 'SAM2Model':
461
+ """Get SAM2 model (lazy-loaded)."""
462
  if not self.sam2.loaded:
463
  self.sam2.load()
464
  return self.sam2
465
+
466
+ def get_matanyone(self) -> 'MatAnyoneModel':
467
+ """Get MatAnyone model (lazy-loaded)."""
468
  if not self.matanyone.loaded:
469
  self.matanyone.load()
470
  return self.matanyone
471
+
472
+ def process_frame(self, image: np.ndarray, mask: Optional[np.ndarray] = None) -> Dict[str, Any]:
473
+ """Process single frame through the pipeline."""
 
 
474
  image_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float() / 255.0
475
  image_tensor = image_tensor.to(self.config.device)
476
+
 
477
  if mask is None:
478
  mask = self.sam2.predict(image)
479
+
480
  mask_tensor = torch.from_numpy(mask).float().to(self.config.device)
481
+
 
482
  result = self.matanyone(image_tensor, mask_tensor)
483
+
 
484
  output = {
485
+ "alpha": result["alpha"].squeeze().cpu().numpy(),
486
+ "foreground": (result["foreground"].squeeze().permute(1, 2, 0).cpu().numpy() * 255.0),
487
+ "confidence": result["confidence"].detach().cpu().numpy(),
488
  }
 
489
  return output
490
+
491
  def cleanup(self):
492
  """Cleanup models and free memory."""
493
  self.cache.clear()
 
496
  torch.cuda.empty_cache()
497
 
498
 
499
+ # -------------------------------
500
+ # ModelType / ModelFactory (compat)
501
+ # -------------------------------
502
+
503
+ class ModelType(Enum):
504
+ SAM2 = "sam2"
505
+ MATANYONE = "matanyone"
506
+
507
+
508
+ class ModelFactory:
509
+ """
510
+ Lightweight factory that returns cached model instances by type.
511
+ Kept for backward compatibility with modules importing from core.models.
512
+ """
513
+
514
+ def __init__(self, config: Optional[ModelConfig] = None):
515
+ self.config = config or ModelConfig()
516
+ self._instances: Dict[ModelType, Any] = {}
517
+
518
+ def get(self, model_type: 'ModelType | str'):
519
+ """Return (and cache) a model instance for the given type."""
520
+ if isinstance(model_type, str):
521
+ try:
522
+ model_type = ModelType(model_type.lower())
523
+ except Exception:
524
+ raise ValueError(f"Unknown model type: {model_type}")
525
+
526
+ if model_type == ModelType.SAM2:
527
+ if model_type not in self._instances:
528
+ self._instances[model_type] = SAM2Model(self.config)
529
+ return self._instances[model_type]
530
+
531
+ if model_type == ModelType.MATANYONE:
532
+ if model_type not in self._instances:
533
+ self._instances[model_type] = MatAnyoneModel(self.config)
534
+ return self._instances[model_type]
535
+
536
+ raise ValueError(f"Unsupported model type: {model_type}")
537
+
538
+ # Alias for older code
539
+ create = get
540
+
541
+
542
+ # -------------------------------
543
+ # Exports
544
+ # -------------------------------
545
+
546
  __all__ = [
547
+ "ModelManager",
548
+ "SAM2Model",
549
+ "MatAnyoneModel",
550
+ "ModelConfig",
551
+ "ModelCache",
552
+ "QualityEnhancer",
553
+ "ModelType",
554
+ "ModelFactory",
555
+ ]