#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
SAM2 Loader + Guarded Predictor Adapter (VRAM-friendly, shape-safe, thread-safe, PyTorch 2.x)
- Uses the traditional build_sam2 path with HF Hub downloads for the original SAM 2 weights (not 2.1, for config compatibility)
- Never assigns predictor.device (read-only) — moves .model to device instead
- Accepts RGB/BGR, float/uint8; strips alpha; optional BGR→RGB via env
- Downscale ladder on set_image(); upsample masks back to original H,W
- torch.autocast(device_type="cuda", ...) + torch.inference_mode()
- Thread-safe (Lock) for Gradio/Spaces concurrency
- Returns {"masks": (N,H,W) float32, "scores": (N,) float32}; safe fallback on failure
"""
from __future__ import annotations
import contextlib
import logging
import os
import threading
import time
import traceback
from typing import Optional, Dict, Any, Tuple, List

import numpy as np
import torch
import cv2

logger = logging.getLogger(__name__)
if not logger.handlers:
    logging.basicConfig(level=logging.INFO)

# Sanitize a bad OMP_NUM_THREADS value before heavy libs read it
_val = os.environ.get("OMP_NUM_THREADS")
if _val is not None and not str(_val).strip().isdigit():
    try:
        del os.environ["OMP_NUM_THREADS"]
    except Exception:
        pass

def _select_device(pref: str) -> str:
    pref = (pref or "").lower()
    if pref.startswith("cuda"):
        return "cuda" if torch.cuda.is_available() else "cpu"
    if pref == "cpu":
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"

def _ensure_rgb_uint8(img: np.ndarray, force_bgr_to_rgb: bool = False) -> np.ndarray:
    if img is None:
        raise ValueError("set_image received None image")
    arr = np.asarray(img)
    if arr.ndim != 3 or arr.shape[2] < 3:
        raise ValueError(f"Expected HxWxC image with C>=3, got shape={arr.shape}")
    if np.issubdtype(arr.dtype, np.floating):
        # Float inputs are assumed to be normalized to [0, 1]
        arr = np.clip(arr, 0.0, 1.0)
        arr = (arr * 255.0 + 0.5).astype(np.uint8)
    elif arr.dtype == np.uint16:
        # 65535 / 257 == 255, so this maps the full 16-bit range onto 8 bits
        arr = (arr / 257).astype(np.uint8)
    elif arr.dtype != np.uint8:
        arr = arr.astype(np.uint8)
    if arr.shape[2] == 4:
        arr = arr[:, :, :3]  # strip alpha channel
    if force_bgr_to_rgb:
        arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
    return arr

def _compute_scaled_size(h: int, w: int, max_edge: int, target_pixels: int) -> Tuple[int, int, float]:
    if h <= 0 or w <= 0:
        return h, w, 1.0
    s1 = min(1.0, float(max_edge) / float(max(h, w))) if max_edge > 0 else 1.0
    s2 = min(1.0, (float(target_pixels) / float(h * w)) ** 0.5) if target_pixels > 0 else 1.0
    s = min(s1, s2)
    nh = max(1, int(round(h * s)))
    nw = max(1, int(round(w * s)))
    return nh, nw, s

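# Worked example with the defaults (max_edge=1024, target_pixels=900000):
# a 2160x3840 frame gives s1 = 1024/3840 ≈ 0.267 and
# s2 = sqrt(900000/8294400) ≈ 0.329, so s = 0.267 and the working size
# becomes 576x1024; the tighter of the two constraints wins.
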
def _ladder(nh: int, nw: int) -> List[Tuple[int, int]]:
    sizes = [(nh, nw)]
    for f in (0.85, 0.70, 0.55, 0.40, 0.30):
        sizes.append((max(64, int(nh * f)), max(64, int(nw * f))))
    uniq, seen = [], set()
    for s in sizes:
        if s not in seen:
            uniq.append(s)
            seen.add(s)
    return uniq

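# For example, _ladder(576, 1024) yields
# [(576, 1024), (489, 870), (403, 716), (316, 563), (230, 409), (172, 307)]:
# each failed attempt in predict() drops to the next size until one fits.
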
def _upsample_stack(masks: np.ndarray, out_hw: Tuple[int, int]) -> np.ndarray:
    masks = np.asarray(masks)
    if masks.ndim == 2:
        masks = masks[None, ...]
    elif masks.ndim == 4 and masks.shape[1] == 1:
        masks = masks[:, 0, :, :]
    if masks.ndim != 3:
        masks = np.squeeze(masks)
        if masks.ndim == 2:
            masks = masks[None, ...]
    n, h, w = masks.shape
    H, W = out_hw
    if (h, w) == (H, W):
        return masks.astype(np.float32, copy=False)
    out = np.zeros((n, H, W), dtype=np.float32)
    for i in range(n):
        out[i] = cv2.resize(masks[i].astype(np.float32), (W, H), interpolation=cv2.INTER_LINEAR)
    return np.clip(out, 0.0, 1.0)

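# For example, a mask stack of shape (3, 1, 576, 1024) is squeezed to
# (3, 576, 1024) and each plane is bilinearly resized back to the original
# (H, W); the result is float32 clipped to [0, 1].
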
def _normalize_masks_dtype(x: np.ndarray) -> np.ndarray:
    x = np.asarray(x)
    if x.dtype == np.uint8:
        return x.astype(np.float32) / 255.0
    return x.astype(np.float32, copy=False)

class _SAM2Adapter:
    def __init__(self, predictor, device: str):
        self.pred = predictor
        self.device = device
        self.orig_hw: Tuple[int, int] = (0, 0)
        self._current_rgb: Optional[np.ndarray] = None
        self._current_hw: Tuple[int, int] = (0, 0)
        self.max_edge = int(os.environ.get("SAM2_MAX_EDGE", "1024"))
        self.target_pixels = int(os.environ.get("SAM2_TARGET_PIXELS", "900000"))
        self.force_bgr_to_rgb = os.environ.get("SAM2_ASSUME_BGR", "0") == "1"
        self._lock = threading.Lock()

    def set_image(self, image: np.ndarray):
        with self._lock:
            rgb = _ensure_rgb_uint8(image, force_bgr_to_rgb=self.force_bgr_to_rgb)
            H, W = rgb.shape[:2]
            self.orig_hw = (H, W)
            nh, nw, s = _compute_scaled_size(H, W, self.max_edge, self.target_pixels)
            if s < 1.0:
                work = cv2.resize(rgb, (nw, nh), interpolation=cv2.INTER_AREA)
                self._current_rgb = work
                self._current_hw = (nh, nw)
            else:
                self._current_rgb = rgb
                self._current_hw = (H, W)
            self.pred.set_image(self._current_rgb)

    def predict(self, **kwargs) -> Dict[str, Any]:
        with self._lock:
            if self._current_rgb is None or self.orig_hw == (0, 0):
                raise RuntimeError("SAM2Adapter.predict called before set_image()")
            H, W = self.orig_hw
            nh, nw = self._current_hw
            sizes = _ladder(nh, nw)
            last_exc: Optional[BaseException] = None
            for (th, tw) in sizes:
                try:
                    if (th, tw) != (nh, nw):
                        small = cv2.resize(self._current_rgb, (tw, th), interpolation=cv2.INTER_AREA)
                        self.pred.set_image(small)
                    if self.device == "cuda":
                        amp_ctx = torch.autocast(
                            device_type="cuda",
                            dtype=(torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16),
                        )
                    else:
                        amp_ctx = contextlib.nullcontext()
                    with torch.inference_mode(), amp_ctx:
                        out = self.pred.predict(**kwargs)
                    # Normalize the possible return shapes: dict, tuple/list, or bare array
                    masks = scores = logits = None
                    if isinstance(out, dict):
                        masks = out.get("masks")
                        scores = out.get("scores")
                        logits = out.get("logits")
                    elif isinstance(out, (tuple, list)):
                        if len(out) >= 1:
                            masks = out[0]
                        if len(out) >= 2:
                            scores = out[1]
                        if len(out) >= 3:
                            logits = out[2]
                    else:
                        masks = out
                    if masks is None:
                        raise RuntimeError("SAM2 returned no masks")
                    masks = _normalize_masks_dtype(masks)
                    masks_up = _upsample_stack(masks, (H, W))
                    if scores is None:
                        scores = np.full((masks_up.shape[0],), 0.5, dtype=np.float32)
                    else:
                        scores = np.asarray(scores).astype(np.float32, copy=False).reshape(-1)
                    out_dict = {"masks": masks_up, "scores": scores}
                    if logits is not None:
                        lg = np.asarray(logits)
                        if lg.ndim == 3:
                            lg = _upsample_stack(lg, (H, W))
                        elif lg.ndim == 4 and lg.shape[1] == 1:
                            lg = _upsample_stack(lg[:, 0, :, :], (H, W))
                        out_dict["logits"] = lg.astype(np.float32, copy=False)
                    return out_dict
                except torch.cuda.OutOfMemoryError as e:
                    last_exc = e
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    logger.warning(f"SAM2 OOM at {th}x{tw}; retrying smaller. {e}")
                    continue
                except Exception as e:
                    last_exc = e
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    logger.debug(traceback.format_exc())
                    logger.warning(f"SAM2 predict failed at {th}x{tw}; retrying smaller. {e}")
                    continue
            logger.warning(f"SAM2 calls failed; returning fallback mask. {last_exc}")
            return {
                "masks": np.ones((1, H, W), dtype=np.float32),
                "scores": np.array([0.5], dtype=np.float32),
            }

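# Minimal usage sketch for the adapter returned by SAM2Loader.load()
# (hypothetical point coordinates; kwargs are forwarded verbatim to the
# underlying SAM2ImagePredictor.predict):
#   adapter.set_image(frame)                           # HxWx3 array
#   out = adapter.predict(point_coords=np.array([[320, 240]]),
#                         point_labels=np.array([1]))
#   mask = out["masks"][0]                             # float32 at original H, W
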
class SAM2Loader:
    """Dedicated loader for SAM2 models (PyTorch 2.x, Spaces-friendly)."""

    def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/sam2_cache"):
        self.device = _select_device(device)
        self.cache_dir = cache_dir
        os.makedirs(self.cache_dir, exist_ok=True)
        os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS", "1")
        os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "0")
        self.model = None
        self.adapter = None
        self.model_id = None
        self.load_time = 0.0

    def _determine_optimal_size(self) -> str:
        # Check the environment override first
        env_size = os.environ.get("USE_SAM2", "").lower()
        if env_size in ["tiny", "small", "base", "large"]:
            logger.info(f"Using SAM2 size from environment: {env_size}")
            return env_size
        # Otherwise pick a size by available VRAM
        try:
            if torch.cuda.is_available():
                props = torch.cuda.get_device_properties(0)
                vram_gb = props.total_memory / (1024**3)
                if vram_gb < 4:
                    return "tiny"
                if vram_gb < 8:
                    return "small"
                if vram_gb < 12:
                    return "base"
                return "large"
        except Exception:
            pass
        return "tiny"

    def load(self, model_size: str = "auto") -> Optional[_SAM2Adapter]:
        if model_size == "auto":
            model_size = self._determine_optimal_size()
        # Use original SAM2 model names (without .1) for compatibility
        model_map = {
            "tiny": "facebook/sam2-hiera-tiny",
            "small": "facebook/sam2-hiera-small",
            "base": "facebook/sam2-hiera-base-plus",
            "large": "facebook/sam2-hiera-large",
        }
        self.model_id = model_map.get(model_size, model_map["tiny"])
        logger.info(f"Loading SAM2 model: {self.model_id} (device={self.device})")
        for name, fn in (("official", self._load_official), ("fallback", self._load_fallback)):
            try:
                t0 = time.time()
                pred = fn()
                if pred is None:
                    continue
                self.model = pred
                self.adapter = _SAM2Adapter(self.model, self.device)
                self.load_time = time.time() - t0
                logger.info(f"SAM2 loaded via {name} in {self.load_time:.2f}s")
                return self.adapter
            except Exception as e:
                logger.error(f"SAM2 {name} strategy failed: {e}")
                logger.debug(traceback.format_exc())
        logger.error("All SAM2 loading strategies failed")
        return None

    def _load_official(self):
        try:
            from huggingface_hub import hf_hub_download
            from sam2.build_sam import build_sam2
            from sam2.sam2_image_predictor import SAM2ImagePredictor
        except ImportError as e:
            logger.error(f"Failed to import SAM2 components: {e}")
            return None
        # Map model IDs to config files and checkpoint names
        config_map = {
            "facebook/sam2-hiera-tiny": ("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"),
            "facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"),
            "facebook/sam2-hiera-base-plus": ("sam2_hiera_b+.yaml", "sam2_hiera_base_plus.pt"),
            "facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"),
        }
        config_file, checkpoint_file = config_map.get(self.model_id, (None, None))
        if not config_file:
            raise ValueError(f"Unknown model: {self.model_id}")
        try:
            # Download the checkpoint from the HuggingFace Hub
            logger.info(f"Downloading checkpoint: {checkpoint_file}")
            checkpoint_path = hf_hub_download(
                repo_id=self.model_id,
                filename=checkpoint_file,
                cache_dir=self.cache_dir,
                local_files_only=False,
            )
            logger.info(f"Checkpoint downloaded to: {checkpoint_path}")
            # Also download the matching config file
            config_path = hf_hub_download(
                repo_id=self.model_id,
                filename=config_file,
                cache_dir=self.cache_dir,
                local_files_only=False,
            )
            logger.info(f"Config downloaded to: {config_path}")
            # Build the model using the traditional build_sam2 path
            sam2_model = build_sam2(config_path, checkpoint_path, device=self.device)
            predictor = SAM2ImagePredictor(sam2_model)
            # Ensure the model is on the correct device and in eval mode;
            # predictor.device itself is read-only, so move .model instead
            if hasattr(predictor, "model"):
                predictor.model = predictor.model.to(self.device)
                predictor.model.eval()
            return predictor
        except Exception as e:
            logger.error(f"Error loading SAM2 model: {e}")
            logger.debug(traceback.format_exc())
            return None

    def _load_fallback(self):
        # Trivial stand-in that returns a full-frame mask; keeps the app
        # running when the real SAM2 package or weights are unavailable
        class FallbackSAM2:
            def __init__(self, device):
                self.device = device
                self._img = None

            def set_image(self, image):
                self._img = np.asarray(image)

            def predict(self, **kwargs):
                h, w = (self._img.shape[:2] if self._img is not None else (512, 512))
                return {
                    "masks": np.ones((1, h, w), dtype=np.float32),
                    "scores": np.array([0.5], dtype=np.float32),
                }

        logger.warning("Using fallback SAM2 (no real segmentation)")
        return FallbackSAM2(self.device)

    def cleanup(self):
        self.adapter = None
        if self.model is not None:
            try:
                del self.model
            except Exception:
                pass
        self.model = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def get_info(self) -> Dict[str, Any]:
        return {
            "loaded": self.adapter is not None,
            "model_id": self.model_id,
            "device": self.device,
            "load_time": self.load_time,
            "model_type": type(self.model).__name__ if self.model else None,
        }

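# Illustrative get_info() result after a successful load (values will vary):
#   {"loaded": True, "model_id": "facebook/sam2-hiera-tiny", "device": "cuda",
#    "load_time": 3.4, "model_type": "SAM2ImagePredictor"}
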
if __name__ == "__main__":
    # Standalone smoke test only; NOT executed when imported in your app
    import sys

    logging.basicConfig(level=logging.INFO)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    if len(sys.argv) < 2:
        print(f"Usage: {sys.argv[0]} image.jpg")
        raise SystemExit(1)
    path = sys.argv[1]
    img = cv2.imread(path, cv2.IMREAD_COLOR)  # BGR; set SAM2_ASSUME_BGR=1 to convert to RGB
    if img is None:
        print(f"Could not load image {path}")
        raise SystemExit(2)
    loader = SAM2Loader(device=dev)
    sam = loader.load("auto")
    if not sam:
        print("Failed to load SAM2")
        raise SystemExit(3)
    sam.set_image(img)
    out = sam.predict(point_coords=None, point_labels=None)
    m = out["masks"]
    print("Masks:", m.shape, m.dtype, m.min(), m.max())
    cv2.imwrite("sam2_mask0.png", (np.clip(m[0], 0, 1) * 255).astype(np.uint8))
    print("Wrote sam2_mask0.png")