Spaces:
Runtime error
Runtime error
Commit
·
a4443af
1
Parent(s):
5bdcaaf
Update app.py
Browse files
app.py
CHANGED
|
@@ -52,6 +52,11 @@ DATASET_COLORMAPS = {
|
|
| 52 |
"ade20k": colormaps.ADE20K_COLORMAP,
|
| 53 |
"voc2012": colormaps.VOC2012_COLORMAP,
|
| 54 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
model = init_segmentor(cfg)
|
| 57 |
load_checkpoint(model, CHECKPOINT_URL, map_location="cpu")
|
|
@@ -91,11 +96,11 @@ def create_segmenter(cfg, backbone_model):
|
|
| 91 |
|
| 92 |
|
| 93 |
def render_segmentation(segmentation_logits, dataset):
|
| 94 |
-
colormap = DATASET_COLORMAPS[dataset]
|
| 95 |
colormap_array = np.array(colormap, dtype=np.uint8)
|
| 96 |
segmentation_logits += 1
|
| 97 |
-
|
| 98 |
-
|
|
|
|
| 99 |
unique_labels = np.unique(segmentation_logits)
|
| 100 |
|
| 101 |
colormap_array = colormap_array[unique_labels]
|
|
@@ -107,7 +112,7 @@ def render_segmentation(segmentation_logits, dataset):
|
|
| 107 |
for idx, color in enumerate(colormap_array):
|
| 108 |
color_box = np.zeros((20, 20, 3), dtype=np.uint8)
|
| 109 |
color_box[:, :] = color
|
| 110 |
-
|
| 111 |
_, img_data = cv2.imencode(".jpg", color_box)
|
| 112 |
img_base64 = base64.b64encode(img_data).decode("utf-8")
|
| 113 |
img_data_uri = f"data:image/jpg;base64,{img_base64}"
|
|
@@ -115,14 +120,15 @@ def render_segmentation(segmentation_logits, dataset):
|
|
| 115 |
|
| 116 |
html_output += "</div>"
|
| 117 |
|
| 118 |
-
return
|
| 119 |
|
| 120 |
|
| 121 |
def predict(image_file):
|
| 122 |
array = np.array(image_file)[:, :, ::-1] # BGR
|
| 123 |
segmentation_logits = inference_segmentor(model, array)[0]
|
|
|
|
| 124 |
segmented_image, html_output = render_segmentation(segmentation_logits, "ade20k")
|
| 125 |
-
return
|
| 126 |
|
| 127 |
description = "Gradio demo for Semantic segmentation. To use it, simply upload your image"
|
| 128 |
|
|
@@ -130,10 +136,10 @@ demo = gr.Interface(
|
|
| 130 |
title="Semantic Segmentation - DinoV2",
|
| 131 |
fn=predict,
|
| 132 |
inputs=gr.inputs.Image(),
|
| 133 |
-
outputs=[gr.outputs.Image(type="
|
| 134 |
examples=["example_1.jpg", "example_2.jpg"],
|
| 135 |
cache_examples=False,
|
| 136 |
description=description,
|
| 137 |
)
|
| 138 |
|
| 139 |
-
demo.launch()
|
|
|
|
| 52 |
"ade20k": colormaps.ADE20K_COLORMAP,
|
| 53 |
"voc2012": colormaps.VOC2012_COLORMAP,
|
| 54 |
}
|
| 55 |
+
colormap = DATASET_COLORMAPS["ade20k"]
|
| 56 |
+
flattened = np.array(colormap).flatten()
|
| 57 |
+
zeros = np.zeros(768)
|
| 58 |
+
zeros[:flattened.shape[0]] = flattened
|
| 59 |
+
colorMap = list(zeros.astype('uint8'))
|
| 60 |
|
| 61 |
model = init_segmentor(cfg)
|
| 62 |
load_checkpoint(model, CHECKPOINT_URL, map_location="cpu")
|
|
|
|
| 96 |
|
| 97 |
|
| 98 |
def render_segmentation(segmentation_logits, dataset):
|
|
|
|
| 99 |
colormap_array = np.array(colormap, dtype=np.uint8)
|
| 100 |
segmentation_logits += 1
|
| 101 |
+
segmented_image = Image.fromarray(segmentation_logits)
|
| 102 |
+
segmented_image.putpalette(colorMap)
|
| 103 |
+
|
| 104 |
unique_labels = np.unique(segmentation_logits)
|
| 105 |
|
| 106 |
colormap_array = colormap_array[unique_labels]
|
|
|
|
| 112 |
for idx, color in enumerate(colormap_array):
|
| 113 |
color_box = np.zeros((20, 20, 3), dtype=np.uint8)
|
| 114 |
color_box[:, :] = color
|
| 115 |
+
color_box = cv2.cvtColor(color_box, cv2.COLOR_RGB2BGR)
|
| 116 |
_, img_data = cv2.imencode(".jpg", color_box)
|
| 117 |
img_base64 = base64.b64encode(img_data).decode("utf-8")
|
| 118 |
img_data_uri = f"data:image/jpg;base64,{img_base64}"
|
|
|
|
| 120 |
|
| 121 |
html_output += "</div>"
|
| 122 |
|
| 123 |
+
return segmented_image, html_output
|
| 124 |
|
| 125 |
|
| 126 |
def predict(image_file):
|
| 127 |
array = np.array(image_file)[:, :, ::-1] # BGR
|
| 128 |
segmentation_logits = inference_segmentor(model, array)[0]
|
| 129 |
+
segmentation_logits = segmentation_logits.astype(np.uint8)
|
| 130 |
segmented_image, html_output = render_segmentation(segmentation_logits, "ade20k")
|
| 131 |
+
return segmented_image, html_output
|
| 132 |
|
| 133 |
description = "Gradio demo for Semantic segmentation. To use it, simply upload your image"
|
| 134 |
|
|
|
|
| 136 |
title="Semantic Segmentation - DinoV2",
|
| 137 |
fn=predict,
|
| 138 |
inputs=gr.inputs.Image(),
|
| 139 |
+
outputs=[gr.outputs.Image(type="pil"), gr.outputs.HTML()],
|
| 140 |
examples=["example_1.jpg", "example_2.jpg"],
|
| 141 |
cache_examples=False,
|
| 142 |
description=description,
|
| 143 |
)
|
| 144 |
|
| 145 |
+
demo.launch()
|