Ubaida10 committed
Commit 550ed07 · verified · 1 Parent(s): 25fbc4b

Upload train_generator.py with huggingface_hub
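
For reference, a minimal sketch of the kind of call that produces a commit like this via huggingface_hub (the `repo_id` below is a placeholder, not taken from this page; only the commit message is):

```python
from huggingface_hub import HfApi

api = HfApi()
# Push a local script to a Hub repo; repo_id is hypothetical.
api.upload_file(
    path_or_fileobj="train_generator.py",   # local file to upload
    path_in_repo="train_generator.py",      # destination path in the repo
    repo_id="Ubaida10/some-repo",           # placeholder, not the actual repo
    commit_message="Upload train_generator.py with huggingface_hub",
)
```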

Files changed (1)
  1. train_generator.py +635 -0
train_generator.py ADDED
@@ -0,0 +1,635 @@
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+
+ import argparse
+ import os
+ import time
+ from cp_dataset import CPDataset, CPDataLoader
+ from cp_dataset_test import CPDatasetTest
+ from networks import ConditionGenerator, VGGLoss, load_checkpoint, save_checkpoint, make_grid, make_grid_3d
+ from network_generator import SPADEGenerator, MultiscaleDiscriminator, GANLoss, Projected_GANs_Loss, set_requires_grad
+
+ from sync_batchnorm import DataParallelWithCallback
+ from utils import create_network
+ import sys
+ from tqdm import tqdm
+
+ import numpy as np
+ from torch.utils.data import Subset
+ from torchvision.transforms import transforms
+ import eval_models as models
+ import torchgeometry as tgm
+
+ from pg_modules.discriminator import ProjectedDiscriminator
+ import cv2
+
+ def remove_overlap(seg_out, warped_cm):
+     assert len(warped_cm.shape) == 4
+     warped_cm = warped_cm - (torch.cat([seg_out[:, 1:3, :, :], seg_out[:, 5:, :, :]], dim=1)).sum(dim=1, keepdim=True) * warped_cm
+     return warped_cm
+
+ def get_opt():
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument('--name', type=str, required=True)
+     parser.add_argument('--gpu_ids', type=str, default='0')
+     parser.add_argument('-j', '--workers', type=int, default=4)
+     parser.add_argument('-b', '--batch_size', type=int, default=8)
+     parser.add_argument('--fp16', action='store_true', help='use amp')
+
+     parser.add_argument("--dataroot", default="./data/")
+     parser.add_argument("--datamode", default="train")
+     parser.add_argument("--data_list", default="train_pairs.txt")
+     parser.add_argument("--fine_width", type=int, default=768)
+     parser.add_argument("--fine_height", type=int, default=1024)
+     parser.add_argument("--radius", type=int, default=20)
+     parser.add_argument("--grid_size", type=int, default=5)
+
+     parser.add_argument('--checkpoint_dir', type=str, default='checkpoints', help='save checkpoint infos')
+     parser.add_argument('--tocg_checkpoint', type=str, help='condition generator checkpoint')
+     parser.add_argument('--gen_checkpoint', type=str, default='', help='gen checkpoint')
+     parser.add_argument('--dis_checkpoint', type=str, default='', help='dis checkpoint')
+
+     parser.add_argument("--display_count", type=int, default=100)
+     parser.add_argument("--save_count", type=int, default=1000)
+     parser.add_argument("--load_step", type=int, default=0)
+     parser.add_argument("--keep_step", type=int, default=100000)
+     parser.add_argument("--decay_step", type=int, default=100000)
+     parser.add_argument("--shuffle", action='store_true', help='shuffle input data')
+     parser.add_argument('--resume', action='store_true', help='resume training from the last checkpoint')
+
+     # test
+     parser.add_argument("--lpips_count", type=int, default=1000)
+     parser.add_argument("--test_datasetting", default="paired")
+     parser.add_argument("--test_dataroot", default="./data/")
+     parser.add_argument("--test_data_list", default="test_pairs.txt")
+
+     # Hyper-parameters
+     parser.add_argument('--G_lr', type=float, default=0.0001, help='initial learning rate for adam')
+     parser.add_argument('--D_lr', type=float, default=0.0004, help='initial learning rate for adam')
+
+     # SEAN-related hyper-parameters
+     parser.add_argument('--GMM_const', type=float, default=None, help='constraint for GMM module')
+     parser.add_argument('--semantic_nc', type=int, default=13, help='# of input label classes without unknown class')
+     parser.add_argument('--gen_semantic_nc', type=int, default=7, help='# of input label classes without unknown class')
+     parser.add_argument('--norm_G', type=str, default='spectralaliasinstance', help='instance normalization or batch normalization')
+     parser.add_argument('--norm_D', type=str, default='spectralinstance', help='instance normalization or batch normalization')
+     parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
+     parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
+     parser.add_argument('--num_upsampling_layers', choices=['normal', 'more', 'most'], default='most',
+                         help='If \'more\', add upsampling layer between the two middle resnet blocks. '
+                              'If \'most\', also add one more (upsampling + resnet) layer at the end of the generator.')
+     parser.add_argument('--init_type', type=str, default='xavier', help='network initialization [normal|xavier|kaiming|orthogonal]')
+     parser.add_argument('--init_variance', type=float, default=0.02, help='variance of the initialization distribution')
+
+     parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss')
+     parser.add_argument('--lambda_l1', type=float, default=1.0, help='weight for image-level l1 loss')
+     parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
+     parser.add_argument('--lambda_vgg', type=float, default=10.0, help='weight for vgg loss')
+
+     # D
+     parser.add_argument('--n_layers_D', type=int, default=3, help='# layers in each discriminator')
+     parser.add_argument('--netD_subarch', type=str, default='n_layer', help='architecture of each discriminator')
+     parser.add_argument('--num_D', type=int, default=2, help='number of discriminators to be used in multiscale')
+
+     # G & D arch-related
+     parser.add_argument("--composition_mask", action='store_true', help='use the composition-mask generator head (paired with the multiscale discriminator)')
+
+     # Training
+     parser.add_argument('--occlusion', action='store_true')
+     # tocg
+     # network
+     parser.add_argument('--cond_G_ngf', type=int, default=96)
+     parser.add_argument("--cond_G_input_width", type=int, default=192)
+     parser.add_argument("--cond_G_input_height", type=int, default=256)
+     parser.add_argument('--cond_G_num_layers', type=int, default=5)
+     parser.add_argument("--warp_feature", choices=['encoder', 'T1'], default="T1")
+     parser.add_argument("--out_layer", choices=['relu', 'conv'], default="relu")
+
+     # New arguments for selective layer freezing and last layer control
+     parser.add_argument('--freeze_tocg_layers', type=int, default=0, help='number of layers to freeze in tocg from the start')
+     parser.add_argument('--freeze_gen_layers', type=int, default=0, help='number of layers to freeze in generator from the start')
+     parser.add_argument('--last_layer_mode', type=str, default='train', choices=['train', 'half', 'freeze'],
+                         help='Mode for the last layer: train (full training), half (half parameters frozen), freeze (fully frozen)')
+
+     opt = parser.parse_args()
+
+     # set gpu ids
+     str_ids = opt.gpu_ids.split(',')
+     opt.gpu_ids = []
+     for str_id in str_ids:
+         id = int(str_id)
+         if id >= 0:
+             opt.gpu_ids.append(id)
+     if len(opt.gpu_ids) > 0:
+         torch.cuda.set_device(opt.gpu_ids[0])
+
+     assert len(opt.gpu_ids) == 0 or opt.batch_size % len(opt.gpu_ids) == 0, \
+         "Batch size %d is wrong. It must be a multiple of # GPUs %d." \
+         % (opt.batch_size, len(opt.gpu_ids))
+
+     return opt
+
+ def apply_layer_freezing(model, num_layers_to_freeze, last_layer_mode):
+     """Apply selective layer freezing and handle the last layer based on mode."""
+     children = list(model.named_children())
+     total_layers = len(children)
+
+     # Freeze specified layers from the start
+     for i, (name, module) in enumerate(children):
+         if i < num_layers_to_freeze:
+             for param in module.parameters():
+                 param.requires_grad = False
+
+     # Handle the last layer based on mode
+     if total_layers > 0 and last_layer_mode != 'train':
+         last_name, last_module = children[-1]
+         if last_layer_mode == 'freeze':
+             for param in last_module.parameters():
+                 param.requires_grad = False
+         elif last_layer_mode == 'half':
+             # Freeze half of the parameters in the last layer
+             params = list(last_module.parameters())
+             half_idx = len(params) // 2
+             for param in params[:half_idx]:
+                 param.requires_grad = False
+             for param in params[half_idx:]:
+                 param.requires_grad = True
+
+ def train(opt, train_loader, test_loader, tocg, generator, discriminator, model):
+     """
+     Train Generator and Condition Generator
+     """
+     # Model
+     tocg.cuda()
+     tocg.train()  # Enable training for tocg
+     generator.train()
+     discriminator.train()
+     if not opt.composition_mask:
+         discriminator.feature_network.requires_grad_(False)
+     discriminator.cuda()
+     model.eval()
+
+     # Apply layer freezing
+     apply_layer_freezing(tocg, opt.freeze_tocg_layers, opt.last_layer_mode)
+     apply_layer_freezing(generator, opt.freeze_gen_layers, opt.last_layer_mode)
+
+     # criterion
+     criterionGAN = None
+     if opt.fp16:
+         if opt.composition_mask:
+             criterionGAN = GANLoss('hinge', tensor=torch.cuda.HalfTensor)
+         else:
+             criterionGAN = Projected_GANs_Loss(tensor=torch.cuda.HalfTensor)
+     else:
+         if opt.composition_mask:
+             criterionGAN = GANLoss('hinge', tensor=torch.cuda.FloatTensor)
+         else:
+             criterionGAN = Projected_GANs_Loss(tensor=torch.cuda.FloatTensor)
+
+     criterionL1 = nn.L1Loss()
+     criterionFeat = nn.L1Loss()
+     criterionVGG = VGGLoss()
+
+     # optimizer
+     optimizer_gen = torch.optim.Adam(
+         list(generator.parameters()) + list(tocg.parameters()),  # Include tocg parameters
+         lr=opt.G_lr, betas=(0.0, 0.9)
+     )
+     scheduler_gen = torch.optim.lr_scheduler.LambdaLR(optimizer_gen, lr_lambda=lambda step: 1.0 -
+                     max(0, step * 1000 + opt.load_step - opt.keep_step) / float(opt.decay_step + 1))
+     optimizer_dis = torch.optim.Adam(discriminator.parameters(), lr=opt.D_lr, betas=(0.0, 0.9))
+     scheduler_dis = torch.optim.lr_scheduler.LambdaLR(optimizer_dis, lr_lambda=lambda step: 1.0 -
+                     max(0, step * 1000 + opt.load_step - opt.keep_step) / float(opt.decay_step + 1))
+
+     if opt.fp16:
+         from apex import amp
+         [tocg, generator, discriminator], [optimizer_gen, optimizer_dis] = amp.initialize(
+             [tocg, generator, discriminator], [optimizer_gen, optimizer_dis], opt_level='O1', num_losses=2)
+
+     if len(opt.gpu_ids) > 0:
+         tocg = DataParallelWithCallback(tocg, device_ids=opt.gpu_ids)
+         generator = DataParallelWithCallback(generator, device_ids=opt.gpu_ids)
+         discriminator = DataParallelWithCallback(discriminator, device_ids=opt.gpu_ids)
+         criterionGAN = DataParallelWithCallback(criterionGAN, device_ids=opt.gpu_ids)
+         criterionFeat = DataParallelWithCallback(criterionFeat, device_ids=opt.gpu_ids)
+         criterionVGG = DataParallelWithCallback(criterionVGG, device_ids=opt.gpu_ids)
+         criterionL1 = DataParallelWithCallback(criterionL1, device_ids=opt.gpu_ids)
+
+     upsample = torch.nn.Upsample(scale_factor=4, mode='bilinear')
+     gauss = tgm.image.GaussianBlur((15, 15), (3, 3))
+     gauss = gauss.cuda()
+
+     checkpoint_path = os.path.join(opt.checkpoint_dir, opt.name, 'checkpoint.pth')
+     if opt.resume:
+         if os.path.exists(checkpoint_path):
+             print(f"Resuming from checkpoint: {checkpoint_path}")
+             checkpoint = torch.load(checkpoint_path)
+             opt.load_step = checkpoint['step']
+             generator.load_state_dict(checkpoint['generator_state_dict'])
+             discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
+             tocg.load_state_dict(checkpoint['tocg_state_dict'])  # Load tocg state
+             optimizer_gen.load_state_dict(checkpoint['optimizer_gen_state_dict'])
+             optimizer_dis.load_state_dict(checkpoint['optimizer_dis_state_dict'])
+             scheduler_gen.load_state_dict(checkpoint['scheduler_gen_state_dict'])
+             scheduler_dis.load_state_dict(checkpoint['scheduler_dis_state_dict'])
+         else:
+             print(f"Checkpoint not found at {checkpoint_path}, starting from scratch")
+
+     for step in tqdm(range(opt.load_step, opt.keep_step + opt.decay_step)):
+         iter_start_time = time.time()
+         inputs = train_loader.next_batch()
+
+         # input
+         agnostic = inputs['agnostic'].cuda()
+         parse_GT = inputs['parse'].cuda()
+         pose = inputs['densepose'].cuda()
+         parse_cloth = inputs['parse_cloth'].cuda()
+         parse_agnostic = inputs['parse_agnostic'].cuda()
+         pcm = inputs['pcm'].cuda()
+         cm = inputs['cloth_mask']['paired'].cuda()
+         c_paired = inputs['cloth']['paired'].cuda()
+
+         # target
+         im = inputs['image'].cuda()
+
+         # Warping Cloth (tocg is now trainable)
+         pre_clothes_mask_down = F.interpolate(cm, size=(opt.cond_G_input_height, opt.cond_G_input_width), mode='nearest')
+         input_parse_agnostic_down = F.interpolate(parse_agnostic, size=(opt.cond_G_input_height, opt.cond_G_input_width), mode='nearest')
+         clothes_down = F.interpolate(c_paired, size=(opt.cond_G_input_height, opt.cond_G_input_width), mode='bilinear')
+         densepose_down = F.interpolate(pose, size=(opt.cond_G_input_height, opt.cond_G_input_width), mode='bilinear')
+
+         input1 = torch.cat([clothes_down, pre_clothes_mask_down], 1)
+         input2 = torch.cat([input_parse_agnostic_down, densepose_down], 1)
+
+         flow_list_taco, fake_segmap, warped_cloth_paired_taco, warped_clothmask_paired_taco, flow_list_tvob, warped_cloth_paired_tvob, warped_clothmask_paired_tvob = tocg(input1, input2)
+
+         warped_clothmask_paired_taco_onehot = torch.FloatTensor((warped_clothmask_paired_taco.detach().cpu().numpy() > 0.5).astype(float)).cuda()
+
+         cloth_mask = torch.ones_like(fake_segmap)
+         cloth_mask[:, 3:4, :, :] = warped_clothmask_paired_taco
+         fake_segmap = fake_segmap * cloth_mask
+
+         N, _, iH, iW = c_paired.shape
+         N, flow_iH, flow_iW, _ = flow_list_tvob[-1].shape
+
+         flow_tvob = F.interpolate(flow_list_tvob[-1].permute(0, 3, 1, 2), size=(iH, iW), mode='bilinear').permute(0, 2, 3, 1)
+         flow_tvob_norm = torch.cat([flow_tvob[:, :, :, 0:1] / ((flow_iW - 1.0) / 2.0), flow_tvob[:, :, :, 1:2] / ((flow_iH - 1.0) / 2.0)], 3)
+
+         grid = make_grid(N, iH, iW)
+         grid_3d = make_grid_3d(N, iH, iW)
+
+         warped_grid_tvob = grid + flow_tvob_norm
+         warped_cloth_tvob = F.grid_sample(c_paired, warped_grid_tvob, padding_mode='border')
+         warped_clothmask_tvob = F.grid_sample(cm, warped_grid_tvob, padding_mode='border')
+
+         flow_taco = F.interpolate(flow_list_taco[-1].permute(0, 4, 1, 2, 3), size=(2, iH, iW), mode='trilinear').permute(0, 2, 3, 4, 1)
+         flow_taco_norm = torch.cat([flow_taco[:, :, :, :, 0:1] / ((flow_iW - 1.0) / 2.0), flow_taco[:, :, :, :, 1:2] / ((flow_iH - 1.0) / 2.0), flow_taco[:, :, :, :, 2:3]], 4)
+         warped_cloth_tvob = warped_cloth_tvob.unsqueeze(2)
+         warped_cloth_paired_taco = F.grid_sample(torch.cat((warped_cloth_tvob, torch.zeros_like(warped_cloth_tvob).cuda()), dim=2), flow_taco_norm + grid_3d, padding_mode='border')
+         warped_cloth_paired_taco = warped_cloth_paired_taco[:, :, 0, :, :]
+
+         warped_clothmask_tvob = warped_clothmask_tvob.unsqueeze(2)
+         warped_clothmask_taco = F.grid_sample(torch.cat((warped_clothmask_tvob, torch.zeros_like(warped_clothmask_tvob).cuda()), dim=2), flow_taco_norm + grid_3d, padding_mode='border')
+         warped_clothmask_taco = warped_clothmask_taco[:, :, 0, :, :]
+
+         fake_parse_gauss = gauss(F.interpolate(fake_segmap, size=(iH, iW), mode='bilinear'))
+         fake_parse = fake_parse_gauss.argmax(dim=1)[:, None]
+
+         if opt.occlusion:
+             warped_clothmask_taco = remove_overlap(F.softmax(fake_parse_gauss, dim=1), warped_clothmask_taco)
+             warped_cloth_paired_taco = warped_cloth_paired_taco * warped_clothmask_taco + torch.ones_like(warped_cloth_paired_taco) * (1 - warped_clothmask_taco)
+             warped_cloth_paired_taco = warped_cloth_paired_taco.detach()
+
+         old_parse = torch.FloatTensor(fake_parse.size(0), 13, opt.fine_height, opt.fine_width).zero_().cuda()
+         old_parse.scatter_(1, fake_parse, 1.0)
+
+         labels = {
+             0: ['background', [0]],
+             1: ['paste', [2, 4, 7, 8, 9, 10, 11]],
+             2: ['upper', [3]],
+             3: ['hair', [1]],
+             4: ['left_arm', [5]],
+             5: ['right_arm', [6]],
+             6: ['noise', [12]]
+         }
+         parse = torch.FloatTensor(fake_parse.size(0), 7, opt.fine_height, opt.fine_width).zero_().cuda()
+         for i in range(len(labels)):
+             for label in labels[i][1]:
+                 parse[:, i] += old_parse[:, label]
+
+         parse = parse.detach()
+
+         # Train the generator and tocg
+         G_losses = {}
+         if opt.composition_mask:
+             output_paired_rendered, output_paired_comp = generator(torch.cat((agnostic, pose, warped_cloth_paired_taco), dim=1), parse)
+             output_paired_comp1 = output_paired_comp * warped_clothmask_taco
+             output_paired_comp = parse[:, 2:3, :, :] * output_paired_comp1
+             output_paired = warped_cloth_paired_taco * output_paired_comp + output_paired_rendered * (1 - output_paired_comp)
+
+             fake_concat = torch.cat((parse, output_paired_rendered), dim=1)
+             real_concat = torch.cat((parse, im), dim=1)
+             pred = discriminator(torch.cat((fake_concat, real_concat), dim=0))
+
+             pred_fake = []
+             pred_real = []
+             for p in pred:
+                 pred_fake.append([tensor[:tensor.size(0) // 2] for tensor in p])
+                 pred_real.append([tensor[tensor.size(0) // 2:] for tensor in p])
+
+             G_losses['GAN'] = criterionGAN(pred_fake, True, for_discriminator=False)
+
+             num_D = len(pred_fake)
+             GAN_Feat_loss = torch.cuda.FloatTensor(len(opt.gpu_ids)).zero_()
+             for i in range(num_D):
+                 num_intermediate_outputs = len(pred_fake[i]) - 1
+                 for j in range(num_intermediate_outputs):
+                     unweighted_loss = criterionFeat(pred_fake[i][j], pred_real[i][j].detach())
+                     GAN_Feat_loss += unweighted_loss * opt.lambda_feat / num_D
+             G_losses['GAN_Feat'] = GAN_Feat_loss
+
+             G_losses['VGG'] = criterionVGG(output_paired, im) * opt.lambda_vgg + criterionVGG(output_paired_rendered, im) * opt.lambda_vgg
+             G_losses['L1'] = criterionL1(output_paired_rendered, im) * opt.lambda_l1 + criterionL1(output_paired, im) * opt.lambda_l1
+             G_losses['Composition_Mask'] = torch.mean(torch.abs(1 - output_paired_comp))
+
+             loss_gen = sum(G_losses.values()).mean()
+
+         else:
+             set_requires_grad(discriminator, False)
+             output_paired = generator(torch.cat((agnostic, pose, warped_cloth_paired_taco), dim=1), parse)
+
+             pred_fake, feats_fake = discriminator(output_paired)
+             pred_real, feats_real = discriminator(im)
+
+             G_losses['GAN'] = criterionGAN(pred_fake, True, for_discriminator=False) * 0.5
+
+             num_D = len(feats_fake)
+             GAN_Feat_loss = torch.cuda.FloatTensor(len(opt.gpu_ids)).zero_()
+             for i in range(num_D):
+                 num_intermediate_outputs = len(feats_fake[i])
+                 for j in range(num_intermediate_outputs):
+                     unweighted_loss = criterionFeat(feats_fake[i][j], feats_real[i][j].detach())
+                     GAN_Feat_loss += unweighted_loss * opt.lambda_feat / num_D
+             G_losses['GAN_Feat'] = GAN_Feat_loss
+
+             G_losses['VGG'] = criterionVGG(output_paired, im) * opt.lambda_vgg
+             G_losses['L1'] = criterionL1(output_paired, im) * opt.lambda_l1
+
+             loss_gen = sum(G_losses.values()).mean()
+
+         optimizer_gen.zero_grad()
+         if opt.fp16:
+             with amp.scale_loss(loss_gen, optimizer_gen, loss_id=0) as loss_gen_scaled:
+                 loss_gen_scaled.backward()
+         else:
+             loss_gen.backward()
+         optimizer_gen.step()
+
+         # Train the discriminator
+         D_losses = {}
+         if opt.composition_mask:
+             with torch.no_grad():
+                 output_paired_rendered, output_comp = generator(torch.cat((agnostic, pose, warped_cloth_paired_taco), dim=1), parse)
+                 output_comp1 = output_comp * warped_clothmask_taco
+                 output_comp = parse[:, 2:3, :, :] * output_comp1
+                 output = warped_cloth_paired_taco * output_comp + output_paired_rendered * (1 - output_comp)
+                 output_comp = output_comp.detach()
+                 output = output.detach()
+                 output_comp.requires_grad_()
+                 output.requires_grad_()
+
+             fake_concat = torch.cat((parse, output_paired_rendered), dim=1)
+             real_concat = torch.cat((parse, im), dim=1)
+             pred = discriminator(torch.cat((fake_concat, real_concat), dim=0))
+
+             pred_fake = []
+             pred_real = []
+             for p in pred:
+                 pred_fake.append([tensor[:tensor.size(0) // 2] for tensor in p])
+                 pred_real.append([tensor[tensor.size(0) // 2:] for tensor in p])
+
+             D_losses['D_Fake'] = criterionGAN(pred_fake, False, for_discriminator=True)
+             D_losses['D_Real'] = criterionGAN(pred_real, True, for_discriminator=True)
+
+             loss_dis = sum(D_losses.values()).mean()
+
+         else:
+             set_requires_grad(discriminator, True)
+             discriminator.module.feature_network.requires_grad_(False)
+
+             with torch.no_grad():
+                 output = generator(torch.cat((agnostic, pose, warped_cloth_paired_taco), dim=1), parse)
+                 output = output.detach()
+                 output.requires_grad_()
+
+             pred_fake, _ = discriminator(output)
+             pred_real, _ = discriminator(im)
+
+             D_losses['D_Fake'] = criterionGAN(pred_fake, False, for_discriminator=True)
+             D_losses['D_Real'] = criterionGAN(pred_real, True, for_discriminator=True)
+
+             loss_dis = sum(D_losses.values()).mean()
+
+         optimizer_dis.zero_grad()
+         if opt.fp16:
+             with amp.scale_loss(loss_dis, optimizer_dis, loss_id=1) as loss_dis_scaled:
+                 loss_dis_scaled.backward()
+         else:
+             loss_dis.backward()
+         optimizer_dis.step()
+
+         if not opt.composition_mask:
+             set_requires_grad(discriminator, False)
+
+         if (step + 1) % 100 == 0:
+             a_0 = im.cuda()[0]
+             b_0 = output.cuda()[0]
+             c_0 = warped_cloth_paired_taco.cuda()[0]
+             combine = torch.cat((a_0, b_0, c_0), dim=2)
+             cv_img = (combine.permute(1, 2, 0).detach().cpu().numpy() + 1) / 2
+             rgb = (cv_img * 255).astype(np.uint8)
+             bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
+             cv2.imwrite('sample_fs_toig/' + str(step) + '.jpg', bgr)
+
+         # Evaluate the generator
+         if (step + 1) % opt.lpips_count == 0:
+             generator.eval()
+             tocg.eval()
+             T2 = transforms.Compose([transforms.Resize((128, 128))])
+             lpips_list = []
+             avg_distance = 0.0
+
+             with torch.no_grad():
+                 print("LPIPS")
+                 for i in tqdm(range(500)):
+                     inputs = test_loader.next_batch()
+                     agnostic = inputs['agnostic'].cuda()
+                     parse_GT = inputs['parse'].cuda()
+                     pose = inputs['densepose'].cuda()
+                     parse_cloth = inputs['parse_cloth'].cuda()
+                     parse_agnostic = inputs['parse_agnostic'].cuda()
+                     pcm = inputs['pcm'].cuda()
+                     cm = inputs['cloth_mask']['paired'].cuda()
+                     c_paired = inputs['cloth']['paired'].cuda()
+                     im = inputs['image'].cuda()
+
+                     pre_clothes_mask_down = F.interpolate(cm, size=(opt.cond_G_input_height, opt.cond_G_input_width), mode='nearest')
+                     input_parse_agnostic_down = F.interpolate(parse_agnostic, size=(opt.cond_G_input_height, opt.cond_G_input_width), mode='nearest')
+                     clothes_down = F.interpolate(c_paired, size=(opt.cond_G_input_height, opt.cond_G_input_width), mode='bilinear')
+                     densepose_down = F.interpolate(pose, size=(opt.cond_G_input_height, opt.cond_G_input_width), mode='bilinear')
+
+                     input1 = torch.cat([clothes_down, pre_clothes_mask_down], 1)
+                     input2 = torch.cat([input_parse_agnostic_down, densepose_down], 1)
+
+                     flow_list_taco, fake_segmap, warped_cloth_paired_taco, warped_clothmask_paired_taco, flow_list_tvob, warped_cloth_paired_tvob, warped_clothmask_paired_tvob = tocg(input1, input2)
+
+                     warped_clothmask_paired_taco_onehot = torch.FloatTensor((warped_clothmask_paired_taco.detach().cpu().numpy() > 0.5).astype(float)).cuda()
+
+                     cloth_mask = torch.ones_like(fake_segmap)
+                     cloth_mask[:, 3:4, :, :] = warped_clothmask_paired_taco
+                     fake_segmap = fake_segmap * cloth_mask
+
+                     N, _, iH, iW = c_paired.shape
+                     N, flow_iH, flow_iW, _ = flow_list_tvob[-1].shape
+
+                     flow_tvob = F.interpolate(flow_list_tvob[-1].permute(0, 3, 1, 2), size=(iH, iW), mode='bilinear').permute(0, 2, 3, 1)
+                     flow_tvob_norm = torch.cat([flow_tvob[:, :, :, 0:1] / ((flow_iW - 1.0) / 2.0), flow_tvob[:, :, :, 1:2] / ((flow_iH - 1.0) / 2.0)], 3)
+
+                     grid = make_grid(N, iH, iW)
+                     grid_3d = make_grid_3d(N, iH, iW)
+
+                     warped_grid_tvob = grid + flow_tvob_norm
+                     warped_cloth_tvob = F.grid_sample(c_paired, warped_grid_tvob, padding_mode='border')
+                     warped_clothmask_tvob = F.grid_sample(cm, warped_grid_tvob, padding_mode='border')
+
+                     flow_taco = F.interpolate(flow_list_taco[-1].permute(0, 4, 1, 2, 3), size=(2, iH, iW), mode='trilinear').permute(0, 2, 3, 4, 1)
+                     flow_taco_norm = torch.cat([flow_taco[:, :, :, :, 0:1] / ((flow_iW - 1.0) / 2.0), flow_taco[:, :, :, :, 1:2] / ((flow_iH - 1.0) / 2.0), flow_taco[:, :, :, :, 2:3]], 4)
+                     warped_cloth_tvob = warped_cloth_tvob.unsqueeze(2)
+                     warped_cloth_paired_taco = F.grid_sample(torch.cat((warped_cloth_tvob, torch.zeros_like(warped_cloth_tvob).cuda()), dim=2), flow_taco_norm + grid_3d, padding_mode='border')
+                     warped_cloth_paired_taco = warped_cloth_paired_taco[:, :, 0, :, :]
+
+                     warped_clothmask_tvob = warped_clothmask_tvob.unsqueeze(2)
+                     warped_clothmask_taco = F.grid_sample(torch.cat((warped_clothmask_tvob, torch.zeros_like(warped_clothmask_tvob).cuda()), dim=2), flow_taco_norm + grid_3d, padding_mode='border')
+                     warped_clothmask_taco = warped_clothmask_taco[:, :, 0, :, :]
+
+                     fake_parse_gauss = gauss(F.interpolate(fake_segmap, size=(iH, iW), mode='bilinear'))
+                     fake_parse = fake_parse_gauss.argmax(dim=1)[:, None]
+
+                     if opt.occlusion:
+                         warped_clothmask_taco = remove_overlap(F.softmax(fake_parse_gauss, dim=1), warped_clothmask_taco)
+                         warped_cloth_paired_taco = warped_cloth_paired_taco * warped_clothmask_taco + torch.ones_like(warped_cloth_paired_taco) * (1 - warped_clothmask_taco)
+                         warped_cloth_paired_taco = warped_cloth_paired_taco.detach()
+
+                     old_parse = torch.FloatTensor(fake_parse.size(0), 13, opt.fine_height, opt.fine_width).zero_().cuda()
+                     old_parse.scatter_(1, fake_parse, 1.0)
+
+                     labels = {
+                         0: ['background', [0]],
+                         1: ['paste', [2, 4, 7, 8, 9, 10, 11]],
+                         2: ['upper', [3]],
+                         3: ['hair', [1]],
+                         4: ['left_arm', [5]],
+                         5: ['right_arm', [6]],
+                         6: ['noise', [12]]
+                     }
+                     parse = torch.FloatTensor(fake_parse.size(0), 7, opt.fine_height, opt.fine_width).zero_().cuda()
+                     for i in range(len(labels)):
+                         for label in labels[i][1]:
+                             parse[:, i] += old_parse[:, label]
+
+                     parse = parse.detach()
+
+                     if opt.composition_mask:
+                         output_paired_rendered, output_paired_comp = generator(torch.cat((agnostic, pose, warped_cloth_paired_taco), dim=1), parse)
+                         output_paired_comp1 = output_paired_comp * warped_clothmask_taco
+                         output_paired_comp = parse[:, 2:3, :, :] * output_paired_comp1
+                         output_paired = warped_cloth_paired_taco * output_paired_comp + output_paired_rendered * (1 - output_paired_comp)
+                     else:
+                         output_paired = generator(torch.cat((agnostic, pose, warped_cloth_paired_taco), dim=1), parse)
+
+                     avg_distance += model.forward(T2(im), T2(output_paired))
+
+             avg_distance = avg_distance / 500
+             print(f"LPIPS: {avg_distance}")
+             generator.train()
+             tocg.train()
+
+         if (step + 1) % opt.display_count == 0:
+             t = time.time() - iter_start_time
+             print("step: %8d, time: %.3f, G_loss: %.4f, G_adv_loss: %.4f, D_loss: %.4f, D_fake_loss: %.4f, D_real_loss: %.4f"
+                   % (step + 1, t, loss_gen.item(), G_losses['GAN'].mean().item(), loss_dis.item(),
+                      D_losses['D_Fake'].mean().item(), D_losses['D_Real'].mean().item()), flush=True)
+
+         if (step + 1) % opt.save_count == 0:
+             checkpoint = {
+                 'step': step + 1,
+                 'generator_state_dict': generator.state_dict(),
+                 'discriminator_state_dict': discriminator.state_dict(),
+                 'tocg_state_dict': tocg.state_dict(),  # Save tocg state
+                 'optimizer_gen_state_dict': optimizer_gen.state_dict(),
+                 'optimizer_dis_state_dict': optimizer_dis.state_dict(),
+                 'scheduler_gen_state_dict': scheduler_gen.state_dict(),
+                 'scheduler_dis_state_dict': scheduler_dis.state_dict(),
+             }
+             torch.save(checkpoint, checkpoint_path)
+
+         if (step + 1) % 1000 == 0:
+             scheduler_gen.step()
+             scheduler_dis.step()
+
+ def main():
+     opt = get_opt()
+     print(opt)
+     print("Start to train %s!" % opt.name)
+
+     os.makedirs('sample_fs_toig', exist_ok=True)
+     os.makedirs(os.path.join(opt.checkpoint_dir, opt.name), exist_ok=True)
+
+     train_dataset = CPDataset(opt)
+     train_loader = CPDataLoader(opt, train_dataset)
+
+     opt.batch_size = 1
+     opt.dataroot = opt.test_dataroot
+     opt.datamode = 'test'
+     opt.data_list = opt.test_data_list
+     test_dataset = CPDatasetTest(opt)
+     test_dataset = Subset(test_dataset, np.arange(500))
+     test_loader = CPDataLoader(opt, test_dataset)
+
+     input1_nc = 4
+     input2_nc = opt.semantic_nc + 3
+     tocg = ConditionGenerator(opt, input1_nc=input1_nc, input2_nc=input2_nc, output_nc=13, ngf=opt.cond_G_ngf, norm_layer=nn.BatchNorm2d, num_layers=opt.cond_G_num_layers)
+     load_checkpoint(tocg, opt.tocg_checkpoint)
+
+     generator = SPADEGenerator(opt, 3 + 3 + 3)
+     generator.print_network()
+     if len(opt.gpu_ids) > 0:
+         assert torch.cuda.is_available()
+         generator.cuda()
+     generator.init_weights(opt.init_type, opt.init_variance)
+
+     discriminator = None
+     if opt.composition_mask:
+         discriminator = create_network(MultiscaleDiscriminator, opt)
+     else:
+         discriminator = ProjectedDiscriminator(interp224=False)
+
+     model = models.PerceptualLoss(model='net-lin', net='alex', use_gpu=True)
+
+     if opt.gen_checkpoint and os.path.exists(opt.gen_checkpoint):
+         load_checkpoint(generator, opt.gen_checkpoint)
+     if opt.dis_checkpoint and os.path.exists(opt.dis_checkpoint):
+         load_checkpoint(discriminator, opt.dis_checkpoint)
+
+     train(opt, train_loader, test_loader, tocg, generator, discriminator, model)
+
+     save_checkpoint(generator, os.path.join(opt.checkpoint_dir, opt.name, 'gen_model_final.pth'))
+     save_checkpoint(discriminator, os.path.join(opt.checkpoint_dir, opt.name, 'dis_model_final.pth'))
+     save_checkpoint(tocg, os.path.join(opt.checkpoint_dir, opt.name, 'tocg_model_final.pth'))
+
+     print("Finished training %s!" % opt.name)
+
+ if __name__ == "__main__":
+     main()
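
The main addition in this file is the selective-freezing helper driven by the new --freeze_tocg_layers, --freeze_gen_layers, and --last_layer_mode flags. A minimal sketch of its behavior on a toy model (the three-block nn.Sequential is an assumption for illustration; apply_layer_freezing is copied from the diff above):

```python
import torch.nn as nn

# Copied from train_generator.py above so the sketch is self-contained.
def apply_layer_freezing(model, num_layers_to_freeze, last_layer_mode):
    children = list(model.named_children())
    total_layers = len(children)
    # Freeze the first num_layers_to_freeze top-level children.
    for i, (name, module) in enumerate(children):
        if i < num_layers_to_freeze:
            for param in module.parameters():
                param.requires_grad = False
    # Optionally freeze all or half of the last child's parameters.
    if total_layers > 0 and last_layer_mode != 'train':
        last_name, last_module = children[-1]
        if last_layer_mode == 'freeze':
            for param in last_module.parameters():
                param.requires_grad = False
        elif last_layer_mode == 'half':
            params = list(last_module.parameters())
            half_idx = len(params) // 2
            for param in params[:half_idx]:
                param.requires_grad = False
            for param in params[half_idx:]:
                param.requires_grad = True

# Hypothetical three-block model, purely for illustration.
toy = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8))

# Freeze the first block; freeze the first half of the last block's parameters.
apply_layer_freezing(toy, num_layers_to_freeze=1, last_layer_mode='half')

for name, p in toy.named_parameters():
    print(name, p.requires_grad)
# Expected: 0.weight/0.bias -> False (frozen block), 1.* -> True,
# 2.weight -> False (first half of last layer), 2.bias -> True.
```

Note that "half" operates on parameter tensors (e.g., weight vs. bias), not on individual elements, so for a Linear last layer it freezes the weight matrix and leaves the bias trainable.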