Spaces: Runtime error

[email protected] committed on
Commit · 1d29cdb
Parent(s): 9941f75
initial segmentation app
- app.py +64 -0
- requirements.txt +7 -0
- utils/__init__.py +13 -0
- utils/__pycache__/__init__.cpython-39.pyc +0 -0
- utils/__pycache__/sam.cpython-39.pyc +0 -0
- utils/sam.py +19 -0
app.py
ADDED
@@ -0,0 +1,64 @@
import numpy as np
import gradio as gr
from PIL import Image
import torch
from transformers import AutoModelForImageSegmentation, AutoFeatureExtractor

from utils import annotate_masks
from utils.sam import predict

# Load the model and feature extractor
model_name = "facebook/detr-resnet-50"
model = AutoModelForImageSegmentation.from_pretrained(model_name)
extractor = AutoFeatureExtractor.from_pretrained(model_name)


# Function to handle segmentation
def segment_image(image):
    method = "sam"  # hard-coded switch; the transformers branch below is currently unreachable
    if method == "sam":
        point = [300, 300]  # hard-coded prompt point (x, y)
        image_rgb = np.array(image)  # convert the PIL image to a NumPy array
        if image_rgb.size == 0:
            raise ValueError("The image is empty!")
        if image_rgb.ndim == 2:  # grayscale -> RGB
            image_rgb = np.stack([image_rgb] * 3, axis=-1)
        elif image_rgb.ndim == 3 and image_rgb.shape[2] == 4:  # RGBA -> RGB
            image_rgb = image_rgb[:, :, :3]

        print(f"Image type: {type(image_rgb)}, shape: {image_rgb.shape}")

        # Ensure the format SAM expects: RGB, np.uint8
        if image_rgb.dtype != np.uint8:
            image_rgb = (image_rgb * 255).astype(np.uint8)

        masks, scores, logits = predict(image_rgb, [point])
        return annotate_masks(image_rgb, masks)
    else:
        # Prepare the image and run segmentation with the transformers model
        inputs = extractor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        segmentation_mask = outputs.logits.argmax(dim=1).squeeze().cpu().numpy()

        # Convert the segmentation mask to an image
        mask_image = Image.fromarray(segmentation_mask.astype("uint8"))
        return mask_image


# Create the Gradio interface
demo = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    live=True,
    title="Image Segmentation App",
    description="Upload an image and get the segmented output using a pre-trained model.",
)

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()
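The SAM branch always prompts with the fixed point [300, 300]. If the app were extended to let the user pick the prompt, a click handler along these lines could replace the hard-coded point. This is a minimal sketch, not part of the commit: segment_at_click and click_demo are illustrative names, and it assumes a recent Gradio release where image components expose a select event whose gr.SelectData payload carries the clicked pixel in evt.index.

import gradio as gr
import numpy as np

from utils import annotate_masks
from utils.sam import predict


# Hypothetical handler: evt.index is the clicked pixel [x, y] on the image.
def segment_at_click(image, evt: gr.SelectData):
    point = list(evt.index)
    image_rgb = np.array(image.convert("RGB"))  # convert() handles grayscale/RGBA
    masks, _, _ = predict(image_rgb, [point])
    return annotate_masks(image_rgb, masks)


with gr.Blocks() as click_demo:
    inp = gr.Image(type="pil", label="Click a point to segment")
    out = gr.Image(type="pil", label="Segmentation overlay")
    inp.select(segment_at_click, inputs=inp, outputs=out)

if __name__ == "__main__":
    click_demo.launch()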
requirements.txt
ADDED
@@ -0,0 +1,7 @@
git+https://github.com/facebookresearch/segment-anything.git
gradio
supervision
torch --index-url https://download.pytorch.org/whl/cu124
torchvision --index-url https://download.pytorch.org/whl/cu124
torchaudio --index-url https://download.pytorch.org/whl/cu124
transformers
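The three torch lines request CUDA 12.4 wheels from the PyTorch index, which only pays off if the Space's hardware actually has a compatible GPU. Whether a GPU build landed can be confirmed at startup with a quick check — a minimal sketch, not specific to this app:

import torch

# Report the installed build and whether CUDA is usable at runtime.
print("torch", torch.__version__)
print("CUDA available:", torch.cuda.is_available())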
utils/__init__.py
ADDED
@@ -0,0 +1,13 @@
import cv2
import numpy as np


def annotate_masks(image, masks):
    canvas = np.zeros_like(image)
    for mask in masks:
        # Generate a unique color for each mask (specific colors could be chosen instead)
        color = np.random.randint(0, 256, size=3)  # random (R, G, B)

        # Apply the color to the masked region
        canvas[mask == 1] = color

    # Blend the colored masks over the original image
    overlay_image = cv2.addWeighted(image, 0.7, canvas, 0.3, 0)
    return overlay_image
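A quick way to sanity-check annotate_masks without running SAM is to feed it a synthetic image and mask. A minimal sketch; all values here are illustrative:

import numpy as np
from utils import annotate_masks

# Flat gray 100x100 RGB image and one circular binary mask (synthetic data).
image = np.full((100, 100, 3), 200, dtype=np.uint8)
yy, xx = np.mgrid[:100, :100]
mask = (((yy - 50) ** 2 + (xx - 50) ** 2) < 30 ** 2).astype(np.uint8)

overlay = annotate_masks(image, [mask])
print(overlay.shape, overlay.dtype)  # (100, 100, 3) uint8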
utils/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (539 Bytes).
utils/__pycache__/sam.cpython-39.pyc
ADDED
Binary file (678 Bytes).
utils/sam.py
ADDED
@@ -0,0 +1,19 @@
import numpy as np
from segment_anything import sam_model_registry, SamPredictor

# Load the SAM model (adjust the path to the model checkpoint).
# NOTE: this local Windows path will not exist on a Hugging Face Space,
# which would make model loading fail at startup.
sam_checkpoint = r"H:\dev\pantareh\data\models\sam_vit_h_4b8939.pth"
sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint)


def predict(image, points):
    # Initialize the predictor and embed the image
    predictor = SamPredictor(sam)
    predictor.set_image(image)

    # Generate masks (parameters can be adjusted as needed)
    masks, scores, logits = predictor.predict(
        point_coords=np.array(points),
        point_labels=np.ones(len(points)),
        box=None,
    )
    return masks, scores, logits
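Called directly, predict takes an RGB uint8 array and a list of (x, y) prompt points; with the upstream default multimask_output=True, SAM returns three candidate masks per prompt. A minimal usage sketch, where example.jpg is a hypothetical local file:

import numpy as np
from PIL import Image
from utils.sam import predict

# example.jpg is a placeholder; any RGB image works.
image = np.array(Image.open("example.jpg").convert("RGB"))
masks, scores, logits = predict(image, [[300, 300]])
print(masks.shape)  # e.g. (3, H, W) boolean masks
print(scores)       # one confidence score per mask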