cmoineau commited on
Commit
159b4ba
·
verified ·
1 Parent(s): 08339d0

Create lenet_mnist_torch.py

Browse files
Files changed (1) hide show
  1. lenet_mnist_torch.py +102 -0
lenet_mnist_torch.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file creates a simple lenet network using the MNIST dataset.
3
+ """
4
+
5
+ import random
6
+
7
+ import torch
8
+ from torchvision import datasets, transforms
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ # Download the MNIST Dataset
13
+
14
+ def get_mnist_dataset():
15
+ transform = transforms.ToTensor()
16
+ train_set = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
17
+ test_set = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
18
+ return train_set, test_set
19
+
20
+ # Create the lenet model
21
+
22
+ class Classifier(torch.nn.Module):
23
+ def __init__(self):
24
+ super().__init__()
25
+ self.network = nn.Sequential(
26
+ nn.Conv2d(1, 32, 5), # 28 -> 24
27
+ nn.ReLU(),
28
+ nn.MaxPool2d(2, 2), # 24 -> 12
29
+ nn.Conv2d(32, 32, 5), # 12 -> 8
30
+ nn.ReLU(),
31
+ nn.MaxPool2d(2, 2), # 8 -> 4
32
+ nn.Flatten(),
33
+ nn.Linear(32*4*4, 100),
34
+ nn.ReLU(),
35
+ nn.Linear(100, 100),
36
+ nn.ReLU(),
37
+ nn.Linear(100, 10)
38
+ )
39
+
40
+ def forward(self, x):
41
+ return self.network(x)
42
+
43
+ # Compute accuracy function
44
+
45
+ def compute_accuracy(model, data_set, nb_samples):
46
+ nb_valid = 0
47
+ for it in range(nb_samples):
48
+ # get a sample
49
+ sample_idx = torch.randint(len(data_set), size=(1,)).item()
50
+ img, label = data_set[sample_idx]
51
+ # compute the output
52
+ x = torch.reshape(img, (1,1,28,28))
53
+ y_h = model.forward(x)
54
+ pred_label = torch.argmax(y_h).item()
55
+ if label == pred_label :
56
+ nb_valid = nb_valid + 1
57
+ return nb_valid / nb_samples
58
+
59
+ # Train the model
60
+
61
+ def train_model(NB_ITERATION, CHECK_PERIOD, train_set, test_set, classifier):
62
+ accuracy_history = []
63
+ for it in range(NB_ITERATION):
64
+ sample_idx = random.randint(0, len(train_set)-1)
65
+ img, label = train_set[sample_idx]
66
+ x = torch.flatten(img)
67
+ x = torch.reshape(x, (1,1,28,28))
68
+ y = torch.zeros(1,10)
69
+ y[0][label] = 1
70
+ y_h = classifier.forward(x)
71
+ #print(y_h.shape, 'test')
72
+ l = F.mse_loss(y, y_h)
73
+ l.backward()
74
+ for p in classifier.parameters():
75
+ with torch.no_grad():
76
+ p -= 0.01 * p.grad
77
+ p.grad.zero_()
78
+
79
+ if it % CHECK_PERIOD == 0:
80
+ accuracy = compute_accuracy(classifier, test_set, CHECK_PERIOD)
81
+ accuracy_history.append(accuracy)
82
+ print(f'it {it}: accuracy = {accuracy:.8f} ')
83
+
84
+
85
+ def create_lenet():
86
+ # Get Dataset
87
+ train_set, test_set = get_mnist_dataset()
88
+
89
+ # Create model
90
+ classifier = Classifier()
91
+
92
+ # Train model
93
+ NB_ITERATION = 50000
94
+ CHECK_PERIOD = 3000
95
+ print("NB_ITERATIONS = ", NB_ITERATION)
96
+ print("CHECK_PERIOD = ", CHECK_PERIOD)
97
+ print("\nTraining LeNet...")
98
+ train_model(NB_ITERATION, CHECK_PERIOD, train_set, test_set, classifier)
99
+
100
+ # Export as ONNX
101
+ x = torch.Tensor(1,1,28,28)
102
+ torch.onnx.export(classifier.network, x, 'lenet.onnx', verbose=False, input_names=[ "input" ], output_names=[ "output" ])