Create lenet_mnist_torch.py
Browse files- lenet_mnist_torch.py +102 -0
lenet_mnist_torch.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This file creates a simple lenet network using the MNIST dataset.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import random
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torchvision import datasets, transforms
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
# Download the MNIST Dataset
|
| 13 |
+
|
| 14 |
+
def get_mnist_dataset():
|
| 15 |
+
transform = transforms.ToTensor()
|
| 16 |
+
train_set = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
|
| 17 |
+
test_set = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
|
| 18 |
+
return train_set, test_set
|
| 19 |
+
|
| 20 |
+
# Create the lenet model
|
| 21 |
+
|
| 22 |
+
class Classifier(torch.nn.Module):
|
| 23 |
+
def __init__(self):
|
| 24 |
+
super().__init__()
|
| 25 |
+
self.network = nn.Sequential(
|
| 26 |
+
nn.Conv2d(1, 32, 5), # 28 -> 24
|
| 27 |
+
nn.ReLU(),
|
| 28 |
+
nn.MaxPool2d(2, 2), # 24 -> 12
|
| 29 |
+
nn.Conv2d(32, 32, 5), # 12 -> 8
|
| 30 |
+
nn.ReLU(),
|
| 31 |
+
nn.MaxPool2d(2, 2), # 8 -> 4
|
| 32 |
+
nn.Flatten(),
|
| 33 |
+
nn.Linear(32*4*4, 100),
|
| 34 |
+
nn.ReLU(),
|
| 35 |
+
nn.Linear(100, 100),
|
| 36 |
+
nn.ReLU(),
|
| 37 |
+
nn.Linear(100, 10)
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
def forward(self, x):
|
| 41 |
+
return self.network(x)
|
| 42 |
+
|
| 43 |
+
# Compute accuracy function
|
| 44 |
+
|
| 45 |
+
def compute_accuracy(model, data_set, nb_samples):
|
| 46 |
+
nb_valid = 0
|
| 47 |
+
for it in range(nb_samples):
|
| 48 |
+
# get a sample
|
| 49 |
+
sample_idx = torch.randint(len(data_set), size=(1,)).item()
|
| 50 |
+
img, label = data_set[sample_idx]
|
| 51 |
+
# compute the output
|
| 52 |
+
x = torch.reshape(img, (1,1,28,28))
|
| 53 |
+
y_h = model.forward(x)
|
| 54 |
+
pred_label = torch.argmax(y_h).item()
|
| 55 |
+
if label == pred_label :
|
| 56 |
+
nb_valid = nb_valid + 1
|
| 57 |
+
return nb_valid / nb_samples
|
| 58 |
+
|
| 59 |
+
# Train the model
|
| 60 |
+
|
| 61 |
+
def train_model(NB_ITERATION, CHECK_PERIOD, train_set, test_set, classifier):
|
| 62 |
+
accuracy_history = []
|
| 63 |
+
for it in range(NB_ITERATION):
|
| 64 |
+
sample_idx = random.randint(0, len(train_set)-1)
|
| 65 |
+
img, label = train_set[sample_idx]
|
| 66 |
+
x = torch.flatten(img)
|
| 67 |
+
x = torch.reshape(x, (1,1,28,28))
|
| 68 |
+
y = torch.zeros(1,10)
|
| 69 |
+
y[0][label] = 1
|
| 70 |
+
y_h = classifier.forward(x)
|
| 71 |
+
#print(y_h.shape, 'test')
|
| 72 |
+
l = F.mse_loss(y, y_h)
|
| 73 |
+
l.backward()
|
| 74 |
+
for p in classifier.parameters():
|
| 75 |
+
with torch.no_grad():
|
| 76 |
+
p -= 0.01 * p.grad
|
| 77 |
+
p.grad.zero_()
|
| 78 |
+
|
| 79 |
+
if it % CHECK_PERIOD == 0:
|
| 80 |
+
accuracy = compute_accuracy(classifier, test_set, CHECK_PERIOD)
|
| 81 |
+
accuracy_history.append(accuracy)
|
| 82 |
+
print(f'it {it}: accuracy = {accuracy:.8f} ')
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def create_lenet():
|
| 86 |
+
# Get Dataset
|
| 87 |
+
train_set, test_set = get_mnist_dataset()
|
| 88 |
+
|
| 89 |
+
# Create model
|
| 90 |
+
classifier = Classifier()
|
| 91 |
+
|
| 92 |
+
# Train model
|
| 93 |
+
NB_ITERATION = 50000
|
| 94 |
+
CHECK_PERIOD = 3000
|
| 95 |
+
print("NB_ITERATIONS = ", NB_ITERATION)
|
| 96 |
+
print("CHECK_PERIOD = ", CHECK_PERIOD)
|
| 97 |
+
print("\nTraining LeNet...")
|
| 98 |
+
train_model(NB_ITERATION, CHECK_PERIOD, train_set, test_set, classifier)
|
| 99 |
+
|
| 100 |
+
# Export as ONNX
|
| 101 |
+
x = torch.Tensor(1,1,28,28)
|
| 102 |
+
torch.onnx.export(classifier.network, x, 'lenet.onnx', verbose=False, input_names=[ "input" ], output_names=[ "output" ])
|