import gc
import logging
import os
import time

import torch
import torch.distributed as dist
import wandb
from omegaconf import OmegaConf

from model import GAN
from utils.dataset import ShardingLMDBDataset, cycle
from utils.distributed import EMA_FSDP, fsdp_wrap, fsdp_state_dict, launch_distributed_job
from utils.misc import set_seed, merge_dict_list

class Trainer:
    def __init__(self, config):
        self.config = config
        self.step = 0

        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

        launch_distributed_job()
        global_rank = dist.get_rank()
        self.world_size = dist.get_world_size()

        self.dtype = torch.bfloat16 if config.mixed_precision else torch.float32
        self.device = torch.cuda.current_device()
        self.is_main_process = global_rank == 0
        self.causal = config.causal
        self.disable_wandb = config.disable_wandb
        self.discriminator_warmup_steps = getattr(config, "discriminator_warmup_steps", 0)
        self.in_discriminator_warmup = self.step < self.discriminator_warmup_steps
        if self.in_discriminator_warmup and self.is_main_process:
            print(f"Starting with discriminator warmup for {self.discriminator_warmup_steps} steps")
        self.loss_scale = getattr(config, "loss_scale", 1.0)

        # Sample a random seed and broadcast rank 0's value so all ranks share the same base seed.
        if config.seed == 0:
            random_seed = torch.randint(0, 10000000, (1,), device=self.device)
            dist.broadcast(random_seed, src=0)
            config.seed = random_seed.item()

        set_seed(config.seed + global_rank)
        if self.is_main_process and not self.disable_wandb:
            wandb.login(host=config.wandb_host, key=config.wandb_key)
            wandb.init(
                config=OmegaConf.to_container(config, resolve=True),
                name=config.config_name,
                mode="online",
                entity=config.wandb_entity,
                project=config.wandb_project,
                dir=config.wandb_save_dir
            )

        self.output_path = config.logdir
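        # Build the GAN wrapper (generator, fake_score/critic, text encoder, VAE) and shard the
        # three transformer modules with FSDP; each gets its own wrap strategy from the config.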
        self.model = GAN(config, device=self.device)

        self.model.generator = fsdp_wrap(
            self.model.generator,
            sharding_strategy=config.sharding_strategy,
            mixed_precision=config.mixed_precision,
            wrap_strategy=config.generator_fsdp_wrap_strategy
        )

        self.model.fake_score = fsdp_wrap(
            self.model.fake_score,
            sharding_strategy=config.sharding_strategy,
            mixed_precision=config.mixed_precision,
            wrap_strategy=config.fake_score_fsdp_wrap_strategy
        )

        self.model.text_encoder = fsdp_wrap(
            self.model.text_encoder,
            sharding_strategy=config.sharding_strategy,
            mixed_precision=config.mixed_precision,
            wrap_strategy=config.text_encoder_fsdp_wrap_strategy,
            cpu_offload=getattr(config, "text_encoder_cpu_offload", False)
        )
        if not config.no_visualize or config.load_raw_video:
            self.model.vae = self.model.vae.to(
                device=self.device, dtype=torch.bfloat16 if config.mixed_precision else torch.float32)

        self.generator_optimizer = torch.optim.AdamW(
            [param for param in self.model.generator.parameters()
             if param.requires_grad],
            lr=config.gen_lr,
            betas=(config.beta1, config.beta2)
        )
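        # Split the fake_score parameters into the critic backbone and the discriminator heads
        # ("_cls_pred_branch" / "_gan_ca_blocks"); the heads use critic_lr scaled by
        # discriminator_lr_multiplier. During discriminator warmup the critic optimizer starts
        # with beta1=0.9 and is re-created with the configured betas once warmup ends (see train()).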
        fake_score_params = []
        discriminator_params = []

        for name, param in self.model.fake_score.named_parameters():
            if param.requires_grad:
                if "_cls_pred_branch" in name or "_gan_ca_blocks" in name:
                    discriminator_params.append(param)
                else:
                    fake_score_params.append(param)

        self.critic_param_groups = [
            {'params': fake_score_params, 'lr': config.critic_lr},
            {'params': discriminator_params, 'lr': config.critic_lr * config.discriminator_lr_multiplier}
        ]

        if self.in_discriminator_warmup:
            self.critic_optimizer = torch.optim.AdamW(
                self.critic_param_groups,
                betas=(0.9, config.beta2_critic)
            )
        else:
            self.critic_optimizer = torch.optim.AdamW(
                self.critic_param_groups,
                betas=(config.beta1_critic, config.beta2_critic)
            )
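        # Dataset and distributed dataloader; cycle() is assumed to re-iterate the dataloader
        # indefinitely so training can run for an arbitrary number of steps.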
        self.data_path = config.data_path
        dataset = ShardingLMDBDataset(config.data_path, max_pair=int(1e8))
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset, shuffle=True, drop_last=True)
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=config.batch_size,
            sampler=sampler,
            num_workers=8)

        if dist.get_rank() == 0:
            print("DATASET SIZE %d" % len(dataset))

        self.dataloader = cycle(dataloader)
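        # Strip FSDP / activation-checkpointing / torch.compile wrapper prefixes so parameter
        # names match those of the unwrapped generator.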
        rename_param = (
            lambda name: name.replace("_fsdp_wrapped_module.", "")
            .replace("_checkpoint_wrapped_module.", "")
            .replace("_orig_mod.", "")
        )
        self.name_to_trainable_params = {}
        for n, p in self.model.generator.named_parameters():
            if not p.requires_grad:
                continue
            renamed_n = rename_param(n)
            self.name_to_trainable_params[renamed_n] = p

        ema_weight = config.ema_weight
        self.generator_ema = None
        if (ema_weight is not None) and (ema_weight > 0.0):
            print(f"Setting up EMA with weight {ema_weight}")
            self.generator_ema = EMA_FSDP(self.model.generator, decay=ema_weight)
        if getattr(config, "generator_ckpt", False):
            print(f"Loading pretrained generator from {config.generator_ckpt}")
            state_dict = torch.load(config.generator_ckpt, map_location="cpu")
            if "generator" in state_dict:
                state_dict = state_dict["generator"]
            elif "model" in state_dict:
                state_dict = state_dict["model"]
            self.model.generator.load_state_dict(
                state_dict, strict=True
            )

        if hasattr(config, "load"):
            resume_ckpt_path_critic = os.path.join(config.load, "critic")
            resume_ckpt_path_generator = os.path.join(config.load, "generator")
        else:
            resume_ckpt_path_critic = "none"
            resume_ckpt_path_generator = "none"
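        # self.checkpointer_critic / self.checkpointer_generator are assumed to be constructed
        # elsewhere (not shown here); try_best_load restores state from the given path and
        # returns the step to resume from.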
        _, _ = self.checkpointer_critic.try_best_load(
            resume_ckpt_path=resume_ckpt_path_critic,
        )
        self.step, _ = self.checkpointer_generator.try_best_load(
            resume_ckpt_path=resume_ckpt_path_generator,
            force_start_w_ema=config.force_start_w_ema,
            force_reset_zero_step=config.force_reset_zero_step,
            force_reinit_ema=config.force_reinit_ema,
            skip_optimizer_scheduler=config.skip_optimizer_scheduler,
        )

        if self.step < config.ema_start_step:
            self.generator_ema = None

        self.max_grad_norm_generator = getattr(config, "max_grad_norm_generator", 10.0)
        self.max_grad_norm_critic = getattr(config, "max_grad_norm_critic", 10.0)
        self.previous_time = None
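    # save() collects state dicts from the FSDP-wrapped modules via fsdp_state_dict (assumed to
    # gather the full, unsharded weights) and writes a single checkpoint from rank 0.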
    def save(self):
        print("Start gathering distributed model states...")
        generator_state_dict = fsdp_state_dict(
            self.model.generator)
        critic_state_dict = fsdp_state_dict(
            self.model.fake_score)

        if self.config.ema_start_step < self.step:
            state_dict = {
                "generator": generator_state_dict,
                "critic": critic_state_dict,
                "generator_ema": self.generator_ema.state_dict(),
            }
        else:
            state_dict = {
                "generator": generator_state_dict,
                "critic": critic_state_dict,
            }

        if self.is_main_process:
            os.makedirs(os.path.join(self.output_path,
                                     f"checkpoint_model_{self.step:06d}"), exist_ok=True)
            torch.save(state_dict, os.path.join(self.output_path,
                                                f"checkpoint_model_{self.step:06d}", "model.pt"))
            print("Model saved to", os.path.join(self.output_path,
                                                 f"checkpoint_model_{self.step:06d}", "model.pt"))
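    # Runs one forward/backward pass for either the generator (train_generator=True) or the
    # critic/discriminator. The loss is scaled by mini_bs * world_size / full_bs so gradients
    # match the full-batch objective under gradient accumulation; optimizer steps happen in train().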
    def fwdbwd_one_step(self, batch, train_generator):
        self.model.eval()

        if self.step % 20 == 0:
            torch.cuda.empty_cache()

        text_prompts = batch["prompts"]
        if "ode_latent" in batch:
            clean_latent = batch["ode_latent"][:, -1].to(device=self.device, dtype=self.dtype)
        else:
            frames = batch["frames"].to(device=self.device, dtype=self.dtype)
            with torch.no_grad():
                clean_latent = self.model.vae.encode_to_latent(
                    frames).to(device=self.device, dtype=self.dtype)

        image_latent = clean_latent[:, 0:1]

        batch_size = len(text_prompts)
        image_or_video_shape = list(self.config.image_or_video_shape)
        image_or_video_shape[0] = batch_size
        with torch.no_grad():
            conditional_dict = self.model.text_encoder(
                text_prompts=text_prompts)

            if not getattr(self, "unconditional_dict", None):
                unconditional_dict = self.model.text_encoder(
                    text_prompts=[self.config.negative_prompt] * batch_size)
                unconditional_dict = {k: v.detach()
                                      for k, v in unconditional_dict.items()}
                self.unconditional_dict = unconditional_dict
            else:
                unconditional_dict = self.unconditional_dict

        mini_bs, full_bs = (
            batch["mini_bs"],
            batch["full_bs"],
        )
        if train_generator:
            gan_G_loss = self.model.generator_loss(
                image_or_video_shape=image_or_video_shape,
                conditional_dict=conditional_dict,
                unconditional_dict=unconditional_dict,
                clean_latent=clean_latent,
                initial_latent=image_latent if self.config.i2v else None
            )

            loss_ratio = mini_bs * self.world_size / full_bs
            total_loss = gan_G_loss * loss_ratio * self.loss_scale

            total_loss.backward()
            generator_grad_norm = self.model.generator.clip_grad_norm_(
                self.max_grad_norm_generator)

            generator_log_dict = {"generator_grad_norm": generator_grad_norm,
                                  "gan_G_loss": gan_G_loss}

            return generator_log_dict
        else:
            generator_log_dict = {}

            (gan_D_loss, r1_loss, r2_loss), critic_log_dict = self.model.critic_loss(
                image_or_video_shape=image_or_video_shape,
                conditional_dict=conditional_dict,
                unconditional_dict=unconditional_dict,
                clean_latent=clean_latent,
                real_image_or_video=clean_latent,
                initial_latent=image_latent if self.config.i2v else None
            )

            loss_ratio = mini_bs * dist.get_world_size() / full_bs
            total_loss = (gan_D_loss + 0.5 * (r1_loss + r2_loss)) * loss_ratio * self.loss_scale

            total_loss.backward()
            critic_grad_norm = self.model.fake_score.clip_grad_norm_(
                self.max_grad_norm_critic)

            critic_log_dict.update({"critic_grad_norm": critic_grad_norm,
                                    "gan_D_loss": gan_D_loss,
                                    "r1_loss": r1_loss,
                                    "r2_loss": r2_loss})

            return critic_log_dict
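    # Samples video from the given pipeline with a fixed latent-noise shape
    # [batch, 21, 16, 60, 104] (frames, channels, height, width), assumed to match the
    # latent resolution of this model configuration.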
    def generate_video(self, pipeline, prompts, image=None):
        batch_size = len(prompts)
        sampled_noise = torch.randn(
            [batch_size, 21, 16, 60, 104], device="cuda", dtype=self.dtype
        )
        video, _ = pipeline.inference(
            noise=sampled_noise,
            text_prompts=prompts,
            return_latents=True
        )
        current_video = video.permute(0, 1, 3, 4, 2).cpu().numpy() * 255.0
        return current_video
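    # Main training loop: the critic is updated every step; the generator is updated every
    # dfake_gen_update_ratio steps and is skipped entirely during discriminator warmup. When
    # warmup ends, the critic optimizer is rebuilt with the configured betas. Checkpointing,
    # EMA activation, wandb logging and periodic garbage collection also happen here.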
    def train(self):
        start_step = self.step

        while True:
            if self.step == self.discriminator_warmup_steps and self.discriminator_warmup_steps != 0:
                print("Resetting critic optimizer")
                del self.critic_optimizer
                torch.cuda.empty_cache()

                self.critic_optimizer = torch.optim.AdamW(
                    self.critic_param_groups,
                    betas=(self.config.beta1_critic, self.config.beta2_critic)
                )

                self.checkpointer_critic.optimizer = self.critic_optimizer

            self.in_discriminator_warmup = self.step < self.discriminator_warmup_steps

            TRAIN_GENERATOR = not self.in_discriminator_warmup and self.step % self.config.dfake_gen_update_ratio == 0
            if TRAIN_GENERATOR:
                self.model.fake_score.requires_grad_(False)
                self.model.generator.requires_grad_(True)
                self.generator_optimizer.zero_grad(set_to_none=True)
                extras_list = []
                # Each item from the cycled dataloader is assumed to be a list of mini-batches
                # (gradient accumulation).
                for ii, mini_batch in enumerate(next(self.dataloader)):
                    extra = self.fwdbwd_one_step(mini_batch, True)
                    extras_list.append(extra)
                generator_log_dict = merge_dict_list(extras_list)
                self.generator_optimizer.step()
                if self.generator_ema is not None:
                    self.generator_ema.update(self.model.generator)
            else:
                generator_log_dict = {}
            if self.in_discriminator_warmup:
                # During warmup, freeze everything except the discriminator heads.
                self.model.generator.requires_grad_(False)
                self.model.fake_score.requires_grad_(False)

                for name, param in self.model.fake_score.named_parameters():
                    if "_cls_pred_branch" in name or "_gan_ca_blocks" in name:
                        param.requires_grad_(True)
            else:
                self.model.generator.requires_grad_(False)
                self.model.fake_score.requires_grad_(True)

            self.critic_optimizer.zero_grad(set_to_none=True)
            extras_list = []
            batch = next(self.dataloader)
            extra = self.fwdbwd_one_step(batch, False)
            extras_list.append(extra)
            critic_log_dict = merge_dict_list(extras_list)
            self.critic_optimizer.step()
            self.step += 1

            if self.is_main_process and self.step == self.discriminator_warmup_steps:
                print(f"Finished discriminator warmup after {self.discriminator_warmup_steps} steps")

            if (self.step >= self.config.ema_start_step) and \
                    (self.generator_ema is None) and (self.config.ema_weight > 0):
                self.generator_ema = EMA_FSDP(self.model.generator, decay=self.config.ema_weight)

            if (not self.config.no_save) and (self.step - start_step) > 0 and self.step % self.config.log_iters == 0:
                torch.cuda.empty_cache()
                self.save()
                torch.cuda.empty_cache()
            wandb_loss_dict = {
                "critic_grad_norm": critic_log_dict["critic_grad_norm"],
                "real_logit": critic_log_dict["noisy_real_logit"],
                "fake_logit": critic_log_dict["noisy_fake_logit"],
                "r1_loss": critic_log_dict["r1_loss"],
                "r2_loss": critic_log_dict["r2_loss"],
            }
            if TRAIN_GENERATOR:
                # generator_grad_norm only exists on generator-update steps.
                wandb_loss_dict.update({
                    "generator_grad_norm": generator_log_dict["generator_grad_norm"],
                })

            self.all_gather_dict(wandb_loss_dict)
            wandb_loss_dict["diff_logit"] = wandb_loss_dict["real_logit"] - wandb_loss_dict["fake_logit"]
            wandb_loss_dict["reg_loss"] = 0.5 * (wandb_loss_dict["r1_loss"] + wandb_loss_dict["r2_loss"])
            if self.is_main_process:
                if self.in_discriminator_warmup:
                    warmup_status = f"[WARMUP {self.step}/{self.discriminator_warmup_steps}] Training only discriminator params"
                    print(warmup_status)
                    if not self.disable_wandb:
                        wandb_loss_dict.update({"warmup_status": 1.0})

                if not self.disable_wandb:
                    wandb.log(wandb_loss_dict, step=self.step)

            if self.step % self.config.gc_interval == 0:
                if dist.get_rank() == 0:
                    logging.info("DistGarbageCollector: Running GC.")
                gc.collect()
                torch.cuda.empty_cache()
            if self.is_main_process:
                current_time = time.time()
                if self.previous_time is None:
                    self.previous_time = current_time
                else:
                    if not self.disable_wandb:
                        wandb.log({"per iteration time": current_time - self.previous_time}, step=self.step)
                    self.previous_time = current_time
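    # All-gathers each (scalar) tensor in the dict across ranks and replaces it in place with
    # the cross-rank mean as a Python float, so the values can be logged directly.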
    def all_gather_dict(self, target_dict):
        for key, value in target_dict.items():
            gathered_value = torch.zeros(
                [self.world_size, *value.shape],
                dtype=value.dtype, device=self.device)
            dist.all_gather_into_tensor(gathered_value, value)
            avg_value = gathered_value.mean().item()
            target_dict[key] = avg_value