nomanmanzoor commited on
Commit
5d6db6d
Β·
verified Β·
1 Parent(s): 9300680

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -72
app.py CHANGED
@@ -1,16 +1,15 @@
1
  import streamlit as st
 
2
  import torch
 
3
  from torchvision.models.detection import fasterrcnn_resnet50_fpn
4
- from torchvision.transforms import functional as F
5
- from PIL import Image, ImageDraw, ImageFont
6
- import io
7
- import cv2
8
- import numpy as np
9
 
10
- # Set up UI
11
- st.set_page_config(page_title="🎯 AI Object Detection", layout="centered")
 
12
 
13
- # COCO class labels (80 classes)
14
  COCO_INSTANCE_CATEGORY_NAMES = [
15
  '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
16
  'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
@@ -20,79 +19,74 @@ COCO_INSTANCE_CATEGORY_NAMES = [
20
  'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
21
  'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass', 'cup', 'fork',
22
  'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli',
23
- 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
24
- 'potted plant', 'bed', 'N/A', 'dining table', 'N/A', 'N/A', 'toilet', 'N/A',
25
- 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
26
- 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book', 'clock', 'vase',
27
- 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
28
  ]
29
 
30
- # Load detection model
31
- @st.cache_resource
32
- def load_model():
33
- model = fasterrcnn_resnet50_fpn(pretrained=True)
34
- model.eval()
35
- return model
 
 
 
 
 
36
 
37
- model = load_model()
 
38
 
39
- # Sidebar
40
- st.sidebar.title("βš™οΈ Settings")
41
- conf_thresh = st.sidebar.slider("Confidence Threshold", 0.2, 1.0, 0.5, 0.05)
42
- theme = st.sidebar.radio("Theme", ["Light", "Dark"])
43
- source = st.sidebar.radio("Input Source", ["Upload Image", "Use Webcam"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
- # Main Title
46
- st.title("🎯 Object Detection with AI")
47
- st.markdown("Upload an image or use webcam to detect objects in real time using **Faster R-CNN** with bounding boxes and labels.")
48
 
49
- def detect_objects(image, threshold=0.5):
50
- tensor = F.to_tensor(image).unsqueeze(0)
51
- outputs = model(tensor)
52
- boxes = outputs[0]['boxes']
53
- labels = outputs[0]['labels']
54
- scores = outputs[0]['scores']
55
-
56
- draw = ImageDraw.Draw(image)
57
 
58
- for box, label, score in zip(boxes, labels, scores):
59
- if score >= threshold:
60
- box = box.tolist()
61
- name = COCO_INSTANCE_CATEGORY_NAMES[label.item()]
62
- draw.rectangle(box, outline="red", width=3)
63
- draw.text((box[0], box[1] - 10), f"{name} {score:.2f}", fill="white" if theme == "Dark" else "black")
64
- return image
65
 
66
- if source == "Upload Image":
67
- uploaded_file = st.file_uploader("πŸ“€ Upload Image", type=["jpg", "jpeg", "png"])
68
- if uploaded_file:
69
- image = Image.open(uploaded_file).convert("RGB")
70
- st.image(image, caption="Original Image", use_column_width=True)
71
- st.subheader("🧠 Detected Objects:")
72
- detected = detect_objects(image.copy(), conf_thresh)
73
- st.image(detected, caption="Detection Result", use_column_width=True)
74
 
75
- # Download button
76
- buf = io.BytesIO()
77
- detected.save(buf, format="PNG")
78
- st.download_button("πŸ“₯ Download Image", data=buf.getvalue(), file_name="detected.png", mime="image/png")
79
 
80
- else:
81
- st.subheader("πŸ“Έ Real-Time Webcam Detection")
82
- run = st.checkbox("Turn On Webcam")
83
 
84
- FRAME_WINDOW = st.image([])
 
 
 
 
 
 
 
 
 
 
85
 
86
- cap = cv2.VideoCapture(0)
87
-
88
- while run:
89
- ret, frame = cap.read()
90
- if not ret:
91
- st.write("Failed to get webcam feed.")
92
- break
93
-
94
- img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
95
- pil_img = Image.fromarray(img)
96
- result = detect_objects(pil_img.copy(), conf_thresh)
97
- FRAME_WINDOW.image(result)
98
- cap.release()
 
1
  import streamlit as st
2
+ from PIL import Image
3
  import torch
4
+ from torchvision import transforms
5
  from torchvision.models.detection import fasterrcnn_resnet50_fpn
6
+ import torchvision
 
 
 
 
7
 
8
+ # Load model
9
+ model = fasterrcnn_resnet50_fpn(pretrained=True)
10
+ model.eval()
11
 
12
+ # Define class labels
13
  COCO_INSTANCE_CATEGORY_NAMES = [
14
  '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
15
  'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
 
19
  'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
20
  'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass', 'cup', 'fork',
21
  'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli',
22
+ 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant',
23
+ 'bed', 'N/A', 'dining table', 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop',
24
+ 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster',
25
+ 'sink', 'refrigerator', 'N/A', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
26
+ 'hair drier', 'toothbrush'
27
  ]
28
 
29
+ def get_prediction(img, threshold):
30
+ transform = transforms.Compose([transforms.ToTensor()])
31
+ img = transform(img)
32
+ pred = model([img])
33
+ pred_classes = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in list(pred[0]['labels'].numpy())]
34
+ pred_boxes = [[(i[0], i[1]), (i[2], i[3])] for i in list(pred[0]['boxes'].detach().numpy())]
35
+ pred_score = list(pred[0]['scores'].detach().numpy())
36
+ pred_t = [pred_score.index(x) for x in pred_score if x > threshold][-1]
37
+ boxes = pred_boxes[:pred_t+1]
38
+ classes = pred_classes[:pred_t+1]
39
+ return boxes, classes
40
 
41
+ # UI design
42
+ st.set_page_config(page_title="AI Object Detector", layout="wide")
43
 
44
+ st.markdown("""
45
+ <style>
46
+ .main {
47
+ background-color: #f5f7fa;
48
+ padding: 20px;
49
+ border-radius: 10px;
50
+ }
51
+ h1 {
52
+ color: #2c3e50;
53
+ }
54
+ .stButton>button {
55
+ background-color: #008CBA;
56
+ color: white;
57
+ font-weight: bold;
58
+ border-radius: 8px;
59
+ padding: 10px 24px;
60
+ }
61
+ </style>
62
+ """, unsafe_allow_html=True)
63
 
64
+ st.title("πŸ” AI Object Detection App")
65
+ st.markdown("Upload an image and let the AI detect what's in it!")
 
66
 
67
+ img_file = st.file_uploader("πŸ“Έ Upload an Image", type=["jpg", "jpeg", "png"])
 
 
 
 
 
 
 
68
 
69
+ confidence = st.slider("🎯 Confidence Threshold", 0.0, 1.0, 0.5)
 
 
 
 
 
 
70
 
71
+ if img_file is not None:
72
+ image = Image.open(img_file).convert("RGB")
73
+ st.image(image, caption="Uploaded Image", use_column_width=True)
 
 
 
 
 
74
 
75
+ boxes, classes = get_prediction(image, confidence)
 
 
 
76
 
77
+ # Draw results
78
+ import matplotlib.pyplot as plt
79
+ import matplotlib.patches as patches
80
 
81
+ fig, ax = plt.subplots(1, figsize=(12, 8))
82
+ ax.imshow(image)
83
+ for i in range(len(boxes)):
84
+ box = boxes[i]
85
+ label = classes[i]
86
+ rect = patches.Rectangle(box[0], box[1][0]-box[0][0], box[1][1]-box[0][1],
87
+ linewidth=2, edgecolor='blue', facecolor='none')
88
+ ax.add_patch(rect)
89
+ ax.text(box[0][0], box[0][1]-10, label, fontsize=12,
90
+ color='black', bbox=dict(facecolor='lightblue', edgecolor='blue', boxstyle='round,pad=0.5'))
91
+ st.pyplot(fig)
92