Update app.py
Browse files
app.py
CHANGED
@@ -2,17 +2,32 @@ import streamlit as st
|
|
2 |
import torch
|
3 |
from torchvision.models.detection import fasterrcnn_resnet50_fpn
|
4 |
from torchvision.transforms import functional as F
|
5 |
-
from PIL import Image, ImageDraw
|
6 |
import io
|
|
|
|
|
7 |
|
8 |
-
# Set
|
9 |
-
st.set_page_config(page_title="Object Detection
|
10 |
|
11 |
-
#
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
-
# Load model
|
16 |
@st.cache_resource
|
17 |
def load_model():
|
18 |
model = fasterrcnn_resnet50_fpn(pretrained=True)
|
@@ -21,46 +36,63 @@ def load_model():
|
|
21 |
|
22 |
model = load_model()
|
23 |
|
24 |
-
#
|
25 |
-
|
|
|
|
|
|
|
26 |
|
27 |
-
#
|
28 |
-
|
|
|
29 |
|
30 |
-
|
31 |
-
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
-
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
|
|
|
|
43 |
|
44 |
-
|
45 |
-
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
-
|
49 |
-
box = boxes[i].tolist()
|
50 |
-
label = labels[i].item()
|
51 |
-
score = scores[i].item()
|
52 |
-
draw.rectangle(box, outline="red", width=3)
|
53 |
-
draw.text((box[0], box[1]), f"{label}:{score:.2f}", fill="white")
|
54 |
|
55 |
-
|
|
|
|
|
|
|
|
|
56 |
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
label="π₯ Download Detected Image",
|
63 |
-
data=byte_im,
|
64 |
-
file_name="detected.png",
|
65 |
-
mime="image/png"
|
66 |
-
)
|
|
|
2 |
import torch
|
3 |
from torchvision.models.detection import fasterrcnn_resnet50_fpn
|
4 |
from torchvision.transforms import functional as F
|
5 |
+
from PIL import Image, ImageDraw, ImageFont
|
6 |
import io
|
7 |
+
import cv2
|
8 |
+
import numpy as np
|
9 |
|
10 |
+
# Set up UI
|
11 |
+
st.set_page_config(page_title="π― AI Object Detection", layout="centered")
|
12 |
|
13 |
+
# COCO class labels (80 classes)
|
14 |
+
COCO_INSTANCE_CATEGORY_NAMES = [
|
15 |
+
'__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
|
16 |
+
'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
|
17 |
+
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
|
18 |
+
'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A',
|
19 |
+
'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
|
20 |
+
'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
|
21 |
+
'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass', 'cup', 'fork',
|
22 |
+
'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli',
|
23 |
+
'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
|
24 |
+
'potted plant', 'bed', 'N/A', 'dining table', 'N/A', 'N/A', 'toilet', 'N/A',
|
25 |
+
'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
|
26 |
+
'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book', 'clock', 'vase',
|
27 |
+
'scissors', 'teddy bear', 'hair drier', 'toothbrush'
|
28 |
+
]
|
29 |
|
30 |
+
# Load detection model
|
31 |
@st.cache_resource
|
32 |
def load_model():
|
33 |
model = fasterrcnn_resnet50_fpn(pretrained=True)
|
|
|
36 |
|
37 |
model = load_model()
|
38 |
|
39 |
+
# Sidebar
|
40 |
+
st.sidebar.title("βοΈ Settings")
|
41 |
+
conf_thresh = st.sidebar.slider("Confidence Threshold", 0.2, 1.0, 0.5, 0.05)
|
42 |
+
theme = st.sidebar.radio("Theme", ["Light", "Dark"])
|
43 |
+
source = st.sidebar.radio("Input Source", ["Upload Image", "Use Webcam"])
|
44 |
|
45 |
+
# Main Title
|
46 |
+
st.title("π― Object Detection with AI")
|
47 |
+
st.markdown("Upload an image or use webcam to detect objects in real time using **Faster R-CNN** with bounding boxes and labels.")
|
48 |
|
49 |
+
def detect_objects(image, threshold=0.5):
|
50 |
+
tensor = F.to_tensor(image).unsqueeze(0)
|
51 |
+
outputs = model(tensor)
|
52 |
+
boxes = outputs[0]['boxes']
|
53 |
+
labels = outputs[0]['labels']
|
54 |
+
scores = outputs[0]['scores']
|
55 |
+
|
56 |
+
draw = ImageDraw.Draw(image)
|
57 |
|
58 |
+
for box, label, score in zip(boxes, labels, scores):
|
59 |
+
if score >= threshold:
|
60 |
+
box = box.tolist()
|
61 |
+
name = COCO_INSTANCE_CATEGORY_NAMES[label.item()]
|
62 |
+
draw.rectangle(box, outline="red", width=3)
|
63 |
+
draw.text((box[0], box[1] - 10), f"{name} {score:.2f}", fill="white" if theme == "Dark" else "black")
|
64 |
+
return image
|
65 |
|
66 |
+
if source == "Upload Image":
|
67 |
+
uploaded_file = st.file_uploader("π€ Upload Image", type=["jpg", "jpeg", "png"])
|
68 |
+
if uploaded_file:
|
69 |
+
image = Image.open(uploaded_file).convert("RGB")
|
70 |
+
st.image(image, caption="Original Image", use_column_width=True)
|
71 |
+
st.subheader("π§ Detected Objects:")
|
72 |
+
detected = detect_objects(image.copy(), conf_thresh)
|
73 |
+
st.image(detected, caption="Detection Result", use_column_width=True)
|
74 |
|
75 |
+
# Download button
|
76 |
+
buf = io.BytesIO()
|
77 |
+
detected.save(buf, format="PNG")
|
78 |
+
st.download_button("π₯ Download Image", data=buf.getvalue(), file_name="detected.png", mime="image/png")
|
79 |
+
|
80 |
+
else:
|
81 |
+
st.subheader("πΈ Real-Time Webcam Detection")
|
82 |
+
run = st.checkbox("Turn On Webcam")
|
83 |
+
|
84 |
+
FRAME_WINDOW = st.image([])
|
85 |
|
86 |
+
cap = cv2.VideoCapture(0)
|
|
|
|
|
|
|
|
|
|
|
87 |
|
88 |
+
while run:
|
89 |
+
ret, frame = cap.read()
|
90 |
+
if not ret:
|
91 |
+
st.write("Failed to get webcam feed.")
|
92 |
+
break
|
93 |
|
94 |
+
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
95 |
+
pil_img = Image.fromarray(img)
|
96 |
+
result = detect_objects(pil_img.copy(), conf_thresh)
|
97 |
+
FRAME_WINDOW.image(result)
|
98 |
+
cap.release()
|
|
|
|
|
|
|
|
|
|