def epoch(loader, model, device, criterion, opt=None):
    """Run one pass over ``loader`` and return the average loss.

    The function trains when an optimizer is supplied and evaluates
    otherwise. The loss is computed between the model output and the
    (flattened) input itself; the second element of every batch is ignored.

    Args:
        loader: Iterable yielding ``(inputs, _)`` pairs.
        model: Callable module exposing ``train()``, ``eval()`` and
            ``clamp()``.
        device: Device the flattened inputs are moved to before the
            forward pass.
        criterion: Callable invoked as ``criterion(outputs, inputs)``.
        opt: Optimizer, or ``None`` to run in evaluation mode.

    Returns:
        ``losses.avg`` of the ``AverageMeter`` updated once per batch with
        the batch loss and the batch size.
    """
    # Function-scope import keeps the block self-contained; it is a no-op
    # lookup when the module already imports torch.
    import torch

    # One flag drives mode selection, gradient tracking and the update
    # step, so the three can never disagree.
    training = opt is not None

    losses = AverageMeter()
    if training:
        model.train()
    else:
        model.eval()

    # Disable autograd during evaluation: no backward pass happens there,
    # so building the graph would only cost memory and time.
    with torch.set_grad_enabled(training):
        for inputs, _ in tqdm(loader, leave=False):
            # Flatten every sample to a 784-element row (28 * 28).
            inputs = inputs.view(-1, 28 * 28).to(device)
            outputs = model(inputs)
            loss = criterion(outputs, inputs)

            if training:
                opt.zero_grad(set_to_none=True)
                loss.backward()
                opt.step()
                # NOTE(review): presumably re-projects the parameters onto
                # a constraint set after the step - confirm against the
                # model's clamp() implementation.
                model.clamp()

            # Second argument is the batch size; presumably used as the
            # weight of this batch in the running average - verify against
            # AverageMeter.update.
            losses.update(loss.item(), inputs.size(0))

    return losses.avg