teamalphabmsit's picture
Upload folder using huggingface_hub
44504f7 verified
import argparse
from pathlib import Path
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
from utils import get_network, epoch
torch.manual_seed(0)
def train_nn_network(args):
p = Path(__file__)
weights_path = f"{p.parent}/weights"
Path(weights_path).mkdir(parents=True, exist_ok=True)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = get_network(args.net)
model.to(device)
mnist_train = datasets.MNIST(
".", train=True, download=True, transform=transforms.ToTensor()
)
mnist_test = datasets.MNIST(
".", train=False, download=True, transform=transforms.ToTensor()
)
train_loader = DataLoader(
mnist_train, batch_size=args.b, shuffle=True, num_workers=4, pin_memory=True
)
test_loader = DataLoader(
mnist_test, batch_size=args.b, shuffle=False, num_workers=4, pin_memory=True
)
opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
criterion = nn.MSELoss()
best_loss = None
for i in range(1, args.epochs + 1):
train_loss = epoch(train_loader, model, device, criterion, opt)
test_loss = epoch(test_loader, model, device, criterion)
if best_loss is None or best_loss > test_loss:
best_loss = test_loss
torch.save(model.state_dict(), f"{weights_path}/{args.net}.pth")
print(f"Epoch: {i} | Train Loss: {train_loss:.4f} | Test Loss: {test_loss:.4f}")
if __name__ == "__main__":
train_nn_network()