Spaces:
Runtime error
Runtime error
| from diffusion.respace import SpacedDiffusion | |
| from .gaussian_diffusion import _extract_into_tensor | |
| import torch as th | |
class InpaintingGaussianDiffusion(SpacedDiffusion):
    """SpacedDiffusion variant that never noises a fixed block of feature channels.

    Both the forward process (``q_sample``) and the reverse step (``p_sample``)
    zero the Gaussian noise on feature channels ``[:num_inpainted_feats]``
    (axis 1 of a ``(bs, feat, _, frames)`` tensor). In ``q_sample`` those
    channels of ``x_t`` are therefore just a scaled copy of ``x_start``; in
    ``p_sample`` they receive the mean from ``p_mean_variance`` (optionally
    adjusted by ``cond_fn``) with no added noise. The original authors
    describe these channels as the root trajectory.
    """

    # Number of leading feature channels (axis 1) kept noise-free. The value 10
    # is the one the original implementation hard-coded in both methods; as a
    # class attribute it can be overridden per subclass or per instance.
    num_inpainted_feats = 10

    def _zero_inpainted_noise(self, noise):
        """Zero ``noise`` on the first ``num_inpainted_feats`` feature channels.

        Operates IN PLACE and also returns ``noise``. Multiplying the slice by
        0.0 is element-for-element identical to ``noise *= 1. - mask`` with a
        mask of ones on those channels, but avoids allocating the mask tensor.
        """
        noise[:, : self.num_inpainted_feats].mul_(0.0)
        return noise

    def q_sample(self, x_start, t, noise=None, model_kwargs=None):
        """
        overrides q_sample to use the inpainting mask
        same usage as in GaussianDiffusion

        :param x_start: clean sample, expected 4-D ``(bs, feat, _, frames)``.
        :param t: diffusion timestep indices (presumably one per batch
            element -- confirm against callers).
        :param noise: optional noise with the same shape as ``x_start``;
            drawn with ``th.randn_like`` when omitted. NOTE: a caller-supplied
            tensor is modified IN PLACE (its inpainted channels are zeroed).
        :param model_kwargs: accepted for interface compatibility; unused here.
        :return: the noised sample ``x_t``.
        :raises AssertionError: if ``noise`` and ``x_start`` shapes differ.
        :raises ValueError: if ``noise`` is not 4-D.
        """
        if noise is None:
            noise = th.randn_like(x_start)
        assert noise.shape == x_start.shape
        # The previous code unpacked ``bs, feat, _, frames = noise.shape`` and
        # never used the names; its only effect was rejecting non-4-D input
        # with a ValueError. Keep that contract, with a readable message.
        if noise.dim() != 4:
            raise ValueError(
                f"expected 4-D noise (bs, feat, _, frames), got shape {tuple(noise.shape)}"
            )
        # Zeroing is deliberately in place, as before.
        # NOTE(review): a caller that reuses ``noise`` afterwards (e.g. as an
        # epsilon target) presumably relies on seeing the masked version --
        # confirm before switching to an out-of-place op.
        self._zero_inpainted_noise(noise)
        return (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )

    def p_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        const_noise=False,
    ):
        """
        overrides p_sample to use the inpainting mask
        same usage as in GaussianDiffusion

        :param model: denoising model passed through to ``p_mean_variance``.
        :param x: current noisy sample ``x_t``.
        :param t: timestep indices, one per batch element (``t == 0`` entries
            receive no noise).
        :param clip_denoised: forwarded to ``p_mean_variance``.
        :param denoised_fn: forwarded to ``p_mean_variance``.
        :param cond_fn: if given, used via ``condition_mean`` to shift the mean.
        :param model_kwargs: forwarded to ``p_mean_variance`` / ``condition_mean``.
        :param const_noise: if True, every batch element reuses the noise drawn
            for element 0 (assumes 4-D ``x``).
        :return: dict with ``"sample"`` (``x_{t-1}``) and ``"pred_xstart"``.
        """
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        noise = th.randn_like(x)
        if const_noise:
            noise = noise[[0]].repeat(x.shape[0], 1, 1, 1)
        # Inpainted channels get no stochastic term, so they follow the mean.
        self._zero_inpainted_noise(noise)
        # Broadcastable per-sample mask: no noise is added when t == 0.
        nonzero_mask = (t != 0).float().view(-1, *([1] * (x.dim() - 1)))
        if cond_fn is not None:
            out["mean"] = self.condition_mean(
                cond_fn, out, x, t, model_kwargs=model_kwargs
            )
        sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}