Afonso B. Sousa commited on
Commit
d95ef01
·
unverified ·
1 Parent(s): a518dad

Added a better title.

Browse files
Files changed (2) hide show
  1. app.ipynb +218 -0
  2. app.py +1 -1
app.ipynb ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "82baf493-aca3-40ae-8d2f-33adafecb6a9",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "#|default_exp app"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "id": "5fec5815-2555-4b0d-bd1c-a77a7fbdeda7",
16
+ "metadata": {},
17
+ "source": [
18
+ "# Digit parser\n"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": 1,
24
+ "id": "2c3da714-bd9c-4b8f-ae28-980f8dea239c",
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "#|export\n",
29
+ "import torch\n",
30
+ "import numpy as np\n",
31
+ "import gradio as gr\n",
32
+ "from PIL import Image\n",
33
+ "from pathlib import Path\n",
34
+ "import sys\n",
35
+ "np.set_printoptions(threshold=sys.maxsize)"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": 2,
41
+ "id": "5664caad-faca-489c-a8ab-74514aa7d706",
42
+ "metadata": {},
43
+ "outputs": [
44
+ {
45
+ "name": "stdout",
46
+ "output_type": "stream",
47
+ "text": [
48
+ "\u001b[2mAudited \u001b[1m1 package\u001b[0m \u001b[2min 2ms\u001b[0m\u001b[0m\n"
49
+ ]
50
+ }
51
+ ],
52
+ "source": [
53
+ "!uv pip install torchmetrics"
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "code",
58
+ "execution_count": null,
59
+ "id": "bff78822-ebd1-4f5f-a765-cb0df804a29b",
60
+ "metadata": {},
61
+ "outputs": [
62
+ {
63
+ "name": "stdout",
64
+ "output_type": "stream",
65
+ "text": [
66
+ "* Running on local URL: http://127.0.0.1:7862\n",
67
+ "\n",
68
+ "To create a public link, set `share=True` in `launch()`.\n",
69
+ "Keyboard interruption in main thread... closing server.\n"
70
+ ]
71
+ },
72
+ {
73
+ "data": {
74
+ "text/plain": []
75
+ },
76
+ "execution_count": 3,
77
+ "metadata": {},
78
+ "output_type": "execute_result"
79
+ }
80
+ ],
81
+ "source": [
82
+ "#|export\n",
83
+ "from lenet import LeNet5\n",
84
+ "# Allowlist the custom class\n",
85
+ "MODEL_PATH = Path(\"models/lenet5-cpu.pt\")\n",
86
+ "model = torch.load(MODEL_PATH, weights_only=False)\n",
87
+ "model.eval()\n",
88
+ "\n",
89
+ "def predict(img):\n",
90
+ " # Create a new image with a white background\n",
91
+ " background = Image.new(\"L\", (28, 28), 255)\n",
92
+ "\n",
93
+ " # Resize the input image\n",
94
+ " img_pil = img[\"composite\"].resize((28, 28))\n",
95
+ "\n",
96
+ " # Paste the resized image onto the white background\n",
97
+ " background.paste(img_pil, (0, 0), img_pil)\n",
98
+ " \n",
99
+ " # Convert to numpy\n",
100
+ " img_array = np.array(background)\n",
101
+ " \n",
102
+ " # Invert colors (MNIST has white digits on black)\n",
103
+ " img_array = 255 - img_array\n",
104
+ "\n",
105
+ " # Create a displayable version of the inverted image (what the model actually sees)\n",
106
+ " inverted_debug = img_array.astype(np.uint8)\n",
107
+ "\n",
108
+ " img_tensor = torch.tensor(img_array, dtype=torch.float32) \n",
109
+ " img_tensor = img_tensor.unsqueeze(0).unsqueeze(0) # Add channel and batch dimensions\n",
110
+ "\n",
111
+ " # Debug: Print the shape and values of the input tensor\n",
112
+ " print(f\"Input tensor shape: {img_tensor.shape}\")\n",
113
+ " print(f\"Input tensor values: {img_tensor}\")\n",
114
+ "\n",
115
+ " with torch.no_grad():\n",
116
+ " output = model(img_tensor)\n",
117
+ " probabilities = torch.nn.functional.softmax(output, dim=1)[0]\n",
118
+ "\n",
119
+ " print(f\"Output shape: {output.shape}\")\n",
120
+ " print(f\"Probabilities shape: {probabilities.shape}\")\n",
121
+ " print(f\"Probabilities: {probabilities}\")\n",
122
+ "\n",
123
+ " # Create dictionary of label: probability for Gradio Label output\n",
124
+ " return {str(i): float(prob) for i, prob in enumerate(probabilities)}, inverted_debug\n",
125
+ "\n",
126
+ "image = gr.Sketchpad(type=\"pil\", sources=(), canvas_size=(280,280), brush=gr.Brush(colors=[\"#000000\"], color_mode=\"fixed\", default_size=20), layers=False, transforms=[])\n",
127
+ "label = gr.Label()\n",
128
+ "processed_image = gr.Image(label=\"What the Model Sees (28x28)\")\n",
129
+ "intf = gr.Interface(title=\"Draw a digit\", description=\"And let me identify it for you...\", fn=predict, inputs=image, outputs=[label, processed_image], clear_btn=None)\n",
130
+ "intf.launch(inline=False, debug=True)"
131
+ ]
132
+ },
133
+ {
134
+ "cell_type": "markdown",
135
+ "id": "cf53a6ec-86bf-44cb-baaa-011f21f5869e",
136
+ "metadata": {},
137
+ "source": [
138
+ "## Export"
139
+ ]
140
+ },
141
+ {
142
+ "cell_type": "code",
143
+ "execution_count": 1,
144
+ "id": "c35ecd80-c0a1-421a-9dd2-04cca2d4c461",
145
+ "metadata": {},
146
+ "outputs": [
147
+ {
148
+ "name": "stdout",
149
+ "output_type": "stream",
150
+ "text": [
151
+ "\u001b[2mUsing Python 3.12.7 environment at: /home/afonso/git/private/pytorch-tutorial/.venv\u001b[0m\n",
152
+ "\u001b[2mAudited \u001b[1m1 package\u001b[0m \u001b[2min 34ms\u001b[0m\u001b[0m\n"
153
+ ]
154
+ }
155
+ ],
156
+ "source": [
157
+ "!uv pip install nbdev\n",
158
+ "from nbdev.export import nb_export"
159
+ ]
160
+ },
161
+ {
162
+ "cell_type": "code",
163
+ "execution_count": 2,
164
+ "id": "de31d563-3696-45ba-9100-06c93072508c",
165
+ "metadata": {},
166
+ "outputs": [
167
+ {
168
+ "name": "stdout",
169
+ "output_type": "stream",
170
+ "text": [
171
+ "Exported\n"
172
+ ]
173
+ }
174
+ ],
175
+ "source": [
176
+ "nb_export('app.ipynb', './')\n",
177
+ "print(\"Exported\")"
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "code",
182
+ "execution_count": null,
183
+ "id": "1a443132-c4ec-4990-89c3-9a6320d14640",
184
+ "metadata": {},
185
+ "outputs": [],
186
+ "source": []
187
+ },
188
+ {
189
+ "cell_type": "code",
190
+ "execution_count": null,
191
+ "id": "3058daee-595f-4829-ae93-a38bebdc4030",
192
+ "metadata": {},
193
+ "outputs": [],
194
+ "source": []
195
+ }
196
+ ],
197
+ "metadata": {
198
+ "kernelspec": {
199
+ "display_name": "pytorch-tutorial",
200
+ "language": "python",
201
+ "name": "pytorch-tutorial"
202
+ },
203
+ "language_info": {
204
+ "codemirror_mode": {
205
+ "name": "ipython",
206
+ "version": 3
207
+ },
208
+ "file_extension": ".py",
209
+ "mimetype": "text/x-python",
210
+ "name": "python",
211
+ "nbconvert_exporter": "python",
212
+ "pygments_lexer": "ipython3",
213
+ "version": "3.12.7"
214
+ }
215
+ },
216
+ "nbformat": 4,
217
+ "nbformat_minor": 5
218
+ }
app.py CHANGED
@@ -59,5 +59,5 @@ def predict(img):
59
  image = gr.Sketchpad(type="pil", sources=(), canvas_size=(280,280), brush=gr.Brush(colors=["#000000"], color_mode="fixed", default_size=20), layers=False, transforms=[])
60
  label = gr.Label()
61
  processed_image = gr.Image(label="What the Model Sees (28x28)")
62
- intf = gr.Interface(title="Title", fn=predict, inputs=image, outputs=[label, processed_image], clear_btn=None)
63
  intf.launch(inline=False, debug=True)
 
59
  image = gr.Sketchpad(type="pil", sources=(), canvas_size=(280,280), brush=gr.Brush(colors=["#000000"], color_mode="fixed", default_size=20), layers=False, transforms=[])
60
  label = gr.Label()
61
  processed_image = gr.Image(label="What the Model Sees (28x28)")
62
+ intf = gr.Interface(title="Draw a digit", description="And let me identify it for you...", fn=predict, inputs=image, outputs=[label, processed_image], clear_btn=None)
63
  intf.launch(inline=False, debug=True)