"""Gradio app that segments an uploaded image.

The active code path prompts SAM (via ``utils.sam.predict``) with a single
point and returns the annotated masks. A DETR-based path is kept for
reference but is not selected by default.
"""

from io import BytesIO

import numpy as np
import gradio as gr
from PIL import Image
import requests
import torch
from transformers import AutoModelForImageSegmentation, AutoFeatureExtractor

from utils import annotate_masks
from utils.sam import predict

# Load the model and feature extractor.
# NOTE(review): these are only used by the non-SAM branch of segment_image,
# which is not selected while SEGMENTATION_METHOD is "sam"; they are still
# loaded eagerly so the module-level names keep existing for any importer.
model_name = "facebook/detr-resnet-50"
model = AutoModelForImageSegmentation.from_pretrained(model_name)
extractor = AutoFeatureExtractor.from_pretrained(model_name)

# Which backend segment_image uses: "sam" selects the point-prompted SAM path,
# any other value selects the DETR path.
SEGMENTATION_METHOD = "sam"

# Default prompt location in pixels.
# NOTE(review): assumed to be (x, y) order — confirm against utils.sam.predict.
DEFAULT_PROMPT_POINT = (300, 300)


def _to_rgb_uint8(image):
    """Return *image* as an (H, W, 3) ``np.uint8`` array.

    Accepts anything ``np.asarray`` understands (PIL images included).
    Grayscale input is replicated across three channels, an alpha channel is
    dropped, and non-uint8 data is converted to the 0-255 range.

    Raises:
        ValueError: if the image is empty or has an unsupported shape.
    """
    pixels = np.asarray(image)
    if pixels.size == 0:
        raise ValueError("The image is empty!")

    if pixels.ndim == 2:
        # Grayscale: replicate the single channel.
        pixels = np.stack([pixels] * 3, axis=-1)
    elif pixels.ndim == 3 and pixels.shape[2] in (1, 2):
        # Single channel with an explicit axis, or luminance + alpha.
        pixels = np.stack([pixels[..., 0]] * 3, axis=-1)
    elif pixels.ndim == 3 and pixels.shape[2] == 4:
        # RGBA: drop alpha.
        pixels = pixels[:, :, :3]

    if pixels.ndim != 3 or pixels.shape[2] != 3:
        raise ValueError(f"Unsupported image shape: {pixels.shape}")

    if pixels.dtype != np.uint8:
        values = np.nan_to_num(pixels.astype(np.float64))
        # Data in [0, 1] (floats, booleans, 0/1 masks) is scaled up as before;
        # anything larger is taken to be in 0-255 already. Clipping prevents
        # the silent wrap-around a bare astype(np.uint8) would produce.
        if values.max() <= 1.0:
            values = values * 255
        pixels = np.clip(values, 0, 255).astype(np.uint8)

    return pixels


def _choose_prompt_point(height, width):
    """Return the prompt point as ``[x, y]``, guaranteed to lie inside the image.

    Uses DEFAULT_PROMPT_POINT when it falls within the image bounds and the
    image centre otherwise, so small images are not prompted off-canvas.
    """
    x, y = DEFAULT_PROMPT_POINT
    if 0 <= x < width and 0 <= y < height:
        return [x, y]
    return [width // 2, height // 2]


# Function to handle segmentation
def segment_image(image):
    """Segment *image* and return the result for display.

    Args:
        image: the input image (a PIL image when called from the Gradio
            interface below), or ``None`` when no image is provided.

    Returns:
        ``None`` if *image* is ``None``; otherwise the value returned by
        ``annotate_masks`` on the SAM path, or a PIL image built from the
        argmax of the model logits on the DETR path.

    Raises:
        ValueError: if the image is empty or has an unsupported shape.
    """
    # Gradio passes None when the input is cleared; with live=True this is
    # easy to trigger, so return an empty output instead of crashing.
    if image is None:
        return None

    if SEGMENTATION_METHOD == "sam":
        # Ensure correct format for SAM (RGB and np.uint8)
        image_rgb = _to_rgb_uint8(image)
        print(f"========================Image type: {type(image_rgb)}, Shape: {image_rgb.shape}")

        height, width = image_rgb.shape[:2]
        point = _choose_prompt_point(height, width)

        # Only the masks are used; scores and logits are ignored.
        masks, _scores, _logits = predict(image_rgb, [point])
        return annotate_masks(image_rgb, masks)

    # Prepare the image and perform segmentation.
    # NOTE(review): this branch is not reached while SEGMENTATION_METHOD is
    # "sam". The argmax over dim=1 is kept as originally written — confirm it
    # matches the layout of outputs.logits before enabling this path.
    inputs = extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    segmentation_mask = outputs.logits.argmax(dim=1).squeeze().cpu().numpy()

    # Convert the segmentation mask to an image
    mask_image = Image.fromarray(segmentation_mask.astype('uint8'))
    return mask_image


# Create Gradio interface
demo = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    live=True,
    title="Image Segmentation App",
    description="Upload an image and get the segmented output using a pre-trained model.",
)

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()