File size: 2,176 Bytes
1d29cdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from io import BytesIO
import numpy as np
import gradio as gr
from PIL import Image
import requests
import torch
from transformers import AutoModelForImageSegmentation, AutoFeatureExtractor

from utils import annotate_masks
from utils.sam import predict

# Pretrained DETR checkpoint used by the non-SAM path of ``segment_image``;
# the feature extractor and the model are loaded from the same checkpoint.
# NOTE(review): "facebook/detr-resnet-50" is published as DETR's
# object-detection checkpoint, so loading it through
# AutoModelForImageSegmentation presumably leaves the mask head freshly
# initialised — confirm whether a segmentation ("-panoptic") checkpoint
# was intended.
model_name = "facebook/detr-resnet-50"
extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForImageSegmentation.from_pretrained(model_name)

def _to_rgb_uint8(image):
    """Convert *image* to an ``(H, W, 3)`` ``uint8`` NumPy array for SAM.

    PIL images go through ``Image.convert("RGB")``, which maps grayscale,
    palette, LA, RGBA and CMYK modes to correct RGB colours. Other
    array-likes get the channel/dtype normalisation below, with clipping so
    out-of-range values saturate instead of wrapping around in ``uint8``.

    Raises:
        ValueError: If the image is empty or cannot be made 3-channel.
    """
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    # np.array (not np.asarray) so the result is a writable copy, as before.
    image_rgb = np.array(image)
    if image_rgb.size == 0:
        raise ValueError("The image is empty!")

    if image_rgb.ndim == 2:  # grayscale -> replicate into three channels
        image_rgb = np.stack([image_rgb] * 3, axis=-1)
    elif image_rgb.ndim == 3 and image_rgb.shape[2] == 1:
        image_rgb = np.repeat(image_rgb, 3, axis=-1)
    elif image_rgb.ndim == 3 and image_rgb.shape[2] == 4:  # RGBA -> drop alpha
        image_rgb = image_rgb[:, :, :3]
    if image_rgb.ndim != 3 or image_rgb.shape[2] != 3:
        raise ValueError(f"Cannot convert image of shape {image_rgb.shape} to RGB")

    # SAM expects uint8 RGB.
    if image_rgb.dtype == np.bool_:
        image_rgb = image_rgb.astype(np.uint8) * 255
    elif np.issubdtype(image_rgb.dtype, np.floating):
        # Floats are assumed to lie in [0, 1], matching the original scaling.
        image_rgb = (np.clip(image_rgb, 0.0, 1.0) * 255).astype(np.uint8)
    elif image_rgb.dtype != np.uint8:
        image_rgb = np.clip(image_rgb, 0, 255).astype(np.uint8)
    return image_rgb


# Function to handle segmentation
def segment_image(image, *, method="sam", point=None):
    """Segment *image* and return the visualised result.

    Args:
        image: The input image. ``gr.Image(type="pil")`` delivers a PIL
            image, or ``None`` when the input is empty.
        method: ``"sam"`` (default) prompts ``utils.sam.predict`` with
            *point* and draws the masks with ``annotate_masks``; ``"detr"``
            runs the module-level ``model`` / ``extractor`` pair.
        point: SAM prompt point; defaults to ``[300, 300]``. Presumably pixel
            coordinates — TODO confirm the (x, y) vs (row, col) convention of
            ``utils.sam.predict``, and note it is not clamped to image bounds.

    Returns:
        The ``annotate_masks`` output for SAM, a PIL mask image for DETR,
        or ``None`` when *image* is ``None``.

    Raises:
        ValueError: For an unknown *method*, or an empty image / one that
            cannot be converted to 3-channel RGB.
    """
    import logging

    if method not in ("sam", "detr"):
        raise ValueError(f"Unknown segmentation method: {method!r}")
    if image is None:
        # Gradio passes None for an empty input (live=True re-runs on change).
        return None

    if method == "sam":
        image_rgb = _to_rgb_uint8(image)
        logging.getLogger(__name__).debug(
            "Image type: %s, Shape: %s", type(image_rgb), image_rgb.shape
        )
        prompt = [300, 300] if point is None else list(point)
        masks, _scores, _logits = predict(image_rgb, [prompt])
        return annotate_masks(image_rgb, masks)

    # NOTE(review): DETR's ``logits`` are per-query class scores, not a
    # per-pixel map, so this argmax is unlikely to be a meaningful mask;
    # kept as in the original — confirm before relying on this path.
    inputs = extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    segmentation_mask = outputs.logits.argmax(dim=1).squeeze().cpu().numpy()
    return Image.fromarray(segmentation_mask.astype("uint8"))

# Gradio UI: one PIL image in, one PIL image out. With ``live=True`` the
# interface re-runs ``segment_image`` whenever the input changes.
demo = gr.Interface(
    title="Image Segmentation App",
    description="Upload an image and get the segmented output using a pre-trained model.",
    fn=segment_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    live=True,
)


def _main():
    """Start the Gradio server for ``demo``."""
    demo.launch()


if __name__ == "__main__":
    _main()