# initial segmentation app
# 1d29cdb
from io import BytesIO
import numpy as np
import gradio as gr
from PIL import Image
import requests
import torch
from transformers import AutoModelForImageSegmentation, AutoFeatureExtractor
from utils import annotate_masks
from utils.sam import predict
# Load the pre-trained model and its feature extractor once, at import time.
# Both objects are module-level globals; segment_image only reads them in its
# non-SAM branch.
# NOTE(review): the two from_pretrained calls run as import side effects and
# presumably fetch/cache the weights on first use — confirm this is acceptable
# for the deployment target.
# NOTE(review): "facebook/detr-resnet-50" looks like the object-detection
# checkpoint; verify it carries segmentation-head weights (a "-panoptic"
# variant may be the intended one) — TODO confirm.
model_name = "facebook/detr-resnet-50"
model = AutoModelForImageSegmentation.from_pretrained(model_name)
extractor = AutoFeatureExtractor.from_pretrained(model_name)
# Function to handle segmentation
# Default prompt point handed to SAM when the caller does not supply one.
_DEFAULT_POINT = (300, 300)


def _to_rgb_uint8(image):
    """Return *image* as an ``(H, W, 3)`` ``uint8`` NumPy array.

    Args:
        image: A PIL image (anything exposing ``convert`` and ``size``) or an
            array-like of pixel values.

    Returns:
        A freshly allocated RGB ``uint8`` array.

    Raises:
        ValueError: If the image contains no pixels.
    """
    if hasattr(image, "convert"):
        # PIL input: let Pillow perform the mode conversion so palette, LA,
        # CMYK and high-bit-depth images all end up as real 8-bit RGB.
        width, height = image.size
        if width == 0 or height == 0:
            raise ValueError("The image is empty!")
        return np.array(image.convert("RGB"))

    pixels = np.array(image)
    if pixels.size == 0:
        raise ValueError("The image is empty!")

    if pixels.ndim == 2:  # grayscale -> replicate into three channels
        pixels = np.stack([pixels] * 3, axis=-1)
    elif pixels.ndim == 3 and pixels.shape[2] == 4:  # RGBA -> drop alpha
        pixels = pixels[:, :, :3]

    if pixels.dtype == np.uint8:
        return pixels
    if np.issubdtype(pixels.dtype, np.floating):
        # Floats are assumed to be normalised to [0, 1] (as the original
        # scaling did); clip first so out-of-range values cannot wrap around.
        pixels = np.clip(pixels, 0.0, 1.0) * 255
    elif pixels.dtype == np.bool_:
        pixels = pixels.astype(np.uint8) * 255
    else:
        # Other integer types are treated as already being on a 0-255 scale.
        pixels = np.clip(pixels, 0, 255)
    return pixels.astype(np.uint8)


def segment_image(image, method="sam", point=None):
    """Segment *image* and return the result for display.

    Args:
        image: The input picture (a PIL image when called from the Gradio
            interface), or ``None`` when no picture is present.
        method: ``"sam"`` (default) runs the point-prompted ``predict`` helper
            and overlays its masks with ``annotate_masks``; any other value
            runs the module-level ``model``/``extractor`` pair instead.
        point: Optional two-element prompt coordinate for the SAM path;
            defaults to ``(300, 300)``.

    Returns:
        ``None`` when *image* is ``None``; otherwise whatever
        ``annotate_masks`` returns (SAM path) or a PIL image built from the
        per-pixel argmax of the model logits (non-SAM path).

    Raises:
        ValueError: If the image contains no pixels (SAM path).
    """
    # A cleared input arrives as None; return None instead of crashing
    # further down on ``None * 255``.
    if image is None:
        return None

    if method == "sam":
        image_rgb = _to_rgb_uint8(image)
        print(f"========================Image type: {type(image_rgb)}, Shape: {image_rgb.shape}")
        # NOTE(review): the prompt is not clamped to the image bounds because
        # the coordinate order expected by ``predict`` is not visible here;
        # the default may fall outside pictures smaller than 301 px — confirm.
        prompt = list(_DEFAULT_POINT if point is None else point)
        masks, _scores, _logits = predict(image_rgb, [prompt])
        return annotate_masks(image_rgb, masks)

    # Non-SAM path: run the pre-loaded transformers model without gradients.
    inputs = extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # NOTE(review): argmax over dim 1 presumes logits laid out as
    # (batch, classes, H, W) — TODO confirm for the loaded checkpoint.
    segmentation_mask = outputs.logits.argmax(dim=1).squeeze().cpu().numpy()
    # Convert the class-index mask to an 8-bit image.
    return Image.fromarray(segmentation_mask.astype("uint8"))
# Assemble the Gradio UI: a single image input wired to segment_image, with
# the function's result shown in a single image output.
demo = gr.Interface(
    title="Image Segmentation App",
    description="Upload an image and get the segmented output using a pre-trained model.",
    fn=segment_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    live=True,
)

# Start the app only when this file is executed directly, not when imported.
if __name__ == "__main__":
    demo.launch()