import numpy as np
import gradio as gr
from PIL import Image
import torch
from transformers import AutoModelForImageSegmentation, AutoFeatureExtractor

from utils import annotate_masks
from utils.sam import predict
# Load the DETR segmentation model and its feature extractor
# (the plain "facebook/detr-resnet-50" checkpoint is detection-only;
# the panoptic checkpoint provides the segmentation head)
model_name = "facebook/detr-resnet-50-panoptic"
model = AutoModelForImageSegmentation.from_pretrained(model_name)
extractor = AutoFeatureExtractor.from_pretrained(model_name)
# Function to handle segmentation
def segment_image(image):
    method = "sam"
    if method == "sam":
        # Fixed foreground point prompt (x, y) used to seed SAM
        point = [300, 300]
        image_rgb = np.array(image)  # Convert the PIL image to a NumPy array
        if image_rgb.size == 0:
            raise ValueError("The image is empty!")
        if len(image_rgb.shape) == 2:  # Grayscale image: replicate to three channels
            image_rgb = np.stack([image_rgb] * 3, axis=-1)
        elif len(image_rgb.shape) == 3 and image_rgb.shape[2] == 4:  # RGBA to RGB
            image_rgb = image_rgb[:, :, :3]
        print(f"Image type: {type(image_rgb)}, shape: {image_rgb.shape}")
        # Ensure the format SAM expects (RGB, np.uint8)
        if image_rgb.dtype != np.uint8:
            image_rgb = (image_rgb * 255).astype(np.uint8)
        masks, scores, logits = predict(image_rgb, [point])
        return annotate_masks(image_rgb, masks)
    else:
        # Prepare the image and run DETR segmentation
        inputs = extractor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        # DETR emits per-query logits and masks; post-process them into a
        # per-pixel class map at the original image size (PIL size is (w, h))
        segmentation_mask = extractor.post_process_semantic_segmentation(
            outputs, target_sizes=[image.size[::-1]]
        )[0].cpu().numpy()
        # Convert the segmentation mask to an image
        mask_image = Image.fromarray(segmentation_mask.astype("uint8"))
        return mask_image
# Create Gradio interface
demo = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    live=True,
    title="Image Segmentation App",
    description="Upload an image and get the segmented output using a pre-trained model.",
)
# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()
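
# ---------------------------------------------------------------------------
# Note: the `utils` package imported above is not shown on this page. The
# sketch below illustrates one way `utils/sam.py` (predict) and
# `utils/__init__.py` (annotate_masks) could be implemented with Meta's
# `segment_anything` package. The checkpoint path, model size ("vit_b"),
# overlay color, and alpha value are illustrative assumptions, not the
# Space's actual code; in the Space these would live in separate files.
# ---------------------------------------------------------------------------

# utils/sam.py (hypothetical sketch)
import numpy as np
from segment_anything import sam_model_registry, SamPredictor

_sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")  # assumed local checkpoint
_predictor = SamPredictor(_sam)

def predict(image_rgb, points):
    """Run SAM on an RGB uint8 array, prompting with foreground points."""
    _predictor.set_image(image_rgb)
    masks, scores, logits = _predictor.predict(
        point_coords=np.array(points),
        point_labels=np.ones(len(points)),  # 1 = foreground point
        multimask_output=True,
    )
    return masks, scores, logits

# utils/__init__.py (hypothetical sketch)
from PIL import Image

def annotate_masks(image_rgb, masks, color=(30, 144, 255), alpha=0.5):
    """Blend each boolean mask over the image and return a PIL image."""
    overlay = image_rgb.astype(np.float32).copy()
    for mask in masks:
        overlay[mask] = (1 - alpha) * overlay[mask] + alpha * np.array(color, dtype=np.float32)
    return Image.fromarray(overlay.astype(np.uint8))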