Spaces:
Sleeping
Sleeping
import torch | |
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor | |
import cv2 | |
class SamWrapper:
    """
    Wrapper for Segment Anything Model (SAM).

    Handles both automatic mask generation (whole image, no prompts) and
    guided segmentation via bounding-box prompts.
    """

    # Default settings for SamAutomaticMaskGenerator. Any of them can be
    # overridden per instance via the ``generator_kwargs`` constructor argument.
    _DEFAULT_GENERATOR_KWARGS = {
        "points_per_side": 12,
        "pred_iou_thresh": 0.92,
        "stability_score_thresh": 0.95,
        "min_mask_region_area": 1500,
        "box_nms_thresh": 0.3,
    }

    def __init__(self, model_type="vit_b", checkpoint_path=None, device=None,
                 generator_kwargs=None):
        """
        Initialize the SAM model.

        :param model_type: Type of SAM backbone (e.g., 'vit_b', 'vit_l', 'vit_h')
        :param checkpoint_path: Path to the .pth checkpoint file
        :param device: 'cuda' or 'cpu'; if None, auto-detects ('cuda' when
            torch reports CUDA as available, otherwise 'cpu')
        :param generator_kwargs: Optional dict of keyword arguments that
            override the default SamAutomaticMaskGenerator settings
        """
        # Previously the device argument was unconditionally replaced by
        # "cpu"; now an explicit device is honoured and None auto-detects.
        self.device = self._resolve_device(device)
        self.model = sam_model_registry[model_type](checkpoint=checkpoint_path)
        self.model.to(self.device)

        # Copy the class-level defaults so instances never mutate them.
        generator_settings = dict(self._DEFAULT_GENERATOR_KWARGS)
        if generator_kwargs:
            generator_settings.update(generator_kwargs)
        self.automatic_generator = SamAutomaticMaskGenerator(
            model=self.model,
            **generator_settings,
        )
        self.predictor = SamPredictor(self.model)

    @staticmethod
    def _resolve_device(device):
        """Return *device* unchanged, or choose 'cuda'/'cpu' when it is None."""
        if device is not None:
            return device
        return "cuda" if torch.cuda.is_available() else "cpu"

    def generate_masks(self, image, boxes=None):
        """
        Generate segmentation masks for the given image.

        :param image: Input image as NumPy array (BGR)
        :param boxes: Optional list of bounding boxes [x1, y1, x2, y2]. When
            None, SAM's automatic mask generator segments the whole image.
            When empty, an empty list is returned without running the model.
        :return: List of binary masks (NumPy arrays); in box mode, one mask
            per box, in the order the boxes were given
        """
        if boxes is not None and len(boxes) == 0:
            # Nothing to segment: skip colour conversion and the costly
            # image embedding, and avoid a zero-length prompt batch.
            return []

        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if boxes is None:
            masks = self.automatic_generator.generate(image_rgb)
            return [mask['segmentation'] for mask in masks]

        # Compute the image embedding once and reuse it for every box.
        self.predictor.set_image(image_rgb)

        # Map boxes from original-image pixel coordinates into the
        # resized frame the model operates on.
        box_tensor = torch.tensor(boxes, dtype=torch.float32, device=self.device)
        transformed_boxes = self.predictor.transform.apply_boxes_torch(
            box_tensor,
            image.shape[:2],
        )

        # Predict masks for all boxes in a single batched call; with
        # multimask_output=False each box yields exactly one mask.
        masks, _, _ = self.predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )
        return [m[0].cpu().numpy() for m in masks]

    def predict_with_box(self, image, box):
        """
        Predict a single segmentation mask for the given box.

        :param image: Input image (BGR)
        :param box: One bounding box [x1, y1, x2, y2]
        :return: Binary mask (NumPy array), or None if no mask was produced
        """
        masks = self.generate_masks(image, boxes=[box])
        return masks[0] if masks else None