[email protected] commited on
Commit
1d29cdb
·
1 Parent(s): 9941f75

initial segmentation app

app.py ADDED
@@ -0,0 +1,62 @@
+ import numpy as np
+ import gradio as gr
+ from PIL import Image
+ import torch
+ from transformers import AutoImageProcessor, AutoModelForImageSegmentation
+
+ from utils import annotate_masks
+ from utils.sam import predict
+
+ # Load the fallback segmentation model and its image processor.
+ # The panoptic checkpoint is needed here: plain "facebook/detr-resnet-50" is an
+ # object-detection checkpoint and has no trained mask head.
+ model_name = "facebook/detr-resnet-50-panoptic"
+ model = AutoModelForImageSegmentation.from_pretrained(model_name)
+ processor = AutoImageProcessor.from_pretrained(model_name)
+
+ # Gradio handler for segmentation
+ def segment_image(image):
+     method = "sam"
+     if method == "sam":
+         # Fixed prompt point; a click-to-segment UI could supply this instead
+         point = [300, 300]
+         image_rgb = np.array(image)  # PIL image -> NumPy array
+         if image_rgb.size == 0:
+             raise ValueError("The image is empty!")
+         if image_rgb.ndim == 2:  # grayscale -> 3-channel RGB
+             image_rgb = np.stack([image_rgb] * 3, axis=-1)
+         elif image_rgb.ndim == 3 and image_rgb.shape[2] == 4:  # RGBA -> RGB
+             image_rgb = image_rgb[:, :, :3]
+
+         # SAM expects an RGB uint8 array; rescale floats assumed to be in [0, 1]
+         if image_rgb.dtype != np.uint8:
+             image_rgb = (image_rgb * 255).astype(np.uint8)
+
+         masks, scores, logits = predict(image_rgb, [point])
+         return annotate_masks(image_rgb, masks)
+     else:
+         # Prepare the image and run DETR panoptic segmentation
+         inputs = processor(images=image, return_tensors="pt")
+         with torch.no_grad():
+             outputs = model(**inputs)
+         # pred_masks: (batch, num_queries, H/4, W/4); label each pixel with its
+         # best-scoring query to get a rough segmentation map
+         segmentation_mask = outputs.pred_masks.squeeze(0).argmax(dim=0).cpu().numpy()
+
+         # Convert the segmentation map to an image
+         mask_image = Image.fromarray(segmentation_mask.astype("uint8"))
+         return mask_image
+
+ # Create the Gradio interface
+ demo = gr.Interface(
+     fn=segment_image,
+     inputs=gr.Image(type="pil"),
+     outputs=gr.Image(type="pil"),
+     live=True,
+     title="Image Segmentation App",
+     description="Upload an image and get the segmented output using a pre-trained model.",
+ )
+
+ # Launch the Gradio app
+ if __name__ == "__main__":
+     demo.launch()
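A quick way to sanity-check the handler outside Gradio, as a sketch (test.jpg is a hypothetical local image):

    from PIL import Image
    from app import segment_image

    result = segment_image(Image.open("test.jpg"))  # hypothetical test image
    # The SAM branch returns a NumPy overlay; the DETR branch returns a PIL image
    out = result if isinstance(result, Image.Image) else Image.fromarray(result)
    out.save("segmented.png")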
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ # pull CUDA 12.4 wheels for torch/torchvision/torchaudio from the PyTorch index
+ --extra-index-url https://download.pytorch.org/whl/cu124
+ git+https://github.com/facebookresearch/segment-anything.git
+ gradio
+ supervision
+ torch
+ torchvision
+ torchaudio
+ transformers
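Beyond these packages, utils/sam.py expects the SAM ViT-H checkpoint on disk at the path in sam_checkpoint. A minimal download sketch; the URL is the one published in the segment-anything README, and the destination filename is illustrative:

    import urllib.request

    SAM_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
    urllib.request.urlretrieve(SAM_URL, "sam_vit_h_4b8939.pth")  # roughly a 2.4 GB download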
utils/__init__.py ADDED
@@ -0,0 +1,12 @@
+ import cv2
+ import numpy as np
+
+ def annotate_masks(image, masks):
+     # Paint each mask onto a blank canvas in a random color,
+     # then blend the canvas over the original image once at the end
+     canvas = np.zeros_like(image)
+     for mask in masks:
+         color = np.random.randint(0, 256, size=3, dtype=np.uint8)  # random (R, G, B)
+         canvas[mask == 1] = color
+     overlay_image = cv2.addWeighted(image, 0.7, canvas, 0.3, 0)
+     return overlay_image
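A small synthetic check of the overlay logic (all shapes and values here are illustrative):

    import numpy as np
    from utils import annotate_masks

    image = np.zeros((64, 64, 3), dtype=np.uint8)
    mask = np.zeros((64, 64), dtype=np.uint8)
    mask[16:48, 16:48] = 1  # a square "object"
    out = annotate_masks(image, [mask])
    assert out.shape == image.shape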
utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (539 Bytes).
 
utils/__pycache__/sam.cpython-39.pyc ADDED
Binary file (678 Bytes).
 
utils/sam.py ADDED
@@ -0,0 +1,22 @@
+ import numpy as np
+ import torch
+ from segment_anything import sam_model_registry, SamPredictor
+
+ # Load the SAM ViT-H checkpoint (machine-specific path; the weights are
+ # published at https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth)
+ sam_checkpoint = r"H:\dev\pantareh\data\models\sam_vit_h_4b8939.pth"
+ sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint)
+ sam.to("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Build the predictor once; set_image is the expensive per-image step
+ predictor = SamPredictor(sam)
+
+ def predict(image, points):
+     predictor.set_image(image)
+     # One foreground label (1) per prompt point
+     masks, scores, logits = predictor.predict(
+         point_coords=np.array(points),
+         point_labels=np.ones(len(points), dtype=int),
+         box=None,
+     )
+     return masks, scores, logits
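Example prompt-point call, assuming image is an RGB uint8 array prepared as in app.py (test.jpg is a hypothetical file):

    import numpy as np
    from PIL import Image
    from utils.sam import predict

    image = np.array(Image.open("test.jpg").convert("RGB"))
    masks, scores, logits = predict(image, [[300, 300]])
    print(masks.shape, scores)  # (3, H, W): SamPredictor returns three candidate masks by default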