MogensR committed
Commit b6786fa · 1 Parent(s): e2ca8f7

Update models/loaders/sam2_loader.py

Files changed (1)
  1. models/loaders/sam2_loader.py +351 -136
models/loaders/sam2_loader.py CHANGED
@@ -1,221 +1,436 @@
 #!/usr/bin/env python3
 """
-SAM2 Model Loader
-Handles all SAM2 loading strategies with proper fallbacks
 """

 import os
 import time
 import logging
 import traceback
-from pathlib import Path
-from typing import Optional, Dict, Any

-import torch
 import numpy as np

 logger = logging.getLogger(__name__)


 class SAM2Loader:
     """Dedicated loader for SAM2 models"""
-
     def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/sam2_cache"):
-        self.device = device
         self.cache_dir = cache_dir
         os.makedirs(self.cache_dir, exist_ok=True)
-
         # Configure HF hub for spaces
-        os.environ["HF_HUB_DISABLE_SYMLINKS"] = "1"
-        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
-
-        self.model = None
         self.model_id = None
         self.load_time = 0.0
-
     def load(self, model_size: str = "auto") -> Optional[Any]:
         """
         Load SAM2 model with specified size
         Args:
             model_size: "tiny", "small", "base", "large", or "auto"
         Returns:
-            Loaded model or None
         """
         if model_size == "auto":
             model_size = self._determine_optimal_size()
-
         model_map = {
-            "tiny": "facebook/sam2.1-hiera-tiny",
             "small": "facebook/sam2.1-hiera-small",
-            "base": "facebook/sam2.1-hiera-base-plus",
             "large": "facebook/sam2.1-hiera-large",
         }
-
         self.model_id = model_map.get(model_size, model_map["tiny"])
-        logger.info(f"Loading SAM2 model: {self.model_id}")
-
-        # Try loading strategies in order
-        strategies = [
-            ("official", self._load_official),
-            ("transformers", self._load_transformers),
-            ("fallback", self._load_fallback)
-        ]
-
-        for strategy_name, strategy_func in strategies:
             try:
-                logger.info(f"Trying SAM2 loading strategy: {strategy_name}")
-                start_time = time.time()
-                model = strategy_func()
-                if model:
-                    self.load_time = time.time() - start_time
-                    self.model = model
-                    logger.info(f"SAM2 loaded successfully via {strategy_name} in {self.load_time:.2f}s")
-                    return model
             except Exception as e:
-                logger.error(f"SAM2 {strategy_name} strategy failed: {e}")
                 logger.debug(traceback.format_exc())
-                continue
-
         logger.error("All SAM2 loading strategies failed")
         return None
-
     def _determine_optimal_size(self) -> str:
         """Determine optimal model size based on available memory"""
         try:
             if torch.cuda.is_available():
                 props = torch.cuda.get_device_properties(0)
                 vram_gb = props.total_memory / (1024**3)
-
-                if vram_gb < 4:
-                    return "tiny"
-                elif vram_gb < 8:
-                    return "small"
-                elif vram_gb < 12:
-                    return "base"
-                else:
-                    return "large"
-        except:
             pass
-        return "tiny"  # Conservative default
-
     def _load_official(self) -> Optional[Any]:
-        """Load using official SAM2 API - return directly without wrapper"""
         from sam2.sam2_image_predictor import SAM2ImagePredictor
-
         predictor = SAM2ImagePredictor.from_pretrained(
             self.model_id,
             cache_dir=self.cache_dir,
             local_files_only=False,
             trust_remote_code=True,
         )
-
-        # Move to device and set to eval mode
         if hasattr(predictor, "model"):
             predictor.model = predictor.model.to(self.device)
             predictor.model.eval()
-
-        # Set device attribute if it exists
         if hasattr(predictor, "device"):
             predictor.device = self.device
-
-        # Return the predictor directly - no wrapper!
-        # The calling code expects the standard SAM2 interface
         return predictor
-
-    def _load_transformers(self) -> Optional[Any]:
-        """Load using transformers library"""
-        from transformers import AutoModel, AutoProcessor
-
-        dtype = torch.float16 if "cuda" in self.device else torch.float32
-
-        model = AutoModel.from_pretrained(
-            self.model_id,
-            trust_remote_code=True,
-            torch_dtype=dtype,
-            cache_dir=self.cache_dir
-        )
-        model = model.to(self.device)
-        model.eval()
-
-        try:
-            processor = AutoProcessor.from_pretrained(
-                self.model_id,
-                cache_dir=self.cache_dir
-            )
-        except:
-            processor = None
-
-        # Wrap to match expected API
-        class SAM2TransformersWrapper:
-            def __init__(self, model, processor, device):
-                self.model = model
-                self.processor = processor
-                self.device = device
-                self.current_image = None
-
-            def set_image(self, image):
-                """Store image for processing"""
-                self.current_image = image
-                # TODO: Actually encode image with model here
-
-            def predict(self, point_coords=None, point_labels=None, box=None, **kwargs):
-                """Generate masks from prompts"""
-                # TODO: Implement actual prediction
-                if self.current_image is not None:
-                    h, w = self.current_image.shape[:2]
-                else:
-                    h, w = 512, 512
-
-                # For now, return dummy mask
-                return {
-                    "masks": np.ones((1, h, w), dtype=np.float32),
-                    "scores": np.array([0.9]),
-                    "logits": np.ones((1, h, w), dtype=np.float32),
-                }
-
-        return SAM2TransformersWrapper(model, processor, self.device)
-
     def _load_fallback(self) -> Optional[Any]:
-        """Create fallback predictor for testing"""
-
         class FallbackSAM2:
             def __init__(self, device):
                 self.device = device
-                self.current_image = None
-
             def set_image(self, image):
-                self.current_image = image
-
-            def predict(self, point_coords=None, point_labels=None, box=None, **kwargs):
-                """Return full mask as fallback"""
-                if self.current_image is not None:
-                    h, w = self.current_image.shape[:2]
                 else:
                     h, w = 512, 512
-
                 return {
                     "masks": np.ones((1, h, w), dtype=np.float32),
-                    "scores": np.array([0.5]),
-                    "logits": np.ones((1, h, w), dtype=np.float32),
                 }
-
         logger.warning("Using fallback SAM2 (no real segmentation)")
         return FallbackSAM2(self.device)
-
     def cleanup(self):
         """Clean up resources"""
-        if self.model:
-            del self.model
         self.model = None
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
-
     def get_info(self) -> Dict[str, Any]:
         """Get loader information"""
         return {
-            "loaded": self.model is not None,
             "model_id": self.model_id,
             "device": self.device,
             "load_time": self.load_time,
-            "model_type": type(self.model).__name__ if self.model else None
-        }

 #!/usr/bin/env python3
 """
+SAM2 Loader + Guarded Predictor Adapter (VRAM-friendly, shape-safe)
+
+- Loads a SAM2 image predictor on the desired device.
+- set_image(): accepts RGB/BGR, uint8/float; optional model-only downscale to save VRAM.
+- predict(): forwards prompts, upsamples masks back to original size, normalizes outputs.
+- Uses torch.inference_mode + optional autocast on CUDA.
+- Returns shapes compatible with utils.cv_processing.segment_person_hq logic.
 """

+from __future__ import annotations
+
 import os
 import time
 import logging
 import traceback
+from typing import Optional, Dict, Any, Tuple, List

 import numpy as np
+import torch
+import cv2

 logger = logging.getLogger(__name__)


+# -------------------------- helpers --------------------------
+
+def _select_device(pref: str) -> str:
+    pref = (pref or "").lower()
+    if pref.startswith("cuda"):
+        return "cuda" if torch.cuda.is_available() else "cpu"
+    if pref == "cpu":
+        return "cpu"
+    return "cuda" if torch.cuda.is_available() else "cpu"
+
+
+def _ensure_rgb_uint8(img: np.ndarray, force_bgr_to_rgb: bool = False) -> np.ndarray:
+    """
+    Accept BGR/RGB, 3ch/4ch, uint8/float; return RGB uint8 [H,W,3].
+    We DO NOT blindly swap channels; cv_processing already feeds RGB.
+    Set force_bgr_to_rgb=True only if you know inputs are BGR.
+    """
+    if img is None:
+        raise ValueError("set_image received None image")
+
+    arr = np.asarray(img)
+    if arr.ndim != 3 or arr.shape[2] < 3:
+        raise ValueError(f"Expected HxWxC image with C>=3, got shape={arr.shape}")
+
+    # If float, clamp + scale to uint8
+    if np.issubdtype(arr.dtype, np.floating):
+        arr = np.clip(arr, 0.0, 1.0)
+        arr = (arr * 255.0 + 0.5).astype(np.uint8)
+    elif arr.dtype != np.uint8:
+        if arr.dtype == np.uint16:
+            arr = (arr / 257).astype(np.uint8)
+        else:
+            arr = arr.astype(np.uint8)
+
+    # If 4-channel, drop alpha
+    if arr.shape[2] == 4:
+        arr = arr[:, :, :3]
+
+    if force_bgr_to_rgb:
+        arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
+
+    return arr
+
+
+def _compute_scaled_size(h: int, w: int, max_edge: int, target_pixels: int) -> Tuple[int, int, float]:
+    if h <= 0 or w <= 0:
+        return h, w, 1.0
+    s1 = min(1.0, float(max_edge) / float(max(h, w))) if max_edge > 0 else 1.0
+    s2 = min(1.0, (float(target_pixels) / float(h * w)) ** 0.5) if target_pixels > 0 else 1.0
+    s = min(s1, s2)
+    nh = max(1, int(round(h * s)))
+    nw = max(1, int(round(w * s)))
+    return nh, nw, s
+
+
+def _ladder(nh: int, nw: int) -> List[Tuple[int, int]]:
+    """
+    Progressive smaller sizes for OOM fallback.
+    """
+    sizes = [(nh, nw)]
+    sizes.append((max(1, int(nh * 0.85)), max(1, int(nw * 0.85))))
+    sizes.append((max(1, int(nh * 0.70)), max(1, int(nw * 0.70))))
+    sizes.append((max(1, int(nh * 0.50)), max(1, int(nw * 0.50))))
+    sizes.append((max(1, int(nh * 0.35)), max(1, int(nw * 0.35))))
+    # de-duplicate and keep order
+    uniq = []
+    seen = set()
+    for s in sizes:
+        if s not in seen:
+            uniq.append(s); seen.add(s)
+    return uniq
+
+
+def _upsample_stack(masks: np.ndarray, out_hw: Tuple[int, int]) -> np.ndarray:
+    """
+    masks: (N,h,w) float → bilinear → (N,H,W) float [0..1]
+    """
+    if masks.ndim != 3:
+        masks = np.asarray(masks)
+        if masks.ndim == 2:
+            masks = masks[None, ...]
+        elif masks.ndim == 4 and masks.shape[1] == 1:
+            masks = masks[:, 0, :, :]
+        else:
+            # try to squeeze to N,H,W
+            masks = np.squeeze(masks)
+            if masks.ndim == 2:
+                masks = masks[None, ...]
+    n, h, w = masks.shape
+    H, W = out_hw
+    if (h, w) == (H, W):
+        return masks.astype(np.float32, copy=False)
+    out = np.zeros((n, H, W), dtype=np.float32)
+    for i in range(n):
+        out[i] = cv2.resize(masks[i].astype(np.float32), (W, H), interpolation=cv2.INTER_LINEAR)
+    return np.clip(out, 0.0, 1.0)
+
+
+def _normalize_masks_dtype(x: np.ndarray) -> np.ndarray:
+    x = np.asarray(x)
+    if x.dtype == np.uint8:
+        return (x.astype(np.float32) / 255.0)
+    return x.astype(np.float32, copy=False)
+
+
+# -------------------------- adapter --------------------------
+
+class _SAM2Adapter:
+    """
+    Wraps SAM2ImagePredictor to:
+      - store original H,W
+      - model-only downscale on set_image
+      - OOM-aware predict with retry at smaller sizes
+      - upsample masks back to original size
+    """
+    def __init__(self, predictor, device: str):
+        self.pred = predictor
+        self.device = device
+
+        # original image size (for upsample)
+        self.orig_hw: Tuple[int, int] = (0, 0)
+
+        # env tunables
+        self.max_edge = int(os.environ.get("SAM2_MAX_EDGE", "1024"))
+        self.target_pixels = int(os.environ.get("SAM2_TARGET_PIXELS", "900000"))
+        self.force_bgr_to_rgb = os.environ.get("SAM2_ASSUME_BGR", "0") == "1"
+
+        # precision
+        self.use_autocast = (device == "cuda")
+        # prefer bf16 if available, else fp16; it's only a hint for the internal ops
+        self.autocast_dtype = None
+        if self.use_autocast:
+            try:
+                if hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
+                    self.autocast_dtype = torch.bfloat16
+                else:
+                    cc = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0)
+                    self.autocast_dtype = torch.float16 if cc[0] >= 7 else None
+            except Exception:
+                self.autocast_dtype = None
+
+        # cached current working image (RGB uint8) and its size
+        self._current_rgb: Optional[np.ndarray] = None
+        self._current_hw: Tuple[int, int] = (0, 0)
+
+    # --- API mirror ---
+
+    def set_image(self, image: np.ndarray):
+        """
+        Accept RGB or BGR, uint8 or float, any resolution.
+        Model-only downscale; keep orig H,W for upsample later.
+        """
+        rgb = _ensure_rgb_uint8(image, force_bgr_to_rgb=self.force_bgr_to_rgb)
+        H, W = rgb.shape[:2]
+        self.orig_hw = (H, W)
+
+        nh, nw, s = _compute_scaled_size(H, W, self.max_edge, self.target_pixels)
+        if s < 1.0:
+            work = cv2.resize(rgb, (nw, nh), interpolation=cv2.INTER_AREA)
+            self._current_rgb = work
+            self._current_hw = (nh, nw)
+        else:
+            self._current_rgb = rgb
+            self._current_hw = (H, W)
+
+        # prime embeddings on predictor
+        self.pred.set_image(self._current_rgb)
+
+    def predict(self, **kwargs) -> Dict[str, Any]:
+        """
+        Forwards prompts to underlying predictor; retries smaller if OOM.
+        Always returns:
+            {"masks": (N,H,W) float32 [0..1], "scores": (N,), "logits": optional}
+        where (H,W) are the ORIGINAL image size provided to set_image().
+        """
+        if self._current_rgb is None or self.orig_hw == (0, 0):
+            raise RuntimeError("SAM2Adapter.predict called before set_image()")
+
+        H, W = self.orig_hw
+        nh, nw = self._current_hw
+        sizes = _ladder(nh, nw)
+
+        last_exc: Optional[BaseException] = None
+
+        for (th, tw) in sizes:
+            try:
+                # if we need a smaller embedding, rebuild set_image()
+                if (th, tw) != (nh, nw):
+                    small = cv2.resize(self._current_rgb, (tw, th), interpolation=cv2.INTER_AREA)
+                    self.pred.set_image(small)
+
+                # inference guard
+                class _NoOp:
+                    def __enter__(self): return None
+                    def __exit__(self, *a): return False
+
+                amp_ctx = _NoOp()
+                if self.use_autocast and self.autocast_dtype is not None:
+                    amp_ctx = torch.cuda.amp.autocast(dtype=self.autocast_dtype)
+
+                with torch.inference_mode():
+                    with amp_ctx:
+                        out = self.pred.predict(**kwargs)

+                # normalize outputs to dict
+                masks = None
+                scores = None
+                logits = None
+
+                if isinstance(out, dict):
+                    masks = out.get("masks", None)
+                    scores = out.get("scores", None)
+                    logits = out.get("logits", None)
+                elif isinstance(out, (tuple, list)):
+                    if len(out) >= 1: masks = out[0]
+                    if len(out) >= 2: scores = out[1]
+                    if len(out) >= 3: logits = out[2]
+                else:
+                    masks = out
+
+                if masks is None:
+                    raise RuntimeError("SAM2 returned no masks")
+
+                masks = np.asarray(masks)
+                # SAM2 variants: (N,H,W) or (N,1,H,W) or (H,W)
+                if masks.ndim == 2:
+                    masks = masks[None, ...]
+                elif masks.ndim == 4 and masks.shape[1] == 1:
+                    masks = masks[:, 0, :, :]
+
+                masks = _normalize_masks_dtype(masks)
+
+                # upsample to original resolution
+                masks_up = _upsample_stack(masks, (H, W))
+
+                # standardize scores
+                if scores is None:
+                    scores = np.ones((masks_up.shape[0],), dtype=np.float32) * 0.5
+                else:
+                    scores = np.asarray(scores).astype(np.float32, copy=False).reshape(-1)
+
+                out_dict = {"masks": masks_up, "scores": scores}
+                if logits is not None:
+                    # best-effort: resize per-channel to (H,W)
+                    lg = np.asarray(logits)
+                    if lg.ndim == 3:
+                        lg = _upsample_stack(lg, (H, W))
+                    elif lg.ndim == 4 and lg.shape[1] == 1:
+                        lg = _upsample_stack(lg[:, 0, :, :], (H, W))
+                    out_dict["logits"] = lg.astype(np.float32, copy=False)
+                return out_dict
+
+            except torch.cuda.OutOfMemoryError as e:
+                last_exc = e
+                logger.warning(f"SAM2 OOM at {th}x{tw}; retrying smaller. {e}")
+                torch.cuda.empty_cache()
+                continue
+            except Exception as e:
+                last_exc = e
+                logger.debug(traceback.format_exc())
+                logger.warning(f"SAM2 predict failed at {th}x{tw}; retrying smaller. {e}")
+                torch.cuda.empty_cache()
+                continue
+
+        # All attempts failed → safe fallback (full mask)
+        logger.warning(f"SAM2 calls failed; returning fallback. {last_exc}")
+        return {
+            "masks": np.ones((1, H, W), dtype=np.float32),
+            "scores": np.array([0.5], dtype=np.float32),
+        }
+
+
+# -------------------------- Loader --------------------------
+
 class SAM2Loader:
     """Dedicated loader for SAM2 models"""
+
     def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/sam2_cache"):
+        self.device = _select_device(device)
         self.cache_dir = cache_dir
         os.makedirs(self.cache_dir, exist_ok=True)
+
         # Configure HF hub for spaces
+        os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS", "1")
+        os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "0")
+
+        self.model = None    # underlying predictor (SAM2ImagePredictor)
+        self.adapter = None  # wrapped predictor exposed to callers
         self.model_id = None
         self.load_time = 0.0
+
     def load(self, model_size: str = "auto") -> Optional[Any]:
         """
         Load SAM2 model with specified size
         Args:
             model_size: "tiny", "small", "base", "large", or "auto"
         Returns:
+            Wrapped predictor (adapter) or None
         """
         if model_size == "auto":
             model_size = self._determine_optimal_size()
+
         model_map = {
+            "tiny": "facebook/sam2.1-hiera-tiny",
             "small": "facebook/sam2.1-hiera-small",
+            "base": "facebook/sam2.1-hiera-base-plus",
             "large": "facebook/sam2.1-hiera-large",
         }
+
         self.model_id = model_map.get(model_size, model_map["tiny"])
+        logger.info(f"Loading SAM2 model: {self.model_id} (device={self.device})")
+
+        # Try the official loader
+        strategies = [("official", self._load_official), ("fallback", self._load_fallback)]
+
+        for name, fn in strategies:
             try:
+                t0 = time.time()
+                pred = fn()
+                if pred is None:
+                    continue
+                self.model = pred
+                self.adapter = _SAM2Adapter(self.model, self.device)
+                self.load_time = time.time() - t0
+                logger.info(f"SAM2 loaded via {name} in {self.load_time:.2f}s")
+                return self.adapter
             except Exception as e:
+                logger.error(f"SAM2 {name} strategy failed: {e}")
                 logger.debug(traceback.format_exc())
+
         logger.error("All SAM2 loading strategies failed")
         return None
+
     def _determine_optimal_size(self) -> str:
         """Determine optimal model size based on available memory"""
         try:
             if torch.cuda.is_available():
                 props = torch.cuda.get_device_properties(0)
                 vram_gb = props.total_memory / (1024**3)
+                if vram_gb < 4: return "tiny"
+                if vram_gb < 8: return "small"
+                if vram_gb < 12: return "base"
+                return "large"
+        except Exception:
             pass
+        return "tiny"
+
     def _load_official(self) -> Optional[Any]:
+        """Load using official SAM2 API"""
         from sam2.sam2_image_predictor import SAM2ImagePredictor
+
         predictor = SAM2ImagePredictor.from_pretrained(
             self.model_id,
             cache_dir=self.cache_dir,
             local_files_only=False,
             trust_remote_code=True,
         )
+
+        # Move internal model to device if present
         if hasattr(predictor, "model"):
             predictor.model = predictor.model.to(self.device)
             predictor.model.eval()
         if hasattr(predictor, "device"):
             predictor.device = self.device
+
         return predictor
+
     def _load_fallback(self) -> Optional[Any]:
+        """Create a tiny fallback predictor"""
+
         class FallbackSAM2:
             def __init__(self, device):
                 self.device = device
+                self._img = None
             def set_image(self, image):
+                self._img = np.asarray(image)
+            def predict(self, **kwargs):
+                if self._img is not None:
+                    h, w = self._img.shape[:2]
                 else:
                     h, w = 512, 512
                 return {
                     "masks": np.ones((1, h, w), dtype=np.float32),
+                    "scores": np.array([0.5], dtype=np.float32),
                 }
+
         logger.warning("Using fallback SAM2 (no real segmentation)")
         return FallbackSAM2(self.device)
+
     def cleanup(self):
         """Clean up resources"""
+        self.adapter = None
+        if self.model is not None:
+            try:
+                del self.model
+            except Exception:
+                pass
         self.model = None
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
+
     def get_info(self) -> Dict[str, Any]:
         """Get loader information"""
         return {
+            "loaded": self.adapter is not None,
             "model_id": self.model_id,
             "device": self.device,
             "load_time": self.load_time,
+            "model_type": type(self.model).__name__ if self.model else None,
+        }