Spaces:
Build error
Build error
File size: 1,985 Bytes
3ed3379 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
from typing import overload
import torch
from torch.nn import functional as F
class Guidance:
    """Base class for a guidance term that is active inside a timestep window.

    ``__call__`` checks the timestep, detaches the inputs, delegates to
    ``_forward`` and multiplies the result by ``scale``. Subclasses provide
    ``_forward``.
    """

    def __init__(self, scale, type, t_start, t_stop, space, repeat, loss_type):
        """Store the guidance configuration on the instance.

        Args:
            scale: Multiplier applied to the value returned by ``_forward``.
            type: Stored as ``self.type``; not read by this class. The name
                shadows the builtin but is kept so keyword callers still work.
            t_start: Exclusive upper bound of the active timestep window.
            t_stop: Exclusive lower bound of the active timestep window.
            space: Stored as ``self.space``; not read by this class.
            repeat: Stored as ``self.repeat``; not read by this class.
            loss_type: Stored as ``self.loss_type``; not read by this class
                (available to ``_forward`` implementations).
        """
        self.scale = scale
        self.type = type
        self.t_start = t_start
        self.t_stop = t_stop
        # Filled in later through load_target(); not read by this class.
        self.target = None
        self.space = space
        self.repeat = repeat
        self.loss_type = loss_type

    def load_target(self, target):
        """Remember ``target`` on the instance as ``self.target``."""
        self.target = target

    def __call__(self, target_x0, pred_x0, t):
        """Return the scaled guidance term, or ``None`` outside the window.

        Guidance is applied only when ``t_stop < t < t_start`` (both bounds
        exclusive).

        Args:
            target_x0: Target passed through to ``_forward`` (detached copy).
            pred_x0: Prediction passed through to ``_forward`` (detached copy).
            t: Current timestep, compared against ``t_stop`` and ``t_start``.

        Returns:
            ``self.scale * self._forward(target_x0, pred_x0)`` inside the
            window, otherwise ``None``.
        """
        if not (self.t_stop < t < self.t_start):
            return None
        # Work on detached copies so no gradient propagates out of this scope
        # and the caller's tensors are never modified.
        pred_x0 = pred_x0.detach().clone()
        target_x0 = target_x0.detach().clone()
        return self.scale * self._forward(target_x0, pred_x0)

    def _forward(self, target_x0, pred_x0):
        """Compute the unscaled guidance term; subclasses must override this.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        # An explicit error replaces the former ``@overload`` stub, which only
        # failed by accident and with a message unrelated to the real problem.
        raise NotImplementedError(
            f"{type(self).__name__} must implement _forward()"
        )
class MSEGuidance(Guidance):
    """Guidance term equal to the negative gradient of an MSE loss.

    The loss is selected by ``self.loss_type``:

    * ``"mse"``: squared error between ``pred_x0`` and ``target_x0``.
    * ``"downsample_mse"``: the same error measured after both images are
      bicubically resized to 1/4 of their spatial size.
    """

    def __init__(self, scale, type, t_start, t_stop, space, repeat, loss_type) -> None:
        """Forward every argument unchanged to ``Guidance.__init__``."""
        super().__init__(scale, type, t_start, t_stop, space, repeat, loss_type)

    @torch.enable_grad()
    def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor):
        """Return ``-d(loss)/d(pred_x0)`` for the configured loss.

        Gradient tracking is enabled by the decorator, so this works even when
        the caller runs under ``torch.no_grad()``.

        Args:
            target_x0: Reference tensor; must be 4-D because the loss averages
                over dims ``(1, 2, 3)``.
            pred_x0: Predicted tensor with the same shape as ``target_x0``.

        Returns:
            A tensor shaped like ``pred_x0`` holding the negative gradient of
            the loss with respect to ``pred_x0``.

        Raises:
            ValueError: If ``self.loss_type`` is neither ``"mse"`` nor
                ``"downsample_mse"``.
        """
        # inputs: [-1, 1], nchw, rgb
        # Differentiate with respect to a fresh leaf so the caller's tensor is
        # not switched to requires_grad=True and non-leaf inputs are accepted.
        pred_x0 = pred_x0.detach().requires_grad_(True)

        if self.loss_type == "mse":
            loss = (pred_x0 - target_x0).pow(2).mean((1, 2, 3)).sum()
        elif self.loss_type == "downsample_mse":
            # Shrink to 1/4 resolution before comparing. The previous
            # scale_factor=4 enlarged the images instead (see the old FIXME).
            # NOTE: spatial sizes below 4 pixels cannot be reduced this way.
            lr_pred_x0 = F.interpolate(pred_x0, scale_factor=0.25, mode="bicubic")
            lr_target_x0 = F.interpolate(target_x0, scale_factor=0.25, mode="bicubic")
            loss = (lr_pred_x0 - lr_target_x0).pow(2).mean((1, 2, 3)).sum()
        else:
            raise ValueError(self.loss_type)

        print(f"loss = {loss.item()}")
        # Negated so that adding the returned term lowers the loss.
        return -torch.autograd.grad(loss, pred_x0)[0]
|