DiffBIR-OpenXLab / model /cond_fn.py
Dynamatrix's picture
Upload 61 files
3ed3379 verified
from typing import overload
import torch
from torch.nn import functional as F
class Guidance:
    """Base class for a guidance term evaluated during sampling.

    Holds the guidance configuration and restricts the computation to the
    open timestep window ``t_stop < t < t_start``. Subclasses provide the
    actual computation by overriding ``_forward``.
    """

    def __init__(self, scale, type, t_start, t_stop, space, repeat, loss_type):
        """Store the guidance configuration.

        Args:
            scale: Multiplier applied to the value returned by ``_forward``.
            type: Stored as ``self.type``; not read by this class.
            t_start: Exclusive upper bound of the active timestep window.
            t_stop: Exclusive lower bound of the active timestep window.
            space: Stored as ``self.space``; not read by this class.
            repeat: Stored as ``self.repeat``; not read by this class.
            loss_type: Stored as ``self.loss_type`` for subclasses to read.
        """
        self.scale = scale
        self.type = type
        self.t_start = t_start
        self.t_stop = t_stop
        # No target until load_target() is called.
        self.target = None
        self.space = space
        self.repeat = repeat
        self.loss_type = loss_type

    def load_target(self, target):
        """Remember ``target`` on the instance as ``self.target``."""
        self.target = target

    def __call__(self, target_x0, pred_x0, t):
        """Return the scaled guidance term, or ``None`` outside the window.

        Args:
            target_x0: Reference value; must support ``detach().clone()``.
            pred_x0: Current prediction; must support ``detach().clone()``.
            t: Current timestep, compared against ``t_stop`` and ``t_start``.

        Returns:
            ``self.scale * self._forward(target_x0, pred_x0)`` when
            ``t_stop < t < t_start`` (both bounds exclusive), else ``None``.
        """
        if not (self.t_stop < t < self.t_start):
            return None
        # Work on detached copies to avoid propagating gradient out of this
        # scope (and to keep the caller's tensors untouched).
        pred_x0 = pred_x0.detach().clone()
        target_x0 = target_x0.detach().clone()
        return self.scale * self._forward(target_x0, pred_x0)

    def _forward(self, target_x0, pred_x0):
        """Compute the unscaled guidance term; subclasses must override.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        # An explicit raise replaces the former ``typing.overload`` marker,
        # which is meant for type checkers only and produced a misleading
        # error message when the base method was actually called.
        raise NotImplementedError(
            f"{type(self).__name__} must implement _forward()"
        )
class MSEGuidance(Guidance):
    """Guidance term given by the negative gradient of an MSE loss."""

    def __init__(self, scale, type, t_start, t_stop, space, repeat, loss_type) -> None:
        """Forward every configuration value unchanged to ``Guidance``."""
        super().__init__(
            scale, type, t_start, t_stop, space, repeat, loss_type
        )

    @torch.enable_grad()
    def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor):
        """Return ``-d(loss)/d(pred_x0)`` for the configured ``loss_type``.

        The loss is the squared error averaged over dims 1-3 of each sample
        and summed over dim 0. With ``loss_type == "mse"`` it is computed on
        the inputs directly; with ``"downsample_mse"`` both inputs are first
        bicubically resized to a quarter of their spatial size.

        Args:
            target_x0: Reference tensor, 4-D, same shape as ``pred_x0``.
            pred_x0: Predicted tensor; ``requires_grad`` is switched on in
                place so the loss can be differentiated with respect to it.

        Returns:
            A tensor with the shape of ``pred_x0`` holding the negated
            gradient of the loss.

        Raises:
            ValueError: If ``self.loss_type`` is not one of the two
                supported names.
        """
        # inputs: [-1, 1], nchw, rgb
        pred_x0.requires_grad_(True)
        if self.loss_type == "mse":
            loss = (pred_x0 - target_x0).pow(2).mean((1, 2, 3)).sum()
        elif self.loss_type == "downsample_mse":
            # Shrink by 4x before comparing. This used to pass
            # scale_factor=4, which enlarged the inputs instead of
            # downsampling them. Spatial dims below 4 pixels would resize
            # to an empty tensor, so inputs are expected to be larger.
            lr_pred_x0 = F.interpolate(pred_x0, scale_factor=0.25, mode="bicubic")
            lr_target_x0 = F.interpolate(target_x0, scale_factor=0.25, mode="bicubic")
            loss = (lr_pred_x0 - lr_target_x0).pow(2).mean((1, 2, 3)).sum()
        else:
            raise ValueError(self.loss_type)
        # Progress output emitted on every guidance step.
        print(f"loss = {loss.item()}")
        # Negated so the returned direction decreases the loss.
        return -torch.autograd.grad(loss, pred_x0)[0]