Gabe Rogan commited on
Commit
acd7a28
·
1 Parent(s): fc1c46b

Add object detection

Browse files
Files changed (1) hide show
  1. app.py +52 -4
app.py CHANGED
@@ -1,7 +1,55 @@
 
1
  import gradio as gr
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline
2
  import gradio as gr
3
+ import io
4
+ import matplotlib.pyplot as plt
5
+ from PIL import Image
6
+ from random import choice
7
 
8
+ COLORS = ["#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
9
+ "#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
10
+ "#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f"]
11
 
12
+ def get_figure(in_pil_img, in_results):
13
+ plt.figure(figsize=(16, 10))
14
+ plt.imshow(in_pil_img)
15
+ ax = plt.gca()
16
+
17
+ for prediction in in_results:
18
+ selected_color = choice(COLORS)
19
+
20
+ x, y = prediction['box']['xmin'], prediction['box']['ymin'],
21
+ w, h = prediction['box']['xmax'] - prediction['box']['xmin'], prediction['box']['ymax'] - prediction['box']['ymin']
22
+
23
+ ax.add_patch(plt.Rectangle((x, y), w, h, fill=False, color=selected_color, linewidth=3))
24
+ ax.text(x, y - 3, f"{prediction['label']}: {round(prediction['score']*100, 1)}%", fontdict={
25
+ "family" : "Arial",
26
+ "size" : 20,
27
+ "color" : selected_color,
28
+ "weight" : "bold",
29
+ })
30
+
31
+ plt.axis("off")
32
+
33
+ return plt.gcf()
34
+
35
+ def classify(in_pil_img):
36
+ detector = pipeline("object-detection", "facebook/detr-resnet-50")
37
+ results = detector(in_pil_img, { "threshold": 0.9 })
38
+
39
+ figure = get_figure(in_pil_img, results)
40
+
41
+ buf = io.BytesIO()
42
+ figure.savefig(buf, bbox_inches='tight')
43
+ buf.seek(0)
44
+ output_pil_img = Image.open(buf)
45
+
46
+ return output_pil_img
47
+
48
+ demo = gr.Interface(classify,
49
+ inputs=gr.Image(type="pil"),
50
+ outputs=gr.Image(type="pil"),
51
+ title="Object Detection",
52
+ examples=["https://iili.io/JgN38oQ.jpg", "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cats.jpg"]
53
+ )
54
+
55
+ demo.launch()