Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,16 +1,15 @@
|
|
1 |
import streamlit as st
|
|
|
2 |
import torch
|
|
|
3 |
from torchvision.models.detection import fasterrcnn_resnet50_fpn
|
4 |
-
|
5 |
-
from PIL import Image, ImageDraw, ImageFont
|
6 |
-
import io
|
7 |
-
import cv2
|
8 |
-
import numpy as np
|
9 |
|
10 |
-
#
|
11 |
-
|
|
|
12 |
|
13 |
-
#
|
14 |
COCO_INSTANCE_CATEGORY_NAMES = [
|
15 |
'__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
|
16 |
'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
|
@@ -20,79 +19,74 @@ COCO_INSTANCE_CATEGORY_NAMES = [
|
|
20 |
'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
|
21 |
'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass', 'cup', 'fork',
|
22 |
'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli',
|
23 |
-
'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
|
24 |
-
'
|
25 |
-
'
|
26 |
-
'
|
27 |
-
'
|
28 |
]
|
29 |
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
-
|
|
|
38 |
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
-
|
46 |
-
st.
|
47 |
-
st.markdown("Upload an image or use webcam to detect objects in real time using **Faster R-CNN** with bounding boxes and labels.")
|
48 |
|
49 |
-
|
50 |
-
tensor = F.to_tensor(image).unsqueeze(0)
|
51 |
-
outputs = model(tensor)
|
52 |
-
boxes = outputs[0]['boxes']
|
53 |
-
labels = outputs[0]['labels']
|
54 |
-
scores = outputs[0]['scores']
|
55 |
-
|
56 |
-
draw = ImageDraw.Draw(image)
|
57 |
|
58 |
-
|
59 |
-
if score >= threshold:
|
60 |
-
box = box.tolist()
|
61 |
-
name = COCO_INSTANCE_CATEGORY_NAMES[label.item()]
|
62 |
-
draw.rectangle(box, outline="red", width=3)
|
63 |
-
draw.text((box[0], box[1] - 10), f"{name} {score:.2f}", fill="white" if theme == "Dark" else "black")
|
64 |
-
return image
|
65 |
|
66 |
-
if
|
67 |
-
|
68 |
-
|
69 |
-
image = Image.open(uploaded_file).convert("RGB")
|
70 |
-
st.image(image, caption="Original Image", use_column_width=True)
|
71 |
-
st.subheader("π§ Detected Objects:")
|
72 |
-
detected = detect_objects(image.copy(), conf_thresh)
|
73 |
-
st.image(detected, caption="Detection Result", use_column_width=True)
|
74 |
|
75 |
-
|
76 |
-
buf = io.BytesIO()
|
77 |
-
detected.save(buf, format="PNG")
|
78 |
-
st.download_button("π₯ Download Image", data=buf.getvalue(), file_name="detected.png", mime="image/png")
|
79 |
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
|
84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
|
86 |
-
cap = cv2.VideoCapture(0)
|
87 |
-
|
88 |
-
while run:
|
89 |
-
ret, frame = cap.read()
|
90 |
-
if not ret:
|
91 |
-
st.write("Failed to get webcam feed.")
|
92 |
-
break
|
93 |
-
|
94 |
-
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
95 |
-
pil_img = Image.fromarray(img)
|
96 |
-
result = detect_objects(pil_img.copy(), conf_thresh)
|
97 |
-
FRAME_WINDOW.image(result)
|
98 |
-
cap.release()
|
|
|
1 |
import streamlit as st
|
2 |
+
from PIL import Image
|
3 |
import torch
|
4 |
+
from torchvision import transforms
|
5 |
from torchvision.models.detection import fasterrcnn_resnet50_fpn
|
6 |
+
import torchvision
|
|
|
|
|
|
|
|
|
7 |
|
8 |
+
# Load model
|
9 |
+
model = fasterrcnn_resnet50_fpn(pretrained=True)
|
10 |
+
model.eval()
|
11 |
|
12 |
+
# Define class labels
|
13 |
COCO_INSTANCE_CATEGORY_NAMES = [
|
14 |
'__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
|
15 |
'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
|
|
|
19 |
'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
|
20 |
'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass', 'cup', 'fork',
|
21 |
'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli',
|
22 |
+
'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant',
|
23 |
+
'bed', 'N/A', 'dining table', 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop',
|
24 |
+
'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster',
|
25 |
+
'sink', 'refrigerator', 'N/A', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
|
26 |
+
'hair drier', 'toothbrush'
|
27 |
]
|
28 |
|
29 |
+
def get_prediction(img, threshold):
|
30 |
+
transform = transforms.Compose([transforms.ToTensor()])
|
31 |
+
img = transform(img)
|
32 |
+
pred = model([img])
|
33 |
+
pred_classes = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in list(pred[0]['labels'].numpy())]
|
34 |
+
pred_boxes = [[(i[0], i[1]), (i[2], i[3])] for i in list(pred[0]['boxes'].detach().numpy())]
|
35 |
+
pred_score = list(pred[0]['scores'].detach().numpy())
|
36 |
+
pred_t = [pred_score.index(x) for x in pred_score if x > threshold][-1]
|
37 |
+
boxes = pred_boxes[:pred_t+1]
|
38 |
+
classes = pred_classes[:pred_t+1]
|
39 |
+
return boxes, classes
|
40 |
|
41 |
+
# UI design
|
42 |
+
st.set_page_config(page_title="AI Object Detector", layout="wide")
|
43 |
|
44 |
+
st.markdown("""
|
45 |
+
<style>
|
46 |
+
.main {
|
47 |
+
background-color: #f5f7fa;
|
48 |
+
padding: 20px;
|
49 |
+
border-radius: 10px;
|
50 |
+
}
|
51 |
+
h1 {
|
52 |
+
color: #2c3e50;
|
53 |
+
}
|
54 |
+
.stButton>button {
|
55 |
+
background-color: #008CBA;
|
56 |
+
color: white;
|
57 |
+
font-weight: bold;
|
58 |
+
border-radius: 8px;
|
59 |
+
padding: 10px 24px;
|
60 |
+
}
|
61 |
+
</style>
|
62 |
+
""", unsafe_allow_html=True)
|
63 |
|
64 |
+
st.title("π AI Object Detection App")
|
65 |
+
st.markdown("Upload an image and let the AI detect what's in it!")
|
|
|
66 |
|
67 |
+
img_file = st.file_uploader("πΈ Upload an Image", type=["jpg", "jpeg", "png"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
+
confidence = st.slider("π― Confidence Threshold", 0.0, 1.0, 0.5)
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
+
if img_file is not None:
|
72 |
+
image = Image.open(img_file).convert("RGB")
|
73 |
+
st.image(image, caption="Uploaded Image", use_column_width=True)
|
|
|
|
|
|
|
|
|
|
|
74 |
|
75 |
+
boxes, classes = get_prediction(image, confidence)
|
|
|
|
|
|
|
76 |
|
77 |
+
# Draw results
|
78 |
+
import matplotlib.pyplot as plt
|
79 |
+
import matplotlib.patches as patches
|
80 |
|
81 |
+
fig, ax = plt.subplots(1, figsize=(12, 8))
|
82 |
+
ax.imshow(image)
|
83 |
+
for i in range(len(boxes)):
|
84 |
+
box = boxes[i]
|
85 |
+
label = classes[i]
|
86 |
+
rect = patches.Rectangle(box[0], box[1][0]-box[0][0], box[1][1]-box[0][1],
|
87 |
+
linewidth=2, edgecolor='blue', facecolor='none')
|
88 |
+
ax.add_patch(rect)
|
89 |
+
ax.text(box[0][0], box[0][1]-10, label, fontsize=12,
|
90 |
+
color='black', bbox=dict(facecolor='lightblue', edgecolor='blue', boxstyle='round,pad=0.5'))
|
91 |
+
st.pyplot(fig)
|
92 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|