import argparse
import os
import time

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from cp_dataset import CPDataset, CPDataLoader
from networks import ConditionGenerator, VGGLoss, GANLoss, load_checkpoint, save_checkpoint, define_D
from networks import make_grid as mkgrid
from networks import make_grid_3d as mkgrid_3d
from utils import *


def remove_overlap(seg_out, warped_cm):
    # Suppress the warped cloth mask wherever the predicted segmentation
    # assigns probability to non-cloth channels (1:3 and 5:), i.e.
    # warped_cm * (1 - sum of overlapping channels).
    assert len(warped_cm.shape) == 4
    warped_cm = warped_cm - (torch.cat([seg_out[:, 1:3, :, :], seg_out[:, 5:, :, :]], dim=1)).sum(dim=1, keepdim=True) * warped_cm
    return warped_cm


def get_opt():
    parser = argparse.ArgumentParser()

    parser.add_argument("--name", default="test")
    parser.add_argument("--gpu_ids", default="")
    parser.add_argument('-j', '--workers', type=int, default=4)
    parser.add_argument('-b', '--batch-size', type=int, default=8)
    parser.add_argument('--fp16', action='store_true', help='use amp')

    parser.add_argument("--dataroot", default="./data/")
    parser.add_argument("--datamode", default="train")
    parser.add_argument("--data_list", default="train_pairs.txt")
    parser.add_argument("--fine_width", type=int, default=192)
    parser.add_argument("--fine_height", type=int, default=256)

    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints', help='directory for saved checkpoints')
    parser.add_argument('--tocg_checkpoint', type=str, default='', help='tocg checkpoint')

    parser.add_argument("--display_count", type=int, default=100)
    parser.add_argument("--save_count", type=int, default=1000)
    parser.add_argument("--load_step", type=int, default=0)
    parser.add_argument("--keep_step", type=int, default=300000)
    parser.add_argument("--shuffle", action='store_true', help='shuffle input data')
    parser.add_argument("--semantic_nc", type=int, default=13)
    parser.add_argument("--output_nc", type=int, default=13)

    # network
    parser.add_argument('--cond_G_ngf', type=int, default=96)
    parser.add_argument('--cond_G_num_layers', type=int, default=5)
    parser.add_argument("--warp_feature", choices=['encoder', 'T1'], default="T1")
    parser.add_argument("--out_layer", choices=['relu', 'conv'], default="relu")
    parser.add_argument('--Ddownx2', action='store_true', help="Downsample D's input to increase the receptive field")
    parser.add_argument('--Ddropout', action='store_true', help="Apply dropout to D")
    parser.add_argument('--num_D', type=int, default=2, help='number of discriminators in the multiscale D')

    # training
    parser.add_argument("--lasttvonly", action='store_true')
    parser.add_argument("--interflowloss", action='store_true', help="Intermediate flow loss")
    parser.add_argument('--G_lr', type=float, default=0.0002, help='Generator initial learning rate for adam')
    parser.add_argument('--D_lr', type=float, default=0.0002, help='Discriminator initial learning rate for adam')
    parser.add_argument('--CElamda', type=float, default=10, help='weight for the cross-entropy loss')
    parser.add_argument('--GANlambda', type=float, default=1)
    parser.add_argument('--tvlambda_tvob', type=float, default=2)
    parser.add_argument('--tvlambda_taco', type=float, default=2)
    parser.add_argument('--upsample', type=str, default='bilinear', choices=['nearest', 'bilinear'])
    parser.add_argument('--spectral', action='store_true', help="Apply spectral normalization to D")
    parser.add_argument('--occlusion', action='store_true', help="Occlusion handling")
    parser.add_argument('--resume', action='store_true', help='resume from the latest checkpoint')

    opt = parser.parse_args()
    return opt
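
# Example invocation (illustrative only; the script filename and flag choices
# here are assumptions, not published training settings):
#
#   python train_condition.py --name tocg --gpu_ids 0 \
#       --dataroot ./data/ --occlusion --interflowloss --shuffle
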

def train(opt, train_loader, tocg, D, optimizer_G, optimizer_D):
    tocg.cuda()
    tocg.train()
    D.cuda()
    D.train()

    # criterion
    criterionL1 = nn.L1Loss()
    criterionVGG = VGGLoss()
    if opt.fp16:
        criterionGAN = GANLoss(use_lsgan=True, tensor=torch.cuda.HalfTensor)
    else:
        criterionGAN = GANLoss(use_lsgan=True, tensor=torch.cuda.FloatTensor if opt.gpu_ids else torch.Tensor)

    for step in tqdm(range(opt.load_step, opt.keep_step)):
        iter_start_time = time.time()
        inputs = train_loader.next_batch()

        # input1: cloth and binarized cloth mask
        c_paired = inputs['cloth']['paired'].cuda()
        cm_paired = inputs['cloth_mask']['paired'].cuda()
        cm_paired = (cm_paired > 0.5).float()
        # input2: clothing-agnostic person representation
        parse_agnostic = inputs['parse_agnostic'].cuda()
        densepose = inputs['densepose'].cuda()
        openpose = inputs['pose'].cuda()
        # GT
        label_onehot = inputs['parse_onehot'].cuda()  # CE
        label = inputs['parse'].cuda()  # GAN loss
        parse_cloth_mask = inputs['pcm'].cuda()  # L1
        im_c = inputs['parse_cloth'].cuda()  # VGG
        # visualization
        im = inputs['image']

        # tucked-out shirts style
        lower_clothes_mask = inputs['lower_clothes_mask'].cuda()
        clothes_no_loss_mask = inputs['clothes_no_loss_mask'].cuda()

        # inputs
        input1 = torch.cat([c_paired, cm_paired], 1)
        input2 = torch.cat([parse_agnostic, densepose], 1)

        # forward
        flow_list_taco, fake_segmap, warped_cloth_paired_taco, warped_clothmask_paired_taco, \
            flow_list_tvob, warped_cloth_paired_tvob, warped_clothmask_paired_tvob = tocg(input1, input2)

        # binarized warped cloth mask (not used by the losses below)
        warped_clothmask_paired_taco_onehot = (warped_clothmask_paired_taco.detach() > 0.5).float()

        # fake segmap cloth channel * warped clothmask
        cloth_mask = torch.ones_like(fake_segmap.detach())
        cloth_mask[:, 3:4, :, :] = warped_clothmask_paired_taco
        fake_segmap = fake_segmap * cloth_mask

        if opt.occlusion:
            # remove body-part overlap from both warped masks, then composite
            # the warped cloth onto a white background
            warped_clothmask_paired_taco = remove_overlap(F.softmax(fake_segmap, dim=1), warped_clothmask_paired_taco)
            warped_cloth_paired_taco = warped_cloth_paired_taco * warped_clothmask_paired_taco + torch.ones_like(warped_cloth_paired_taco) * (1 - warped_clothmask_paired_taco)
            warped_clothmask_paired_tvob = remove_overlap(F.softmax(fake_segmap, dim=1), warped_clothmask_paired_tvob)
            warped_cloth_paired_tvob = warped_cloth_paired_tvob * warped_clothmask_paired_tvob + torch.ones_like(warped_cloth_paired_tvob) * (1 - warped_clothmask_paired_tvob)

        # loss warping
        loss_l1_cloth = criterionL1(warped_clothmask_paired_taco, parse_cloth_mask)
        loss_vgg = criterionVGG(warped_cloth_paired_taco, im_c)

        # Eq.8 & Eq.9 of SD-VITON: exclude the masked lower-clothes region
        # from the TVOB warping losses
        inv_lower_clothes_mask = lower_clothes_mask * clothes_no_loss_mask
        inv_lower_clothes_mask = 1. - inv_lower_clothes_mask
        loss_l1_cloth += criterionL1(warped_clothmask_paired_tvob * inv_lower_clothes_mask, parse_cloth_mask * inv_lower_clothes_mask)
        loss_vgg += criterionVGG(warped_cloth_paired_tvob * inv_lower_clothes_mask, im_c * inv_lower_clothes_mask)
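
        # The z-distance losses below constrain the z-channel of the last TACO
        # flow: minimizing |z_src + flow_z + z_gt| drives the sampled depth
        # toward -z_gt, so ROI (cloth) pixels stay on the source plane
        # (z = -1) while non-ROI pixels are pushed to z = +1. This geometric
        # reading is inferred from the constants below, not stated in this file.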
        # Eq.12 of SD-VITON
        roi_mask = F.interpolate(parse_cloth_mask, scale_factor=0.5, mode='nearest')
        non_roi_mask = 1. - roi_mask
        flow_taco = flow_list_taco[-1]
        z_gt_non_roi = -1
        z_gt_roi = 1
        z_src_coordinate = -1
        z_dist_loss_non_roi = (torch.abs(z_src_coordinate + flow_taco[:, 0:1, :, :, 2] + z_gt_non_roi) * non_roi_mask).mean()
        z_dist_loss_roi = (torch.abs(z_src_coordinate + flow_taco[:, 0:1, :, :, 2] + z_gt_roi) * roi_mask).mean()

        # total-variation smoothness on the flows; TACO flows are 5-D
        # (N, C, H, W, 3) and TVOB flows are 4-D (N, H, W, 2)
        loss_tv_tvob = 0
        loss_tv_taco = 0
        if not opt.lasttvonly:
            for flow in flow_list_taco:
                y_tv = torch.abs(flow[:, :, 1:, :, :] - flow[:, :, :-1, :, :]).mean()
                x_tv = torch.abs(flow[:, :, :, 1:, :] - flow[:, :, :, :-1, :]).mean()
                loss_tv_taco = loss_tv_taco + y_tv + x_tv
            for flow in flow_list_tvob:
                y_tv = torch.abs(flow[:, 1:, :, :] - flow[:, :-1, :, :]).mean()
                x_tv = torch.abs(flow[:, :, 1:, :] - flow[:, :, :-1, :]).mean()
                loss_tv_tvob = loss_tv_tvob + y_tv + x_tv
        else:
            for flow in flow_list_taco[-1:]:
                y_tv = torch.abs(flow[:, :, 1:, :, :] - flow[:, :, :-1, :, :]).mean()
                x_tv = torch.abs(flow[:, :, :, 1:, :] - flow[:, :, :, :-1, :]).mean()
                loss_tv_taco = loss_tv_taco + y_tv + x_tv
            for flow in flow_list_tvob[-1:]:
                y_tv = torch.abs(flow[:, 1:, :, :] - flow[:, :-1, :, :]).mean()
                x_tv = torch.abs(flow[:, :, 1:, :] - flow[:, :, :-1, :]).mean()
                loss_tv_tvob = loss_tv_tvob + y_tv + x_tv

        N, _, iH, iW = c_paired.size()

        # Intermediate flow loss: supervise every TVOB flow except the last,
        # down-weighting coarser scales by powers of two
        if opt.interflowloss:
            layers_max_idx = len(flow_list_tvob) - 1
            for i in range(len(flow_list_tvob) - 1):
                flow = flow_list_tvob[i]
                N, fH, fW, _ = flow.size()
                grid = mkgrid(N, iH, iW)
                grid_3d = mkgrid_3d(N, iH, iW)  # built here but unused in this loss
                flow = F.interpolate(flow.permute(0, 3, 1, 2), size=c_paired.shape[2:], mode=opt.upsample).permute(0, 2, 3, 1)
                flow_norm = torch.cat([flow[:, :, :, 0:1] / ((fW - 1.0) / 2.0), flow[:, :, :, 1:2] / ((fH - 1.0) / 2.0)], 3)
                warped_c = F.grid_sample(c_paired, flow_norm + grid, padding_mode='border')
                warped_cm = F.grid_sample(cm_paired, flow_norm + grid, padding_mode='border')
                warped_cm = remove_overlap(F.softmax(fake_segmap, dim=1), warped_cm)
                # Eq.8 & Eq.9 of SD-VITON
                loss_l1_cloth += criterionL1(warped_cm * inv_lower_clothes_mask, parse_cloth_mask * inv_lower_clothes_mask) / (2 ** (layers_max_idx - i))
                loss_vgg += criterionVGG(warped_c * inv_lower_clothes_mask, im_c * inv_lower_clothes_mask) / (2 ** (layers_max_idx - i))

        # loss segmentation
        # generator
        CE_loss = cross_entropy2d(fake_segmap, label_onehot.transpose(0, 1)[0].long())

        fake_segmap_softmax = torch.softmax(fake_segmap, dim=1)
        pred_segmap = D(torch.cat((input1.detach(), input2.detach(), fake_segmap_softmax), dim=1))
        loss_G_GAN = criterionGAN(pred_segmap, True)

        # discriminator
        fake_segmap_pred = D(torch.cat((input1.detach(), input2.detach(), fake_segmap_softmax.detach()), dim=1))
        real_segmap_pred = D(torch.cat((input1.detach(), input2.detach(), label), dim=1))
        loss_D_fake = criterionGAN(fake_segmap_pred, False)
        loss_D_real = criterionGAN(real_segmap_pred, True)

        # loss sum
        loss_G = (10 * loss_l1_cloth + loss_vgg + opt.tvlambda_tvob * loss_tv_tvob + opt.tvlambda_taco * loss_tv_taco) \
            + (CE_loss * opt.CElamda + loss_G_GAN * opt.GANlambda) \
            + z_dist_loss_non_roi + z_dist_loss_roi
        loss_D = loss_D_fake + loss_D_real

        # step
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()
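
        # Every 100 steps, save a qualitative strip to sample_fs_3/<step>.jpg:
        # cloth | person | TVOB warp | TACO warp | lower-clothes mask |
        # densepose | no-loss mask | (lower * no-loss) | inverted loss mask,
        # with single-channel masks tiled to three channels.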
        if step % 100 == 0:
            a_0 = c_paired[0].cuda()
            b_0 = im[0].cuda()
            c_0 = warped_cloth_paired_tvob[0]
            d_0 = warped_cloth_paired_taco[0]
            e_0 = lower_clothes_mask
            e_0 = torch.cat((e_0[0], e_0[0], e_0[0]), dim=0)
            f_0 = densepose[0].cuda()
            g_0 = clothes_no_loss_mask
            g_0 = torch.cat((g_0[0], g_0[0], g_0[0]), dim=0)
            h_0 = lower_clothes_mask * clothes_no_loss_mask
            h_0 = torch.cat((h_0[0], h_0[0], h_0[0]), dim=0)
            i_0 = inv_lower_clothes_mask
            i_0 = torch.cat((i_0[0], i_0[0], i_0[0]), dim=0)
            combine = torch.cat((a_0, b_0, c_0, d_0, e_0, f_0, g_0, h_0, i_0), dim=2)
            cv_img = (combine.permute(1, 2, 0).detach().cpu().numpy() + 1) / 2
            rgb = (cv_img * 255).astype(np.uint8)
            bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
            cv2.imwrite('sample_fs_3/' + str(step) + '.jpg', bgr)

        if (step + 1) % opt.display_count == 0:
            t = time.time() - iter_start_time
            print("step: %8d, time: %.3f\n"
                  "loss G: %.4f, L1_cloth loss: %.4f, VGG loss: %.4f, TV_tvob loss: %.4f, TV_taco loss: %.4f, CE: %.4f, G GAN: %.4f\n"
                  "loss D: %.4f, D real: %.4f, D fake: %.4f, z_non_roi: %.4f, z_roi: %.4f"
                  % (step + 1, t, loss_G.item(), loss_l1_cloth.item(), loss_vgg.item(),
                     loss_tv_tvob.item(), loss_tv_taco.item(), CE_loss.item(), loss_G_GAN.item(),
                     loss_D.item(), loss_D_real.item(), loss_D_fake.item(),
                     z_dist_loss_non_roi.item(), z_dist_loss_roi.item()), flush=True)

        # Save checkpoint
        if (step + 1) % opt.save_count == 0:
            checkpoint = {
                'step': step + 1,
                'tocg_state_dict': tocg.state_dict(),
                'D_state_dict': D.state_dict(),
                'optimizer_G_state_dict': optimizer_G.state_dict(),
                'optimizer_D_state_dict': optimizer_D.state_dict(),
            }
            checkpoint_path = os.path.join(opt.checkpoint_dir, opt.name, 'latest_checkpoint.pth')
            torch.save(checkpoint, checkpoint_path)
            print(f"Saved checkpoint at step {step + 1} to {checkpoint_path}")


def main():
    opt = get_opt()
    print(f"Start to train {opt.name}!")

    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_ids
    os.makedirs('sample_fs_3', exist_ok=True)
    os.makedirs(os.path.join(opt.checkpoint_dir, opt.name), exist_ok=True)

    # Create train dataset & loader
    train_dataset = CPDataset(opt)
    train_loader = CPDataLoader(opt, train_dataset)

    # Model
    input1_nc = 4  # cloth + cloth-mask
    input2_nc = opt.semantic_nc + 3  # parse_agnostic + densepose
    tocg = ConditionGenerator(opt, input1_nc=input1_nc, input2_nc=input2_nc, output_nc=opt.output_nc,
                              ngf=opt.cond_G_ngf, norm_layer=nn.BatchNorm2d, num_layers=opt.cond_G_num_layers)
    D = define_D(input_nc=input1_nc + input2_nc + opt.output_nc, Ddownx2=opt.Ddownx2, Ddropout=opt.Ddropout,
                 n_layers_D=(opt.cond_G_num_layers - 2), spectral=opt.spectral, num_D=opt.num_D)

    # Move models to GPU
    tocg.cuda()
    D.cuda()

    # Define optimizers
    optimizer_G = torch.optim.Adam(tocg.parameters(), lr=opt.G_lr, betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(D.parameters(), lr=opt.D_lr, betas=(0.5, 0.999))

    # Load checkpoint if resuming
    if opt.resume:
        checkpoint_path = os.path.join(opt.checkpoint_dir, opt.name, 'latest_checkpoint.pth')
        if os.path.exists(checkpoint_path):
            checkpoint = torch.load(checkpoint_path)
            tocg.load_state_dict(checkpoint['tocg_state_dict'])
            D.load_state_dict(checkpoint['D_state_dict'])
            optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
            optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])
            opt.load_step = checkpoint['step']
            print(f"Resuming from step {opt.load_step} using checkpoint {checkpoint_path}")
        else:
            print(f"No checkpoint found at {checkpoint_path}, starting from scratch")

    # Train
    train(opt, train_loader, tocg, D, optimizer_G, optimizer_D)

    # Save final models
    save_checkpoint(tocg, os.path.join(opt.checkpoint_dir, opt.name, 'tocg_final.pth'))
    save_checkpoint(D, os.path.join(opt.checkpoint_dir, opt.name, 'D_final.pth'))
    print(f"Finished training {opt.name}!")


if __name__ == "__main__":
    main()