Spaces:
Runtime error
Runtime error
import argparse | |
from pathlib import Path | |
import torch | |
from torch import nn | |
from torch import optim | |
from torch.utils.data import DataLoader | |
from torchvision import transforms | |
from torchvision import datasets | |
from utils import get_network, epoch | |
torch.manual_seed(0) | |
def train_nn_network(args): | |
p = Path(__file__) | |
weights_path = f"{p.parent}/weights" | |
Path(weights_path).mkdir(parents=True, exist_ok=True) | |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |
model = get_network(args.net) | |
model.to(device) | |
mnist_train = datasets.MNIST( | |
".", train=True, download=True, transform=transforms.ToTensor() | |
) | |
mnist_test = datasets.MNIST( | |
".", train=False, download=True, transform=transforms.ToTensor() | |
) | |
train_loader = DataLoader( | |
mnist_train, batch_size=args.b, shuffle=True, num_workers=4, pin_memory=True | |
) | |
test_loader = DataLoader( | |
mnist_test, batch_size=args.b, shuffle=False, num_workers=4, pin_memory=True | |
) | |
opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd) | |
criterion = nn.MSELoss() | |
best_loss = None | |
for i in range(1, args.epochs + 1): | |
train_loss = epoch(train_loader, model, device, criterion, opt) | |
test_loss = epoch(test_loader, model, device, criterion) | |
if best_loss is None or best_loss > test_loss: | |
best_loss = test_loss | |
torch.save(model.state_dict(), f"{weights_path}/{args.net}.pth") | |
print(f"Epoch: {i} | Train Loss: {train_loss:.4f} | Test Loss: {test_loss:.4f}") | |
if __name__ == "__main__": | |
train_nn_network() | |