nomanmanzoor commited on
Commit
729d3e8
Β·
verified Β·
1 Parent(s): a158ee2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -41
app.py CHANGED
@@ -1,49 +1,66 @@
1
- # app.py
2
  import streamlit as st
3
- from PIL import Image
4
  import torch
5
- import torchvision.transforms as T
6
- import requests
7
- from io import BytesIO
 
8
 
9
- # Load a pre-trained model (Faster R-CNN)
10
- model = torch.hub.load('pytorch/vision:v0.10.0', 'fasterrcnn_resnet50_fpn', pretrained=True)
11
- model.eval()
12
 
13
- # COCO dataset labels
14
- LABELS_URL = 'https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt'
15
- labels = requests.get(LABELS_URL).text.split('\n')
16
-
17
- # Streamlit UI
18
- st.set_page_config(page_title="🎯 Object Detector", layout="centered")
19
  st.title("🎯 AI Object Detection App")
20
- st.markdown("Detect objects in your images using a pre-trained AI model.")
 
 
 
 
 
 
 
 
 
21
 
22
- uploaded_file = st.file_uploader("πŸ“€ Upload an Image", type=["jpg", "jpeg", "png"])
 
23
 
24
- if uploaded_file:
 
 
 
25
  image = Image.open(uploaded_file).convert("RGB")
26
- st.image(image, caption="Uploaded Image", use_column_width=True)
27
-
28
- # Transform image
29
- transform = T.Compose([T.ToTensor()])
30
- input_tensor = transform(image).unsqueeze(0)
31
-
32
- with st.spinner('Detecting objects...'):
33
- outputs = model(input_tensor)[0]
34
-
35
- threshold = st.slider("Confidence Threshold", 0.0, 1.0, 0.5, 0.05)
36
-
37
- draw = image.copy()
38
- draw_edit = ImageDraw.Draw(draw)
39
-
40
- for idx, score in enumerate(outputs['scores']):
41
- if score > threshold:
42
- box = outputs['boxes'][idx].detach().numpy()
43
- label_id = int(outputs['labels'][idx])
44
- label = labels[label_id] if label_id < len(labels) else str(label_id)
45
- draw_edit.rectangle(box, outline="red", width=3)
46
- draw_edit.text((box[0], box[1]), f"{label}: {score:.2f}", fill="white")
47
-
48
- st.image(draw, caption="Detected Objects", use_column_width=True)
49
- st.success("Detection complete!")
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
 
2
  import torch
3
+ from torchvision.models.detection import fasterrcnn_resnet50_fpn
4
+ from torchvision.transforms import functional as F
5
+ from PIL import Image, ImageDraw
6
+ import io
7
 
8
+ # Set page config
9
+ st.set_page_config(page_title="Object Detection App", layout="centered")
 
10
 
11
+ # Title and description
 
 
 
 
 
12
  st.title("🎯 AI Object Detection App")
13
+ st.markdown("Upload an image, and let AI detect objects with bounding boxes using a pretrained Faster R-CNN model.")
14
+
15
+ # Load model
16
+ @st.cache_resource
17
+ def load_model():
18
+ model = fasterrcnn_resnet50_fpn(pretrained=True)
19
+ model.eval()
20
+ return model
21
+
22
+ model = load_model()
23
 
24
+ # Upload image
25
+ uploaded_file = st.file_uploader("πŸ“· Upload Image", type=["jpg", "jpeg", "png"])
26
 
27
+ # Confidence threshold slider
28
+ conf_thresh = st.slider("🎚 Confidence Threshold", min_value=0.1, max_value=1.0, value=0.5, step=0.05)
29
+
30
+ if uploaded_file is not None:
31
  image = Image.open(uploaded_file).convert("RGB")
32
+ st.image(image, caption="Original Image", use_column_width=True)
33
+
34
+ # Convert image to tensor
35
+ image_tensor = F.to_tensor(image).unsqueeze(0)
36
+
37
+ # Run detection
38
+ with st.spinner("Detecting objects..."):
39
+ outputs = model(image_tensor)
40
+ boxes = outputs[0]["boxes"]
41
+ labels = outputs[0]["labels"]
42
+ scores = outputs[0]["scores"]
43
+
44
+ # Filter boxes by confidence threshold
45
+ selected_indices = [i for i, score in enumerate(scores) if score >= conf_thresh]
46
+ draw = ImageDraw.Draw(image)
47
+
48
+ for i in selected_indices:
49
+ box = boxes[i].tolist()
50
+ label = labels[i].item()
51
+ score = scores[i].item()
52
+ draw.rectangle(box, outline="red", width=3)
53
+ draw.text((box[0], box[1]), f"{label}:{score:.2f}", fill="white")
54
+
55
+ st.image(image, caption="🧠 Detected Image", use_column_width=True)
56
+
57
+ # Download button
58
+ buf = io.BytesIO()
59
+ image.save(buf, format="PNG")
60
+ byte_im = buf.getvalue()
61
+ st.download_button(
62
+ label="πŸ“₯ Download Detected Image",
63
+ data=byte_im,
64
+ file_name="detected.png",
65
+ mime="image/png"
66
+ )