Spaces:
Runtime error
Runtime error
from io import BytesIO | |
import numpy as np | |
import gradio as gr | |
from PIL import Image | |
import requests | |
import torch | |
from transformers import AutoModelForImageSegmentation, AutoFeatureExtractor | |
from utils import annotate_masks | |
from utils.sam import predict | |
# Load the model and feature extractor used by the "detr" path of segment_image.
# These are loaded once, at import time.
# "facebook/detr-resnet-50" is DETR's object-detection checkpoint and has no
# trained mask head: AutoModelForImageSegmentation would build a segmentation
# model around it with the mask head randomly initialised. The "-panoptic"
# checkpoint carries the segmentation weights.
model_name = "facebook/detr-resnet-50-panoptic"
model = AutoModelForImageSegmentation.from_pretrained(model_name)
extractor = AutoFeatureExtractor.from_pretrained(model_name)
# Function to handle segmentation


def _to_rgb_uint8(image):
    """Convert a PIL image or array-like to an ``(H, W, 3)`` ``uint8`` array.

    Raises:
        ValueError: if the image is empty or cannot be read as 1/3/4 channels.
    """
    if hasattr(image, "convert"):
        # PIL image: let PIL map palette ("P"), "LA", "I", "F", "1", ... to RGB
        # rather than reinterpreting raw pixel values (palette indices etc.).
        image = image.convert("RGB")
    rgb = np.array(image)
    if rgb.size == 0:
        raise ValueError("The image is empty!")
    if rgb.ndim == 2:  # grayscale -> replicate into 3 channels
        rgb = np.stack([rgb] * 3, axis=-1)
    elif rgb.ndim == 3 and rgb.shape[2] == 1:  # single-channel HxWx1
        rgb = np.repeat(rgb, 3, axis=2)
    elif rgb.ndim == 3 and rgb.shape[2] == 4:  # RGBA -> drop alpha
        rgb = rgb[:, :, :3]
    if rgb.ndim != 3 or rgb.shape[2] != 3:
        raise ValueError(f"Unsupported image shape: {rgb.shape}")
    # Ensure the np.uint8 format expected downstream (SAM).
    if rgb.dtype == np.bool_:
        rgb = rgb.astype(np.uint8) * 255
    elif np.issubdtype(rgb.dtype, np.floating):
        # Floats are assumed to be in [0, 1] (as before); clip so out-of-range
        # values saturate instead of wrapping around in the uint8 cast.
        rgb = np.rint(np.clip(rgb, 0.0, 1.0) * 255).astype(np.uint8)
    elif rgb.dtype != np.uint8:
        rgb = np.clip(rgb, 0, 255).astype(np.uint8)
    return rgb


def _clamp_point(point, height, width):
    """Clamp a 2-D prompt point into the image so small images get a valid prompt."""
    # NOTE(review): the coordinate order utils.sam.predict expects ((x, y) vs
    # (row, col)) is not visible here, so both coordinates are bounded by the
    # smaller dimension; images larger than the point are unaffected.
    limit = min(height, width) - 1
    return [min(max(int(coord), 0), limit) for coord in point]


def _segment_with_detr(image_rgb):
    """Run the module-level DETR model and return a per-pixel class-id image."""
    inputs = extractor(images=image_rgb, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # outputs.logits holds per-*query* class scores, not a pixel map; the image
    # processor combines them with outputs.pred_masks into an (H, W) class-id
    # map resized to the input resolution.
    class_map = extractor.post_process_semantic_segmentation(
        outputs, target_sizes=[image_rgb.shape[:2]]
    )[0]
    # Convert the segmentation mask to an image
    return Image.fromarray(class_map.cpu().numpy().astype(np.uint8))


def segment_image(image, method="sam", point=(300, 300)):
    """Segment *image* and return the result as an image.

    Args:
        image: input image (a PIL image when called from the Gradio interface,
            which uses ``type="pil"``) or an array-like. ``None`` yields ``None``.
        method: ``"sam"`` (default) prompts ``utils.sam.predict`` with one point
            and overlays the masks via ``annotate_masks``; ``"detr"`` returns a
            class-id map from the module-level DETR model.
        point: SAM prompt point, clamped into the image bounds.

    Returns:
        The annotated image (SAM), a PIL class-id image (DETR), or ``None``.

    Raises:
        ValueError: for an unknown *method* or an empty/unsupported image.
    """
    if method not in ("sam", "detr"):
        raise ValueError(f"Unknown segmentation method: {method!r}")
    if image is None:
        # Presumably what Gradio sends when the live input is cleared;
        # np.array(None) would otherwise crash further down.
        return None
    image_rgb = _to_rgb_uint8(image)
    if method == "detr":
        return _segment_with_detr(image_rgb)
    height, width = image_rgb.shape[:2]
    prompt = _clamp_point(point, height, width)
    masks, _scores, _logits = predict(image_rgb, [prompt])
    return annotate_masks(image_rgb, masks)
# Create Gradio interface: one PIL image in, one PIL image out, wired to
# segment_image. live=True asks Gradio to re-run on input changes instead of
# waiting for a submit click (see the Gradio Interface docs).
_image_input = gr.Image(type="pil")
_image_output = gr.Image(type="pil")
demo = gr.Interface(
    fn=segment_image,
    inputs=_image_input,
    outputs=_image_output,
    title="Image Segmentation App",
    description="Upload an image and get the segmented output using a pre-trained model.",
    live=True,
)

# Launch the Gradio app only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()