from typing import List

import torch

from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper


class BidirectionalInferencePipeline(torch.nn.Module):
    def __init__(
        self,
        args,
        device,
        generator=None,
        text_encoder=None,
        vae=None
    ):
        super().__init__()

        # Submodules: the bidirectional (non-causal) diffusion generator, the text
        # encoder, and the VAE. Each can be overridden by the caller.
        self.generator = WanDiffusionWrapper(
            **getattr(args, "model_kwargs", {}), is_causal=False) if generator is None else generator
        self.text_encoder = WanTextEncoder() if text_encoder is None else text_encoder
        self.vae = WanVAEWrapper() if vae is None else vae

        # Sampling hyperparameters: the scheduler and the few-step denoising timesteps.
        self.scheduler = self.generator.get_scheduler()
        self.denoising_step_list = torch.tensor(
            args.denoising_step_list, dtype=torch.long, device=device)
        if self.denoising_step_list[-1] == 0:
            # Drop the trailing zero timestep; the final prediction is decoded directly.
            self.denoising_step_list = self.denoising_step_list[:-1]
        if args.warp_denoising_step:
            # Map each raw step index onto the scheduler's (possibly shifted)
            # timestep schedule instead of using the indices themselves.
            timesteps = torch.cat((self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)))
            self.denoising_step_list = timesteps[1000 - self.denoising_step_list]

    def inference(self, noise: torch.Tensor, text_prompts: List[str]) -> torch.Tensor:
        """
        Perform inference on the given noise and text prompts.

        Inputs:
            noise (torch.Tensor): The input noise tensor of shape
                (batch_size, num_frames, num_channels, height, width).
            text_prompts (List[str]): The list of text prompts.

        Outputs:
            video (torch.Tensor): The generated video tensor of shape
                (batch_size, num_frames, num_channels, height, width),
                normalized to the range [0, 1].
        """
        # Encode the text prompts once; the result conditions every denoising step.
        conditional_dict = self.text_encoder(
            text_prompts=text_prompts
        )

        # Start from pure noise and progressively denoise it.
        noisy_image_or_video = noise

        for index, current_timestep in enumerate(self.denoising_step_list):
            # Predict the clean latent at the current timestep.
            _, pred_image_or_video = self.generator(
                noisy_image_or_video=noisy_image_or_video,
                conditional_dict=conditional_dict,
                timestep=torch.ones(
                    noise.shape[:2], dtype=torch.long, device=noise.device) * current_timestep
            )

            # Re-noise the prediction to the next (smaller) timestep, except after the
            # final step, whose prediction is decoded directly.
            if index < len(self.denoising_step_list) - 1:
                next_timestep = self.denoising_step_list[index + 1] * torch.ones(
                    noise.shape[:2], dtype=torch.long, device=noise.device)
                noisy_image_or_video = self.scheduler.add_noise(
                    pred_image_or_video.flatten(0, 1),
                    torch.randn_like(pred_image_or_video.flatten(0, 1)),
                    next_timestep.flatten(0, 1)
                ).unflatten(0, noise.shape[:2])

        # Decode latents to pixel space and map from [-1, 1] to [0, 1].
        video = self.vae.decode_to_pixel(pred_image_or_video)
        video = (video * 0.5 + 0.5).clamp(0, 1)
        return video
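

# Minimal usage sketch (not part of the pipeline itself). It assumes an `args`
# object exposing `denoising_step_list` and `warp_denoising_step` (and optionally
# `model_kwargs`), as read by `__init__` above, and that the Wan checkpoints the
# wrapper classes load are available locally. The latent shape follows the
# `inference` docstring: (batch_size, num_frames, num_channels, height, width);
# the concrete sizes below are illustrative placeholders, not values from this repo.
if __name__ == "__main__":
    from types import SimpleNamespace

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args = SimpleNamespace(
        model_kwargs={},                                # forwarded to WanDiffusionWrapper
        denoising_step_list=[1000, 750, 500, 250, 0],   # example few-step schedule
        warp_denoising_step=False,
    )
    pipeline = BidirectionalInferencePipeline(args, device=device).to(device)

    # Hypothetical latent dimensions; match them to the actual Wan VAE latent size.
    noise = torch.randn(1, 21, 16, 60, 104, device=device)
    video = pipeline.inference(noise, text_prompts=["a cat surfing a wave"])
    print(video.shape)  # (batch_size, num_frames, num_channels, height, width), values in [0, 1]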