Spaces:
Running
Running
| import os | |
| import gc | |
| import timm | |
| import gradio as gr | |
| import torch | |
| import tensorflow as tf | |
# timm identifiers of the MobileNetV4 checkpoints offered in the UI.
# Order matters: the first entry is the dropdown's default value and is the
# model used by the bundled examples.
model_names = [
    "mobilenetv4_conv_small.e2400_r224_in1k",
    "mobilenetv4_conv_medium.e500_r224_in1k",
    "mobilenetv4_conv_blur_medium.e500_r224_in1k",
    "mobilenetv4_conv_medium.e500_r256_in1k",
    "mobilenetv4_conv_large.e500_r256_in1k",
    "mobilenetv4_conv_large.e600_r384_in1k",
]
# Build the class-index -> label lookup from the labels file: the 0-based line
# number is the class index and the stripped line text is the label.
# The encoding is pinned so decoding does not depend on the host's locale.
with open('imagenet_classes.txt', 'r', encoding='utf-8') as label_file:
    lines = label_file.readlines()
index_to_label = {index: line.strip() for index, line in enumerate(lines)}
# Module-level cache of the currently loaded model pair. Everything starts out
# empty and is filled in by classify() the first time a model is requested.

# PyTorch side: the timm model and its preprocessing transform.
model = None
transforms = None

# TFLite side: the interpreter plus its input/output tensor descriptions.
tfl_model = None
input_details = None
output_details = None

# Name of the model the cache above belongs to (None while nothing is loaded).
last_model = None
def load_models(timm_model, convert_dir="tflite_models"):
    """Load the PyTorch and TFLite versions of one model.

    Args:
        timm_model: timm model identifier. It is also used to derive the
            TFLite file name ``<timm_model>_float16.tflite``.
        convert_dir: Directory that holds the converted ``.tflite`` files.
            Defaults to ``"tflite_models"``.

    Returns:
        A 5-tuple ``(model, transforms, tfl_model, input_details,
        output_details)``: the timm model in eval mode, its inference
        transform, the allocated TFLite interpreter, and the interpreter's
        input and output tensor details.
    """
    tflite_path = os.path.join(convert_dir, f"{timm_model}_float16.tflite")

    # Open the local TFLite file first: if the converted model is missing this
    # fails right away, before timm is asked for pretrained weights (which may
    # have to be fetched).
    interpreter = tf.lite.Interpreter(model_path=tflite_path)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # PyTorch reference model, switched to eval mode for inference.
    pt_model = timm.create_model(timm_model, pretrained=True).eval()

    # Build the preprocessing pipeline from the model's own data config.
    data_config = timm.data.resolve_data_config(model=pt_model)
    preprocess = timm.data.create_transform(**data_config, is_training=False)

    return pt_model, preprocess, interpreter, input_details, output_details
def _label_scores(indices, probs):
    """Pair class indices with their probabilities as a ``{label: prob}`` dict."""
    return {index_to_label[index]: prob for index, prob in zip(indices, probs)}


def classify(img, model_name):
    """Classify an image with both the PyTorch and the TFLite model.

    The model pair is cached in module globals and only reloaded when
    ``model_name`` differs from the one currently loaded.

    Args:
        img: PIL image to classify.
        model_name: timm model identifier; passed to ``load_models``.

    Returns:
        A pair ``(pt_result_labels, tfl_result_labels)``; each maps the top-5
        class labels to their softmax probability.

    Raises:
        gr.Error: If no image was supplied.
    """
    global model, transforms, tfl_model, input_details, output_details, last_model

    if img is None:
        raise gr.Error("Please provide an image to classify.")

    if last_model is None or model_name != last_model:
        # Clear the cache marker *before* dropping the old models. If loading
        # fails, the next request reloads instead of finding `last_model`
        # still set while `model` / `tfl_model` are already None.
        last_model = None
        model = None
        tfl_model = None
        gc.collect()  # release the previous models before loading new ones
        model, transforms, tfl_model, input_details, output_details = load_models(model_name)
        last_model = model_name

    # The models are created with timm defaults, i.e. three input channels, so
    # grayscale / RGBA / palette images are normalised to RGB first.
    if img.mode != "RGB":
        img = img.convert("RGB")

    processed_img = transforms(img).unsqueeze(0)  # add the batch dimension

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        pt_probs = model(processed_img).softmax(dim=1)
    pt_top5_probs, pt_top5_indices = torch.topk(pt_probs, k=5)
    pt_result_labels = _label_scores(
        pt_top5_indices[0].tolist(), pt_top5_probs[0].tolist()
    )

    ############################################################
    # BCHW -> BHWC for TFLite. permute() only returns a view, so make it
    # contiguous and hand the interpreter a real numpy array.
    img_tf = processed_img.permute(0, 2, 3, 1).contiguous().numpy()
    tfl_model.set_tensor(input_details[0]["index"], img_tf)
    tfl_model.invoke()
    tfl_output = tfl_model.get_tensor(output_details[0]["index"])

    tfl_softmax_output = tf.nn.softmax(tf.convert_to_tensor(tfl_output), axis=1)
    tfl_top5_probs, tfl_top5_indices = tf.math.top_k(tfl_softmax_output, k=5)
    tfl_result_labels = _label_scores(
        tfl_top5_indices[0].numpy().tolist(), tfl_top5_probs[0].numpy().tolist()
    )

    return pt_result_labels, tfl_result_labels
# Gradio UI: an image plus a model selector go in, and two label widgets come
# out (one per backend) so the PyTorch and TFLite predictions sit side by side.
iface = gr.Interface(
    fn=classify,
    inputs=[
        gr.Image(type="pil"),
        gr.Dropdown(choices=model_names, value=model_names[0], label="Model Variant."),
    ],
    outputs=[
        gr.Label(label="Pytorch Output"),
        gr.Label(label="TFLite Output"),
    ],
    title="MobileNetV4 Pytorch vs TFLite Imagenet1K Classification",
    # Every bundled example is paired with the first (default) model.
    examples=[
        [image_path, model_names[0]]
        for image_path in (
            "example_images/n01818515_macaw.JPEG",
            "example_images/n01828970_bee_eater.jpg",
            "example_images/n01833805_hummingbird.JPEG",
        )
    ],
)

iface.launch()