Ubaida10 commited on
Commit
8b261fe
·
verified ·
1 Parent(s): c39d2ef

Upload network_generator.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. network_generator.py +476 -0
network_generator.py ADDED
@@ -0,0 +1,476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn import init
5
+ from torch.nn.utils import spectral_norm
6
+ import numpy as np
7
+
8
+
9
+ class BaseNetwork(nn.Module):
10
+ def __init__(self):
11
+ super(BaseNetwork, self).__init__()
12
+
13
+ def print_network(self):
14
+ num_params = 0
15
+ for param in self.parameters():
16
+ num_params += param.numel()
17
+ print("Network [{}] was created. Total number of parameters: {:.1f} million. "
18
+ "To see the architecture, do print(network).".format(self.__class__.__name__, num_params / 1000000))
19
+
20
+ def init_weights(self, init_type='normal', gain=0.02):
21
+ def init_func(m):
22
+ classname = m.__class__.__name__
23
+ if 'BatchNorm2d' in classname:
24
+ if hasattr(m, 'weight') and m.weight is not None:
25
+ init.normal_(m.weight.data, 1.0, gain)
26
+ if hasattr(m, 'bias') and m.bias is not None:
27
+ init.constant_(m.bias.data, 0.0)
28
+ elif ('Conv' in classname or 'Linear' in classname) and hasattr(m, 'weight'):
29
+ if init_type == 'normal':
30
+ init.normal_(m.weight.data, 0.0, gain)
31
+ elif init_type == 'xavier':
32
+ init.xavier_normal_(m.weight.data, gain=gain)
33
+ elif init_type == 'xavier_uniform':
34
+ init.xavier_uniform_(m.weight.data, gain=1.0)
35
+ elif init_type == 'kaiming':
36
+ init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
37
+ elif init_type == 'orthogonal':
38
+ init.orthogonal_(m.weight.data, gain=gain)
39
+ elif init_type == 'none': # uses pytorch's default init method
40
+ m.reset_parameters()
41
+ else:
42
+ raise NotImplementedError("initialization method '{}' is not implemented".format(init_type))
43
+ if hasattr(m, 'bias') and m.bias is not None:
44
+ init.constant_(m.bias.data, 0.0)
45
+
46
+ self.apply(init_func)
47
+
48
+ def forward(self, *inputs):
49
+ pass
50
+
51
+
52
+ class MaskNorm(nn.Module):
53
+ def __init__(self, norm_nc):
54
+ super(MaskNorm, self).__init__()
55
+
56
+ self.norm_layer = nn.InstanceNorm2d(norm_nc, affine=False)
57
+
58
+ def normalize_region(self, region, mask):
59
+ b, c, h, w = region.size()
60
+
61
+ num_pixels = mask.sum((2, 3), keepdim=True) # size: (b, 1, 1, 1)
62
+ num_pixels[num_pixels == 0] = 1
63
+ mu = region.sum((2, 3), keepdim=True) / num_pixels # size: (b, c, 1, 1)
64
+
65
+ normalized_region = self.norm_layer(region + (1 - mask) * mu)
66
+ return normalized_region * torch.sqrt(num_pixels / (h * w))
67
+
68
+ def forward(self, x, mask):
69
+ mask = mask.detach()
70
+ normalized_foreground = self.normalize_region(x * mask, mask)
71
+ normalized_background = self.normalize_region(x * (1 - mask), 1 - mask)
72
+ return normalized_foreground + normalized_background
73
+
74
+
75
+ class SPADENorm(nn.Module):
76
+ def __init__(self, norm_type, norm_nc, label_nc):
77
+ super(SPADENorm, self).__init__()
78
+
79
+ self.noise_scale = nn.Parameter(torch.zeros(norm_nc))
80
+
81
+ assert norm_type.startswith('alias')
82
+ param_free_norm_type = norm_type[len('alias'):]
83
+ if param_free_norm_type == 'batch':
84
+ self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
85
+ elif param_free_norm_type == 'instance':
86
+ self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
87
+ elif param_free_norm_type == 'mask':
88
+ self.param_free_norm = MaskNorm(norm_nc)
89
+ else:
90
+ raise ValueError(
91
+ "'{}' is not a recognized parameter-free normalization type in SPADENorm".format(param_free_norm_type)
92
+ )
93
+
94
+ nhidden = 128
95
+ ks = 3
96
+ pw = ks // 2
97
+ self.conv_shared = nn.Sequential(nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), nn.ReLU())
98
+ self.conv_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
99
+ self.conv_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
100
+
101
+ def forward(self, x, seg, misalign_mask=None):
102
+ # Part 1. Generate parameter-free normalized activations.
103
+ b, c, h, w = x.size()
104
+ noise = (torch.randn(b, w, h, 1).cuda() * self.noise_scale).transpose(1, 3)
105
+
106
+ if misalign_mask is None:
107
+ normalized = self.param_free_norm(x + noise)
108
+ else:
109
+ normalized = self.param_free_norm(x + noise, misalign_mask)
110
+
111
+ # Part 2. Produce affine parameters conditioned on the segmentation map.
112
+ actv = self.conv_shared(seg)
113
+ gamma = self.conv_gamma(actv)
114
+ beta = self.conv_beta(actv)
115
+
116
+ # Apply the affine parameters.
117
+ output = normalized * (1 + gamma) + beta
118
+ return output
119
+
120
+
121
+ class SPADEResBlock(nn.Module):
122
+ def __init__(self, opt, input_nc, output_nc, use_mask_norm=True):
123
+ super(SPADEResBlock, self).__init__()
124
+
125
+ self.learned_shortcut = (input_nc != output_nc)
126
+ middle_nc = min(input_nc, output_nc)
127
+
128
+ self.conv_0 = nn.Conv2d(input_nc, middle_nc, kernel_size=3, padding=1)
129
+ self.conv_1 = nn.Conv2d(middle_nc, output_nc, kernel_size=3, padding=1)
130
+ if self.learned_shortcut:
131
+ self.conv_s = nn.Conv2d(input_nc, output_nc, kernel_size=1, bias=False)
132
+
133
+ subnorm_type = opt.norm_G
134
+ if subnorm_type.startswith('spectral'):
135
+ subnorm_type = subnorm_type[len('spectral'):]
136
+ self.conv_0 = spectral_norm(self.conv_0)
137
+ self.conv_1 = spectral_norm(self.conv_1)
138
+ if self.learned_shortcut:
139
+ self.conv_s = spectral_norm(self.conv_s)
140
+
141
+ gen_semantic_nc = opt.gen_semantic_nc
142
+ if use_mask_norm:
143
+ subnorm_type = 'aliasmask'
144
+ gen_semantic_nc = gen_semantic_nc + 1
145
+
146
+ self.norm_0 = SPADENorm(subnorm_type, input_nc, gen_semantic_nc)
147
+ self.norm_1 = SPADENorm(subnorm_type, middle_nc, gen_semantic_nc)
148
+ if self.learned_shortcut:
149
+ self.norm_s = SPADENorm(subnorm_type, input_nc, gen_semantic_nc)
150
+
151
+ self.relu = nn.LeakyReLU(0.2)
152
+
153
+ def shortcut(self, x, seg, misalign_mask):
154
+ if self.learned_shortcut:
155
+ return self.conv_s(self.norm_s(x, seg, misalign_mask))
156
+ else:
157
+ return x
158
+
159
+ def forward(self, x, seg, misalign_mask=None):
160
+ seg = F.interpolate(seg, size=x.size()[2:], mode='nearest')
161
+ if misalign_mask is not None:
162
+ misalign_mask = F.interpolate(misalign_mask, size=x.size()[2:], mode='nearest')
163
+
164
+ x_s = self.shortcut(x, seg, misalign_mask)
165
+
166
+ dx = self.conv_0(self.relu(self.norm_0(x, seg, misalign_mask)))
167
+ dx = self.conv_1(self.relu(self.norm_1(dx, seg, misalign_mask)))
168
+ output = x_s + dx
169
+ return output
170
+
171
+
172
+ class SPADEGenerator(BaseNetwork):
173
+ def __init__(self, opt, input_nc):
174
+ super(SPADEGenerator, self).__init__()
175
+
176
+ self.opt = opt
177
+
178
+ self.num_upsampling_layers = opt.num_upsampling_layers
179
+
180
+ self.sh, self.sw = self.compute_latent_vector_size(opt)
181
+
182
+ nf = opt.ngf
183
+ self.conv_0 = nn.Conv2d(input_nc, nf * 16, kernel_size=3, padding=1)
184
+ for i in range(1, 8):
185
+ self.add_module('conv_{}'.format(i), nn.Conv2d(input_nc, 16, kernel_size=3, padding=1))
186
+
187
+ self.head_0 = SPADEResBlock(opt, nf * 16, nf * 16, use_mask_norm=False)
188
+
189
+ self.G_middle_0 = SPADEResBlock(opt, nf * 16 + 16, nf * 16, use_mask_norm=False)
190
+ self.G_middle_1 = SPADEResBlock(opt, nf * 16 + 16, nf * 16, use_mask_norm=False)
191
+
192
+ self.up_0 = SPADEResBlock(opt, nf * 16 + 16, nf * 8, use_mask_norm=False)
193
+ self.up_1 = SPADEResBlock(opt, nf * 8 + 16, nf * 4, use_mask_norm=False)
194
+ self.up_2 = SPADEResBlock(opt, nf * 4 + 16, nf * 2, use_mask_norm=False)
195
+ self.up_3 = SPADEResBlock(opt, nf * 2 + 16, nf * 1, use_mask_norm=False)
196
+ if self.num_upsampling_layers == 'most':
197
+ self.up_4 = SPADEResBlock(opt, nf * 1 + 16, nf // 2, use_mask_norm=False)
198
+ nf = nf // 2
199
+
200
+ if opt.composition_mask:
201
+ self.conv_img = nn.Conv2d(nf, 4, kernel_size=3, padding=1)
202
+ self.sigmoid = nn.Sigmoid()
203
+ else:
204
+ self.conv_img = nn.Conv2d(nf, 3, kernel_size=3, padding=1)
205
+
206
+ self.up = nn.Upsample(scale_factor=2, mode='nearest')
207
+ self.relu = nn.LeakyReLU(0.2)
208
+ self.tanh = nn.Tanh()
209
+
210
+ def compute_latent_vector_size(self, opt):
211
+ if self.num_upsampling_layers == 'normal':
212
+ num_up_layers = 5
213
+ elif self.num_upsampling_layers == 'more':
214
+ num_up_layers = 6
215
+ elif self.num_upsampling_layers == 'most':
216
+ num_up_layers = 7
217
+ else:
218
+ raise ValueError("opt.num_upsampling_layers '{}' is not recognized".format(self.num_upsampling_layers))
219
+
220
+ sh = opt.fine_height // 2**num_up_layers
221
+ sw = opt.fine_width // 2**num_up_layers
222
+ return sh, sw
223
+
224
+ def forward(self, x, seg):
225
+ samples = [F.interpolate(x, size=(self.sh * 2**i, self.sw * 2**i), mode='nearest') for i in range(8)]
226
+ features = [self._modules['conv_{}'.format(i)](samples[i]) for i in range(8)]
227
+
228
+ x = self.head_0(features[0], seg)
229
+ x = self.up(x)
230
+ x = self.G_middle_0(torch.cat((x, features[1]), 1), seg)
231
+ if self.num_upsampling_layers in ['more', 'most']:
232
+ x = self.up(x)
233
+ x = self.G_middle_1(torch.cat((x, features[2]), 1), seg)
234
+
235
+ x = self.up(x)
236
+ x = self.up_0(torch.cat((x, features[3]), 1), seg)
237
+ x = self.up(x)
238
+ x = self.up_1(torch.cat((x, features[4]), 1), seg)
239
+ x = self.up(x)
240
+ x = self.up_2(torch.cat((x, features[5]), 1), seg)
241
+ x = self.up(x)
242
+ x = self.up_3(torch.cat((x, features[6]), 1), seg)
243
+ if self.num_upsampling_layers == 'most':
244
+ x = self.up(x)
245
+ x = self.up_4(torch.cat((x, features[7]), 1), seg)
246
+
247
+ x = self.conv_img(self.relu(x))
248
+
249
+ if self.opt.composition_mask:
250
+ x, comp_x = torch.split(x, [3, 1], 1)
251
+ return self.tanh(x), self.sigmoid(comp_x)
252
+ else:
253
+ return self.tanh(x)
254
+
255
+ ########################################################################
256
+
257
+ ########################################################################
258
+
259
+ class NLayerDiscriminator(BaseNetwork):
260
+
261
+ def __init__(self, opt):
262
+ super().__init__()
263
+ self.no_ganFeat_loss = opt.no_ganFeat_loss
264
+ nf = opt.ndf
265
+
266
+ kw = 4
267
+ pw = int(np.ceil((kw - 1.0) / 2))
268
+ norm_layer = get_nonspade_norm_layer(opt.norm_D)
269
+
270
+ input_nc = opt.gen_semantic_nc + 3
271
+ # input_nc = opt.gen_semantic_nc + 13
272
+ sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=pw),
273
+ nn.LeakyReLU(0.2, False)]]
274
+
275
+ for n in range(1, opt.n_layers_D):
276
+ nf_prev = nf
277
+ nf = min(nf * 2, 512)
278
+ sequence += [[norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=pw)),
279
+ nn.LeakyReLU(0.2, False)]]
280
+
281
+ sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=pw)]]
282
+
283
+ # We divide the layers into groups to extract intermediate layer outputs
284
+ for n in range(len(sequence)):
285
+ self.add_module('model' + str(n), nn.Sequential(*sequence[n]))
286
+
287
+ def forward(self, input):
288
+ results = [input]
289
+ for submodel in self.children():
290
+ intermediate_output = submodel(results[-1])
291
+ results.append(intermediate_output)
292
+
293
+ get_intermediate_features = not self.no_ganFeat_loss
294
+ if get_intermediate_features:
295
+ return results[1:]
296
+ else:
297
+ return results[-1]
298
+
299
+
300
+ class MultiscaleDiscriminator(BaseNetwork):
301
+
302
+ def __init__(self, opt):
303
+ super().__init__()
304
+ self.no_ganFeat_loss = opt.no_ganFeat_loss
305
+
306
+ for i in range(opt.num_D):
307
+ subnetD = NLayerDiscriminator(opt)
308
+ self.add_module('discriminator_%d' % i, subnetD)
309
+
310
+ def downsample(self, input):
311
+ return F.avg_pool2d(input, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)
312
+
313
+ # Returns list of lists of discriminator outputs.
314
+ # The final result is of size opt.num_D x opt.n_layers_D
315
+ def forward(self, input):
316
+ result = []
317
+ get_intermediate_features = not self.no_ganFeat_loss
318
+ for name, D in self.named_children():
319
+ out = D(input)
320
+ if not get_intermediate_features:
321
+ out = [out]
322
+ result.append(out)
323
+ input = self.downsample(input)
324
+
325
+ return result
326
+
327
+ def set_requires_grad(nets, requires_grad=False):
328
+ if not isinstance(nets, list):
329
+ nets = [nets]
330
+ for net in nets:
331
+ if net is not None:
332
+ for param in net.parameters():
333
+ param.requires_grad = requires_grad
334
+
335
+ class Projected_GANs_Loss(nn.Module):
336
+
337
+ def __init__(self, tensor=torch.FloatTensor):
338
+ super(Projected_GANs_Loss, self).__init__()
339
+ self.Tensor = tensor
340
+
341
+ def __call__(self, input, label, for_discriminator):
342
+
343
+ return self.loss(input, label, for_discriminator)
344
+
345
+ def loss(self, input, target_is_real, for_discriminator=True):
346
+
347
+ if for_discriminator == False:
348
+
349
+ return (-input).mean()
350
+
351
+ else:
352
+ real_label_tensor = self.Tensor(1).fill_(1.0)
353
+ real_label_tensor = real_label_tensor.requires_grad_(False)
354
+ real_label_tensor = real_label_tensor.expand_as(input)
355
+ if target_is_real:
356
+ return (F.relu(real_label_tensor - input)).mean()
357
+ else:
358
+ return (F.relu(real_label_tensor + input)).mean()
359
+
360
+
361
+ class GANLoss(nn.Module):
362
+ def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0, tensor=torch.FloatTensor):
363
+ super(GANLoss, self).__init__()
364
+ self.real_label = target_real_label
365
+ self.fake_label = target_fake_label
366
+ self.real_label_tensor = None
367
+ self.fake_label_tensor = None
368
+ self.zero_tensor = None
369
+ self.Tensor = tensor
370
+ self.gan_mode = gan_mode
371
+ if gan_mode == 'ls':
372
+ pass
373
+ elif gan_mode == 'original':
374
+ pass
375
+ elif gan_mode == 'w':
376
+ pass
377
+ elif gan_mode == 'hinge':
378
+ pass
379
+ else:
380
+ raise ValueError('Unexpected gan_mode {}'.format(gan_mode))
381
+
382
+ def get_target_tensor(self, input, target_is_real):
383
+ if target_is_real:
384
+ if self.real_label_tensor is None:
385
+ self.real_label_tensor = self.Tensor(1).fill_(self.real_label)
386
+ self.real_label_tensor.requires_grad_(False)
387
+ return self.real_label_tensor.expand_as(input)
388
+ else:
389
+ if self.fake_label_tensor is None:
390
+ self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label)
391
+ self.fake_label_tensor.requires_grad_(False)
392
+ return self.fake_label_tensor.expand_as(input)
393
+
394
+ def get_zero_tensor(self, input):
395
+ if self.zero_tensor is None:
396
+ self.zero_tensor = self.Tensor(1).fill_(0)
397
+ self.zero_tensor.requires_grad_(False)
398
+ return self.zero_tensor.expand_as(input)
399
+
400
+ def loss(self, input, target_is_real, for_discriminator=True):
401
+ if self.gan_mode == 'original': # cross entropy loss
402
+ target_tensor = self.get_target_tensor(input, target_is_real)
403
+ loss = F.binary_cross_entropy_with_logits(input, target_tensor)
404
+ return loss
405
+ elif self.gan_mode == 'ls':
406
+ target_tensor = self.get_target_tensor(input, target_is_real)
407
+ return F.mse_loss(input, target_tensor)
408
+ elif self.gan_mode == 'hinge':
409
+ if for_discriminator:
410
+ if target_is_real:
411
+ minval = torch.min(input - 1, self.get_zero_tensor(input))
412
+ loss = -torch.mean(minval)
413
+ else:
414
+ minval = torch.min(-input - 1, self.get_zero_tensor(input))
415
+ loss = -torch.mean(minval)
416
+ else:
417
+ assert target_is_real, "The generator's hinge loss must be aiming for real"
418
+ loss = -torch.mean(input)
419
+ return loss
420
+ else:
421
+ # wgan
422
+ if target_is_real:
423
+ return -input.mean()
424
+ else:
425
+ return input.mean()
426
+
427
+ def __call__(self, input, target_is_real, for_discriminator=True):
428
+ # computing loss is a bit complicated because |input| may not be
429
+ # a tensor, but list of tensors in case of multiscale discriminator
430
+ if isinstance(input, list):
431
+ loss = 0
432
+ for pred_i in input:
433
+ if isinstance(pred_i, list):
434
+ pred_i = pred_i[-1]
435
+ loss_tensor = self.loss(pred_i, target_is_real, for_discriminator)
436
+ bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0)
437
+ new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1)
438
+ loss += new_loss
439
+ return loss / len(input)
440
+ else:
441
+ return self.loss(input, target_is_real, for_discriminator)
442
+
443
+
444
+ def get_nonspade_norm_layer(norm_type='instance'):
445
+ def get_out_channel(layer):
446
+ if hasattr(layer, 'out_channels'):
447
+ return getattr(layer, 'out_channels')
448
+ return layer.weight.size(0)
449
+
450
+ def add_norm_layer(layer):
451
+ nonlocal norm_type
452
+ if norm_type.startswith('spectral'):
453
+ layer = spectral_norm(layer)
454
+ subnorm_type = norm_type[len('spectral'):]
455
+
456
+ if subnorm_type == 'none' or len(subnorm_type) == 0:
457
+ return layer
458
+
459
+ # remove bias in the previous layer, which is meaningless
460
+ # since it has no effect after normalization
461
+ if getattr(layer, 'bias', None) is not None:
462
+ delattr(layer, 'bias')
463
+ layer.register_parameter('bias', None)
464
+
465
+ if subnorm_type == 'batch':
466
+ norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
467
+ # elif subnorm_type == 'sync_batch':
468
+ # norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True)
469
+ elif subnorm_type == 'instance':
470
+ norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
471
+ else:
472
+ raise ValueError('normalization layer %s is not recognized' % subnorm_type)
473
+
474
+ return nn.Sequential(layer, norm_layer)
475
+
476
+ return add_norm_layer