object-detect / app.py
Gabe Rogan
Add object detection
acd7a28
raw
history blame
1.64 kB
import io
from functools import lru_cache
from random import choice

import gradio as gr
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from PIL import Image
from transformers import pipeline
# Palette of light hex colours. get_figure draws each detected object's box
# and caption in one colour chosen at random from this list.
COLORS = [
    "#ff7f7f",
    "#ff7fbf",
    "#ff7fff",
    "#bf7fff",
    "#7f7fff",
    "#7fbfff",
    "#7fffff",
    "#7fffbf",
    "#7fff7f",
    "#bfff7f",
    "#ffff7f",
    "#ffbf7f",
]
def get_figure(in_pil_img, in_results):
    """Draw detection boxes and captions on top of an image.

    Args:
        in_pil_img: Image to draw on (anything ``Axes.imshow`` accepts; the
            caller in this file passes a PIL image).
        in_results: Iterable of prediction dicts, each with a ``"box"`` dict
            holding ``xmin``/``ymin``/``xmax``/``ymax``, a ``"label"`` and a
            fractional ``"score"`` (shown multiplied by 100 as a percentage).

    Returns:
        A ``matplotlib.figure.Figure`` showing the image with one unfilled
        rectangle per prediction and a "label: NN.N%" caption just above the
        box's top-left corner, with the axes hidden.
    """
    # Create the Figure directly rather than through pyplot: pyplot keeps a
    # global reference to every figure it creates (so figures that are never
    # explicitly closed accumulate in a long-running server), and its
    # "current figure/axes" state is shared between request threads. A bare
    # Figure is garbage-collected like any other object.
    fig = Figure(figsize=(16, 10))
    ax = fig.add_subplot()
    ax.imshow(in_pil_img)
    for prediction in in_results:
        selected_color = choice(COLORS)
        box = prediction["box"]
        x, y = box["xmin"], box["ymin"]
        w, h = box["xmax"] - x, box["ymax"] - y
        ax.add_patch(plt.Rectangle((x, y), w, h, fill=False, color=selected_color, linewidth=3))
        # Offset the caption slightly upward so it sits above the box edge.
        ax.text(x, y - 3, f"{prediction['label']}: {round(prediction['score']*100, 1)}%", fontdict={
            "family": "Arial",
            "size": 20,
            "color": selected_color,
            "weight": "bold",
        })
    ax.axis("off")
    return fig
@lru_cache(maxsize=1)
def _get_detector():
    """Build the DETR object-detection pipeline once and reuse it for every request."""
    return pipeline("object-detection", "facebook/detr-resnet-50")


def classify(in_pil_img):
    """Detect objects in an image and return an annotated copy.

    Args:
        in_pil_img: Input PIL image (supplied by the Gradio ``gr.Image(type="pil")``
            input component).

    Returns:
        A PIL image, decoded from a PNG render of the figure built by
        ``get_figure``, showing only detections scored at or above 0.9.
    """
    # Loading model weights is expensive; the cached pipeline is shared
    # across calls instead of being rebuilt on every request.
    detector = _get_detector()
    # The threshold must be a keyword argument: transformers pipelines ignore
    # extra positional arguments (they only log "Ignoring args"), so passing
    # {"threshold": 0.9} positionally left the default threshold in effect.
    results = detector(in_pil_img, threshold=0.9)
    figure = get_figure(in_pil_img, results)
    buf = io.BytesIO()
    try:
        figure.savefig(buf, format="png", bbox_inches="tight")
    finally:
        # Drop pyplot's reference to the figure (a no-op if pyplot never
        # tracked it) so figures do not accumulate across requests.
        plt.close(figure)
    buf.seek(0)
    return Image.open(buf)
# Gradio UI: a single PIL image in, the annotated PIL image out, with two
# example images loaded from URLs.
demo = gr.Interface(
    fn=classify,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Object Detection",
    examples=["https://iili.io/JgN38oQ.jpg", "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cats.jpg"],
)

# Only start the server when run as a script (e.g. `python app.py`), so that
# importing this module does not block on launching the web UI.
if __name__ == "__main__":
    demo.launch()