|
|
|
|
|
|
|
|
|
|
|
"""
SAM2 Loader + Guarded Predictor Adapter (VRAM-friendly, shape-safe, thread-safe, PyTorch 2.x)

- Builds models with sam2's build_sam2, downloading SAM 2 (Hiera) checkpoints (facebook/sam2-hiera-*) from the Hugging Face Hub
- Never assigns predictor.device (read-only) — moves .model to the device instead
- Accepts float/uint8/uint16 images with 3+ channels and strips alpha; input is treated as RGB unless SAM2_ASSUME_BGR=1 requests a BGR→RGB swap
- Downscales large images in set_image(); if predict() fails (e.g. CUDA OOM) it retries over a ladder of smaller sizes; masks are upsampled back to the original H,W
- Runs predict() under torch.inference_mode() and, on CUDA, torch.autocast (bf16 when supported, otherwise fp16)
- Thread-safe (one Lock) for Gradio/Spaces concurrency
- Returns {"masks": (N,H,W) float32, "scores": (N,) float32} (plus "logits" when the predictor provides them); falls back to a single all-ones mask if every attempt fails
"""
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import os |
|
|
import time |
|
|
import logging |
|
|
import traceback |
|
|
from typing import Optional, Dict, Any, Tuple, List |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import cv2 |
|
|
import threading |
|
|
|
|
|
logger = logging.getLogger(__name__)
if not logger.handlers:
    # Configure basic INFO logging when this module's logger has no handlers.
    # NOTE(review): this inspects the module logger, not the root logger;
    # basicConfig is itself a no-op once the root logger has handlers.
    logging.basicConfig(level=logging.INFO)

# Remove an OMP_NUM_THREADS value that is not a plain non-negative integer
# (surrounding whitespace is tolerated). Presumably some hosts export values
# OpenMP runtimes reject. Note that comma-separated nesting lists such as "4,2"
# are removed as well.
# NOTE(review): numpy, torch and cv2 are imported above this point, so an OpenMP
# runtime they already loaded may have read the value before it is removed —
# confirm whether this cleanup needs to run before those imports.
_val = os.environ.get("OMP_NUM_THREADS")
if _val is not None and not str(_val).strip().isdigit():
    try:
        os.environ.pop("OMP_NUM_THREADS")
    except Exception:
        pass
|
|
|
|
|
|
|
|
def _select_device(pref: str) -> str:
    """Resolve a device preference string to either "cuda" or "cpu".

    Matching is case-insensitive. Only an explicit "cpu" forces the CPU; every
    other value ("cuda", "cuda:N", "auto", "", None, ...) resolves to "cuda"
    when torch reports CUDA as available and to "cpu" otherwise. A CUDA device
    index suffix is not preserved.
    """
    wanted = (pref or "").lower()
    if wanted == "cpu":
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
|
def _ensure_rgb_uint8(img: np.ndarray, force_bgr_to_rgb: bool = False) -> np.ndarray:
    """Convert an HxWxC image (C >= 3) into an HxWx3 uint8 array.

    Conversion by dtype:
      * floating point: assumed to lie in [0, 1]; NaN becomes 0, values are
        clipped to [0, 1] and scaled to 0..255 with rounding;
      * uint16: scaled down to 0..255 (division by 257);
      * other integer types: clipped into 0..255 rather than wrapping around
        on the uint8 cast;
      * anything else (e.g. bool): plain cast to uint8.
    Channels beyond the first three (alpha or extra planes) are dropped. The
    channel order is assumed to be RGB already unless ``force_bgr_to_rgb`` is
    True, in which case OpenCV swaps BGR -> RGB.

    Raises:
        ValueError: if ``img`` is None or not a 3-D array with at least 3 channels.
    """
    if img is None:
        raise ValueError("set_image received None image")
    arr = np.asarray(img)
    if arr.ndim != 3 or arr.shape[2] < 3:
        raise ValueError(f"Expected HxWxC image with C>=3, got shape={arr.shape}")

    if np.issubdtype(arr.dtype, np.floating):
        # NaN would otherwise survive np.clip and produce an undefined uint8 cast.
        arr = np.clip(np.nan_to_num(arr, nan=0.0), 0.0, 1.0)
        arr = (arr * 255.0 + 0.5).astype(np.uint8)
    elif arr.dtype == np.uint16:
        arr = (arr / 257).astype(np.uint8)
    elif np.issubdtype(arr.dtype, np.integer) and arr.dtype != np.uint8:
        # Saturate instead of wrapping (e.g. 300 -> 255, -5 -> 0). Bounds are
        # intersected with the dtype range so narrow types (int8) never see an
        # out-of-range Python int.
        info = np.iinfo(arr.dtype)
        arr = np.clip(arr, max(0, int(info.min)), min(255, int(info.max))).astype(np.uint8)
    elif arr.dtype != np.uint8:
        arr = arr.astype(np.uint8)

    if arr.shape[2] > 3:
        # Strip alpha / extra planes; make the result contiguous again.
        arr = np.ascontiguousarray(arr[:, :, :3])
    if force_bgr_to_rgb:
        arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
    return arr
|
|
|
|
|
|
|
|
def _compute_scaled_size(h: int, w: int, max_edge: int, target_pixels: int) -> Tuple[int, int, float]:
    """Compute a downscaled (height, width, scale) that respects two caps.

    The scale is the smallest of 1.0, ``max_edge / max(h, w)`` (only when
    ``max_edge > 0``) and ``sqrt(target_pixels / (h * w))`` (only when
    ``target_pixels > 0``), so images are never upscaled. New edges are
    rounded and floored at 1. Non-positive inputs are returned unchanged with
    scale 1.0.
    """
    if h <= 0 or w <= 0:
        return h, w, 1.0

    scale = 1.0
    if max_edge > 0:
        scale = min(scale, float(max_edge) / float(max(h, w)))
    if target_pixels > 0:
        scale = min(scale, (float(target_pixels) / float(h * w)) ** 0.5)

    new_h = max(1, int(round(h * scale)))
    new_w = max(1, int(round(w * scale)))
    return new_h, new_w, scale
|
|
|
|
|
|
|
|
def _ladder(
    nh: int,
    nw: int,
    factors: Tuple[float, ...] = (0.85, 0.70, 0.55, 0.40, 0.30),
    min_edge: int = 64,
) -> List[Tuple[int, int]]:
    """Return the (height, width) sizes to try, largest first, starting at (nh, nw).

    Each later rung scales both edges by the next factor and floors each edge
    at ``min_edge``. Every rung is also capped at the starting size. Without
    the cap, images already smaller than ``min_edge`` would be *upscaled* by a
    "retry smaller" step. Duplicate sizes are removed and order is preserved.
    """
    rungs = [(nh, nw)]
    for f in factors:
        th = min(nh, max(min_edge, int(nh * f)))
        tw = min(nw, max(min_edge, int(nw * f)))
        rungs.append((th, tw))
    # dict preserves insertion order -> order-stable de-duplication
    return list(dict.fromkeys(rungs))
|
|
|
|
|
|
|
|
def _upsample_stack(masks: np.ndarray, out_hw: Tuple[int, int]) -> np.ndarray:
    """Return ``masks`` as an (N, H, W) float32 stack at ``out_hw`` = (H, W).

    Any array with at least 2 dims is accepted. All leading axes are folded
    into N, so (h, w) -> (1, h, w), (N, 1, h, w) -> (N, h, w) and batched
    multimask output (B, C, h, w) -> (B*C, h, w), matching how callers
    flatten the corresponding scores. The last two axes are always treated as
    (h, w), even when one of them has size 1.

    If resizing is needed, each plane is bilinearly resized with OpenCV and the
    result is clipped to [0, 1]. A stack that already matches ``out_hw`` is
    only cast to float32 (no clipping).

    Raises:
        ValueError: if ``masks`` has fewer than 2 dimensions.
    """
    masks = np.asarray(masks)
    if masks.ndim < 2:
        raise ValueError(f"Expected a mask array with >=2 dims, got shape={masks.shape}")
    masks = masks.reshape((-1,) + masks.shape[-2:])

    n, h, w = masks.shape
    H, W = out_hw
    if (h, w) == (H, W):
        return masks.astype(np.float32, copy=False)

    out = np.zeros((n, H, W), dtype=np.float32)
    for i, plane in enumerate(masks):
        # cv2.resize expects a numeric dtype (not bool) and takes (width, height)
        out[i] = cv2.resize(plane.astype(np.float32), (W, H), interpolation=cv2.INTER_LINEAR)
    return np.clip(out, 0.0, 1.0)
|
|
|
|
|
|
|
|
def _normalize_masks_dtype(x: np.ndarray) -> np.ndarray:
    """Return ``x`` as a float32 array.

    uint8 input is assumed to be 0..255 and is rescaled to 0..1; every other
    dtype is only cast (without copying when it is already float32).
    """
    arr = np.asarray(x)
    if arr.dtype != np.uint8:
        return arr.astype(np.float32, copy=False)
    return arr.astype(np.float32) / 255.0
|
|
|
|
|
|
|
|
class _SAM2Adapter:
    """Guarded, thread-safe wrapper around a SAM2-style image predictor.

    The wrapped ``predictor`` is used only through ``set_image(image)`` and
    ``predict(**kwargs)``. Input images are normalised to RGB uint8. When they
    exceed SAM2_MAX_EDGE / SAM2_TARGET_PIXELS, they are downscaled to a
    working size before reaching the predictor. Prompt coordinates passed to
    :meth:`predict` (``point_coords``, ``box``) are expected in the ORIGINAL
    image's pixel frame. They are rescaled to whatever resolution the predictor
    currently holds, and returned masks/logits are resized back to the original
    (H, W).

    Environment variables read once at construction:
        SAM2_MAX_EDGE       longest-edge cap for the working image (default 1024)
        SAM2_TARGET_PIXELS  pixel-count cap for the working image (default 900000)
        SAM2_ASSUME_BGR     "1" to swap BGR->RGB on every set_image() (default off)

    NOTE(review): the lock serialises calls, but the adapter holds one
    "current image". Interleaved set_image()/predict() calls from different
    sessions therefore still see each other's image.
    """

    def __init__(self, predictor, device: str):
        self.pred = predictor
        self.device = device
        self.orig_hw: Tuple[int, int] = (0, 0)  # (H, W) of the last set_image() input
        self._current_rgb: Optional[np.ndarray] = None  # working (possibly downscaled) RGB image
        self._current_hw: Tuple[int, int] = (0, 0)  # (h, w) of _current_rgb
        # (h, w) of the image the predictor currently holds; None = unknown, must (re)set.
        self._pred_hw: Optional[Tuple[int, int]] = None
        self.max_edge = self._env_int("SAM2_MAX_EDGE", 1024)
        self.target_pixels = self._env_int("SAM2_TARGET_PIXELS", 900000)
        self.force_bgr_to_rgb = os.environ.get("SAM2_ASSUME_BGR", "0") == "1"
        self._lock = threading.Lock()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _env_int(name: str, default: int) -> int:
        """Read an integer env var; a malformed value falls back to ``default`` with a warning.

        A bad value must not raise here: the loader constructs the adapter
        inside its strategy loop, so an exception would make every strategy fail.
        """
        raw = os.environ.get(name)
        if raw is None:
            return default
        try:
            return int(raw)
        except ValueError:
            logger.warning(f"Ignoring non-integer {name}={raw!r}; using {default}")
            return default

    @staticmethod
    def _scale_prompts(kwargs: Dict[str, Any], sx: float, sy: float) -> Dict[str, Any]:
        """Return ``kwargs`` with ``point_coords``/``box`` scaled by (sx, sy).

        ``point_coords`` is (..., 2) in (x, y) order and ``box`` is (..., 4) in
        (x0, y0, x1, y1) order. No scaling is applied when both factors are 1,
        or when the caller passes ``normalize_coords=False`` (coords are then
        not absolute pixels). The input dict is never mutated.
        """
        if (sx, sy) == (1.0, 1.0) or not kwargs.get("normalize_coords", True):
            return kwargs
        scaled = dict(kwargs)
        points = scaled.get("point_coords")
        if points is not None:
            scaled["point_coords"] = np.asarray(points, dtype=np.float32) * np.array([sx, sy], dtype=np.float32)
        box = scaled.get("box")
        if box is not None:
            scaled["box"] = np.asarray(box, dtype=np.float32) * np.array([sx, sy, sx, sy], dtype=np.float32)
        return scaled

    @staticmethod
    def _resize_logits(logits: Any, out_hw: Tuple[int, int]) -> np.ndarray:
        """Resize raw mask logits to ``out_hw`` WITHOUT clamping them to [0, 1].

        (N, h, w) and (N, 1, h, w) inputs become (N, H, W) float32. Any other
        rank is returned unchanged apart from the float32 cast.
        """
        lg = np.asarray(logits).astype(np.float32, copy=False)
        if lg.ndim == 4 and lg.shape[1] == 1:
            lg = lg[:, 0, :, :]
        if lg.ndim != 3:
            return lg
        H, W = out_hw
        if tuple(lg.shape[1:]) == (H, W):
            return lg
        out = np.empty((lg.shape[0], H, W), dtype=np.float32)
        for i, plane in enumerate(lg):
            out[i] = cv2.resize(np.ascontiguousarray(plane), (W, H), interpolation=cv2.INTER_LINEAR)
        return out

    def _amp_context(self):
        """Autocast context on CUDA (bf16 when supported, else fp16); a no-op elsewhere."""
        from contextlib import nullcontext

        if self.device != "cuda":
            return nullcontext()
        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        return torch.autocast(device_type="cuda", dtype=dtype)

    def _load_into_predictor(self, th: int, tw: int) -> None:
        """Make the predictor hold the working image at (th, tw); re-encode only when needed."""
        if self._pred_hw == (th, tw):
            return
        if (th, tw) == self._current_hw:
            img = self._current_rgb
        else:
            img = cv2.resize(self._current_rgb, (tw, th), interpolation=cv2.INTER_AREA)
        # Invalidate first: if set_image raises, the predictor's state is unknown.
        self._pred_hw = None
        self.pred.set_image(img)
        self._pred_hw = (th, tw)

    def _format_output(self, out: Any, out_hw: Tuple[int, int]) -> Dict[str, Any]:
        """Normalise a predictor result (dict, tuple/list or bare masks) to the adapter's output dict."""
        masks = scores = logits = None
        if isinstance(out, dict):
            masks, scores, logits = out.get("masks"), out.get("scores"), out.get("logits")
        elif isinstance(out, (tuple, list)):
            if len(out) >= 1:
                masks = out[0]
            if len(out) >= 2:
                scores = out[1]
            if len(out) >= 3:
                logits = out[2]
        else:
            masks = out

        if masks is None:
            raise RuntimeError("SAM2 returned no masks")

        masks_up = _upsample_stack(_normalize_masks_dtype(masks), out_hw)
        if scores is None:
            scores = np.full((masks_up.shape[0],), 0.5, dtype=np.float32)
        else:
            scores = np.asarray(scores).astype(np.float32, copy=False).reshape(-1)

        result: Dict[str, Any] = {"masks": masks_up, "scores": scores}
        if logits is not None:
            result["logits"] = self._resize_logits(logits, out_hw)
        return result

    # --------------------------------------------------------------- public API

    def set_image(self, image: np.ndarray):
        """Normalise ``image``, downscale it if needed, and hand it to the predictor.

        A CUDA OOM while encoding is logged rather than raised. The predictor is
        then marked as holding no image, and :meth:`predict` re-encodes it via
        its size ladder.

        Raises:
            ValueError: from :func:`_ensure_rgb_uint8` for None / malformed images.
        """
        with self._lock:
            rgb = _ensure_rgb_uint8(image, force_bgr_to_rgb=self.force_bgr_to_rgb)
            H, W = rgb.shape[:2]
            self.orig_hw = (H, W)
            nh, nw, s = _compute_scaled_size(H, W, self.max_edge, self.target_pixels)
            if s < 1.0:
                self._current_rgb = cv2.resize(rgb, (nw, nh), interpolation=cv2.INTER_AREA)
                self._current_hw = (nh, nw)
            else:
                self._current_rgb = rgb
                self._current_hw = (H, W)

            # A new image invalidates whatever the predictor held before.
            self._pred_hw = None
            try:
                self._load_into_predictor(*self._current_hw)
            except torch.cuda.OutOfMemoryError as e:
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                ch, cw = self._current_hw
                logger.warning(f"SAM2 set_image OOM at {ch}x{cw}; deferring to predict() size ladder. {e}")

    def predict(self, **kwargs) -> Dict[str, Any]:
        """Run the predictor, retrying at smaller resolutions on failure.

        ``kwargs`` are forwarded to ``predictor.predict``. Before that,
        ``point_coords`` and ``box`` are rescaled from original-image pixels to
        the size of the image the predictor holds for the current attempt.

        Returns:
            {"masks": (N, H, W) float32, "scores": (N,) float32} at the original
            resolution, plus "logits" when the predictor returns them. If every
            size fails, returns a single all-ones mask with score 0.5.

        Raises:
            RuntimeError: if called before :meth:`set_image`.
        """
        with self._lock:
            if self._current_rgb is None or self.orig_hw == (0, 0):
                raise RuntimeError("SAM2Adapter.predict called before set_image()")

            H, W = self.orig_hw
            nh, nw = self._current_hw
            last_exc: Optional[BaseException] = None

            for th, tw in _ladder(nh, nw):
                try:
                    # Re-encodes only when the predictor holds a different size,
                    # e.g. a smaller rung left over from an earlier fallback.
                    self._load_into_predictor(th, tw)
                    call_kwargs = self._scale_prompts(kwargs, tw / W, th / H)
                    with torch.inference_mode(), self._amp_context():
                        out = self.pred.predict(**call_kwargs)
                    return self._format_output(out, (H, W))

                except torch.cuda.OutOfMemoryError as e:
                    last_exc = e
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    logger.warning(f"SAM2 OOM at {th}x{tw}; retrying smaller. {e}")

                except Exception as e:
                    last_exc = e
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    logger.debug(traceback.format_exc())
                    logger.warning(f"SAM2 predict failed at {th}x{tw}; retrying smaller. {e}")

            logger.warning(f"SAM2 calls failed; returning fallback mask. {last_exc}")
            return {
                "masks": np.ones((1, H, W), dtype=np.float32),
                "scores": np.array([0.5], dtype=np.float32),
            }
|
|
|
|
|
|
|
|
class SAM2Loader: |
|
|
"""Dedicated loader for SAM2 models (PyTorch 2.x, Spaces-friendly).""" |
|
|
|
|
|
def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/sam2_cache"): |
|
|
self.device = _select_device(device) |
|
|
self.cache_dir = cache_dir |
|
|
os.makedirs(self.cache_dir, exist_ok=True) |
|
|
os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS", "1") |
|
|
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "0") |
|
|
self.model = None |
|
|
self.adapter = None |
|
|
self.model_id = None |
|
|
self.load_time = 0.0 |
|
|
|
|
|
def _determine_optimal_size(self) -> str: |
|
|
|
|
|
env_size = os.environ.get("USE_SAM2", "").lower() |
|
|
if env_size in ["tiny", "small", "base", "large"]: |
|
|
logger.info(f"Using SAM2 size from environment: {env_size}") |
|
|
return env_size |
|
|
|
|
|
try: |
|
|
if torch.cuda.is_available(): |
|
|
props = torch.cuda.get_device_properties(0) |
|
|
vram_gb = props.total_memory / (1024**3) |
|
|
if vram_gb < 4: return "tiny" |
|
|
if vram_gb < 8: return "small" |
|
|
if vram_gb < 12: return "base" |
|
|
return "large" |
|
|
except Exception: |
|
|
pass |
|
|
return "tiny" |
|
|
|
|
|
def load(self, model_size: str = "auto") -> Optional[_SAM2Adapter]: |
|
|
if model_size == "auto": |
|
|
model_size = self._determine_optimal_size() |
|
|
|
|
|
|
|
|
model_map = { |
|
|
"tiny": "facebook/sam2-hiera-tiny", |
|
|
"small": "facebook/sam2-hiera-small", |
|
|
"base": "facebook/sam2-hiera-base-plus", |
|
|
"large": "facebook/sam2-hiera-large", |
|
|
} |
|
|
self.model_id = model_map.get(model_size, model_map["tiny"]) |
|
|
logger.info(f"Loading SAM2 model: {self.model_id} (device={self.device})") |
|
|
|
|
|
for name, fn in (("official", self._load_official), ("fallback", self._load_fallback)): |
|
|
try: |
|
|
t0 = time.time() |
|
|
pred = fn() |
|
|
if pred is None: |
|
|
continue |
|
|
self.model = pred |
|
|
self.adapter = _SAM2Adapter(self.model, self.device) |
|
|
self.load_time = time.time() - t0 |
|
|
logger.info(f"SAM2 loaded via {name} in {self.load_time:.2f}s") |
|
|
return self.adapter |
|
|
except Exception as e: |
|
|
logger.error(f"SAM2 {name} strategy failed: {e}") |
|
|
logger.debug(traceback.format_exc()) |
|
|
|
|
|
logger.error("All SAM2 loading strategies failed") |
|
|
return None |
|
|
|
|
|
def _load_official(self): |
|
|
try: |
|
|
from huggingface_hub import hf_hub_download |
|
|
from sam2.build_sam import build_sam2 |
|
|
from sam2.sam2_image_predictor import SAM2ImagePredictor |
|
|
except ImportError as e: |
|
|
logger.error(f"Failed to import SAM2 components: {e}") |
|
|
return None |
|
|
|
|
|
|
|
|
config_map = { |
|
|
"facebook/sam2-hiera-tiny": ("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"), |
|
|
"facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"), |
|
|
"facebook/sam2-hiera-base-plus": ("sam2_hiera_b+.yaml", "sam2_hiera_base_plus.pt"), |
|
|
"facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"), |
|
|
} |
|
|
|
|
|
config_file, checkpoint_file = config_map.get(self.model_id, (None, None)) |
|
|
if not config_file: |
|
|
raise ValueError(f"Unknown model: {self.model_id}") |
|
|
|
|
|
try: |
|
|
|
|
|
logger.info(f"Downloading checkpoint: {checkpoint_file}") |
|
|
checkpoint_path = hf_hub_download( |
|
|
repo_id=self.model_id, |
|
|
filename=checkpoint_file, |
|
|
cache_dir=self.cache_dir, |
|
|
local_files_only=False |
|
|
) |
|
|
logger.info(f"Checkpoint downloaded to: {checkpoint_path}") |
|
|
|
|
|
|
|
|
config_path = hf_hub_download( |
|
|
repo_id=self.model_id, |
|
|
filename=config_file, |
|
|
cache_dir=self.cache_dir, |
|
|
local_files_only=False |
|
|
) |
|
|
logger.info(f"Config downloaded to: {config_path}") |
|
|
|
|
|
|
|
|
sam2_model = build_sam2(config_path, checkpoint_path, device=self.device) |
|
|
predictor = SAM2ImagePredictor(sam2_model) |
|
|
|
|
|
|
|
|
if hasattr(predictor, "model"): |
|
|
predictor.model = predictor.model.to(self.device) |
|
|
predictor.model.eval() |
|
|
|
|
|
return predictor |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error loading SAM2 model: {e}") |
|
|
logger.debug(traceback.format_exc()) |
|
|
return None |
|
|
|
|
|
def _load_fallback(self): |
|
|
class FallbackSAM2: |
|
|
def __init__(self, device): |
|
|
self.device = device |
|
|
self._img = None |
|
|
def set_image(self, image): |
|
|
self._img = np.asarray(image) |
|
|
def predict(self, **kwargs): |
|
|
h, w = (self._img.shape[:2] if self._img is not None else (512, 512)) |
|
|
return { |
|
|
"masks": np.ones((1, h, w), dtype=np.float32), |
|
|
"scores": np.array([0.5], dtype=np.float32), |
|
|
} |
|
|
logger.warning("Using fallback SAM2 (no real segmentation)") |
|
|
return FallbackSAM2(self.device) |
|
|
|
|
|
def cleanup(self): |
|
|
self.adapter = None |
|
|
if self.model is not None: |
|
|
try: |
|
|
del self.model |
|
|
except Exception: |
|
|
pass |
|
|
self.model = None |
|
|
if torch.cuda.is_available(): |
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
def get_info(self) -> Dict[str, Any]: |
|
|
return { |
|
|
"loaded": self.adapter is not None, |
|
|
"model_id": self.model_id, |
|
|
"device": self.device, |
|
|
"load_time": self.load_time, |
|
|
"model_type": type(self.model).__name__ if self.model else None, |
|
|
} |
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: `python <this file> image.jpg` writes the first mask to sam2_mask0.png.
    import sys

    logging.basicConfig(level=logging.INFO)
    dev = "cuda" if torch.cuda.is_available() else "cpu"

    if len(sys.argv) < 2:
        print(f"Usage: {sys.argv[0]} image.jpg")
        raise SystemExit(1)
    path = sys.argv[1]

    img = cv2.imread(path, cv2.IMREAD_COLOR)
    if img is None:
        print(f"Could not load image {path}")
        raise SystemExit(2)

    loader = SAM2Loader(device=dev)
    sam = loader.load("auto")
    if sam is None:
        print("Failed to load SAM2")
        raise SystemExit(3)

    # cv2.imread returns BGR. The adapter treats input as RGB unless
    # SAM2_ASSUME_BGR=1 made it swap channels itself, so convert only then.
    if not sam.force_bgr_to_rgb:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    sam.set_image(img)
    out = sam.predict(point_coords=None, point_labels=None)
    m = out["masks"]
    print("Masks:", m.shape, m.dtype, m.min(), m.max())
    cv2.imwrite("sam2_mask0.png", (np.clip(m[0], 0, 1) * 255).astype(np.uint8))
    print("Wrote sam2_mask0.png")