Spaces:
Runtime error
Runtime error
# Gradio demo script: loads the three saved sub-models, assembles them into a
# PatchConvNet, and serves the attention-plotting callable through a web UI.

# import the necessary packages
from utilities import config
from utilities import model
from utilities import visualization
from tensorflow import keras
import gradio as gr

# load the models from disk; compile=False because no training happens in
# this script, so the models do not need to be compiled after loading
conv_stem = keras.models.load_model(
    config.IMAGENETTE_STEM_PATH,
    compile=False,
)
conv_trunk = keras.models.load_model(
    config.IMAGENETTE_TRUNK_PATH,
    compile=False,
)
conv_attn = keras.models.load_model(
    config.IMAGENETTE_ATTN_PATH,
    compile=False,
)

# create the patch conv net from the stem, trunk and attention-pooling parts
patch_conv_net = model.PatchConvNet(
    stem=conv_stem,
    trunk=conv_trunk,
    attention_pooling=conv_attn,
)

# get the plot attention function (a callable wrapping the assembled model)
plot_attention = visualization.PlotAttention(model=patch_conv_net)

# The `gr.inputs` namespace was deprecated in Gradio 3.x and removed in 4.x,
# so prefer the top-level component and only fall back to the legacy one on
# old Gradio releases that do not expose `gr.Image`.
if hasattr(gr, "Image"):
    image_input = gr.Image(label="Input Image")
else:
    image_input = gr.inputs.Image(label="Input Image")

# Build the interface first so that `iface` refers to the Interface itself
# rather than to the value returned by `launch()`.
iface = gr.Interface(
    fn=plot_attention,
    inputs=[image_input],
    outputs="image",
)

# start the web server
iface.launch()