import torch
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
import cv2


class SamWrapper:
    """
    Wrapper for Segment Anything Model (SAM).

    Handles both automatic mask generation and guided segmentation via
    bounding boxes.
    """

    # Default keyword arguments for SamAutomaticMaskGenerator. Individual
    # entries can be overridden per instance via ``generator_kwargs``.
    # Never mutated: __init__ merges it into a fresh dict.
    DEFAULT_GENERATOR_KWARGS = {
        "points_per_side": 12,
        "pred_iou_thresh": 0.92,
        "stability_score_thresh": 0.95,
        "min_mask_region_area": 1500,
        "box_nms_thresh": 0.3,
    }

    def __init__(self, model_type="vit_b", checkpoint_path=None, device=None,
                 generator_kwargs=None):
        """
        Initialize the SAM model.

        :param model_type: Key into ``sam_model_registry`` selecting the SAM
            backbone (e.g., 'vit_b', 'vit_l', 'vit_h')
        :param checkpoint_path: Path to the .pth checkpoint file
        :param device: 'cuda' or 'cpu'; if None, uses 'cuda' when
            ``torch.cuda.is_available()`` and 'cpu' otherwise
        :param generator_kwargs: Optional dict of keyword arguments for
            ``SamAutomaticMaskGenerator``; entries override
            ``DEFAULT_GENERATOR_KWARGS``
        :raises KeyError: if ``model_type`` is not in ``sam_model_registry``
        """
        # Previously the device was hard-coded to "cpu", silently ignoring
        # the argument; honour it and auto-detect only when it is omitted.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        self.model = sam_model_registry[model_type](checkpoint=checkpoint_path)
        self.model.to(self.device)

        amg_settings = {**self.DEFAULT_GENERATOR_KWARGS, **(generator_kwargs or {})}
        self.automatic_generator = SamAutomaticMaskGenerator(
            model=self.model,
            **amg_settings,
        )
        self.predictor = SamPredictor(self.model)

    def generate_masks(self, image, boxes=None):
        """
        Generate segmentation masks for the given image.

        Without ``boxes``, runs SAM's automatic mask generator over the whole
        image. With ``boxes``, runs box-prompted prediction, producing one
        mask per box in the same order.

        :param image: Input image as NumPy array (BGR, as loaded by OpenCV)
        :param boxes: Optional sequence of bounding boxes [x1, y1, x2, y2]
            in original-image pixel coordinates
        :return: List of binary masks (NumPy arrays); empty if ``boxes``
            is an empty sequence
        """
        # SAM expects RGB input; OpenCV images are BGR.
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if boxes is None:
            # Each generated record is a dict; 'segmentation' holds the mask.
            records = self.automatic_generator.generate(image_rgb)
            return [record['segmentation'] for record in records]

        # Nothing to prompt with: skip the expensive image-encoder pass and
        # the zero-length batch. len() (not truthiness) so NumPy arrays work.
        if len(boxes) == 0:
            return []

        # Compute the image embedding once and reuse it for every box.
        self.predictor.set_image(image_rgb)

        # Boxes must be rescaled from original-image coordinates to the
        # model's resized input frame; image.shape[:2] is (H, W).
        box_tensor = torch.as_tensor(boxes, dtype=torch.float32, device=self.device)
        transformed_boxes = self.predictor.transform.apply_boxes_torch(
            box_tensor, image.shape[:2]
        )

        # Batched prediction for all boxes. With multimask_output=False, SAM
        # returns a single mask channel per box, so m[0] is that mask.
        masks, _, _ = self.predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )
        return [m[0].cpu().numpy() for m in masks]

    def predict_with_box(self, image, box):
        """
        Predict a single segmentation mask for the given box.

        :param image: Input image (BGR)
        :param box: One bounding box [x1, y1, x2, y2]
        :return: Binary mask (NumPy array), or None if no mask was produced
        """
        masks = self.generate_masks(image, boxes=[box])
        return masks[0] if masks else None