#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ SAM2 Loader + Guarded Predictor Adapter (VRAM-friendly, shape-safe, thread-safe, PyTorch 2.x) - Uses traditional build_sam2 method with HF hub downloads for SAM 2.1 weights - Never assigns predictor.device (read-only) — moves .model to device instead - Accepts RGB/BGR, float/uint8; strips alpha; optional BGR→RGB via env - Downscale ladder on set_image(); upsample masks back to original H,W - torch.autocast(device_type="cuda", ...) + torch.inference_mode() - Thread-safe (Lock) for Gradio/Spaces concurrency - Returns {"masks": (N,H,W) float32, "scores": (N,) float32}; safe fallback on failure """ from __future__ import annotations import os import time import logging import traceback from typing import Optional, Dict, Any, Tuple, List import numpy as np import torch import cv2 import threading logger = logging.getLogger(__name__) if not logger.handlers: logging.basicConfig(level=logging.INFO) # Sanitize bad OMP before heavy libs use it _val = os.environ.get("OMP_NUM_THREADS") if _val is not None and not str(_val).strip().isdigit(): try: del os.environ["OMP_NUM_THREADS"] except Exception: pass def _select_device(pref: str) -> str: pref = (pref or "").lower() if pref.startswith("cuda"): return "cuda" if torch.cuda.is_available() else "cpu" if pref == "cpu": return "cpu" return "cuda" if torch.cuda.is_available() else "cpu" def _ensure_rgb_uint8(img: np.ndarray, force_bgr_to_rgb: bool = False) -> np.ndarray: if img is None: raise ValueError("set_image received None image") arr = np.asarray(img) if arr.ndim != 3 or arr.shape[2] < 3: raise ValueError(f"Expected HxWxC image with C>=3, got shape={arr.shape}") if np.issubdtype(arr.dtype, np.floating): arr = np.clip(arr, 0.0, 1.0) arr = (arr * 255.0 + 0.5).astype(np.uint8) elif arr.dtype == np.uint16: arr = (arr / 257).astype(np.uint8) elif arr.dtype != np.uint8: arr = arr.astype(np.uint8) if arr.shape[2] == 4: arr = arr[:, :, :3] if force_bgr_to_rgb: arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB) return arr def _compute_scaled_size(h: int, w: int, max_edge: int, target_pixels: int) -> Tuple[int, int, float]: if h <= 0 or w <= 0: return h, w, 1.0 s1 = min(1.0, float(max_edge) / float(max(h, w))) if max_edge > 0 else 1.0 s2 = min(1.0, (float(target_pixels) / float(h * w)) ** 0.5) if target_pixels > 0 else 1.0 s = min(s1, s2) nh = max(1, int(round(h * s))) nw = max(1, int(round(w * s))) return nh, nw, s def _ladder(nh: int, nw: int) -> List[Tuple[int, int]]: sizes = [(nh, nw)] for f in (0.85, 0.70, 0.55, 0.40, 0.30): sizes.append((max(64, int(nh * f)), max(64, int(nw * f)))) uniq, seen = [], set() for s in sizes: if s not in seen: uniq.append(s); seen.add(s) return uniq def _upsample_stack(masks: np.ndarray, out_hw: Tuple[int, int]) -> np.ndarray: masks = np.asarray(masks) if masks.ndim == 2: masks = masks[None, ...] elif masks.ndim == 4 and masks.shape[1] == 1: masks = masks[:, 0, :, :] if masks.ndim != 3: masks = np.squeeze(masks) if masks.ndim == 2: masks = masks[None, ...] n, h, w = masks.shape H, W = out_hw if (h, w) == (H, W): return masks.astype(np.float32, copy=False) out = np.zeros((n, H, W), dtype=np.float32) for i in range(n): out[i] = cv2.resize(masks[i].astype(np.float32), (W, H), interpolation=cv2.INTER_LINEAR) return np.clip(out, 0.0, 1.0) def _normalize_masks_dtype(x: np.ndarray) -> np.ndarray: x = np.asarray(x) if x.dtype == np.uint8: return (x.astype(np.float32) / 255.0) return x.astype(np.float32, copy=False) class _SAM2Adapter: def __init__(self, predictor, device: str): self.pred = predictor self.device = device self.orig_hw: Tuple[int, int] = (0, 0) self._current_rgb: Optional[np.ndarray] = None self._current_hw: Tuple[int, int] = (0, 0) self.max_edge = int(os.environ.get("SAM2_MAX_EDGE", "1024")) self.target_pixels = int(os.environ.get("SAM2_TARGET_PIXELS", "900000")) self.force_bgr_to_rgb = os.environ.get("SAM2_ASSUME_BGR", "0") == "1" self._lock = threading.Lock() def set_image(self, image: np.ndarray): with self._lock: rgb = _ensure_rgb_uint8(image, force_bgr_to_rgb=self.force_bgr_to_rgb) H, W = rgb.shape[:2] self.orig_hw = (H, W) nh, nw, s = _compute_scaled_size(H, W, self.max_edge, self.target_pixels) if s < 1.0: work = cv2.resize(rgb, (nw, nh), interpolation=cv2.INTER_AREA) self._current_rgb = work self._current_hw = (nh, nw) else: self._current_rgb = rgb self._current_hw = (H, W) self.pred.set_image(self._current_rgb) def predict(self, **kwargs) -> Dict[str, Any]: with self._lock: if self._current_rgb is None or self.orig_hw == (0, 0): raise RuntimeError("SAM2Adapter.predict called before set_image()") H, W = self.orig_hw nh, nw = self._current_hw sizes = _ladder(nh, nw) last_exc: Optional[BaseException] = None for (th, tw) in sizes: try: if (th, tw) != (nh, nw): small = cv2.resize(self._current_rgb, (tw, th), interpolation=cv2.INTER_AREA) self.pred.set_image(small) class _NoOp: def __enter__(self): return None def __exit__(self, *a): return False use_amp = (self.device == "cuda") if use_amp: amp_ctx = torch.autocast( device_type="cuda", dtype=(torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16) ) else: amp_ctx = _NoOp() with torch.inference_mode(): with amp_ctx: out = self.pred.predict(**kwargs) masks = None; scores = None; logits = None if isinstance(out, dict): masks = out.get("masks"); scores = out.get("scores"); logits = out.get("logits") elif isinstance(out, (tuple, list)): if len(out) >= 1: masks = out[0] if len(out) >= 2: scores = out[1] if len(out) >= 3: logits = out[2] else: masks = out if masks is None: raise RuntimeError("SAM2 returned no masks") masks = _normalize_masks_dtype(masks) masks_up = _upsample_stack(masks, (H, W)) if scores is None: scores = np.ones((masks_up.shape[0],), dtype=np.float32) * 0.5 else: scores = np.asarray(scores).astype(np.float32, copy=False).reshape(-1) out_dict = {"masks": masks_up, "scores": scores} if logits is not None: lg = np.asarray(logits) if lg.ndim == 3: lg = _upsample_stack(lg, (H, W)) elif lg.ndim == 4 and lg.shape[1] == 1: lg = _upsample_stack(lg[:, 0, :, :], (H, W)) out_dict["logits"] = lg.astype(np.float32, copy=False) return out_dict except torch.cuda.OutOfMemoryError as e: last_exc = e if torch.cuda.is_available(): torch.cuda.empty_cache() logger.warning(f"SAM2 OOM at {th}x{tw}; retrying smaller. {e}") continue except Exception as e: last_exc = e if torch.cuda.is_available(): torch.cuda.empty_cache() logger.debug(traceback.format_exc()) logger.warning(f"SAM2 predict failed at {th}x{tw}; retrying smaller. {e}") continue logger.warning(f"SAM2 calls failed; returning fallback mask. {last_exc}") return { "masks": np.ones((1, H, W), dtype=np.float32), "scores": np.array([0.5], dtype=np.float32), } class SAM2Loader: """Dedicated loader for SAM2 models (PyTorch 2.x, Spaces-friendly).""" def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/sam2_cache"): self.device = _select_device(device) self.cache_dir = cache_dir os.makedirs(self.cache_dir, exist_ok=True) os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS", "1") os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "0") self.model = None self.adapter = None self.model_id = None self.load_time = 0.0 def _determine_optimal_size(self) -> str: # Check environment variable first env_size = os.environ.get("USE_SAM2", "").lower() if env_size in ["tiny", "small", "base", "large"]: logger.info(f"Using SAM2 size from environment: {env_size}") return env_size try: if torch.cuda.is_available(): props = torch.cuda.get_device_properties(0) vram_gb = props.total_memory / (1024**3) if vram_gb < 4: return "tiny" if vram_gb < 8: return "small" if vram_gb < 12: return "base" return "large" except Exception: pass return "tiny" def load(self, model_size: str = "auto") -> Optional[_SAM2Adapter]: if model_size == "auto": model_size = self._determine_optimal_size() # Use original SAM2 model names (without .1) for compatibility model_map = { "tiny": "facebook/sam2-hiera-tiny", "small": "facebook/sam2-hiera-small", "base": "facebook/sam2-hiera-base-plus", "large": "facebook/sam2-hiera-large", } self.model_id = model_map.get(model_size, model_map["tiny"]) logger.info(f"Loading SAM2 model: {self.model_id} (device={self.device})") for name, fn in (("official", self._load_official), ("fallback", self._load_fallback)): try: t0 = time.time() pred = fn() if pred is None: continue self.model = pred self.adapter = _SAM2Adapter(self.model, self.device) self.load_time = time.time() - t0 logger.info(f"SAM2 loaded via {name} in {self.load_time:.2f}s") return self.adapter except Exception as e: logger.error(f"SAM2 {name} strategy failed: {e}") logger.debug(traceback.format_exc()) logger.error("All SAM2 loading strategies failed") return None def _load_official(self): try: from huggingface_hub import hf_hub_download from sam2.build_sam import build_sam2 from sam2.sam2_image_predictor import SAM2ImagePredictor except ImportError as e: logger.error(f"Failed to import SAM2 components: {e}") return None # Map model IDs to config files and checkpoint names config_map = { "facebook/sam2-hiera-tiny": ("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"), "facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"), "facebook/sam2-hiera-base-plus": ("sam2_hiera_b+.yaml", "sam2_hiera_base_plus.pt"), "facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"), } config_file, checkpoint_file = config_map.get(self.model_id, (None, None)) if not config_file: raise ValueError(f"Unknown model: {self.model_id}") try: # Download the checkpoint from HuggingFace logger.info(f"Downloading checkpoint: {checkpoint_file}") checkpoint_path = hf_hub_download( repo_id=self.model_id, filename=checkpoint_file, cache_dir=self.cache_dir, local_files_only=False ) logger.info(f"Checkpoint downloaded to: {checkpoint_path}") # Also download the config file if needed config_path = hf_hub_download( repo_id=self.model_id, filename=config_file, cache_dir=self.cache_dir, local_files_only=False ) logger.info(f"Config downloaded to: {config_path}") # Build the model using the traditional method sam2_model = build_sam2(config_path, checkpoint_path, device=self.device) predictor = SAM2ImagePredictor(sam2_model) # Ensure model is on the correct device and in eval mode if hasattr(predictor, "model"): predictor.model = predictor.model.to(self.device) predictor.model.eval() return predictor except Exception as e: logger.error(f"Error loading SAM2 model: {e}") logger.debug(traceback.format_exc()) return None def _load_fallback(self): class FallbackSAM2: def __init__(self, device): self.device = device self._img = None def set_image(self, image): self._img = np.asarray(image) def predict(self, **kwargs): h, w = (self._img.shape[:2] if self._img is not None else (512, 512)) return { "masks": np.ones((1, h, w), dtype=np.float32), "scores": np.array([0.5], dtype=np.float32), } logger.warning("Using fallback SAM2 (no real segmentation)") return FallbackSAM2(self.device) def cleanup(self): self.adapter = None if self.model is not None: try: del self.model except Exception: pass self.model = None if torch.cuda.is_available(): torch.cuda.empty_cache() def get_info(self) -> Dict[str, Any]: return { "loaded": self.adapter is not None, "model_id": self.model_id, "device": self.device, "load_time": self.load_time, "model_type": type(self.model).__name__ if self.model else None, } if __name__ == "__main__": # Standalone smoke test only; NOT executed when imported in your app import sys logging.basicConfig(level=logging.INFO) dev = "cuda" if torch.cuda.is_available() else "cpu" if len(sys.argv) < 2: print(f"Usage: {sys.argv[0]} image.jpg") raise SystemExit(1) path = sys.argv[1] img = cv2.imread(path, cv2.IMREAD_COLOR) if img is None: print(f"Could not load image {path}") raise SystemExit(2) loader = SAM2Loader(device=dev) sam = loader.load("auto") if not sam: print("Failed to load SAM2") raise SystemExit(3) sam.set_image(img) out = sam.predict(point_coords=None, point_labels=None) m = out["masks"] print("Masks:", m.shape, m.dtype, m.min(), m.max()) cv2.imwrite("sam2_mask0.png", (np.clip(m[0], 0, 1) * 255).astype(np.uint8)) print("Wrote sam2_mask0.png")