Spaces:
Runtime error
Runtime error
from models import ( | |
SigmoidNNAutoencoder, | |
TanhNNAutoencoder, | |
TanhPNAutoencoder, | |
ReLUNNAutoencoder, | |
ReLUPNAutoencoder, | |
TanhSwishNNAutoencoder, | |
ReLUSigmoidNRAutoencoder, | |
ReLUSigmoidRRAutoencoder, | |
) | |
from tqdm import tqdm | |
def get_network(name): | |
match name: | |
case "nn_sigmoid": | |
return SigmoidNNAutoencoder() | |
case "nn_tanh": | |
return TanhNNAutoencoder() | |
case "pn_tanh": | |
return TanhPNAutoencoder() | |
case "nn_relu": | |
return ReLUNNAutoencoder() | |
case "pn_relu": | |
return ReLUPNAutoencoder() | |
case "nn_tanh_swish": | |
return TanhSwishNNAutoencoder() | |
case "nr_relu_sigmoid": | |
return ReLUSigmoidNRAutoencoder() | |
case "rr_relu_sigmoid": | |
return ReLUSigmoidRRAutoencoder() | |
case _: | |
raise NotImplementedError( | |
f"Autoencoder of name '{name}' currently is not supported" | |
) | |
class AverageMeter(object): | |
"""Computes and stores the average and current value""" | |
def __init__(self): | |
self.reset() | |
def reset(self): | |
self.val = 0 | |
self.avg = 0 | |
self.sum = 0 | |
self.count = 0 | |
def update(self, val, n=1): | |
self.val = val | |
self.sum += val * n | |
self.count += n | |
self.avg = self.sum / self.count | |
def epoch(loader, model, device, criterion, opt=None): | |
losses = AverageMeter() | |
if opt is None: | |
model.eval() | |
else: | |
model.train() | |
for inputs, _ in tqdm(loader, leave=False): | |
inputs = inputs.view(-1, 28 * 28).to(device) | |
outputs = model(inputs) | |
loss = criterion(outputs, inputs) | |
if opt: | |
opt.zero_grad(set_to_none=True) | |
loss.backward() | |
opt.step() | |
model.clamp() | |
losses.update(loss.item(), inputs.size(0)) | |
return losses.avg | |