from typing import Optional, Tuple, Dict, List

import torch
import numpy as np
from tqdm import tqdm

from ldm.modules.diffusionmodules.util import make_beta_schedule
from model.cond_fn import Guidance
from utils.image import (
    wavelet_reconstruction, adaptive_instance_normalization
)


# https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py
def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.

    For example, if there are 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.

    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.

    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim"):])
            for i in range(1, num_timesteps):
                if len(range(0, num_timesteps, i)) == desired_count:
                    return set(range(0, num_timesteps, i))
            raise ValueError(
                f"cannot create exactly {desired_count} steps with an integer stride"
            )
        section_counts = [int(x) for x in section_counts.split(",")]
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        if section_count <= 1:
            frac_stride = 1
        else:
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)
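
# Illustrative examples (comments added for clarity, not part of the original
# upstream module): with 300 original timesteps and section_counts "10,15,20",
# each of the three 100-step sections contributes evenly strided indices, so
#   space_timesteps(300, "10,15,20")  ->  45 timesteps in total
#   space_timesteps(1000, "ddim50")   ->  {0, 20, 40, ..., 980}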


# https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/gaussian_diffusion.py
def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.

    :param arr: the 1-D numpy array.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch
                            dimension equal to the length of timesteps.
    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
    """
    res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res.expand(broadcast_shape)
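
# Shape example (comment added for clarity): for arr of shape (T,), timesteps
# of shape (N,) and broadcast_shape (N, C, H, W), the selected values are
# reshaped to (N, 1, 1, 1) and then expanded to (N, C, H, W), so per-timestep
# scalars can scale image tensors elementwise.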


class SpacedSampler:
    """
    Implementation for the spaced sampling schedule proposed in IDDPM. This class
    is designed for sampling ControlLDM.

    https://arxiv.org/pdf/2102.09672.pdf
    """

    def __init__(
        self,
        model: "ControlLDM",
        schedule: str="linear",
        var_type: str="fixed_small"
    ) -> None:
        self.model = model
        self.original_num_steps = model.num_timesteps
        self.schedule = schedule
        self.var_type = var_type

    def make_schedule(self, num_steps: int) -> None:
        """
        Initialize sampling parameters according to `num_steps`.

        Args:
            num_steps (int): Sampling steps.

        Returns:
            None
        """
        # NOTE: this schedule, which generates betas linearly in log space, is a little
        # different from guided diffusion.
        original_betas = make_beta_schedule(
            self.schedule, self.original_num_steps, linear_start=self.model.linear_start,
            linear_end=self.model.linear_end
        )
        original_alphas = 1.0 - original_betas
        original_alphas_cumprod = np.cumprod(original_alphas, axis=0)

        # calculate betas for spaced sampling
        # https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py
        used_timesteps = space_timesteps(self.original_num_steps, str(num_steps))
        print(f"timesteps used in spaced sampler: \n\t{sorted(list(used_timesteps))}")

        betas = []
        last_alpha_cumprod = 1.0
        for i, alpha_cumprod in enumerate(original_alphas_cumprod):
            if i in used_timesteps:
                # marginal distribution is the same as q(x_{S_t}|x_0)
                betas.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
        assert len(betas) == num_steps
        betas = np.array(betas, dtype=np.float64)
        self.betas = betas

        self.timesteps = np.array(sorted(list(used_timesteps)), dtype=np.int32) # e.g. [0, 10, 20, ...]
        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
        assert self.alphas_cumprod_prev.shape == (num_steps,)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        # log calculation clipped because the posterior variance is 0 at the
        # beginning of the diffusion chain.
        self.posterior_log_variance_clipped = np.log(
            np.append(self.posterior_variance[1], self.posterior_variance[1:])
        )
        self.posterior_mean_coef1 = (
            betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev)
            * np.sqrt(alphas)
            / (1.0 - self.alphas_cumprod)
        )
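
    # Reference for the quantities computed in make_schedule (DDPM, Ho et al. 2020;
    # comment added for clarity, matches the code above):
    #   q(x_{t-1} | x_t, x_0) = N(x_{t-1}; mu_t, sigma_t^2 I)  with
    #   mu_t      = coef1 * x_0 + coef2 * x_t
    #   coef1     = beta_t * sqrt(alphas_cumprod_{t-1}) / (1 - alphas_cumprod_t)
    #   coef2     = (1 - alphas_cumprod_{t-1}) * sqrt(alpha_t) / (1 - alphas_cumprod_t)
    #   sigma_t^2 = beta_t * (1 - alphas_cumprod_{t-1}) / (1 - alphas_cumprod_t)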

    def q_sample(
        self,
        x_start: torch.Tensor,
        t: torch.Tensor,
        noise: Optional[torch.Tensor]=None
    ) -> torch.Tensor:
        """
        Implement the marginal distribution q(x_t|x_0).

        Args:
            x_start (torch.Tensor): Images (NCHW) sampled from the data distribution.
            t (torch.Tensor): Timestep (N) for the diffusion process. `t` serves as an
                index to get parameters for each timestep.
            noise (torch.Tensor, optional): Specify the noise (NCHW) added to `x_start`.

        Returns:
            x_t (torch.Tensor): The noisy images.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        assert noise.shape == x_start.shape
        return (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )
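
    # q_sample implements the closed-form forward process (comment added for
    # clarity; matches the return expression above):
    #   x_t = sqrt(alphas_cumprod_t) * x_0 + sqrt(1 - alphas_cumprod_t) * eps,
    #   eps ~ N(0, I)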

    def q_posterior_mean_variance(
        self,
        x_start: torch.Tensor,
        x_t: torch.Tensor,
        t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Implement the posterior distribution q(x_{t-1}|x_t, x_0).

        Args:
            x_start (torch.Tensor): The predicted images (NCHW) at timestep `t`.
            x_t (torch.Tensor): The sampled intermediate variables (NCHW) of timestep `t`.
            t (torch.Tensor): Timestep (N) of `x_t`. `t` serves as an index to get
                parameters for each timestep.

        Returns:
            posterior_mean (torch.Tensor): Mean of the posterior distribution.
            posterior_variance (torch.Tensor): Variance of the posterior distribution.
            posterior_log_variance_clipped (torch.Tensor): Log variance of the posterior distribution.
        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = _extract_into_tensor(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def _predict_xstart_from_eps(
        self,
        x_t: torch.Tensor,
        t: torch.Tensor,
        eps: torch.Tensor
    ) -> torch.Tensor:
        assert x_t.shape == eps.shape
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )
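
    # _predict_xstart_from_eps inverts the forward process (comment added for clarity):
    #   x_0 = x_t * sqrt(1 / alphas_cumprod_t) - sqrt(1 / alphas_cumprod_t - 1) * eps
    # which follows from solving x_t = sqrt(ac_t) * x_0 + sqrt(1 - ac_t) * eps for x_0.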

    def predict_noise(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        cond: Dict[str, torch.Tensor],
        cfg_scale: float,
        uncond: Optional[Dict[str, torch.Tensor]]
    ) -> torch.Tensor:
        if uncond is None or cfg_scale == 1.:
            model_output = self.model.apply_model(x, t, cond)
        else:
            # apply classifier-free guidance
            model_cond = self.model.apply_model(x, t, cond)
            model_uncond = self.model.apply_model(x, t, uncond)
            model_output = model_uncond + cfg_scale * (model_cond - model_uncond)

        if self.model.parameterization == "v":
            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
        else:
            e_t = model_output
        return e_t
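
    # Classifier-free guidance used above (Ho & Salimans 2022; comment added
    # for clarity):
    #   eps_guided = eps_uncond + s * (eps_cond - eps_uncond)
    # s = 1 recovers the conditional prediction, while s > 1 pushes samples
    # further towards the condition. For v-parameterized models, the guided
    # output is first converted back to eps.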

    def apply_cond_fn(
        self,
        x: torch.Tensor,
        cond: Dict[str, torch.Tensor],
        t: torch.Tensor,
        index: torch.Tensor,
        cond_fn: Guidance,
        cfg_scale: float,
        uncond: Optional[Dict[str, torch.Tensor]]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        device = x.device
        t_now = int(t[0].item()) + 1
        # ----------------- predict noise and x0 ----------------- #
        e_t = self.predict_noise(
            x, t, cond, cfg_scale, uncond
        )
        pred_x0: torch.Tensor = self._predict_xstart_from_eps(x_t=x, t=index, eps=e_t)
        model_mean, _, _ = self.q_posterior_mean_variance(
            x_start=pred_x0, x_t=x, t=index
        )

        # apply classifier guidance multiple times
        for _ in range(cond_fn.repeat):
            # ----------------- compute gradient for x0 in latent space ----------------- #
            target, pred = None, None
            if cond_fn.space == "latent":
                target = self.model.get_first_stage_encoding(
                    self.model.encode_first_stage(cond_fn.target.to(device))
                )
                pred = pred_x0
            elif cond_fn.space == "rgb":
                # We need to backward the gradient to x0 in latent space, so it's
                # required to trace the computation graph while decoding the latent.
                with torch.enable_grad():
                    pred_x0.requires_grad_(True)
                    target = cond_fn.target.to(device)
                    pred = self.model.decode_first_stage_with_grad(pred_x0)
            else:
                raise NotImplementedError(cond_fn.space)
            delta_pred = cond_fn(target, pred, t_now)
            # ----------------- apply classifier guidance ----------------- #
            if delta_pred is not None:
                if cond_fn.space == "rgb":
                    # compute gradient for pred_x0
                    pred.backward(delta_pred)
                    delta_pred_x0 = pred_x0.grad
                    # update pred_x0
                    pred_x0 += delta_pred_x0
                    # our classifier guidance is equivalent to multiplying delta_pred_x0
                    # by a constant and adding it to model_mean; we set the constant to 0.5
                    model_mean += 0.5 * delta_pred_x0
                    pred_x0.grad.zero_()
                else:
                    delta_pred_x0 = delta_pred
                    pred_x0 += delta_pred_x0
                    model_mean += 0.5 * delta_pred_x0
            else:
                # a None delta_pred means the guidance has stopped
                break
        return model_mean.detach().clone(), pred_x0.detach().clone()
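
    # Guidance flow sketch (comment added for clarity; assumes the Guidance
    # interface from model.cond_fn): cond_fn(target, pred, t) returns a
    # correction `delta_pred` in its own space ("latent" or "rgb"), or None to
    # stop. For "rgb", the correction is pulled back to the latent x0 via
    # backward() with delta_pred as the incoming gradient, i.e. a
    # vector-Jacobian product through the VAE decoder.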

    def p_sample(
        self,
        x: torch.Tensor,
        cond: Dict[str, torch.Tensor],
        t: torch.Tensor,
        index: torch.Tensor,
        cfg_scale: float,
        uncond: Optional[Dict[str, torch.Tensor]],
        cond_fn: Optional[Guidance]
    ) -> torch.Tensor:
        # variance of posterior distribution q(x_{t-1}|x_t, x_0)
        model_variance = {
            "fixed_large": np.append(self.posterior_variance[1], self.betas[1:]),
            "fixed_small": self.posterior_variance
        }[self.var_type]
        model_variance = _extract_into_tensor(model_variance, index, x.shape)

        # mean of posterior distribution q(x_{t-1}|x_t, x_0)
        if cond_fn is not None:
            # apply classifier guidance
            model_mean, pred_x0 = self.apply_cond_fn(
                x, cond, t, index, cond_fn,
                cfg_scale, uncond
            )
        else:
            e_t = self.predict_noise(
                x, t, cond, cfg_scale, uncond
            )
            pred_x0 = self._predict_xstart_from_eps(x_t=x, t=index, eps=e_t)
            model_mean, _, _ = self.q_posterior_mean_variance(
                x_start=pred_x0, x_t=x, t=index
            )

        # sample x_{t-1} from q(x_{t-1}|x_t, x_0)
        noise = torch.randn_like(x)
        nonzero_mask = (
            (index != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )
        x_prev = model_mean + nonzero_mask * torch.sqrt(model_variance) * noise
        return x_prev
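
    # One reverse step, as implemented above (comment added for clarity):
    #   x_{t-1} = mu_t(x_t, x0_hat) + sigma_t * z,  z ~ N(0, I)
    # with the noise masked out at index 0, so the final step returns the
    # posterior mean directly.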

    def sample_with_mixdiff(
        self,
        tile_size: int,
        tile_stride: int,
        steps: int,
        shape: Tuple[int, int, int, int],
        cond_img: torch.Tensor,
        positive_prompt: str,
        negative_prompt: str,
        x_T: Optional[torch.Tensor]=None,
        cfg_scale: float=1.,
        cond_fn: Optional[Guidance]=None,
        color_fix_type: str="none"
    ) -> torch.Tensor:
        def _sliding_windows(h: int, w: int, tile_size: int, tile_stride: int) -> List[Tuple[int, int, int, int]]:
            hi_list = list(range(0, h - tile_size + 1, tile_stride))
            if (h - tile_size) % tile_stride != 0:
                hi_list.append(h - tile_size)

            wi_list = list(range(0, w - tile_size + 1, tile_stride))
            if (w - tile_size) % tile_stride != 0:
                wi_list.append(w - tile_size)

            coords = []
            for hi in hi_list:
                for wi in wi_list:
                    coords.append((hi, hi + tile_size, wi, wi + tile_size))
            return coords
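
        # Tiling example (comment added for clarity): for h = w = 96 latent
        # pixels, tile_size = 64 and tile_stride = 32, _sliding_windows yields
        # [(0, 64, 0, 64), (0, 64, 32, 96), (32, 96, 0, 64), (32, 96, 32, 96)],
        # covering the full latent with overlapping tiles.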
        # make sampling parameters (e.g. sigmas)
        self.make_schedule(num_steps=steps)

        device = next(self.model.parameters()).device
        b, _, h, w = shape
        if x_T is None:
            img = torch.randn(shape, dtype=torch.float32, device=device)
        else:
            img = x_T
        # create buffers for accumulating the predicted noise of the different diffusion processes
        noise_buffer = torch.zeros_like(img)
        count = torch.zeros(shape, dtype=torch.long, device=device)
        # timesteps iterator
        time_range = np.flip(self.timesteps) # e.g. [999, 949, 899, ...]
        total_steps = len(self.timesteps)
        iterator = tqdm(time_range, desc="Spaced Sampler", total=total_steps)

        # sampling loop
        for i, step in enumerate(iterator):
            ts = torch.full((b,), step, device=device, dtype=torch.long)
            index = torch.full_like(ts, fill_value=total_steps - i - 1)

            # predict noise for each tile
            tiles_iterator = tqdm(_sliding_windows(h, w, tile_size // 8, tile_stride // 8))
            for hi, hi_end, wi, wi_end in tiles_iterator:
                tiles_iterator.set_description(f"Process tile with location ({hi} {hi_end}) ({wi} {wi_end})")
                # noisy latent of this diffusion process (tile) at this step
                tile_img = img[:, :, hi:hi_end, wi:wi_end]
                # prepare condition for this tile
                tile_cond_img = cond_img[:, :, hi * 8:hi_end * 8, wi * 8: wi_end * 8]
                tile_cond = {
                    "c_latent": [self.model.apply_condition_encoder(tile_cond_img)],
                    "c_crossattn": [self.model.get_learned_conditioning([positive_prompt] * b)]
                }
                tile_uncond = {
                    "c_latent": [self.model.apply_condition_encoder(tile_cond_img)],
                    "c_crossattn": [self.model.get_learned_conditioning([negative_prompt] * b)]
                }
                # TODO: tile_cond_fn
                # predict noise for this tile
                tile_noise = self.predict_noise(tile_img, ts, tile_cond, cfg_scale, tile_uncond)

                # accumulate noise predictions
                noise_buffer[:, :, hi:hi_end, wi:wi_end] += tile_noise
                count[:, :, hi:hi_end, wi:wi_end] += 1

            if (count == 0).any().item():
                print("found count == 0 in noise buffer!")
            # average the noise predictions of overlapping tiles
            noise_buffer.div_(count)
            # sample previous latent
            pred_x0 = self._predict_xstart_from_eps(x_t=img, t=index, eps=noise_buffer)
            mean, _, _ = self.q_posterior_mean_variance(
                x_start=pred_x0, x_t=img, t=index
            )
            variance = {
                "fixed_large": np.append(self.posterior_variance[1], self.betas[1:]),
                "fixed_small": self.posterior_variance
            }[self.var_type]
            variance = _extract_into_tensor(variance, index, noise_buffer.shape)
            nonzero_mask = (
                (index != 0).float().view(-1, *([1] * (len(noise_buffer.shape) - 1)))
            )
            img = mean + nonzero_mask * torch.sqrt(variance) * torch.randn_like(mean)

            noise_buffer.zero_()
            count.zero_()

        # decode samples of each diffusion process
        img_buffer = torch.zeros_like(cond_img)
        count = torch.zeros_like(cond_img, dtype=torch.long)
        for hi, hi_end, wi, wi_end in _sliding_windows(h, w, tile_size // 8, tile_stride // 8):
            tile_img = img[:, :, hi:hi_end, wi:wi_end]
            tile_img_pixel = (self.model.decode_first_stage(tile_img) + 1) / 2
            tile_cond_img = cond_img[:, :, hi * 8:hi_end * 8, wi * 8: wi_end * 8]
            # apply color correction (borrowed from StableSR)
            if color_fix_type == "adain":
                tile_img_pixel = adaptive_instance_normalization(tile_img_pixel, tile_cond_img)
            elif color_fix_type == "wavelet":
                tile_img_pixel = wavelet_reconstruction(tile_img_pixel, tile_cond_img)
            else:
                assert color_fix_type == "none", f"unexpected color fix type: {color_fix_type}"
            img_buffer[:, :, hi * 8:hi_end * 8, wi * 8: wi_end * 8] += tile_img_pixel
            count[:, :, hi * 8:hi_end * 8, wi * 8: wi_end * 8] += 1
        img_buffer.div_(count)

        return img_buffer

    def sample(
        self,
        steps: int,
        shape: Tuple[int, int, int, int],
        cond_img: torch.Tensor,
        positive_prompt: str,
        negative_prompt: str,
        x_T: Optional[torch.Tensor]=None,
        cfg_scale: float=1.,
        cond_fn: Optional[Guidance]=None,
        color_fix_type: str="none"
    ) -> torch.Tensor:
        self.make_schedule(num_steps=steps)

        device = next(self.model.parameters()).device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        time_range = np.flip(self.timesteps) # e.g. [999, 949, 899, ...]
        total_steps = len(self.timesteps)
        iterator = tqdm(time_range, desc="Spaced Sampler", total=total_steps)

        cond = {
            "c_latent": [self.model.apply_condition_encoder(cond_img)],
            "c_crossattn": [self.model.get_learned_conditioning([positive_prompt] * b)]
        }
        uncond = {
            "c_latent": [self.model.apply_condition_encoder(cond_img)],
            "c_crossattn": [self.model.get_learned_conditioning([negative_prompt] * b)]
        }
        for i, step in enumerate(iterator):
            ts = torch.full((b,), step, device=device, dtype=torch.long)
            index = torch.full_like(ts, fill_value=total_steps - i - 1)
            img = self.p_sample(
                img, cond, ts, index=index,
                cfg_scale=cfg_scale, uncond=uncond,
                cond_fn=cond_fn
            )

        img_pixel = (self.model.decode_first_stage(img) + 1) / 2
        # apply color correction (borrowed from StableSR)
        if color_fix_type == "adain":
            img_pixel = adaptive_instance_normalization(img_pixel, cond_img)
        elif color_fix_type == "wavelet":
            img_pixel = wavelet_reconstruction(img_pixel, cond_img)
        else:
            assert color_fix_type == "none", f"unexpected color fix type: {color_fix_type}"
        return img_pixel
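

# Usage sketch (hedged; a minimal illustration, not part of the original
# module). It assumes `model` is a ControlLDM instance already loaded with its
# checkpoint, and `lq` is a low-quality condition image as an NCHW float
# tensor in [0, 1] on the same device; the 4-channel 64x64 latent shape is an
# assumption matching a 512x512 output with an 8x VAE downsampling factor:
#
#   sampler = SpacedSampler(model, var_type="fixed_small")
#   hq = sampler.sample(
#       steps=50,
#       shape=(1, 4, 64, 64),   # latent shape, 8x smaller than pixel space
#       cond_img=lq,
#       positive_prompt="",
#       negative_prompt="",
#       cfg_scale=1.0,
#       color_fix_type="wavelet",
#   )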