nomanmanzoor commited on
Commit
7b90a74
Β·
verified Β·
1 Parent(s): 68fe094

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -42
app.py CHANGED
@@ -2,17 +2,32 @@ import streamlit as st
2
  import torch
3
  from torchvision.models.detection import fasterrcnn_resnet50_fpn
4
  from torchvision.transforms import functional as F
5
- from PIL import Image, ImageDraw
6
  import io
 
 
7
 
8
- # Set page config
9
- st.set_page_config(page_title="Object Detection App", layout="centered")
10
 
11
- # Title and description
12
- st.title("🎯 AI Object Detection App")
13
- st.markdown("Upload an image, and let AI detect objects with bounding boxes using a pretrained Faster R-CNN model.")
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- # Load model
16
  @st.cache_resource
17
  def load_model():
18
  model = fasterrcnn_resnet50_fpn(pretrained=True)
@@ -21,46 +36,63 @@ def load_model():
21
 
22
  model = load_model()
23
 
24
- # Upload image
25
- uploaded_file = st.file_uploader("πŸ“· Upload Image", type=["jpg", "jpeg", "png"])
 
 
 
26
 
27
- # Confidence threshold slider
28
- conf_thresh = st.slider("🎚 Confidence Threshold", min_value=0.1, max_value=1.0, value=0.5, step=0.05)
 
29
 
30
- if uploaded_file is not None:
31
- image = Image.open(uploaded_file).convert("RGB")
32
- st.image(image, caption="Original Image", use_column_width=True)
 
 
 
 
 
33
 
34
- # Convert image to tensor
35
- image_tensor = F.to_tensor(image).unsqueeze(0)
 
 
 
 
 
36
 
37
- # Run detection
38
- with st.spinner("Detecting objects..."):
39
- outputs = model(image_tensor)
40
- boxes = outputs[0]["boxes"]
41
- labels = outputs[0]["labels"]
42
- scores = outputs[0]["scores"]
 
 
43
 
44
- # Filter boxes by confidence threshold
45
- selected_indices = [i for i, score in enumerate(scores) if score >= conf_thresh]
46
- draw = ImageDraw.Draw(image)
 
 
 
 
 
 
 
47
 
48
- for i in selected_indices:
49
- box = boxes[i].tolist()
50
- label = labels[i].item()
51
- score = scores[i].item()
52
- draw.rectangle(box, outline="red", width=3)
53
- draw.text((box[0], box[1]), f"{label}:{score:.2f}", fill="white")
54
 
55
- st.image(image, caption="🧠 Detected Image", use_column_width=True)
 
 
 
 
56
 
57
- # Download button
58
- buf = io.BytesIO()
59
- image.save(buf, format="PNG")
60
- byte_im = buf.getvalue()
61
- st.download_button(
62
- label="πŸ“₯ Download Detected Image",
63
- data=byte_im,
64
- file_name="detected.png",
65
- mime="image/png"
66
- )
 
2
  import torch
3
  from torchvision.models.detection import fasterrcnn_resnet50_fpn
4
  from torchvision.transforms import functional as F
5
+ from PIL import Image, ImageDraw, ImageFont
6
  import io
7
+ import cv2
8
+ import numpy as np
9
 
10
+ # Set up UI
11
+ st.set_page_config(page_title="🎯 AI Object Detection", layout="centered")
12
 
13
+ # COCO class labels (80 classes)
14
+ COCO_INSTANCE_CATEGORY_NAMES = [
15
+ '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
16
+ 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
17
+ 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
18
+ 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A',
19
+ 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
20
+ 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
21
+ 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass', 'cup', 'fork',
22
+ 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli',
23
+ 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
24
+ 'potted plant', 'bed', 'N/A', 'dining table', 'N/A', 'N/A', 'toilet', 'N/A',
25
+ 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
26
+ 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book', 'clock', 'vase',
27
+ 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
28
+ ]
29
 
30
+ # Load detection model
31
  @st.cache_resource
32
  def load_model():
33
  model = fasterrcnn_resnet50_fpn(pretrained=True)
 
36
 
37
  model = load_model()
38
 
39
+ # Sidebar
40
+ st.sidebar.title("βš™οΈ Settings")
41
+ conf_thresh = st.sidebar.slider("Confidence Threshold", 0.2, 1.0, 0.5, 0.05)
42
+ theme = st.sidebar.radio("Theme", ["Light", "Dark"])
43
+ source = st.sidebar.radio("Input Source", ["Upload Image", "Use Webcam"])
44
 
45
+ # Main Title
46
+ st.title("🎯 Object Detection with AI")
47
+ st.markdown("Upload an image or use webcam to detect objects in real time using **Faster R-CNN** with bounding boxes and labels.")
48
 
49
+ def detect_objects(image, threshold=0.5):
50
+ tensor = F.to_tensor(image).unsqueeze(0)
51
+ outputs = model(tensor)
52
+ boxes = outputs[0]['boxes']
53
+ labels = outputs[0]['labels']
54
+ scores = outputs[0]['scores']
55
+
56
+ draw = ImageDraw.Draw(image)
57
 
58
+ for box, label, score in zip(boxes, labels, scores):
59
+ if score >= threshold:
60
+ box = box.tolist()
61
+ name = COCO_INSTANCE_CATEGORY_NAMES[label.item()]
62
+ draw.rectangle(box, outline="red", width=3)
63
+ draw.text((box[0], box[1] - 10), f"{name} {score:.2f}", fill="white" if theme == "Dark" else "black")
64
+ return image
65
 
66
+ if source == "Upload Image":
67
+ uploaded_file = st.file_uploader("πŸ“€ Upload Image", type=["jpg", "jpeg", "png"])
68
+ if uploaded_file:
69
+ image = Image.open(uploaded_file).convert("RGB")
70
+ st.image(image, caption="Original Image", use_column_width=True)
71
+ st.subheader("🧠 Detected Objects:")
72
+ detected = detect_objects(image.copy(), conf_thresh)
73
+ st.image(detected, caption="Detection Result", use_column_width=True)
74
 
75
+ # Download button
76
+ buf = io.BytesIO()
77
+ detected.save(buf, format="PNG")
78
+ st.download_button("πŸ“₯ Download Image", data=buf.getvalue(), file_name="detected.png", mime="image/png")
79
+
80
+ else:
81
+ st.subheader("πŸ“Έ Real-Time Webcam Detection")
82
+ run = st.checkbox("Turn On Webcam")
83
+
84
+ FRAME_WINDOW = st.image([])
85
 
86
+ cap = cv2.VideoCapture(0)
 
 
 
 
 
87
 
88
+ while run:
89
+ ret, frame = cap.read()
90
+ if not ret:
91
+ st.write("Failed to get webcam feed.")
92
+ break
93
 
94
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
95
+ pil_img = Image.fromarray(img)
96
+ result = detect_objects(pil_img.copy(), conf_thresh)
97
+ FRAME_WINDOW.image(result)
98
+ cap.release()