Update models/loaders/sam2_loader.py
Browse files- models/loaders/sam2_loader.py +68 -17
models/loaders/sam2_loader.py
CHANGED
|
@@ -3,7 +3,7 @@
|
|
| 3 |
|
| 4 |
"""
|
| 5 |
SAM2 Loader + Guarded Predictor Adapter (VRAM-friendly, shape-safe, thread-safe, PyTorch 2.x)
|
| 6 |
-
-
|
| 7 |
- Never assigns predictor.device (read-only) — moves .model to device instead
|
| 8 |
- Accepts RGB/BGR, float/uint8; strips alpha; optional BGR→RGB via env
|
| 9 |
- Downscale ladder on set_image(); upsample masks back to original H,W
|
|
@@ -244,6 +244,12 @@ def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/sam2_ca
|
|
| 244 |
self.load_time = 0.0
|
| 245 |
|
| 246 |
def _determine_optimal_size(self) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
try:
|
| 248 |
if torch.cuda.is_available():
|
| 249 |
props = torch.cuda.get_device_properties(0)
|
|
@@ -260,11 +266,12 @@ def load(self, model_size: str = "auto") -> Optional[_SAM2Adapter]:
|
|
| 260 |
if model_size == "auto":
|
| 261 |
model_size = self._determine_optimal_size()
|
| 262 |
|
|
|
|
| 263 |
model_map = {
|
| 264 |
-
"tiny": "facebook/sam2
|
| 265 |
-
"small": "facebook/sam2
|
| 266 |
-
"base": "facebook/sam2
|
| 267 |
-
"large": "facebook/sam2
|
| 268 |
}
|
| 269 |
self.model_id = model_map.get(model_size, model_map["tiny"])
|
| 270 |
logger.info(f"Loading SAM2 model: {self.model_id} (device={self.device})")
|
|
@@ -288,17 +295,61 @@ def load(self, model_size: str = "auto") -> Optional[_SAM2Adapter]:
|
|
| 288 |
return None
|
| 289 |
|
| 290 |
def _load_official(self):
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
|
| 303 |
def _load_fallback(self):
|
| 304 |
class FallbackSAM2:
|
|
@@ -360,4 +411,4 @@ def get_info(self) -> Dict[str, Any]:
|
|
| 360 |
m = out["masks"]
|
| 361 |
print("Masks:", m.shape, m.dtype, m.min(), m.max())
|
| 362 |
cv2.imwrite("sam2_mask0.png", (np.clip(m[0], 0, 1) * 255).astype(np.uint8))
|
| 363 |
-
print("Wrote sam2_mask0.png")
|
|
|
|
| 3 |
|
| 4 |
"""
|
| 5 |
SAM2 Loader + Guarded Predictor Adapter (VRAM-friendly, shape-safe, thread-safe, PyTorch 2.x)
|
| 6 |
+
- Uses traditional build_sam2 method with HF hub downloads for SAM 2.1 weights
|
| 7 |
- Never assigns predictor.device (read-only) — moves .model to device instead
|
| 8 |
- Accepts RGB/BGR, float/uint8; strips alpha; optional BGR→RGB via env
|
| 9 |
- Downscale ladder on set_image(); upsample masks back to original H,W
|
|
|
|
| 244 |
self.load_time = 0.0
|
| 245 |
|
| 246 |
def _determine_optimal_size(self) -> str:
|
| 247 |
+
# Check environment variable first
|
| 248 |
+
env_size = os.environ.get("USE_SAM2", "").lower()
|
| 249 |
+
if env_size in ["tiny", "small", "base", "large"]:
|
| 250 |
+
logger.info(f"Using SAM2 size from environment: {env_size}")
|
| 251 |
+
return env_size
|
| 252 |
+
|
| 253 |
try:
|
| 254 |
if torch.cuda.is_available():
|
| 255 |
props = torch.cuda.get_device_properties(0)
|
|
|
|
| 266 |
if model_size == "auto":
|
| 267 |
model_size = self._determine_optimal_size()
|
| 268 |
|
| 269 |
+
# Use original SAM2 model names (without .1) for compatibility
|
| 270 |
model_map = {
|
| 271 |
+
"tiny": "facebook/sam2-hiera-tiny",
|
| 272 |
+
"small": "facebook/sam2-hiera-small",
|
| 273 |
+
"base": "facebook/sam2-hiera-base-plus",
|
| 274 |
+
"large": "facebook/sam2-hiera-large",
|
| 275 |
}
|
| 276 |
self.model_id = model_map.get(model_size, model_map["tiny"])
|
| 277 |
logger.info(f"Loading SAM2 model: {self.model_id} (device={self.device})")
|
|
|
|
| 295 |
return None
|
| 296 |
|
| 297 |
def _load_official(self):
|
| 298 |
+
try:
|
| 299 |
+
from huggingface_hub import hf_hub_download
|
| 300 |
+
from sam2.build_sam import build_sam2
|
| 301 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
| 302 |
+
except ImportError as e:
|
| 303 |
+
logger.error(f"Failed to import SAM2 components: {e}")
|
| 304 |
+
return None
|
| 305 |
+
|
| 306 |
+
# Map model IDs to config files and checkpoint names
|
| 307 |
+
config_map = {
|
| 308 |
+
"facebook/sam2-hiera-tiny": ("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"),
|
| 309 |
+
"facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"),
|
| 310 |
+
"facebook/sam2-hiera-base-plus": ("sam2_hiera_b+.yaml", "sam2_hiera_base_plus.pt"),
|
| 311 |
+
"facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"),
|
| 312 |
+
}
|
| 313 |
+
|
| 314 |
+
config_file, checkpoint_file = config_map.get(self.model_id, (None, None))
|
| 315 |
+
if not config_file:
|
| 316 |
+
raise ValueError(f"Unknown model: {self.model_id}")
|
| 317 |
+
|
| 318 |
+
try:
|
| 319 |
+
# Download the checkpoint from HuggingFace
|
| 320 |
+
logger.info(f"Downloading checkpoint: {checkpoint_file}")
|
| 321 |
+
checkpoint_path = hf_hub_download(
|
| 322 |
+
repo_id=self.model_id,
|
| 323 |
+
filename=checkpoint_file,
|
| 324 |
+
cache_dir=self.cache_dir,
|
| 325 |
+
local_files_only=False
|
| 326 |
+
)
|
| 327 |
+
logger.info(f"Checkpoint downloaded to: {checkpoint_path}")
|
| 328 |
+
|
| 329 |
+
# Also download the config file if needed
|
| 330 |
+
config_path = hf_hub_download(
|
| 331 |
+
repo_id=self.model_id,
|
| 332 |
+
filename=config_file,
|
| 333 |
+
cache_dir=self.cache_dir,
|
| 334 |
+
local_files_only=False
|
| 335 |
+
)
|
| 336 |
+
logger.info(f"Config downloaded to: {config_path}")
|
| 337 |
+
|
| 338 |
+
# Build the model using the traditional method
|
| 339 |
+
sam2_model = build_sam2(config_path, checkpoint_path, device=self.device)
|
| 340 |
+
predictor = SAM2ImagePredictor(sam2_model)
|
| 341 |
+
|
| 342 |
+
# Ensure model is on the correct device and in eval mode
|
| 343 |
+
if hasattr(predictor, "model"):
|
| 344 |
+
predictor.model = predictor.model.to(self.device)
|
| 345 |
+
predictor.model.eval()
|
| 346 |
+
|
| 347 |
+
return predictor
|
| 348 |
+
|
| 349 |
+
except Exception as e:
|
| 350 |
+
logger.error(f"Error loading SAM2 model: {e}")
|
| 351 |
+
logger.debug(traceback.format_exc())
|
| 352 |
+
return None
|
| 353 |
|
| 354 |
def _load_fallback(self):
|
| 355 |
class FallbackSAM2:
|
|
|
|
| 411 |
m = out["masks"]
|
| 412 |
print("Masks:", m.shape, m.dtype, m.min(), m.max())
|
| 413 |
cv2.imwrite("sam2_mask0.png", (np.clip(m[0], 0, 1) * 255).astype(np.uint8))
|
| 414 |
+
print("Wrote sam2_mask0.png")
|