#!/usr/bin/env python3
"""Gradio demo for UnSAMv2 interactive image segmentation with Hugging Face ZeroGPU support."""

from __future__ import annotations

import logging
import os
import shutil
import sys
import tempfile
import threading
import uuid
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

import cv2
import gradio as gr
import numpy as np
import torch

try:
    import spaces  # type: ignore
except ImportError:  # pragma: no cover - optional dependency on Spaces runtime
    spaces = None

REPO_ROOT = Path(__file__).resolve().parent
SAM2_REPO = REPO_ROOT / "sam2"
if SAM2_REPO.exists():
    sys.path.insert(0, str(SAM2_REPO))

from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator  # noqa: E402
from sam2.build_sam import build_sam2, build_sam2_video_predictor  # noqa: E402
from sam2.sam2_image_predictor import SAM2ImagePredictor  # noqa: E402

logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger("unsamv2-gradio")

USE_M2M_REFINEMENT = True
CONFIG_PATH = os.getenv("UNSAMV2_CONFIG", "configs/unsamv2_small.yaml")
CKPT_PATH = Path(
    os.getenv("UNSAMV2_CKPT", SAM2_REPO / "checkpoints" / "unsamv2_plus_ckpt.pt")
).resolve()
if not CKPT_PATH.exists():
    raise FileNotFoundError(
        f"Checkpoint not found at {CKPT_PATH}. Set UNSAMV2_CKPT to a valid .pt file."
    )

GRANULARITY_MIN = float(os.getenv("UNSAMV2_GRAN_MIN", 0.1))
GRANULARITY_MAX = float(os.getenv("UNSAMV2_GRAN_MAX", 1.0))
ZERO_GPU_ENABLED = os.getenv("UNSAMV2_ENABLE_ZEROGPU", "1").lower() in {"1", "true", "yes"}
ZERO_GPU_DURATION = int(os.getenv("UNSAMV2_ZEROGPU_DURATION", "60"))
ZERO_GPU_WHOLE_DURATION = int(
    os.getenv("UNSAMV2_ZEROGPU_WHOLE_DURATION", str(ZERO_GPU_DURATION))
)
ZERO_GPU_VIDEO_DURATION = int(
    os.getenv("UNSAMV2_ZEROGPU_VIDEO_DURATION", str(max(120, ZERO_GPU_DURATION)))
)
MAX_VIDEO_FRAMES = int(os.getenv("UNSAMV2_MAX_VIDEO_FRAMES", "360"))
WHOLE_IMAGE_POINTS_PER_SIDE = int(os.getenv("UNSAMV2_WHOLE_POINTS", "64"))
WHOLE_IMAGE_MAX_MASKS = 1000

POINT_MODE_TO_LABEL = {"Foreground (+)": 1, "Background (-)": 0}
POINT_COLORS_BGR = {
    1: (72, 201, 127),  # green-ish for positives
    0: (64, 76, 225),  # red-ish for negatives
}
MASK_COLOR_BGR = (0, 0, 255)

DEFAULT_IMAGE_PATH = REPO_ROOT / "demo" / "bird.webp"
WHOLE_IMAGE_DEFAULT_PATH = REPO_ROOT / "demo" / "sa_291195.jpg"
DEFAULT_VIDEO_PATH = REPO_ROOT / "demo" / "bedroom.mp4"


def _load_image_from_path(path: Path) -> Optional[np.ndarray]:
    if not path.exists():
        LOGGER.warning("Default image missing at %s", path)
        return None
    img_bgr = cv2.imread(str(path), cv2.IMREAD_COLOR)
    if img_bgr is None:
        LOGGER.warning("Could not read default image at %s", path)
        return None
    return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)


DEFAULT_IMAGE = _load_image_from_path(DEFAULT_IMAGE_PATH)
WHOLE_IMAGE_DEFAULT = _load_image_from_path(WHOLE_IMAGE_DEFAULT_PATH)

TMP_ROOT = REPO_ROOT / "_tmp"
TMP_ROOT.mkdir(exist_ok=True)


class ModelManager:
    """Keeps SAM2 models on each device and spawns lightweight predictors."""

    def __init__(self) -> None:
        self._models: dict[str, torch.nn.Module] = {}
        self._lock = threading.Lock()

    def _build(self, device: torch.device) -> torch.nn.Module:
        LOGGER.info("Loading UnSAMv2 weights onto %s", device)
        return build_sam2(
            CONFIG_PATH,
            ckpt_path=str(CKPT_PATH),
            device=device,
            mode="eval",
        )

    def get_model(self, device: torch.device) -> torch.nn.Module:
        key = f"{device.type}:{device.index}" if device.type == "cuda" else device.type
        with self._lock:
            if key not in self._models:
                self._models[key] = self._build(device)
            return self._models[key]

    def make_predictor(self, device: torch.device) -> SAM2ImagePredictor:
        return SAM2ImagePredictor(self.get_model(device), mask_threshold=-1.0)

    def make_auto_mask_generator(
        self,
        device: torch.device,
        **kwargs,
    ) -> SAM2AutomaticMaskGenerator:
        return SAM2AutomaticMaskGenerator(self.get_model(device), **kwargs)


MODEL_MANAGER = ModelManager()


class VideoPredictorManager:
    """Caches heavy video predictors per device."""

    def __init__(self) -> None:
        self._predictors: dict[str, torch.nn.Module] = {}
        self._lock = threading.Lock()

    def _build(self, device: torch.device) -> torch.nn.Module:
        LOGGER.info("Loading UnSAMv2 video predictor onto %s", device)
        return build_sam2_video_predictor(
            CONFIG_PATH,
            ckpt_path=str(CKPT_PATH),
            device=device,
        )

    def get_predictor(self, device: torch.device) -> torch.nn.Module:
        key = f"{device.type}:{device.index}" if device.type == "cuda" else device.type
        with self._lock:
            if key not in self._predictors:
                self._predictors[key] = self._build(device)
            return self._predictors[key]


VIDEO_PREDICTOR_MANAGER = VideoPredictorManager()


def make_empty_video_state() -> Dict[str, Any]:
    return {
        "frame_dir": None,
        "frame_paths": [],
        "fps": 0.0,
        "frame_size": (0, 0),
    }


def ensure_uint8(image: Optional[np.ndarray]) -> Optional[np.ndarray]:
    if image is None:
        return None
    img = image[..., :3]  # drop alpha if present
    if img.dtype == np.float32 or img.dtype == np.float64:
        if img.max() <= 1.0:
            img = (img * 255).clip(0, 255).astype(np.uint8)
        else:
            img = img.clip(0, 255).astype(np.uint8)
    elif img.dtype != np.uint8:
        img = img.clip(0, 255).astype(np.uint8)
    return img


def make_temp_subdir(prefix: str) -> Path:
    TMP_ROOT.mkdir(exist_ok=True)
    return Path(tempfile.mkdtemp(prefix=prefix, dir=str(TMP_ROOT)))


def remove_dir_if_exists(path_str: Optional[str]) -> None:
    if not path_str:
        return
    path = Path(path_str)
    if path.exists():
        shutil.rmtree(path, ignore_errors=True)


def load_rgb_image(path: Path) -> np.ndarray:
    bgr = cv2.imread(str(path), cv2.IMREAD_COLOR)
    if bgr is None:
        raise FileNotFoundError(f"Failed to read frame at {path}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def resolve_video_path(video_value: Any) -> Optional[str]:
    if video_value is None:
        return None
    if isinstance(video_value, str):
        return video_value
    if isinstance(video_value, dict):
        return video_value.get("name") or video_value.get("path")
    # Gradio may pass a FileData/MediaData object with a .name attribute
    for attr in ("name", "path", "video", "data"):
        candidate = getattr(video_value, attr, None)
        if isinstance(candidate, str):
            return candidate
    return None


def match_mask_to_image(mask: np.ndarray, image: np.ndarray) -> np.ndarray:
    mask_arr = np.asarray(mask)
    if mask_arr.ndim == 3:
        mask_arr = mask_arr.squeeze()
    h, w = image.shape[:2]
    if mask_arr.shape[:2] != (h, w):
        mask_arr = cv2.resize(
            mask_arr.astype(np.float32),
            (w, h),
            interpolation=cv2.INTER_NEAREST,
        )
    return mask_arr.astype(bool)


def colorize_mask_collection(
    image: np.ndarray,
    masks: Sequence[np.ndarray],
    alpha: float = 0.55,
) -> np.ndarray:
    if not masks:
        return image
    canvas = image.astype(np.float32)
    rng = np.random.default_rng(1337)
    for mask in masks:
        mask_arr = match_mask_to_image(mask, image)
        if not mask_arr.any():
            continue
        color = rng.integers(20, 235, size=3)
        canvas[mask_arr] = canvas[mask_arr] * (1.0 - alpha) + color * alpha
    return canvas.clip(0, 255).astype(np.uint8)


def render_video_overlay(
    video_state: Dict[str, Any],
    frame_idx: int,
    pts: Sequence[Sequence[float]],
    lbls: Sequence[int],
) -> Optional[np.ndarray]:
    frame_paths: List[str] = list(video_state.get("frame_paths", []))
    if not frame_paths:
        return None
    safe_idx = int(np.clip(frame_idx, 0, len(frame_paths) - 1))
    frame = load_rgb_image(Path(frame_paths[safe_idx]))
    return draw_overlay(frame, None, pts, lbls)


def mask_entries_to_arrays(entries: Sequence[Dict[str, Any]]) -> List[np.ndarray]:
    arrays: List[np.ndarray] = []
    for entry in entries:
        seg = entry.get("segmentation", entry)
        if isinstance(seg, np.ndarray):
            mask = seg
        elif isinstance(seg, dict):
            from sam2.utils.amg import rle_to_mask

            mask = rle_to_mask(seg)
        else:
            mask = np.asarray(seg)
        arrays.append(mask.astype(bool))
    return arrays


def summarize_masks(entries: Sequence[Dict[str, Any]]) -> List[Dict[str, Any]]:
    summary: List[Dict[str, Any]] = []
    for idx, entry in enumerate(entries, start=1):
        summary.append(
            {
                "mask": idx,
                "area": int(entry.get("area", 0)),
                "pred_iou": round(float(entry.get("predicted_iou", 0.0)), 3),
                "stability": round(float(entry.get("stability_score", 0.0)), 3),
            }
        )
    return summary


def extract_video_frames(video_path: str) -> Tuple[List[Path], float, Tuple[int, int], Path]:
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError("Could not open the uploaded video.")
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps or fps <= 1e-3:
        fps = 12.0
    frame_dir = make_temp_subdir("video_frames_")
    frame_paths: List[Path] = []
    height = width = 0
    idx = 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        if idx == 0:
            height, width = rgb.shape[:2]
        out_path = frame_dir / f"{idx:05d}.jpg"
        if not cv2.imwrite(str(out_path), cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)):
            cap.release()
            raise RuntimeError(f"Failed to write frame {idx} to disk")
        frame_paths.append(out_path)
        idx += 1
        if idx >= MAX_VIDEO_FRAMES:
            LOGGER.warning(
                "Stopping frame extraction at %d frames per UNSAMV2_MAX_VIDEO_FRAMES",
                MAX_VIDEO_FRAMES,
            )
            break
    cap.release()
    if not frame_paths:
        remove_dir_if_exists(str(frame_dir))
        raise ValueError("No frames decoded from the provided video.")
    if height == 0 or width == 0:
        sample = load_rgb_image(frame_paths[0])
        height, width = sample.shape[:2]
    return frame_paths, float(fps), (height, width), frame_dir


def write_video_from_frames(frames: Sequence[np.ndarray], fps: float) -> Path:
    if not frames:
        raise ValueError("No frames available to write video output.")
    height, width = frames[0].shape[:2]
    safe_fps = fps if fps and fps > 0 else 12.0
    out_path = TMP_ROOT / f"video_seg_{uuid.uuid4().hex}.mp4"
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(str(out_path), fourcc, safe_fps, (width, height))
    if not writer.isOpened():
        raise RuntimeError("Failed to initialize video writer. Check codec support.")
    for frame in frames:
        if frame.shape[:2] != (height, width):
            raise ValueError("All frames must share the same spatial resolution.")
        writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    writer.release()
    return out_path


def choose_device() -> torch.device:
    preference = os.getenv("UNSAMV2_DEVICE", "auto").lower()
    if preference == "cpu":
        return torch.device("cpu")
    if preference.startswith("cuda") or preference == "gpu":
        if torch.cuda.is_available():
            return torch.device(preference if preference.startswith("cuda") else "cuda")
        LOGGER.warning("CUDA requested but not available; defaulting to CPU")
        return torch.device("cpu")
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def wrap_with_zero_gpu(
    fn: Callable[..., Any],
    duration: int,
) -> Callable[..., Any]:
    if spaces is None or not ZERO_GPU_ENABLED:
        return fn
    try:
        LOGGER.info("Enabling ZeroGPU (duration=%ss) for %s", duration, fn.__name__)
        return spaces.GPU(duration=duration)(fn)  # type: ignore[misc]
    except Exception:  # pragma: no cover - defensive logging
        LOGGER.exception("Failed to wrap %s with ZeroGPU; running on CPU", fn.__name__)
        return fn


def build_granularity_tensor(value: float, device: torch.device) -> torch.Tensor:
    tensor = torch.tensor([[[[value]]]], dtype=torch.float32, device=device)
    return tensor


def apply_m2m_refinement(
    predictor,
    point_coords,
    point_labels,
    granularity,
    logits,
    best_mask_idx,
    use_m2m: bool = True,
):
    """Optionally run a second M2M pass using the best mask's logits."""
    if not use_m2m:
        return None
    logging.info("Applying M2M refinement...")
    try:
        if logits is None:
            raise ValueError("logits must be provided for M2M refinement.")
        low_res_logits = logits[best_mask_idx : best_mask_idx + 1]
        refined_masks, refined_scores, _ = predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            multimask_output=False,
            gra=granularity,
            mask_input=low_res_logits,
        )
        refined_mask = refined_masks[0]
        refined_score = float(refined_scores[0])
        logging.info("M2M refinement completed with score: %.3f", refined_score)
        return refined_mask, refined_score
    except Exception as exc:  # pragma: no cover - logging only
        logging.error("M2M refinement failed: %s, using original mask", exc)
        return None


def draw_overlay(
    image: np.ndarray,
    mask: Optional[np.ndarray],
    points: Sequence[Sequence[float]],
    labels: Sequence[int],
    alpha: float = 0.55,
) -> np.ndarray:
    canvas_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    if mask is not None:
        mask_bool = match_mask_to_image(mask, image)
        overlay = np.zeros_like(canvas_bgr, dtype=np.uint8)
        overlay[mask_bool] = MASK_COLOR_BGR
        canvas_bgr = np.where(
            mask_bool[..., None],
            (canvas_bgr * (1.0 - alpha) + overlay * alpha).astype(np.uint8),
            canvas_bgr,
        )
    for (x, y), lbl in zip(points, labels):
        color = POINT_COLORS_BGR.get(lbl, (255, 255, 255))
        center = (int(round(x)), int(round(y)))
        cv2.circle(canvas_bgr, center, 7, color, thickness=-1, lineType=cv2.LINE_AA)
        cv2.circle(canvas_bgr, center, 9, (255, 255, 255), thickness=2, lineType=cv2.LINE_AA)
    return cv2.cvtColor(canvas_bgr, cv2.COLOR_BGR2RGB)


def handle_image_upload(image: Optional[np.ndarray]):
    img = ensure_uint8(image)
    if img is None:
        return None, None, [], [], "Upload an image to start adding clicks."
    return img, img, [], [], "Image loaded. Choose click type, then tap on the image."


def handle_click(
    point_mode: str,
    pts: List[Sequence[float]],
    lbls: List[int],
    image: Optional[np.ndarray],
    evt: gr.SelectData,
):
    if image is None:
        return gr.update(), pts, lbls, "Upload an image first."
    coord = evt.index  # (x, y)
    if coord is None:
        return gr.update(), pts, lbls, "Couldn't read click position."
    x, y = coord
    label = POINT_MODE_TO_LABEL.get(point_mode, 1)
    pts = pts + [[float(x), float(y)]]
    lbls = lbls + [label]
    overlay = draw_overlay(image, None, pts, lbls)
    status = f"Added {'positive' if label == 1 else 'negative'} click at ({int(x)}, {int(y)})."
    return overlay, pts, lbls, status


def undo_last_click(image: Optional[np.ndarray], pts: List[Sequence[float]], lbls: List[int]):
    if not pts:
        return gr.update(), pts, lbls, "No clicks to undo."
    pts = pts[:-1]
    lbls = lbls[:-1]
    overlay = draw_overlay(image, None, pts, lbls) if image is not None else None
    status = "Removed the last click."
    return overlay, pts, lbls, status


def clear_clicks(image: Optional[np.ndarray]):
    overlay = image if image is not None else None
    return overlay, [], [], "Cleared all clicks."


def _run_segmentation(
    image: Optional[np.ndarray],
    pts: List[Sequence[float]],
    lbls: List[int],
    granularity: float,
):
    img = ensure_uint8(image)
    if img is None:
        return None, "Upload an image to segment."
    if not pts:
        return draw_overlay(img, None, [], []), "Add at least one click before running segmentation."
    device = choose_device()
    predictor = MODEL_MANAGER.make_predictor(device)
    predictor.set_image(img)
    coords = np.asarray(pts, dtype=np.float32)
    labels = np.asarray(lbls, dtype=np.int32)
    gran_tensor = build_granularity_tensor(granularity, predictor.device)
    masks, scores, logits = predictor.predict(
        point_coords=coords,
        point_labels=labels,
        multimask_output=True,
        gra=float(granularity),
        granularity=gran_tensor,
    )
    best_idx = int(np.argmax(scores))
    best_mask = masks[best_idx].astype(bool)
    status = (
        f"Best mask #{best_idx + 1} IoU score: {float(scores[best_idx]):.3f} | "
        f"granularity={granularity:.2f}"
    )
    refinement = apply_m2m_refinement(
        predictor=predictor,
        point_coords=coords,
        point_labels=labels,
        granularity=float(granularity),
        logits=logits,
        best_mask_idx=best_idx,
        use_m2m=USE_M2M_REFINEMENT,
    )
    if refinement is not None:
        refined_mask, refined_score = refinement
        best_mask = refined_mask.astype(bool)
        status += f" | M2M IoU: {refined_score:.3f}"
    overlay = draw_overlay(img, best_mask, pts, lbls)
    return overlay, status


def run_whole_image_segmentation(
    image: Optional[np.ndarray],
    granularity: float,
    pred_iou_thresh: float,
    stability_thresh: float,
):
    img = ensure_uint8(image)
    if img is None:
        return None, [], "Upload an image to run whole-image segmentation."
    device = choose_device()
    mask_generator = MODEL_MANAGER.make_auto_mask_generator(
        device=device,
        points_per_side=WHOLE_IMAGE_POINTS_PER_SIDE,
        points_per_batch=128,
        pred_iou_thresh=float(pred_iou_thresh),
        stability_score_thresh=float(stability_thresh),
        mask_threshold=-1.0,
        box_nms_thresh=0.7,
        crop_n_layers=0,
        min_mask_region_area=0,
        use_m2m=USE_M2M_REFINEMENT,
        output_mode="binary_mask",
    )
    try:
        masks = mask_generator.generate(img, gra=float(granularity))
    except Exception as exc:
        LOGGER.exception("Whole-image segmentation failed")
        return None, [], f"Whole-image segmentation failed: {exc}"
    if not masks:
        return img, [], "Mask generator did not return any regions. Try lowering thresholds."
    trimmed = masks[:WHOLE_IMAGE_MAX_MASKS]
    mask_arrays = mask_entries_to_arrays(trimmed)
    overlay = colorize_mask_collection(img, mask_arrays)
    table = summarize_masks(trimmed)
    status = (
        f"Generated {len(trimmed)} masks | granularity={granularity:.2f}, "
        f"IoU≥{pred_iou_thresh:.2f}, stability≥{stability_thresh:.2f}"
    )
    return overlay, table, status


def handle_video_upload(
    video_file: Any,
    current_state: Optional[Dict[str, Any]] = None,
):
    if current_state:
        remove_dir_if_exists(current_state.get("frame_dir"))
    state = make_empty_video_state()
    if isinstance(video_file, (list, tuple)):
        video_file = video_file[0] if video_file else None
    video_path = resolve_video_path(video_file)
    if not video_path:
        return (
            gr.update(value=None, visible=False),
            state,
            gr.update(value=0, minimum=0, maximum=0, interactive=False),
            [],
            [],
            0,
            "Upload a video to start adding clicks.",
        )
    try:
        frame_paths, fps, frame_size, frame_dir = extract_video_frames(video_path)
    except Exception as exc:
        LOGGER.exception("Video decoding failed")
        return (
            gr.update(value=None, visible=False),
            state,
            gr.update(value=0, minimum=0, maximum=0, interactive=False),
            [],
            [],
            0,
            f"Video decoding failed: {exc}",
        )
    state.update(
        {
            "frame_dir": str(frame_dir),
            "frame_paths": [str(p) for p in frame_paths],
            "fps": fps,
            "frame_size": frame_size,
        }
    )
    first_overlay = render_video_overlay(state, 0, [], [])
    slider_update = gr.update(
        value=0,
        minimum=0,
        maximum=len(frame_paths) - 1,
        step=1,
        interactive=True,
    )
    status = f"Loaded video with {len(frame_paths)} frames at {fps:.1f} FPS."
    return (
        gr.update(value=first_overlay, visible=True),
        state,
        slider_update,
        [],
        [],
        0,
        status,
    )


def handle_video_frame_change(
    frame_idx: int,
    video_state: Dict[str, Any],
):
    overlay = render_video_overlay(video_state, frame_idx, [], [])
    if overlay is None:
        return gr.update(), [], [], 0, "Upload a video first."
    safe_idx = int(np.clip(frame_idx, 0, len(video_state.get("frame_paths", [])) - 1))
    status = f"Annotating frame {safe_idx}."
    return overlay, [], [], safe_idx, status


def handle_video_click(
    point_mode: str,
    pts: List[Sequence[float]],
    lbls: List[int],
    video_state: Dict[str, Any],
    frame_idx: int,
    evt: gr.SelectData,
):
    overlay = render_video_overlay(video_state, frame_idx, pts, lbls)
    if overlay is None:
        return gr.update(), pts, lbls, "Upload a video first."
    if evt.index is None:
        return overlay, pts, lbls, "Couldn't read click position."
    x, y = evt.index
    label = POINT_MODE_TO_LABEL.get(point_mode, 1)
    pts = pts + [[float(x), float(y)]]
    lbls = lbls + [label]
    overlay = render_video_overlay(video_state, frame_idx, pts, lbls)
    status = (
        f"Added {'positive' if label == 1 else 'negative'} click at "
        f"({int(x)}, {int(y)}) on frame {int(frame_idx)}."
    )
    return overlay, pts, lbls, status


def undo_video_click(
    video_state: Dict[str, Any],
    pts: List[Sequence[float]],
    lbls: List[int],
    frame_idx: int,
):
    if not pts:
        return gr.update(), pts, lbls, "No clicks to undo."
    pts = pts[:-1]
    lbls = lbls[:-1]
    overlay = render_video_overlay(video_state, frame_idx, pts, lbls)
    return overlay, pts, lbls, "Removed the last click."


def clear_video_clicks(video_state: Dict[str, Any], frame_idx: int):
    overlay = render_video_overlay(video_state, frame_idx, [], [])
    return overlay, [], [], "Cleared all clicks for the selected frame."


def reset_video_interface(current_state: Dict[str, Any]):
    remove_dir_if_exists(current_state.get("frame_dir"))
    state = make_empty_video_state()
    return (
        gr.update(value=None, visible=False),
        state,
        gr.update(value=0, minimum=0, maximum=0, interactive=False),
        [],
        [],
        0,
        "Cleared video. Upload a new clip to continue.",
    )


def run_video_segmentation(
    video_state: Dict[str, Any],
    pts: List[Sequence[float]],
    lbls: List[int],
    frame_idx: int,
    granularity: float,
):
    frame_paths: List[str] = list(video_state.get("frame_paths", []))
    if not frame_paths:
        return None, "Upload a video to segment."
    if not pts:
        return None, "Add at least one click on the annotation frame."
    frame_dir = video_state.get("frame_dir")
    if not frame_dir:
        return None, "Video frames are unavailable. Please re-upload the video."
    safe_idx = int(np.clip(frame_idx, 0, len(frame_paths) - 1))
    device = choose_device()
    predictor = VIDEO_PREDICTOR_MANAGER.get_predictor(device)
    inference_state = predictor.init_state(video_path=frame_dir)
    predictor.reset_state(inference_state)
    coords = np.asarray(pts, dtype=np.float32)
    labels = np.asarray(lbls, dtype=np.int32)
    try:
        _, obj_ids, mask_logits = predictor.add_new_points_or_box(
            inference_state=inference_state,
            frame_idx=safe_idx,
            obj_id=1,
            points=coords,
            labels=labels,
            gra=float(granularity),
        )
    except Exception as exc:
        LOGGER.exception("Video add_new_points_or_box failed")
        return None, f"Video segmentation failed during prompting: {exc}"
    video_masks: Dict[int, Dict[int, np.ndarray]] = {}
    video_masks[safe_idx] = {
        int(obj_id): (mask_logits[i] > -1.0).cpu().numpy()
        for i, obj_id in enumerate(obj_ids)
    }
    try:
        for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
            inference_state,
            gra=float(granularity),
        ):
            video_masks[out_frame_idx] = {
                int(obj_id): (out_mask_logits[i] > -1.0).cpu().numpy()
                for i, obj_id in enumerate(out_obj_ids)
            }
    except Exception as exc:
        LOGGER.exception("Video propagation failed")
        return None, f"Video propagation failed: {exc}"
    overlays: List[np.ndarray] = []
    for idx, frame_path in enumerate(frame_paths):
        base = load_rgb_image(Path(frame_path))
        mask = video_masks.get(idx, {}).get(1)
        overlays.append(draw_overlay(base, mask, [], []))
    try:
        video_path = write_video_from_frames(overlays, video_state.get("fps", 12.0))
    except Exception as exc:
        LOGGER.exception("Failed to encode output video")
        return None, f"Tracking succeeded but video export failed: {exc}"
    status = (
        f"Tracked object from frame {safe_idx} across {len(frame_paths)} frames | "
        f"granularity={granularity:.2f}"
    )
    return str(video_path), status


def run_video_frame_segmentation(
    video_state: Dict[str, Any],
    pts: List[Sequence[float]],
    lbls: List[int],
    frame_idx: int,
    granularity: float,
):
    frame_paths: List[str] = list(video_state.get("frame_paths", []))
    if not frame_paths:
        return None, "Upload a video to segment."
    if not pts:
        return None, "Add at least one click on the annotation frame."
    frame_dir = video_state.get("frame_dir")
    if not frame_dir:
        return None, "Video frames are unavailable. Please re-upload the video."
    safe_idx = int(np.clip(frame_idx, 0, len(frame_paths) - 1))
    device = choose_device()
    predictor = VIDEO_PREDICTOR_MANAGER.get_predictor(device)
    inference_state = predictor.init_state(video_path=frame_dir)
    predictor.reset_state(inference_state)
    coords = np.asarray(pts, dtype=np.float32)
    labels = np.asarray(lbls, dtype=np.int32)
    try:
        _, obj_ids, mask_logits = predictor.add_new_points_or_box(
            inference_state=inference_state,
            frame_idx=safe_idx,
            obj_id=1,
            points=coords,
            labels=labels,
            gra=float(granularity),
        )
    except Exception as exc:
        LOGGER.exception("Video frame segmentation failed")
        return None, f"Frame segmentation failed: {exc}"
    if not obj_ids:
        return None, "Predictor did not return a mask for this frame."
    mask = (mask_logits[0] > -1.0).cpu().numpy()
    base = load_rgb_image(Path(frame_paths[safe_idx]))
    overlay = draw_overlay(base, mask, pts, lbls)
    status = (
        f"Segmented frame {safe_idx} with {len(pts)} clicks | "
        f"granularity={granularity:.2f}"
    )
    return overlay, status


segment_fn = wrap_with_zero_gpu(_run_segmentation, ZERO_GPU_DURATION)
whole_image_fn = wrap_with_zero_gpu(run_whole_image_segmentation, ZERO_GPU_WHOLE_DURATION)
video_frame_fn = wrap_with_zero_gpu(run_video_frame_segmentation, ZERO_GPU_VIDEO_DURATION)
video_segmentation_fn = wrap_with_zero_gpu(run_video_segmentation, ZERO_GPU_VIDEO_DURATION)


def build_demo() -> gr.Blocks:
    with gr.Blocks(title="UnSAMv2 Interactive + Whole Image + Video", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            """
UnSAMv2 · Segment Anything at Any Granularity
            """
        )
        gr.HTML(
            """
            """
        )
        with gr.Tabs(elem_id="mode-tabs"):
            # Interactive Image Tab
            with gr.Tab("Interactive Image Segmentation"):
                image_state = gr.State(DEFAULT_IMAGE)
                points_state = gr.State([])
                labels_state = gr.State([])
                image_input = gr.Image(
                    label="Image · clicks & mask",
                    type="numpy",
                    height=480,
                    value=DEFAULT_IMAGE,
                    sources=["upload"],
                )
                with gr.Row(equal_height=True):
                    point_mode = gr.Radio(
                        choices=list(POINT_MODE_TO_LABEL.keys()),
                        value="Foreground (+)",
                        label="Click type",
                    )
                    granularity_slider = gr.Slider(
                        minimum=GRANULARITY_MIN,
                        maximum=GRANULARITY_MAX,
                        value=0.2,
                        step=0.01,
                        label="Granularity",
                        info="Lower = finer details, Higher = coarser regions",
                    )
                    segment_button = gr.Button("Segment", variant="primary")
                with gr.Row():
                    undo_button = gr.Button("Undo last click")
                    clear_button = gr.Button("Clear clicks")
                status_markdown = gr.Markdown("Ready for interactive clicks.")

                image_input.upload(
                    handle_image_upload,
                    inputs=[image_input],
                    outputs=[image_input, image_state, points_state, labels_state, status_markdown],
                )
                image_input.clear(
                    handle_image_upload,
                    inputs=[image_input],
                    outputs=[image_input, image_state, points_state, labels_state, status_markdown],
                )
                image_input.select(
                    handle_click,
                    inputs=[point_mode, points_state, labels_state, image_state],
                    outputs=[image_input, points_state, labels_state, status_markdown],
                )
                undo_button.click(
                    undo_last_click,
                    inputs=[image_state, points_state, labels_state],
                    outputs=[image_input, points_state, labels_state, status_markdown],
                )
                clear_button.click(
                    clear_clicks,
                    inputs=[image_state],
                    outputs=[image_input, points_state, labels_state, status_markdown],
                )
                segment_button.click(
                    segment_fn,
                    inputs=[image_state, points_state, labels_state, granularity_slider],
                    outputs=[image_input, status_markdown],
                )

            # Whole Image Tab
            with gr.Tab("Whole Image Segmentation"):
                whole_image_input = gr.Image(
                    label="Image · automatic masks",
                    type="numpy",
                    height=480,
                    value=WHOLE_IMAGE_DEFAULT if WHOLE_IMAGE_DEFAULT is not None else DEFAULT_IMAGE,
                    sources=["upload"],
                )
                whole_granularity = gr.Slider(
                    minimum=GRANULARITY_MIN,
                    maximum=GRANULARITY_MAX,
                    value=0.15,
                    step=0.01,
                    label="Granularity",
                )
                whole_generate_btn = gr.Button("Generate masks", variant="primary")
                with gr.Accordion("Advanced mask filtering", open=False):
                    pred_iou_thresh = gr.Slider(
                        minimum=0.1,
                        maximum=0.99,
                        value=0.77,
                        step=0.01,
                        label="Predicted IoU threshold",
                    )
                    stability_thresh = gr.Slider(
                        minimum=0.1,
                        maximum=0.99,
                        value=0.9,
                        step=0.01,
                        label="Stability threshold",
                    )
                whole_overlay = gr.Image(label="Mask overlay", height=480)
                whole_table = gr.Dataframe(
                    headers=["mask", "area", "pred_iou", "stability"],
                    datatype=["number", "number", "number", "number"],
                    label="Mask stats",
                    wrap=True,
                    visible=False,
                )
                whole_status = gr.Markdown("Ready for whole-image masks.")

                whole_generate_btn.click(
                    whole_image_fn,
                    inputs=[whole_image_input, whole_granularity, pred_iou_thresh, stability_thresh],
                    outputs=[whole_overlay, whole_table, whole_status],
                )

            # Video Tab
            with gr.Tab("Video Segmentation"):
                video_state = gr.State(make_empty_video_state())
                video_points_state = gr.State([])
                video_labels_state = gr.State([])
                annotation_frame_state = gr.State(0)
                with gr.Row(equal_height=True):
                    with gr.Column(scale=1, min_width=360):
                        upload_button = gr.UploadButton(
                            "Upload video",
                            file_types=["video"],
                            file_count="single",
                        )
                        frame_display = gr.Image(
                            label="Video · add clicks",
                            type="numpy",
                            height=420,
                            interactive=True,
                            visible=False,
                        )
                        frame_slider = gr.Slider(
                            minimum=0,
                            maximum=0,
                            value=0,
                            step=1,
                            interactive=False,
                            label="Select frame",
                        )
                        video_point_mode = gr.Radio(
                            choices=list(POINT_MODE_TO_LABEL.keys()),
                            value="Foreground (+)",
                            label="Click type",
                        )
                        with gr.Row():
                            video_undo = gr.Button("Undo click")
                            video_clear = gr.Button("Clear clicks")
                        video_granularity = gr.Slider(
                            minimum=GRANULARITY_MIN,
                            maximum=GRANULARITY_MAX,
                            value=0.33,
                            step=0.01,
                            label="Granularity",
                        )
                        with gr.Row():
                            video_frame_btn = gr.Button("Segment frame", variant="secondary")
                            video_segment_btn = gr.Button("Propagate video", variant="primary")
                    with gr.Column(scale=1, min_width=320):
                        video_output = gr.Video(
                            label="Segmented preview",
                            autoplay=False,
                            height=420,
                        )
                        video_status = gr.Markdown("Ready for video segmentation.")

                upload_button.upload(
                    handle_video_upload,
                    inputs=[upload_button, video_state],
                    outputs=[
                        frame_display,
                        video_state,
                        frame_slider,
                        video_points_state,
                        video_labels_state,
                        annotation_frame_state,
                        video_status,
                    ],
                )

                if DEFAULT_VIDEO_PATH.exists():

                    def _load_default_video(state):
                        return handle_video_upload(str(DEFAULT_VIDEO_PATH), state)

                    demo.load(
                        _load_default_video,
                        inputs=[video_state],
                        outputs=[
                            frame_display,
                            video_state,
                            frame_slider,
                            video_points_state,
                            video_labels_state,
                            annotation_frame_state,
                            video_status,
                        ],
                        queue=False,
                    )

                frame_slider.change(
                    handle_video_frame_change,
                    inputs=[frame_slider, video_state],
                    outputs=[
                        frame_display,
                        video_points_state,
                        video_labels_state,
                        annotation_frame_state,
                        video_status,
                    ],
                )
                frame_display.select(
                    handle_video_click,
                    inputs=[
                        video_point_mode,
                        video_points_state,
                        video_labels_state,
                        video_state,
                        annotation_frame_state,
                    ],
                    outputs=[frame_display, video_points_state, video_labels_state, video_status],
                )
                frame_display.clear(
                    reset_video_interface,
                    inputs=[video_state],
                    outputs=[
                        frame_display,
                        video_state,
                        frame_slider,
                        video_points_state,
                        video_labels_state,
                        annotation_frame_state,
                        video_status,
                    ],
                )
                video_frame_btn.click(
                    video_frame_fn,
                    inputs=[
                        video_state,
                        video_points_state,
                        video_labels_state,
                        annotation_frame_state,
                        video_granularity,
                    ],
                    outputs=[frame_display, video_status],
                )
                video_undo.click(
                    undo_video_click,
                    inputs=[video_state, video_points_state, video_labels_state, annotation_frame_state],
                    outputs=[frame_display, video_points_state, video_labels_state, video_status],
                )
                video_clear.click(
                    clear_video_clicks,
                    inputs=[video_state, annotation_frame_state],
                    outputs=[frame_display, video_points_state, video_labels_state, video_status],
                )
                video_segment_btn.click(
                    video_segmentation_fn,
                    inputs=[
                        video_state,
                        video_points_state,
                        video_labels_state,
                        annotation_frame_state,
                        video_granularity,
                    ],
                    outputs=[video_output, video_status],
                )

    demo.queue(max_size=8)
    return demo


demo = build_demo()


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")), share=True)
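
# Usage notes (a suggested local setup, not part of the app logic): the demo is configured
# entirely through the environment variables read above. Assuming a local checkpoint and no
# ZeroGPU runtime, one plausible way to launch it is:
#
#   UNSAMV2_CKPT=/path/to/unsamv2_plus_ckpt.pt \
#   UNSAMV2_CONFIG=configs/unsamv2_small.yaml \
#   UNSAMV2_DEVICE=cuda \
#   UNSAMV2_ENABLE_ZEROGPU=0 \
#   PORT=7860 python app.py
#
# The file name `app.py` is an assumption; substitute whatever this script is saved as.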