Stanislav
feat: pre-ready, switch to hf
aa1c1e5
import torch
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
import cv2
class SamWrapper:
    """
    Wrapper for Segment Anything Model (SAM).

    Builds one SAM model and exposes it through two front-ends that share
    the same weights: a ``SamAutomaticMaskGenerator`` for prompt-free
    segmentation and a ``SamPredictor`` for segmentation guided by
    bounding boxes.
    """

    def __init__(self, model_type="vit_b", checkpoint_path=None, device=None):
        """
        Initialize the SAM model.

        :param model_type: Key into ``sam_model_registry`` selecting the SAM
            backbone (e.g., 'vit_b', 'vit_l', 'vit_h')
        :param checkpoint_path: Path to the .pth checkpoint file
        :param device: 'cuda' or 'cpu'; if None, auto-detects
        """
        # Honour an explicit device; only fall back to detection when the
        # caller did not choose one.
        self.device = self._resolve_device(device)

        self.model = sam_model_registry[model_type](checkpoint=checkpoint_path)
        self.model.to(self.device)

        # Prompt-free generator used when generate_masks() gets no boxes.
        self.automatic_generator = SamAutomaticMaskGenerator(
            model=self.model,
            points_per_side=12,
            pred_iou_thresh=0.92,
            stability_score_thresh=0.95,
            min_mask_region_area=1500,
            box_nms_thresh=0.3,
        )
        # Prompted predictor used when generate_masks() is given boxes.
        self.predictor = SamPredictor(self.model)

    @staticmethod
    def _resolve_device(device):
        """
        Return the device to run on.

        :param device: Explicit device (e.g. 'cuda' or 'cpu'), or None
        :return: ``device`` unchanged when given; otherwise 'cuda' if
            ``torch.cuda.is_available()`` reports a usable GPU, else 'cpu'
        """
        if device is not None:
            return device
        return "cuda" if torch.cuda.is_available() else "cpu"

    def generate_masks(self, image, boxes=None):
        """
        Generate segmentation masks for the given image.

        :param image: Input image as NumPy array (BGR)
        :param boxes: Optional list of bounding boxes [x1, y1, x2, y2].
            When None, the automatic mask generator is run on the whole
            image. When empty, no prediction is run.
        :return: List of binary masks (NumPy arrays); empty when ``boxes``
            is an empty sequence
        """
        # The image arrives in OpenCV's BGR order; both SAM front-ends are
        # fed the RGB conversion.
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if boxes is None:
            masks = self.automatic_generator.generate(image_rgb)
            return [mask['segmentation'] for mask in masks]

        # Nothing to segment: skip set_image() and the predictor call
        # rather than handing the predictor a zero-box batch.
        if len(boxes) == 0:
            return []

        # Set the image once; every box below reuses it.
        self.predictor.set_image(image_rgb)

        # Convert boxes to a tensor on the model's device and map them into
        # the predictor's input frame (needs the original height/width).
        transformed_boxes = self.predictor.transform.apply_boxes_torch(
            torch.as_tensor(boxes, dtype=torch.float32, device=self.device),
            image.shape[:2],
        )

        # Predict masks for all boxes at once.
        masks, _, _ = self.predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )
        # Keep the first mask of each per-box result and move it to NumPy.
        return [m[0].cpu().numpy() for m in masks]

    def predict_with_box(self, image, box):
        """
        Predict a single segmentation mask for the given box.

        :param image: Input image (BGR)
        :param box: One bounding box [x1, y1, x2, y2]
        :return: Binary mask (NumPy array), or None if no mask was produced
        """
        masks = self.generate_masks(image, boxes=[box])
        return masks[0] if masks else None