Spaces:
Build error
Build error
from typing import overload | |
import torch | |
from torch.nn import functional as F | |
class Guidance:
    """Base class for a guidance term applied during diffusion sampling.

    Subclasses implement ``_forward``. ``__call__`` hands it detached copies
    of the target and predicted x0 and multiplies the result by ``scale``,
    but only while the timestep lies strictly inside ``(t_stop, t_start)``.
    """

    def __init__(self, scale, type, t_start, t_stop, space, repeat, loss_type):
        # NOTE: ``type`` shadows the builtin inside this method only; the name
        # is kept because callers may pass it by keyword.
        self.scale = scale
        self.type = type
        self.t_start = t_start
        self.t_stop = t_stop
        self.target = None
        self.space = space
        self.repeat = repeat
        self.loss_type = loss_type

    def load_target(self, target):
        """Store ``target`` on ``self.target``.

        ``__call__`` does not read this attribute; it uses its own argument.
        """
        self.target = target

    def __call__(self, target_x0, pred_x0, t):
        """Return the scaled guidance term, or ``None`` outside the window.

        Args:
            target_x0: target tensor.
            pred_x0: predicted x0 tensor.
            t: current timestep; guidance applies when ``t_stop < t < t_start``.

        Returns:
            ``self.scale * self._forward(target_x0, pred_x0)`` inside the
            window, otherwise ``None``.
        """
        if not self.t_stop < t < self.t_start:
            return None
        # Work on detached copies so no gradient propagates out of this scope.
        pred_x0 = pred_x0.detach().clone()
        target_x0 = target_x0.detach().clone()
        return self.scale * self._forward(target_x0, pred_x0)

    def _forward(self, target_x0, pred_x0):
        """Compute the unscaled guidance term; subclasses must override this.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement _forward()"
        )
class MSEGuidance(Guidance):
    """Guidance that pulls the predicted x0 toward the target via an MSE loss.

    ``loss_type`` selects the loss:
      * ``"mse"``: mean squared error over dims (1, 2, 3) per sample,
        summed over the batch.
      * ``"downsample_mse"``: the same loss, computed after bicubically
        downsampling both tensors by a factor of 4.
    """

    def __init__(self, scale, type, t_start, t_stop, space, repeat, loss_type) -> None:
        super().__init__(
            scale, type, t_start, t_stop, space, repeat, loss_type
        )

    def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor) -> torch.Tensor:
        """Return the negative gradient of the configured loss w.r.t. ``pred_x0``.

        Args:
            target_x0: target tensor (4-D; the original note says [-1, 1],
                NCHW, RGB).
            pred_x0: predicted x0 with the same shape. Must be a leaf tensor,
                because ``requires_grad_`` is set on it in place. ``__call__``
                passes a detached clone, which satisfies this.

        Returns:
            ``-d(loss)/d(pred_x0)``, with the same shape as ``pred_x0``.

        Raises:
            ValueError: if ``self.loss_type`` is not a supported loss.
        """
        # inputs: [-1, 1], nchw, rgb
        # enable_grad lets the gradient be computed even when the calling
        # sampling loop runs under torch.no_grad().
        with torch.enable_grad():
            pred_x0.requires_grad_(True)
            if self.loss_type == "mse":
                loss = (pred_x0 - target_x0).pow(2).mean((1, 2, 3)).sum()
            elif self.loss_type == "downsample_mse":
                # Downsample by 4 (scale_factor < 1). The earlier scale_factor=4
                # upsampled instead, contradicting the loss name. Spatial dims
                # are presumably >= 4, since smaller sizes floor to 0 here.
                lr_pred_x0 = F.interpolate(pred_x0, scale_factor=0.25, mode="bicubic")
                lr_target_x0 = F.interpolate(target_x0, scale_factor=0.25, mode="bicubic")
                loss = (lr_pred_x0 - lr_target_x0).pow(2).mean((1, 2, 3)).sum()
            else:
                raise ValueError(self.loss_type)
            print(f"loss = {loss.item()}")
            return -torch.autograd.grad(loss, pred_x0)[0]