File size: 652 Bytes
1d29cdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import os

import numpy as np
import torch
from segment_anything import sam_model_registry, SamPredictor

# Load the SAM model once, at import time.
# The checkpoint path and model type default to the original developer-machine
# values. They can be overridden with the SAM_CHECKPOINT and SAM_MODEL_TYPE
# environment variables, so the module works on other machines without
# editing this file.
sam_checkpoint = os.environ.get(
    "SAM_CHECKPOINT", r"H:\dev\pantareh\data\models\sam_vit_h_4b8939.pth"
)
# The registry key must match the checkpoint's architecture
# (the default "vit_h" pairs with the sam_vit_h_* checkpoint above).
sam_model_type = os.environ.get("SAM_MODEL_TYPE", "vit_h")
sam = sam_model_registry[sam_model_type](checkpoint=sam_checkpoint)


def predict(image, points, labels=None, multimask_output=True):
    """Segment ``image`` with SAM using point prompts.

    Args:
        image: Image handed to ``SamPredictor.set_image``. SAM's default
            ``image_format`` is "RGB"; presumably an HxWx3 uint8 array --
            TODO confirm against callers.
        points: Sequence of ``(x, y)`` pixel coordinates, or a single
            ``(x, y)`` pair.
        labels: Optional per-point labels (1 = foreground, 0 = background,
            per SAM's ``predict`` contract). When omitted, every point is
            treated as foreground, which matches the original behavior.
        multimask_output: Forwarded to ``SamPredictor.predict``. It defaults
            to ``True``, SAM's own default, so existing calls behave the same.

    Returns:
        The ``(masks, scores, logits)`` tuple returned by
        ``SamPredictor.predict``.

    Raises:
        ValueError: If ``points`` is empty or not made of ``(x, y)`` pairs,
            or if ``labels`` does not have one entry per point.
    """
    coords = np.asarray(points, dtype=float)
    if coords.ndim == 1:
        # A lone (x, y) pair would otherwise be a shape-(2,) array that SAM
        # cannot consume; promote it to a single-row (1, 2) array.
        coords = coords.reshape(1, -1)
    if coords.ndim != 2 or coords.shape[1] != 2 or coords.shape[0] == 0:
        raise ValueError(
            f"points must be a non-empty sequence of (x, y) pairs, got shape {coords.shape}"
        )

    if labels is None:
        point_labels = np.ones(coords.shape[0], dtype=int)
    else:
        point_labels = np.asarray(labels, dtype=int).reshape(-1)
        if point_labels.shape[0] != coords.shape[0]:
            raise ValueError(
                f"expected {coords.shape[0]} labels, got {point_labels.shape[0]}"
            )

    # Validation happens above, before set_image, so bad input fails fast
    # instead of after the (expensive) image embedding is computed.
    # A fresh predictor per call keeps the image state set by set_image
    # local to this call rather than shared through module state.
    predictor = SamPredictor(sam)
    predictor.set_image(image)

    masks, scores, logits = predictor.predict(
        point_coords=coords,
        point_labels=point_labels,
        box=None,
        multimask_output=multimask_output,
    )
    return masks, scores, logits