teamalphabmsit's picture
Upload folder using huggingface_hub
44504f7 verified
import torch
from tqdm import tqdm

from models import (
    SigmoidNNAutoencoder,
    TanhNNAutoencoder,
    TanhPNAutoencoder,
    ReLUNNAutoencoder,
    ReLUPNAutoencoder,
    TanhSwishNNAutoencoder,
    ReLUSigmoidNRAutoencoder,
    ReLUSigmoidRRAutoencoder,
)
def get_network(name):
    """Instantiate the autoencoder variant identified by ``name``.

    Args:
        name: Identifier of the autoencoder variant, e.g. ``"nn_sigmoid"``.

    Returns:
        A newly constructed instance of the autoencoder class associated
        with ``name``.

    Raises:
        NotImplementedError: If ``name`` equals none of the known identifiers.
    """
    # Each guard compares by equality and only touches the class it returns,
    # so an unknown name falls straight through to the error below.
    if name == "nn_sigmoid":
        return SigmoidNNAutoencoder()
    if name == "nn_tanh":
        return TanhNNAutoencoder()
    if name == "pn_tanh":
        return TanhPNAutoencoder()
    if name == "nn_relu":
        return ReLUNNAutoencoder()
    if name == "pn_relu":
        return ReLUPNAutoencoder()
    if name == "nn_tanh_swish":
        return TanhSwishNNAutoencoder()
    if name == "nr_relu_sigmoid":
        return ReLUSigmoidNRAutoencoder()
    if name == "rr_relu_sigmoid":
        return ReLUSigmoidRRAutoencoder()

    raise NotImplementedError(
        f"Autoencoder of name '{name}' currently is not supported"
    )
class AverageMeter:
    """Track the most recent value and the running weighted mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset every statistic to zero."""
        self.val = 0  # value passed to the latest update() call
        self.avg = 0  # running mean: sum / count
        self.sum = 0  # total of val * n over all updates
        self.count = 0  # total of n over all updates

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running mean.

        Args:
            val: The new measurement.
            n: Weight of the measurement (e.g. the number of samples it
                summarizes). Defaults to 1.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero accumulated weight (e.g. update(v, n=0) on a fresh meter)
        # has no defined mean; report 0 instead of raising ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0
def epoch(loader, model, device, criterion, opt=None):
    """Run one pass over ``loader`` and return the weighted mean loss.

    The model is put in training mode and updated when ``opt`` is given;
    otherwise it is put in evaluation mode and only the loss is measured.
    Every batch is reshaped to rows of ``28 * 28`` values, and the model
    output is scored against that same reshaped input.

    Args:
        loader: Iterable yielding ``(inputs, _)`` pairs; the second element
            of each pair is ignored.
        model: Callable model exposing ``train()``, ``eval()`` and
            ``clamp()``.
        device: Target handed to ``inputs.to(...)``.
        criterion: Callable invoked as ``criterion(outputs, inputs)`` whose
            result supports ``backward()`` and ``item()``.
        opt: Optimizer used to update the model, or ``None`` to evaluate
            without updating.

    Returns:
        The mean of the per-batch losses weighted by batch size (``0`` when
        the loader yields no batches).
    """
    # Decide train vs. eval once, with the same test everywhere.
    training = opt is not None
    if training:
        model.train()
    else:
        model.eval()

    losses = AverageMeter()
    # Autograd bookkeeping is only needed when a backward pass follows;
    # disabling it during evaluation avoids building unused graphs.
    with torch.set_grad_enabled(training):
        for inputs, _ in tqdm(loader, leave=False):
            inputs = inputs.view(-1, 28 * 28).to(device)
            outputs = model(inputs)
            loss = criterion(outputs, inputs)
            if training:
                opt.zero_grad(set_to_none=True)
                loss.backward()
                opt.step()
                # NOTE(review): clamp() is defined on the model elsewhere;
                # it is assumed to belong to the optimizer-step branch
                # (run right after each parameter update) - confirm.
                model.clamp()
            losses.update(loss.item(), inputs.size(0))
    return losses.avg