|
|
|
import gradio as gr |
|
import os |
|
import torch |
|
|
|
from timeit import default_timer as timer |
|
from model import create_effnetb2_model |
|
from typing import Tuple , Dict |
|
|
|
|
|
# The 101 food class labels. `predict` pairs position i of the model's
# output with class_names[i], so the order of this list must not change.
class_names = [
    'apple_pie',
    'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare',
    'beet_salad', 'beignets', 'bibimbap', 'bread_pudding',
    'breakfast_burrito', 'bruschetta',
    'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake', 'ceviche',
    'cheese_plate', 'cheesecake', 'chicken_curry', 'chicken_quesadilla',
    'chicken_wings', 'chocolate_cake', 'chocolate_mousse', 'churros',
    'clam_chowder', 'club_sandwich', 'crab_cakes', 'creme_brulee',
    'croque_madame', 'cup_cakes',
    'deviled_eggs', 'donuts', 'dumplings',
    'edamame', 'eggs_benedict', 'escargots',
    'falafel', 'filet_mignon', 'fish_and_chips', 'foie_gras',
    'french_fries', 'french_onion_soup', 'french_toast', 'fried_calamari',
    'fried_rice', 'frozen_yogurt',
    'garlic_bread', 'gnocchi', 'greek_salad', 'grilled_cheese_sandwich',
    'grilled_salmon', 'guacamole', 'gyoza',
    'hamburger', 'hot_and_sour_soup', 'hot_dog', 'huevos_rancheros',
    'hummus',
    'ice_cream',
    'lasagna', 'lobster_bisque', 'lobster_roll_sandwich',
    'macaroni_and_cheese', 'macarons', 'miso_soup', 'mussels',
    'nachos',
    'omelette', 'onion_rings', 'oysters',
    'pad_thai', 'paella', 'pancakes', 'panna_cotta', 'peking_duck', 'pho',
    'pizza', 'pork_chop', 'poutine', 'prime_rib', 'pulled_pork_sandwich',
    'ramen', 'ravioli', 'red_velvet_cake', 'risotto',
    'samosa', 'sashimi', 'scallops', 'seaweed_salad', 'shrimp_and_grits',
    'spaghetti_bolognese', 'spaghetti_carbonara', 'spring_rolls', 'steak',
    'strawberry_shortcake', 'sushi',
    'tacos', 'takoyaki', 'tiramisu', 'tuna_tartare',
    'waffles',
]
|
|
|
from torchvision.models._api import WeightsEnum |
|
from torch.hub import load_state_dict_from_url |
|
def get_state_dict(self, *args, **kwargs):
    """Fetch the state dict for this weights entry from ``self.url``.

    Replacement for ``WeightsEnum.get_state_dict`` that discards any
    ``check_hash`` keyword before delegating to
    ``torch.hub.load_state_dict_from_url``, so that function falls back to
    its own default for hash checking.

    Args:
        self: The weights enum member; only its ``url`` attribute is read.
        *args: Forwarded positionally to ``load_state_dict_from_url``.
        **kwargs: Forwarded as keywords, minus ``check_hash``.

    Returns:
        Whatever ``load_state_dict_from_url`` returns for ``self.url``.
    """
    # A default is supplied so callers that never pass check_hash do not
    # trigger a KeyError here.
    kwargs.pop("check_hash", None)
    return load_state_dict_from_url(self.url, *args, **kwargs)


# Monkey-patch: every WeightsEnum member now downloads through the
# function above.
WeightsEnum.get_state_dict = get_state_dict
|
|
|
|
|
# Build the model together with the transforms that `predict` applies to
# each incoming image. 101 output classes matches len(class_names).
effnetb2_model, effnet_b2_transforms = create_effnetb2_model(
    num_classes=101,
    seed=42,
)

# Read the saved weights onto the CPU, then copy them into the model.
# NOTE(review): torch.load unpickles the file, so the checkpoint must come
# from a trusted source.
_saved_state_dict = torch.load(
    f='11_pretrained_effnet_feature_extractor_food101.pth',
    map_location=torch.device('cpu'),
)
effnetb2_model.load_state_dict(_saved_state_dict)
|
|
|
def predict(img) -> Tuple[Dict[str, float], float]:
    """Classify an image and report how long the prediction took.

    Args:
        img: The image to classify. It is handed directly to
            ``effnet_b2_transforms``; the Gradio interface in this file
            requests PIL images (``gr.Image(type='pil')``).

    Returns:
        A tuple ``(pred_label_and_probs, pred_time)``.
        ``pred_label_and_probs`` maps every entry of ``class_names`` to the
        softmax probability at the same index of the model output, and
        ``pred_time`` is the elapsed time in seconds (measured with
        ``timeit.default_timer``), rounded to 4 decimal places.

    Raises:
        ValueError: If ``img`` is ``None`` (nothing was submitted).
    """
    # Fail with a readable message instead of an opaque error from deep
    # inside the transform pipeline when no image was supplied.
    if img is None:
        raise ValueError("No image provided: please supply an image to classify.")

    start_time = timer()

    # Apply the model's transforms and add a leading batch dimension.
    image = effnet_b2_transforms(img).unsqueeze(0)

    effnetb2_model.eval()
    with torch.inference_mode():
        pred_probs = torch.softmax(effnetb2_model(image), dim=1)

    # One bulk conversion to Python floats rather than indexing the tensor
    # once per class.
    probs = pred_probs[0].tolist()
    pred_label_and_probs = {
        name: probs[i] for i, name in enumerate(class_names)
    }

    end_time = timer()
    pred_time = round(end_time - start_time, 4)

    return pred_label_and_probs, pred_time
|
|
|
|
|
|
|
# Text shown by the Gradio interface: heading, blurb under the heading,
# and footer article.
title = 'FoodVision Big'
description = 'An EfficientNetB2 Feature extractor computer vision model to classify 101 Food images '
article = 'created at PyTorch Model Deployment'
|
|
|
|
|
|
|
# Example inputs for the Gradio interface: one single-element list per file
# in the 'examples' directory (resolved against the current working
# directory).
# - sorted(): os.listdir returns names in arbitrary order, so sort to keep
#   the displayed examples stable between runs.
# - isfile(): skip sub-directories, which cannot be offered as an image
#   example.
example_list = [
    ['examples/' + example]
    for example in sorted(os.listdir('examples'))
    if os.path.isfile(os.path.join('examples', example))
]
|
|
|
|
|
# Input component: the uploaded image is delivered to `predict` as a PIL
# image.
image_input = gr.Image(type='pil')

# Output components, in the same order as the tuple `predict` returns:
# the label/probability mapping (top 3 shown) and the prediction time.
label_output = gr.Label(num_top_classes=3, label='prediction')
time_output = gr.Number(label='Prediction time (s)')

# Wire everything into a single Gradio interface.
demo = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=[label_output, time_output],
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

# Start serving the app.
demo.launch(debug=False)
|
|