{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82baf493-aca3-40ae-8d2f-33adafecb6a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "#|default_exp app"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5fec5815-2555-4b0d-bd1c-a77a7fbdeda7",
   "metadata": {},
   "source": [
    "# Digit parser\n",
    "\n",
    "Draw a digit on a sketchpad and let a saved LeNet-5 model classify it through a Gradio interface. Cells marked `#|export` are written to `app.py` by the Export section at the end."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c3da714-bd9c-4b8f-ae28-980f8dea239c",
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "import logging\n",
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "import gradio as gr\n",
    "import numpy as np\n",
    "import torch\n",
    "from PIL import Image\n",
    "\n",
    "# Never summarize arrays with '...' when they are printed or logged.\n",
    "np.set_printoptions(threshold=sys.maxsize)\n",
    "\n",
    "# Per-request diagnostics go to this logger at DEBUG level;\n",
    "# enable them with logging.basicConfig(level=logging.DEBUG).\n",
    "logger = logging.getLogger(__name__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5664caad-faca-489c-a8ab-74514aa7d706",
   "metadata": {},
   "outputs": [],
   "source": [
    "!uv pip install torchmetrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bff78822-ebd1-4f5f-a765-cb0df804a29b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "# The checkpoint holds a whole pickled model object (it is used directly as a\n",
    "# module below), so the LeNet5 class definition has to be importable here.\n",
    "from lenet import LeNet5\n",
    "\n",
    "MODEL_PATH = Path(\"models/lenet5-cpu.pt\")\n",
    "\n",
    "# SECURITY: weights_only=False performs a full unpickle, which can execute\n",
    "# arbitrary code. Only load checkpoint files you trust.\n",
    "# map_location=\"cpu\" keeps the load working on hosts without a GPU.\n",
    "model = torch.load(MODEL_PATH, map_location=\"cpu\", weights_only=False)\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "predict-function",
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def predict(img):\n",
    "    \"\"\"Classify the digit drawn on the sketchpad.\n",
    "\n",
    "    Args:\n",
    "        img: Value delivered by the `gr.Sketchpad` input: a dict whose\n",
    "            \"composite\" entry is the drawn PIL image. A missing value or a\n",
    "            None composite is treated as an empty canvas.\n",
    "\n",
    "    Returns:\n",
    "        A `(probabilities, model_input)` pair. `probabilities` maps each class\n",
    "        index (as a string) to its softmax probability; `model_input` is the\n",
    "        28x28 uint8 array that was fed to the model. Both are None when there\n",
    "        is nothing to classify, which clears the two output components.\n",
    "    \"\"\"\n",
    "    composite = img.get(\"composite\") if img else None\n",
    "    if composite is None:\n",
    "        # Empty canvas: clear the outputs instead of raising.\n",
    "        return None, None\n",
    "\n",
    "    # The sketch is used as its own paste mask below, so it needs an alpha band.\n",
    "    sketch = composite.convert(\"RGBA\").resize((28, 28))\n",
    "\n",
    "    # Flatten onto a white background so transparent canvas pixels become white.\n",
    "    background = Image.new(\"L\", (28, 28), 255)\n",
    "    background.paste(sketch, (0, 0), sketch)\n",
    "\n",
    "    # Invert colors (MNIST has white digits on black).\n",
    "    img_array = 255 - np.array(background)\n",
    "\n",
    "    # NOTE(review): pixels reach the model as floats in 0-255; confirm this\n",
    "    # matches the preprocessing used when the model was trained.\n",
    "    img_tensor = torch.tensor(img_array, dtype=torch.float32)\n",
    "    img_tensor = img_tensor.unsqueeze(0).unsqueeze(0)  # (28, 28) -> (1, 1, 28, 28)\n",
    "    logger.debug(\"Input tensor shape: %s\", tuple(img_tensor.shape))\n",
    "    logger.debug(\"Input tensor values: %s\", img_tensor)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        output = model(img_tensor)\n",
    "        probabilities = torch.nn.functional.softmax(output, dim=1)[0]\n",
    "    logger.debug(\"Output shape: %s\", tuple(output.shape))\n",
    "    logger.debug(\"Probabilities: %s\", probabilities)\n",
    "\n",
    "    # gr.Label expects a {label: confidence} dict; the inverted array is shown\n",
    "    # alongside it so the user sees exactly what the model received.\n",
    "    scores = {str(i): p for i, p in enumerate(probabilities.tolist())}\n",
    "    return scores, img_array"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "gradio-interface",
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "image = gr.Sketchpad(\n",
    "    type=\"pil\",\n",
    "    sources=(),\n",
    "    canvas_size=(280, 280),\n",
    "    brush=gr.Brush(colors=[\"#000000\"], color_mode=\"fixed\", default_size=20),\n",
    "    layers=False,\n",
    "    transforms=[],\n",
    ")\n",
    "label = gr.Label()\n",
    "processed_image = gr.Image(label=\"What the Model Sees (28x28)\")\n",
    "intf = gr.Interface(\n",
    "    title=\"Draw a digit\",\n",
    "    description=\"And let me identify it for you...\",\n",
    "    fn=predict,\n",
    "    inputs=image,\n",
    "    outputs=[label, processed_image],\n",
    "    clear_btn=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "gradio-launch",
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "# debug=True blocks this cell until the kernel is interrupted; interrupt it\n",
    "# before running the Export section below.\n",
    "intf.launch(inline=False, debug=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf53a6ec-86bf-44cb-baaa-011f21f5869e",
   "metadata": {},
   "source": [
    "## Export"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c35ecd80-c0a1-421a-9dd2-04cca2d4c461",
   "metadata": {},
   "outputs": [],
   "source": [
    "!uv pip install nbdev\n",
    "from nbdev.export import nb_export"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de31d563-3696-45ba-9100-06c93072508c",
   "metadata": {},
   "outputs": [],
   "source": [
    "nb_export('app.ipynb', './')\n",
    "print(\"Exported\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch-tutorial",
   "language": "python",
   "name": "pytorch-tutorial"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}