# Afonso B. Sousa
# Added a better title.
# d95ef01 unverified
# AUTOGENERATED! DO NOT EDIT! File to edit: app.ipynb.
# %% auto 0
# Public names exported by `from app import *`. The file header says this file
# is generated from app.ipynb, so this list is presumably regenerated on export.
__all__ = [
    'MODEL_PATH',
    'model',
    'image',
    'label',
    'processed_image',
    'intf',
    'predict',
]
# %% app.ipynb 2
import logging
import sys
from pathlib import Path

import gradio as gr
import numpy as np
import torch
from PIL import Image
# Print NumPy arrays in full instead of summarising them with "...".
# NOTE(review): this only configures NumPy's printing; the debug output in
# predict formats torch tensors, which follow torch.set_printoptions instead.
# Confirm whether this line is still needed.
np.set_printoptions(threshold=sys.maxsize)
# %% app.ipynb 4
from lenet import LeNet5
# The checkpoint holds the whole model object (it is called and .eval()'d
# directly below), so unpickling presumably needs the LeNet5 class importable
# (hence the import above). It also needs weights_only=False, which turns off
# torch's restricted unpickler rather than allowlisting anything.
# SECURITY: weights_only=False executes arbitrary pickle code from the file;
# only point MODEL_PATH at a checkpoint you trust.
MODEL_PATH = Path("models/lenet5-cpu.pt")
# map_location="cpu": restore every tensor onto the CPU regardless of the
# device it was saved from. predict feeds the model CPU tensors.
model = torch.load(MODEL_PATH, map_location="cpu", weights_only=False)
# Switch to inference mode (matters for layers such as dropout/batch-norm,
# if LeNet5 has any).
model.eval()
def predict(img):
    """Classify a digit drawn on the sketchpad.

    Args:
        img: Value emitted by the ``gr.Sketchpad`` input: a dict whose
            ``"composite"`` entry is the flattened drawing as a PIL image.
            ``None``, an empty dict, or a missing/``None`` composite means
            there is nothing to classify.

    Returns:
        ``(scores, debug_image)`` for the two Gradio outputs. ``scores`` maps
        each model output index (as a string) to its softmax probability.
        ``debug_image`` is the inverted 28x28 ``uint8`` array the model was
        fed. Both are ``None`` when there is nothing to classify.
    """
    logger = logging.getLogger(__name__)
    size = 28  # side length of the square image the model is fed

    composite = img.get("composite") if img else None
    if composite is None:
        # Clear both outputs instead of crashing on a missing drawing.
        return None, None

    # Force RGBA so the drawing can act as its own paste mask: Image.paste
    # rejects masks in modes such as "RGB".
    sketch = composite.convert("RGBA").resize((size, size))

    # Flatten onto an opaque white canvas; alpha hides the undrawn pixels.
    background = Image.new("L", (size, size), 255)
    background.paste(sketch, (0, 0), sketch)

    # Invert colours (MNIST has white digits on black).
    img_array = 255 - np.array(background)
    # Displayable copy of exactly what the model sees.
    inverted_debug = img_array.astype(np.uint8)

    # NOTE(review): pixels are passed unscaled in [0, 255]; confirm this
    # matches the preprocessing LeNet5 was trained with.
    img_tensor = torch.tensor(img_array, dtype=torch.float32)
    img_tensor = img_tensor.unsqueeze(0).unsqueeze(0)  # -> (batch=1, channel=1, H, W)
    # Debug output goes through logging (off by default) instead of dumping
    # every pixel to stdout on each request.
    logger.debug("Input tensor shape: %s", img_tensor.shape)
    logger.debug("Input tensor values: %s", img_tensor)

    with torch.no_grad():
        output = model(img_tensor)
        probabilities = torch.nn.functional.softmax(output, dim=1)[0]
    logger.debug("Output shape: %s", output.shape)
    logger.debug("Probabilities shape: %s", probabilities.shape)
    logger.debug("Probabilities: %s", probabilities)

    # Gradio's Label output takes a {label: confidence} mapping.
    return {str(i): float(prob) for i, prob in enumerate(probabilities)}, inverted_debug
# Input: a 280x280 drawing canvas with a fixed black brush. No image sources
# (sources=()), no layers and no transform tools, so only freehand drawing is
# offered. The drawing reaches predict as a PIL image.
image = gr.Sketchpad(
    type="pil",
    canvas_size=(280, 280),
    sources=(),
    layers=False,
    transforms=[],
    brush=gr.Brush(
        default_size=20,
        colors=["#000000"],
        color_mode="fixed",
    ),
)

# Outputs, in the order predict returns them: class probabilities, then the
# 28x28 image that was actually fed to the model.
label = gr.Label()
processed_image = gr.Image(label="What the Model Sees (28x28)")

intf = gr.Interface(
    fn=predict,
    inputs=image,
    outputs=[label, processed_image],
    title="Draw a digit",
    description="And let me identify it for you...",
    clear_btn=None,
)

# Start the web app when this module runs; inline=False keeps it out of a
# notebook cell, and debug=True surfaces errors in the console.
intf.launch(inline=False, debug=True)