"""Point-prompted segmentation with Segment Anything (SAM).

Loads a ViT-H SAM model once at import time and exposes ``predict``, which
segments an image from a list of (x, y) point prompts.
"""

import os

import numpy as np
import torch  # noqa: F401  (kept from the original module; not used directly here)
from segment_anything import SamPredictor, sam_model_registry

# Path to the SAM ViT-H checkpoint. Override with the SAM_CHECKPOINT environment
# variable; otherwise fall back to the original hard-coded location.
sam_checkpoint = os.environ.get(
    "SAM_CHECKPOINT", r"H:\dev\pantareh\data\models\sam_vit_h_4b8939.pth"
)
# Loaded once at import time and shared by every predict() call.
sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint)


def predict(image, points, labels=None, multimask_output=True):
    """Segment ``image`` using point prompts.

    Args:
        image: Image passed to ``SamPredictor.set_image``. Presumably an
            HxWx3 uint8 array in RGB order (SAM's default) -- confirm with callers.
        points: Non-empty sequence of (x, y) pixel coordinates used as prompts.
        labels: Optional sequence with one label per point (1 = foreground,
            0 = background in SAM's convention). Defaults to all foreground,
            matching the previous behavior.
        multimask_output: Forwarded to ``SamPredictor.predict``; ``True`` (SAM's
            default) asks for several candidate masks.

    Returns:
        The ``(masks, scores, logits)`` tuple returned by ``SamPredictor.predict``.

    Raises:
        ValueError: If ``points`` is empty or not shaped (N, 2), or if
            ``labels`` does not have exactly one entry per point.
    """
    coords = np.asarray(points, dtype=float)
    # SAM indexes coords[..., 0] / coords[..., 1], so an empty or 1-D input
    # would otherwise fail deep inside the library with an opaque IndexError.
    if coords.ndim != 2 or coords.shape[0] == 0 or coords.shape[1] != 2:
        raise ValueError(
            f"points must be a non-empty sequence of (x, y) pairs; got shape {coords.shape}"
        )

    if labels is None:
        point_labels = np.ones(len(coords), dtype=int)
    else:
        point_labels = np.asarray(labels, dtype=int)
        if point_labels.shape != (len(coords),):
            raise ValueError(
                f"labels must have one entry per point ({len(coords)}); "
                f"got shape {point_labels.shape}"
            )

    # A fresh predictor per call: set_image() stores per-image state on the
    # predictor, so sharing one across calls would couple unrelated requests.
    predictor = SamPredictor(sam)
    predictor.set_image(image)

    masks, scores, logits = predictor.predict(
        point_coords=coords,
        point_labels=point_labels,
        box=None,
        multimask_output=multimask_output,
    )
    return masks, scores, logits