Spaces:
Runtime error
Runtime error
import cv2 | |
import numpy as np | |
from PIL import Image | |
import streamlit as st | |
import tensorflow as tf | |
from tensorflow.keras.models import load_model | |
# Most of this code has been obtained from Datature's prediction script
# https://github.com/datature/resources/blob/main/scripts/bounding_box/prediction.py | |
# 'deprecation.showfileUploaderEncoding' is a legacy option. Streamlit releases
# that no longer define it are expected to reject the name with a
# StreamlitAPIException, which would stop the app before any UI is drawn.
# The setting only silences a deprecation warning, so apply it best-effort.
from streamlit.errors import StreamlitAPIException

try:
    st.set_option('deprecation.showfileUploaderEncoding', False)
except StreamlitAPIException:
    # Option not available in this Streamlit version; nothing to silence.
    pass
def load_model():
    """
    Loads the TensorFlow SavedModel stored next to this script
    Returns:
        the object returned by tf.saved_model.load for './saved_model'
    """
    saved_model_dir = './saved_model'
    loaded = tf.saved_model.load(saved_model_dir)
    return loaded
def load_label_map(label_map_path):
    """
    Reads label map in the format of .pbtxt and parses it into a dictionary

    Each item is expected to span several lines, one `key: value` pair per
    line, e.g.

        item {
          id: 1
          name: "ripe"
        }

    The `id` and `name` lines may appear in either order and the name may be
    wrapped in single or double quotes. Other keys are ignored.

    Args:
        label_map_path: the file path to the label_map
    Returns:
        dictionary with the format of {label_index: {'id': label_index, 'name': label_name}}
    Raises:
        ValueError: if the value of an `id` line is not an integer
    """
    label_map = {}
    label_index = None
    label_name = None
    with open(label_map_path, "r", encoding="utf-8") as label_file:
        for raw_line in label_file:
            if not raw_line.strip():
                continue
            key, separator, value = raw_line.partition(":")
            if not separator:
                # Structural lines such as "item {" or "}" delimit entries, so
                # forget whatever was collected for the previous item.
                label_index = None
                label_name = None
                continue
            key = key.strip()
            value = value.strip()
            # Compare the exact key instead of searching for the substring
            # "id", which would also match names such as "wide".
            if key == "id":
                label_index = int(value)
            elif key == "name":
                label_name = value.strip("\"'")
            if label_index is not None and label_name is not None:
                label_map[label_index] = {"id": label_index, "name": label_name}
    return label_map
def predict_class(image, model):
    """
    Runs model.predict on a single image
    Args:
        image: the image to classify; it is cast to float32 and resized to 150x150
        model: an object exposing a predict method that accepts a batch of images
    Returns:
        whatever model.predict returns for the one-image batch
    """
    as_float = tf.cast(image, tf.float32)
    resized = tf.image.resize(as_float, [150, 150])
    # add a leading batch axis so the model receives a batch of one
    batch = np.expand_dims(resized, axis = 0)
    predictions = model.predict(batch)
    return predictions
def plot_boxes_on_img(color_map, classes, bboxes, image_origi, origi_shape,
                      scores=None, category_index=None):
    """
    Draws bounding boxes with a class/score label onto an image (in place)
    Args:
        color_map: dictionary mapping a class id to the color of its box
        classes: class id of each box, in the same order as bboxes
        bboxes: boxes given as four relative coordinates; entries 1 and 3 are
            scaled by origi_shape[1], entries 0 and 2 by origi_shape[0]
        image_origi: the image array to draw on; it is modified and returned
        origi_shape: shape of image_origi, used to scale the box coordinates
        scores: score of each box, in the same order as bboxes. When omitted,
            the module-level variable `scores` is used
        category_index: dictionary of the format produced by load_label_map.
            When omitted, the module-level variable `category_index` is used
    Returns:
        image_origi with boxes and labels drawn on it
    """
    # Earlier versions read `scores` and `category_index` from module-level
    # globals; keep that as the fallback so five-argument calls still work.
    if scores is None:
        scores = globals()["scores"]
    if category_index is None:
        category_index = globals()["category_index"]

    # Used for class ids that have no entry in color_map, so an extra class
    # in the label map does not abort drawing with a KeyError.
    fallback_color = [255, 255, 255]

    for each_bbox, class_id, score in zip(bboxes, classes, scores):
        color = color_map.get(class_id, fallback_color)
        label_name = category_index.get(class_id, {}).get("name", class_id)

        # Scale the relative coordinates to pixels once per box
        left = int(each_bbox[1] * origi_shape[1])
        top = int(each_bbox[0] * origi_shape[0])
        right = int(each_bbox[3] * origi_shape[1])
        bottom = int(each_bbox[2] * origi_shape[0])

        ## Draw bounding box
        cv2.rectangle(image_origi, (left, top), (right, bottom), color, 2)

        ## Draw label background: a filled 15 px strip starting at the box's
        ## bottom edge
        cv2.rectangle(
            image_origi,
            (left, bottom),
            (right, int(each_bbox[2] * origi_shape[0] + 15)),
            color,
            -1,
        )

        ## Insert label class & score
        cv2.putText(
            image_origi,
            "Class: {}, Score: {}".format(
                str(label_name),
                str(round(score, 2)),
            ),
            (left, int(each_bbox[2] * origi_shape[0] + 10)),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.3,
            (0, 0, 0),
            1,
            cv2.LINE_AA,
        )
    return image_origi
# Webpage code starts here
st.title('Banana ripeness detection 🍌')
st.text('made by Luzie') #TODO change with your name
st.markdown('## Find out if a banana is too ripe!')

with st.spinner('Model is being loaded...'):
    model = load_model()

# ask user to upload an image
file = st.file_uploader("Upload an image of a banana", type=["jpg", "jpeg", "png"])

if file is None:
    st.text('Waiting for upload...')
else:
    st.text('Running inference...')

    # open image and keep its full-resolution pixels for drawing later
    test_image = Image.open(file).convert("RGB")
    image_origi = np.array(test_image)
    origi_shape = image_origi.shape

    # resize image to default shape
    default_shape = 320
    image_resized = np.array(test_image.resize((default_shape, default_shape)))

    ## Load label map
    # NOTE: plot_boxes_on_img can read the module-level names `category_index`
    # and `scores`, so these two names must not be renamed.
    category_index = load_label_map("./label_map.pbtxt")

    # color of each label. check label_map.pbtxt to check the index for each class
    # TODO Add more colors if there are more classes
    color_map = {
        2: [255, 0, 0],  # overripe -> red
        1: [0, 255, 0],  # ripe -> green
    }

    ## The model input needs to be a tensor
    input_tensor = tf.convert_to_tensor(image_resized)
    ## The model expects a batch of images, so add an axis with `tf.newaxis`.
    input_tensor = input_tensor[tf.newaxis, ...]

    ## Feed image into model and obtain output
    detections_output = model(input_tensor)
    num_detections = int(detections_output.pop("num_detections"))
    # keep only the first batch entry, cut to the reported number of detections
    detections = {
        key: value[0, :num_detections].numpy()
        for key, value in detections_output.items()
    }
    detections["num_detections"] = num_detections

    ## Filter out predictions below threshold
    # if threshold is higher, there will be fewer predictions
    # TODO change this number to see how the predictions change
    confidence_threshold = 0.5
    indexes = np.where(detections["detection_scores"] > confidence_threshold)

    ## Extract predicted bounding boxes
    bboxes = detections["detection_boxes"][indexes]

    if len(bboxes) == 0:
        # there are no predicted boxes
        st.error('No boxes predicted')
    else:
        # there are predicted boxes
        st.success('Boxes predicted')
        classes = detections["detection_classes"][indexes].astype(np.int64)
        scores = detections["detection_scores"][indexes]

        # plot boxes and labels on the original image rather than on an
        # upscaled copy of the 320x320 model input, which would lose detail
        image_origi = plot_boxes_on_img(color_map, classes, bboxes, image_origi, origi_shape)

        # show image in web page
        st.image(Image.fromarray(image_origi), caption="Image with predictions", width=400)

        st.markdown("### Predicted boxes")
        for class_id, score in zip(classes, scores):
            st.markdown(f"* Class: {str(category_index[class_id]['name'])}, confidence score: {str(round(score, 2))}")