"""Streamlit app that runs Faster R-CNN object detection on an uploaded JPEG."""

from io import BytesIO

import streamlit as st
import torch
import torchvision.transforms as transforms
from PIL import Image, ImageDraw
from torchvision.models.detection import fasterrcnn_resnet50_fpn


def _load_model():
    """Build the pre-trained Faster R-CNN detector and put it in eval mode."""
    model = fasterrcnn_resnet50_fpn(pretrained=True)
    model.eval()
    return model


# Streamlit re-executes this script on every user interaction; keep the model
# in Streamlit's resource cache so it is not rebuilt each time. Older Streamlit
# releases have no ``cache_resource``, in which case the plain loader is used.
if hasattr(st, "cache_resource"):
    _load_model = st.cache_resource(_load_model)


def draw_bounding_boxes(image, boxes, labels):
    """Draw a red rectangle and a text label on ``image`` for each detection.

    Args:
        image: PIL image to draw on. It is modified in place.
        boxes: Iterable of boxes; the first four items of each box are used
            as the rectangle's two corner points (x0, y0, x1, y1).
        labels: Iterable of labels, paired with ``boxes`` by position. Each
            label is rendered with ``str()`` at the box's first corner.

    Returns:
        The same ``image`` object that was passed in.
    """
    draw = ImageDraw.Draw(image)
    for box, label in zip(boxes, labels):
        x0, y0, x1, y1 = box[0], box[1], box[2], box[3]
        draw.rectangle([x0, y0, x1, y1], outline="red", width=3)
        draw.text((x0, y0), str(label), fill="red")
    return image


def run_inference(image, threshold=0.25):
    """Detect objects in ``image`` and render the detections onto a copy.

    Args:
        image: PIL image to analyse. It is not modified.
        threshold: Detections whose score is not strictly greater than this
            value are discarded. Defaults to 0.25.

    Returns:
        A tuple ``(boxes, labels, jpeg_bytes)`` where ``boxes`` and ``labels``
        are plain lists for the kept detections and ``jpeg_bytes`` is the
        annotated image encoded as JPEG. If anything fails, the error is
        reported through ``st.error`` and ``([], [], None)`` is returned.
    """
    try:
        # Uploaded JPEGs are not necessarily RGB (e.g. grayscale or CMYK).
        # Normalise the mode so the tensor has three channels and the
        # annotated copy can always be encoded as JPEG.
        rgb_image = image.convert("RGB")

        # ToTensor yields a CHW tensor; unsqueeze adds the batch dimension.
        input_tensor = transforms.ToTensor()(rgb_image).unsqueeze(0)

        model = _load_model()
        with torch.no_grad():
            predictions = model(input_tensor)

        detection = predictions[0]
        boxes = detection['boxes'].cpu().numpy()
        labels = detection['labels'].cpu().numpy()
        scores = detection['scores'].cpu().numpy()

        # Boolean mask selecting the detections above the confidence cut-off.
        keep = scores > threshold
        boxes = boxes[keep]
        labels = labels[keep]

        annotated_image = draw_bounding_boxes(
            rgb_image.copy(), boxes.astype(int), labels.astype(int)
        )

        buffer = BytesIO()
        annotated_image.save(buffer, format='JPEG')

        return boxes.tolist(), labels.tolist(), buffer.getvalue()
    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing
        # the Streamlit script run.
        st.error(f"Error processing the image: {e}")
        return [], [], None


def main():
    """Render the page: upload a JPEG, run detection, and show the results."""
    st.title("Faster R-CNN Object Detection with Streamlit")

    uploaded_file = st.file_uploader("Choose an image...", type="jpg")
    if uploaded_file is None:
        return

    try:
        image = Image.open(uploaded_file)
    except OSError as e:
        # Raised when the upload cannot be identified or read as an image.
        st.error(f"Error processing the image: {e}")
        return

    bounding_boxes, labels, result_image_bytes = run_inference(image)

    # Display bounding boxes coordinates
    st.text(f"Bounding Boxes: {bounding_boxes}")

    # Display detected labels
    if labels:
        st.text(f"Detected Labels: {labels}")

    # Display the result image if available
    if result_image_bytes is not None:
        # ``output_format`` is the st.image keyword for the encoding; the
        # former ``format`` keyword is not accepted by current Streamlit.
        st.image(
            result_image_bytes,
            caption="Object Detection Result",
            use_column_width=True,
            output_format="JPEG",
        )


if __name__ == "__main__":
    main()