# Initial segmentation app.
# Source snapshot: commit 1d29cdb (652 bytes).
import os

import numpy as np
import torch
from segment_anything import sam_model_registry, SamPredictor
# Load the SAM ViT-H model once, at import time, so every predict() call
# reuses the same weights. Set the SAM_CHECKPOINT environment variable to
# point at a different checkpoint file; the original location is the default.
sam_checkpoint = os.environ.get(
    "SAM_CHECKPOINT", r"H:\dev\pantareh\data\models\sam_vit_h_4b8939.pth"
)
sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint)

# Run on the GPU when one is available. Set SAM_DEVICE (e.g. "cpu" or
# "cuda:1") to force a specific device.
sam_device = os.environ.get(
    "SAM_DEVICE", "cuda" if torch.cuda.is_available() else "cpu"
)
sam.to(device=sam_device)
def predict(image, points, labels=None, *, multimask_output=True):
    """Segment ``image`` with the module-level SAM model using point prompts.

    Args:
        image: Image handed to ``SamPredictor.set_image``. The SAM docs
            specify an HWC ``uint8`` RGB array; confirm callers supply that.
        points: Point prompts as ``(x, y)`` pairs. This is either a single
            pair or a sequence of pairs, i.e. anything ``np.asarray`` turns
            into shape ``(2,)`` or ``(N, 2)``.
        labels: Optional per-point labels, one per point. Per the
            ``SamPredictor`` docs, 1 marks a foreground point and 0 a
            background point. Defaults to all 1s, the previous behaviour.
        multimask_output: Forwarded to ``SamPredictor.predict``. The
            default ``True`` matches the previous behaviour.

    Returns:
        The ``(masks, scores, logits)`` tuple returned by
        ``SamPredictor.predict``.

    Raises:
        ValueError: If ``points`` is empty or not made of ``(x, y)`` pairs,
            or if ``labels`` does not have exactly one entry per point.
    """
    # atleast_2d lets a single (x, y) pair be passed as one point; without
    # it a flat pair produced a shape/length mismatch inside SAM.
    coords = np.atleast_2d(np.asarray(points, dtype=np.float64))
    if coords.size == 0:
        raise ValueError("predict() needs at least one point prompt")
    if coords.ndim != 2 or coords.shape[1] != 2:
        raise ValueError(
            f"points must be (x, y) pairs, got an array of shape {coords.shape}"
        )

    if labels is None:
        point_labels = np.ones(len(coords))
    else:
        point_labels = np.asarray(labels).reshape(-1)
        if len(point_labels) != len(coords):
            raise ValueError(
                f"got {len(point_labels)} labels for {len(coords)} points"
            )

    # Build a predictor per call rather than sharing one. set_image() stores
    # state on the predictor that predict() later reads, so a shared instance
    # could mix up images between overlapping calls.
    predictor = SamPredictor(sam)
    predictor.set_image(image)
    masks, scores, logits = predictor.predict(
        point_coords=coords,
        point_labels=point_labels,
        box=None,
        multimask_output=multimask_output,
    )
    return masks, scores, logits