Ubaida10 committed · Commit fb413bf · verified · 1 Parent(s): 550ed07

Upload networks.py with huggingface_hub

Files changed (1): networks.py (+545, -0)
networks.py ADDED
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import models
import os
from torch.nn.utils import spectral_norm
import numpy as np

import functools

class ConditionGenerator(nn.Module):
    # Two-branch feature pyramid (cloth / pose) that jointly predicts the try-on
    # segmentation map, the 2D appearance flows (`tvob`), and the residual 3D
    # flow (`taco`, Eqs. 10-11 of SD-VITON).
    # Note: forward() assumes opt.warp_feature == 'T1' and opt.out_layer == 'relu',
    # since SegDecoder and out_layer are only built for those option values.
    def __init__(self, opt, input1_nc, input2_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, num_layers=5):
        super(ConditionGenerator, self).__init__()
        self.warp_feature = opt.warp_feature
        self.out_layer_opt = opt.out_layer

        if num_layers == 5:
            self.ClothEncoder = nn.Sequential(
                ResBlock(input1_nc, ngf, norm_layer=norm_layer, scale='down'),    # 256
                ResBlock(ngf, ngf * 2, norm_layer=norm_layer, scale='down'),      # 128
                ResBlock(ngf * 2, ngf * 4, norm_layer=norm_layer, scale='down'),  # 64
                ResBlock(ngf * 4, ngf * 4, norm_layer=norm_layer, scale='down'),  # 32
                ResBlock(ngf * 4, ngf * 4, norm_layer=norm_layer, scale='down'),  # 16
            )

            self.PoseEncoder = nn.Sequential(
                ResBlock(input2_nc, ngf, norm_layer=norm_layer, scale='down'),
                ResBlock(ngf, ngf * 2, norm_layer=norm_layer, scale='down'),
                ResBlock(ngf * 2, ngf * 4, norm_layer=norm_layer, scale='down'),
                ResBlock(ngf * 4, ngf * 4, norm_layer=norm_layer, scale='down'),
                ResBlock(ngf * 4, ngf * 4, norm_layer=norm_layer, scale='down'),
            )

            if opt.warp_feature == 'T1':
                # in_nc -> skip connection + T1, T2 channels
                self.SegDecoder = nn.Sequential(
                    ResBlock(ngf * 8, ngf * 4, norm_layer=norm_layer, scale='up'),                # 16
                    ResBlock(ngf * 4 * 2 + ngf * 4, ngf * 4, norm_layer=norm_layer, scale='up'),  # 32
                    ResBlock(ngf * 4 * 2 + ngf * 4, ngf * 2, norm_layer=norm_layer, scale='up'),  # 64
                    ResBlock(ngf * 2 * 2 + ngf * 4, ngf, norm_layer=norm_layer, scale='up'),      # 128
                    ResBlock(ngf * 1 * 2 + ngf * 4, ngf, norm_layer=norm_layer, scale='up'),      # 256
                )

            # Cloth Conv 1x1 (indexed per pyramid level, never run as one Sequential)
            self.conv1 = nn.Sequential(
                nn.Conv2d(ngf, ngf * 4, kernel_size=1, bias=True),
                nn.Conv2d(ngf * 2, ngf * 4, kernel_size=1, bias=True),
                nn.Conv2d(ngf * 4, ngf * 4, kernel_size=1, bias=True),
                nn.Conv2d(ngf * 4, ngf * 4, kernel_size=1, bias=True),
            )

            # Person Conv 1x1
            self.conv2 = nn.Sequential(
                nn.Conv2d(ngf, ngf * 4, kernel_size=1, bias=True),
                nn.Conv2d(ngf * 2, ngf * 4, kernel_size=1, bias=True),
                nn.Conv2d(ngf * 4, ngf * 4, kernel_size=1, bias=True),
                nn.Conv2d(ngf * 4, ngf * 4, kernel_size=1, bias=True),
            )

            self.flow_conv = nn.ModuleList([
                nn.Conv2d(ngf * 8, 2, kernel_size=3, stride=1, padding=1, bias=True),
                nn.Conv2d(ngf * 8, 2, kernel_size=3, stride=1, padding=1, bias=True),
                nn.Conv2d(ngf * 8, 2, kernel_size=3, stride=1, padding=1, bias=True),
                nn.Conv2d(ngf * 8, 2, kernel_size=3, stride=1, padding=1, bias=True),
                nn.Conv2d(ngf * 8, 2, kernel_size=3, stride=1, padding=1, bias=True),
            ])

            self.bottleneck = nn.Sequential(
                nn.Sequential(nn.Conv2d(ngf * 4, ngf * 4, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU()),
                nn.Sequential(nn.Conv2d(ngf * 4, ngf * 4, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU()),
                nn.Sequential(nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU()),
                nn.Sequential(nn.Conv2d(ngf, ngf * 4, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU()),
            )

        if num_layers == 6:
            self.ClothEncoder = nn.Sequential(
                ResBlock(input1_nc, ngf, norm_layer=norm_layer, scale='down'),    # 512
                ResBlock(ngf, ngf * 2, norm_layer=norm_layer, scale='down'),      # 256
                ResBlock(ngf * 2, ngf * 4, norm_layer=norm_layer, scale='down'),  # 128
                ResBlock(ngf * 4, ngf * 4, norm_layer=norm_layer, scale='down'),  # 64
                ResBlock(ngf * 4, ngf * 4, norm_layer=norm_layer, scale='down'),  # 32
                ResBlock(ngf * 4, ngf * 4, norm_layer=norm_layer, scale='down'),  # 16
            )

            self.PoseEncoder = nn.Sequential(
                ResBlock(input2_nc, ngf, norm_layer=norm_layer, scale='down'),
                ResBlock(ngf, ngf * 2, norm_layer=norm_layer, scale='down'),
                ResBlock(ngf * 2, ngf * 4, norm_layer=norm_layer, scale='down'),
                ResBlock(ngf * 4, ngf * 4, norm_layer=norm_layer, scale='down'),
                ResBlock(ngf * 4, ngf * 4, norm_layer=norm_layer, scale='down'),
                ResBlock(ngf * 4, ngf * 4, norm_layer=norm_layer, scale='down'),
            )

            if opt.warp_feature == 'T1':
                # in_nc -> skip connection + T1, T2 channels
                self.SegDecoder = nn.Sequential(
                    ResBlock(ngf * 8, ngf * 4, norm_layer=norm_layer, scale='up'),                # 16
                    ResBlock(ngf * 4 * 2 + ngf * 4, ngf * 4, norm_layer=norm_layer, scale='up'),  # 32
                    ResBlock(ngf * 4 * 2 + ngf * 4, ngf * 4, norm_layer=norm_layer, scale='up'),  # 64
                    ResBlock(ngf * 4 * 2 + ngf * 4, ngf * 2, norm_layer=norm_layer, scale='up'),  # 128
                    ResBlock(ngf * 2 * 2 + ngf * 4, ngf, norm_layer=norm_layer, scale='up'),      # 256
                    ResBlock(ngf * 1 * 2 + ngf * 4, ngf, norm_layer=norm_layer, scale='up'),      # 512
                )

            # Cloth Conv 1x1
            self.conv1 = nn.Sequential(
                nn.Conv2d(ngf, ngf * 4, kernel_size=1, bias=True),
                nn.Conv2d(ngf * 2, ngf * 4, kernel_size=1, bias=True),
                nn.Conv2d(ngf * 4, ngf * 4, kernel_size=1, bias=True),
                nn.Conv2d(ngf * 4, ngf * 4, kernel_size=1, bias=True),
                nn.Conv2d(ngf * 4, ngf * 4, kernel_size=1, bias=True),
            )

            # Person Conv 1x1
            self.conv2 = nn.Sequential(
                nn.Conv2d(ngf, ngf * 4, kernel_size=1, bias=True),
                nn.Conv2d(ngf * 2, ngf * 4, kernel_size=1, bias=True),
                nn.Conv2d(ngf * 4, ngf * 4, kernel_size=1, bias=True),
                nn.Conv2d(ngf * 4, ngf * 4, kernel_size=1, bias=True),
                nn.Conv2d(ngf * 4, ngf * 4, kernel_size=1, bias=True),
            )

            self.flow_conv = nn.ModuleList([
                nn.Conv2d(ngf * 8, 2, kernel_size=3, stride=1, padding=1, bias=True),
                nn.Conv2d(ngf * 8, 2, kernel_size=3, stride=1, padding=1, bias=True),
                nn.Conv2d(ngf * 8, 2, kernel_size=3, stride=1, padding=1, bias=True),
                nn.Conv2d(ngf * 8, 2, kernel_size=3, stride=1, padding=1, bias=True),
                nn.Conv2d(ngf * 8, 2, kernel_size=3, stride=1, padding=1, bias=True),
                nn.Conv2d(ngf * 8, 2, kernel_size=3, stride=1, padding=1, bias=True),
            ])

            self.bottleneck = nn.Sequential(
                nn.Sequential(nn.Conv2d(ngf * 4, ngf * 4, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU()),
                nn.Sequential(nn.Conv2d(ngf * 4, ngf * 4, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU()),
                nn.Sequential(nn.Conv2d(ngf * 4, ngf * 4, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU()),
                nn.Sequential(nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU()),
                nn.Sequential(nn.Conv2d(ngf, ngf * 4, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU()),
            )

        self.conv = ResBlock(ngf * 4, ngf * 8, norm_layer=norm_layer, scale='same')

        if opt.out_layer == 'relu':
            self.out_layer = ResBlock(ngf + ngf, output_nc, norm_layer=norm_layer, scale='same')

        # 3D conv that predicts the TACO flow (Eq. 10 of SD-VITON)
        self.residual_sequential_flow_list = nn.Sequential(
            nn.Sequential(nn.Conv3d(ngf * 8, 3, kernel_size=3, stride=1, padding=1, bias=True)),
        )
        self.out_layer_input1_resblk = ResBlock(input1_nc + input2_nc, ngf, norm_layer=norm_layer, scale='same')

        self.num_layers = num_layers

    def normalize(self, x):
        # Identity; kept as a hook for optional feature normalization.
        return x

    def forward(self, input1, input2, upsample='bilinear'):
        E1_list = []
        E2_list = []
        flow_list_tvob = []
        flow_list_taco = []
        layers_max_idx = self.num_layers - 1

        # Feature Pyramid Network: encode cloth (E1) and pose (E2) top-down
        for i in range(self.num_layers):
            if i == 0:
                E1_list.append(self.ClothEncoder[i](input1))
                E2_list.append(self.PoseEncoder[i](input2))
            else:
                E1_list.append(self.ClothEncoder[i](E1_list[i - 1]))
                E2_list.append(self.PoseEncoder[i](E2_list[i - 1]))

        # Compute cloth flow, coarse to fine
        for i in range(self.num_layers):
            N, _, iH, iW = E1_list[layers_max_idx - i].size()
            grid = make_grid(N, iH, iW)

            if i == 0:
                T1 = E1_list[layers_max_idx - i]  # (ngf * 4) x 8 x 6
                T2 = E2_list[layers_max_idx - i]
                E4 = torch.cat([T1, T2], 1)

                flow = self.flow_conv[i](self.normalize(E4)).permute(0, 2, 3, 1)
                flow_list_tvob.append(flow)

                x = self.conv(T2)
                x = self.SegDecoder[i](x)

            else:
                T1 = F.interpolate(T1, scale_factor=2, mode=upsample) + self.conv1[layers_max_idx - i](E1_list[layers_max_idx - i])
                T2 = F.interpolate(T2, scale_factor=2, mode=upsample) + self.conv2[layers_max_idx - i](E2_list[layers_max_idx - i])

                # Upsample the level n-1 flow, warp T1 with it, then predict a residual flow F(n)
                flow = F.interpolate(flow_list_tvob[i - 1].permute(0, 3, 1, 2), scale_factor=2, mode=upsample).permute(0, 2, 3, 1)
                flow_norm = torch.cat([flow[:, :, :, 0:1] / ((iW / 2 - 1.0) / 2.0), flow[:, :, :, 1:2] / ((iH / 2 - 1.0) / 2.0)], 3)
                warped_T1 = F.grid_sample(T1, flow_norm + grid, padding_mode='border')

                flow = flow + self.flow_conv[i](self.normalize(torch.cat([warped_T1, self.bottleneck[i - 1](x)], 1))).permute(0, 2, 3, 1)  # F(n)
                flow_list_tvob.append(flow)

                # TACO layer of SD-VITON, applied only at the finest pyramid level
                if i == layers_max_idx:
                    ## Eq. 10 of SD-VITON
                    flow_norm = torch.cat([flow[:, :, :, 0:1] / ((iW - 1.0) / 2.0), flow[:, :, :, 1:2] / ((iH - 1.0) / 2.0)], 3)
                    warped_T1 = F.grid_sample(T1, flow_norm + grid, padding_mode='border')
                    input_3d_flow_out = self.normalize(torch.cat([warped_T1, T2], 1)).unsqueeze(2)
                    # zeros_like already matches the input's device, so no .cuda() is needed here
                    flow_out = self.residual_sequential_flow_list[0](torch.cat((input_3d_flow_out, torch.zeros_like(input_3d_flow_out)), dim=2)).permute(0, 2, 3, 4, 1)
                    flow_list_taco.append(flow_out)

                if self.warp_feature == 'T1':
                    x = self.SegDecoder[i](torch.cat([x, E2_list[layers_max_idx - i], warped_T1], 1))

        ## Eq. 11 of SD-VITON: warp the raw input at full resolution
        N, _, iH, iW = input1.size()
        grid = make_grid(N, iH, iW)
        grid_3d = make_grid_3d(N, iH, iW)

        flow_tvob = F.interpolate(flow_list_tvob[-1].permute(0, 3, 1, 2), scale_factor=2, mode=upsample).permute(0, 2, 3, 1)
        flow_tvob_norm = torch.cat([flow_tvob[:, :, :, 0:1] / ((iW / 2 - 1.0) / 2.0), flow_tvob[:, :, :, 1:2] / ((iH / 2 - 1.0) / 2.0)], 3)
        warped_input1_tvob = F.grid_sample(input1, flow_tvob_norm + grid, padding_mode='border')

        flow_taco = F.interpolate(flow_list_taco[-1].permute(0, 4, 1, 2, 3), scale_factor=(1, 2, 2), mode='trilinear').permute(0, 2, 3, 4, 1)
        flow_taco_norm = torch.cat([flow_taco[:, :, :, :, 0:1] / ((iW / 2 - 1.0) / 2.0), flow_taco[:, :, :, :, 1:2] / ((iH / 2 - 1.0) / 2.0), flow_taco[:, :, :, :, 2:3]], 4)
        warped_input1_tvob = warped_input1_tvob.unsqueeze(2)
        warped_input1_taco = F.grid_sample(torch.cat((warped_input1_tvob, torch.zeros_like(warped_input1_tvob)), dim=2), flow_taco_norm + grid_3d, padding_mode='border')
        warped_input1_taco_non_roi = warped_input1_taco[:, :, 1, :, :]
        warped_input1_taco_roi = warped_input1_taco[:, :, 0, :, :]
        warped_input1_tvob = warped_input1_tvob[:, :, 0, :, :]

        out_inputs_resblk = self.out_layer_input1_resblk(torch.cat([input2, warped_input1_taco_roi], 1))

        x = self.out_layer(torch.cat([x, out_inputs_resblk], 1))

        # The last input channel is treated as the cloth mask
        warped_c_tvob = warped_input1_tvob[:, :-1, :, :]
        warped_cm_tvob = warped_input1_tvob[:, -1:, :, :]

        warped_c_taco_roi = warped_input1_taco_roi[:, :-1, :, :]
        warped_cm_taco_roi = warped_input1_taco_roi[:, -1:, :, :]

        # Computed but not returned below
        warped_c_taco_non_roi = warped_input1_taco_non_roi[:, :-1, :, :]
        warped_cm_taco_non_roi = warped_input1_taco_non_roi[:, -1:, :, :]

        return flow_list_taco, x, warped_c_taco_roi, warped_cm_taco_roi, flow_list_tvob, warped_c_tvob, warped_cm_tvob

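# Usage sketch (illustrative only; `warp_feature` and `out_layer` are real options
# of this class, but the channel counts and resolution below are hypothetical
# placeholders, not values taken from the original repo):
#
#   from types import SimpleNamespace
#   opt = SimpleNamespace(warp_feature='T1', out_layer='relu')
#   tocg = ConditionGenerator(opt, input1_nc=4, input2_nc=16, output_nc=13, ngf=96).cuda()
#   cloth = torch.randn(8, 4, 256, 192).cuda()   # e.g. cloth RGB + cloth mask
#   pose = torch.randn(8, 16, 256, 192).cuda()   # e.g. pose / parse-agnostic stack
#   (flow_taco, fake_segmap, warped_c_roi, warped_cm_roi,
#    flow_tvob, warped_c, warped_cm) = tocg(cloth, pose)
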
def make_grid_3d(N, iH, iW):
    # Identity sampling grid for 3D grid_sample: (N, 2, iH, iW, 3), values in [-1, 1].
    grid_x = torch.linspace(-1.0, 1.0, iW).view(1, 1, 1, iW, 1).expand(N, 2, iH, -1, -1)
    grid_y = torch.linspace(-1.0, 1.0, iH).view(1, 1, iH, 1, 1).expand(N, 2, -1, iW, -1)
    grid_z = torch.linspace(-1.0, 1.0, 2).view(1, 2, 1, 1, 1).expand(N, -1, iH, iW, -1)
    grid = torch.cat([grid_x, grid_y, grid_z], 4).cuda()  # CUDA-only, like the rest of this file
    return grid


def make_grid(N, iH, iW):
    # Identity sampling grid for 2D grid_sample: (N, iH, iW, 2), values in [-1, 1].
    grid_x = torch.linspace(-1.0, 1.0, iW).view(1, 1, iW, 1).expand(N, iH, -1, -1)
    grid_y = torch.linspace(-1.0, 1.0, iH).view(1, iH, 1, 1).expand(N, -1, iW, -1)
    grid = torch.cat([grid_x, grid_y], 3).cuda()
    return grid

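# Shape sketch (illustrative): for N=1, iH=4, iW=3,
#   make_grid(1, 4, 3).shape    -> torch.Size([1, 4, 3, 2])
#   make_grid_3d(1, 4, 3).shape -> torch.Size([1, 2, 4, 3, 3])
# Predicted flows are added to these identity grids before F.grid_sample, so a
# zero flow samples each pixel from (approximately) its own location.
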
class ResBlock(nn.Module):
    def __init__(self, in_nc, out_nc, scale='down', norm_layer=nn.BatchNorm2d):
        super(ResBlock, self).__init__()
        use_bias = norm_layer == nn.InstanceNorm2d
        assert scale in ['up', 'down', 'same'], "ResBlock scale must be in 'up' 'down' 'same'"

        if scale == 'same':
            self.scale = nn.Conv2d(in_nc, out_nc, kernel_size=1, bias=True)
        if scale == 'up':
            self.scale = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear'),
                nn.Conv2d(in_nc, out_nc, kernel_size=1, bias=True)
            )
        if scale == 'down':
            self.scale = nn.Conv2d(in_nc, out_nc, kernel_size=3, stride=2, padding=1, bias=use_bias)

        self.block = nn.Sequential(
            nn.Conv2d(out_nc, out_nc, kernel_size=3, stride=1, padding=1, bias=use_bias),
            norm_layer(out_nc),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_nc, out_nc, kernel_size=3, stride=1, padding=1, bias=use_bias),
            norm_layer(out_nc)
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = self.scale(x)  # rescale first, then add the residual branch
        return self.relu(residual + self.block(residual))

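# Shape sketch (illustrative):
#   blk = ResBlock(3, 64, scale='down')
#   blk(torch.randn(1, 3, 256, 192)).shape  -> torch.Size([1, 64, 128, 96])
# The skip path (`self.scale`) changes resolution/channels first, so the
# residual addition always operates on matching shapes.
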
class Vgg19(nn.Module):
    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        # `pretrained=True` is the legacy torchvision API (newer versions use `weights=`).
        vgg_pretrained_features = models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(2):       # -> relu1_1
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):    # -> relu2_1
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):   # -> relu3_1
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):  # -> relu4_1
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):  # -> relu5_1
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out

class VGGLoss(nn.Module):
    def __init__(self, layids=None):
        super(VGGLoss, self).__init__()
        self.vgg = Vgg19()
        self.vgg.cuda()
        self.criterion = nn.L1Loss()
        # Deeper features carry more weight.
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
        self.layids = layids

    def forward(self, x, y):
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        if self.layids is None:
            self.layids = list(range(len(x_vgg)))
        for i in self.layids:
            loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
        return loss

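# Usage sketch (illustrative; tensor names are placeholders):
#   criterion_vgg = VGGLoss()
#   loss_vgg = criterion_vgg(fake_image, real_image)  # both (N, 3, H, W), on GPU
# Only the prediction receives gradients; targets are detached inside forward().
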
class GANLoss(nn.Module):
    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor  # pass torch.cuda.FloatTensor when the discriminator runs on GPU
        if use_lsgan:
            self.loss = nn.MSELoss()  # LSGAN
        else:
            self.loss = nn.BCELoss()  # vanilla GAN (expects sigmoid outputs)

    def get_target_tensor(self, input, target_is_real):
        # Cache the target tensor; rebuild only when the prediction size changes.
        if target_is_real:
            create_label = ((self.real_label_var is None) or
                            (self.real_label_var.numel() != input.numel()))
            if create_label:
                real_tensor = self.Tensor(input.size()).fill_(self.real_label)
                self.real_label_var = Variable(real_tensor, requires_grad=False)
            target_tensor = self.real_label_var
        else:
            create_label = ((self.fake_label_var is None) or
                            (self.fake_label_var.numel() != input.numel()))
            if create_label:
                fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
                self.fake_label_var = Variable(fake_tensor, requires_grad=False)
            target_tensor = self.fake_label_var
        return target_tensor

    def __call__(self, input, target_is_real):
        # `input` is either a list of per-scale prediction lists (multiscale D)
        # or a single prediction list.
        if isinstance(input[0], list):
            loss = 0
            for input_i in input:
                pred = input_i[-1]
                target_tensor = self.get_target_tensor(pred, target_is_real)
                loss += self.loss(pred, target_tensor)
            return loss
        else:
            target_tensor = self.get_target_tensor(input[-1], target_is_real)
            return self.loss(input[-1], target_tensor)

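# Usage sketch (illustrative; `netD`, `real_pair`, and `fake_pair` are placeholders):
#   criterion_gan = GANLoss(use_lsgan=True, tensor=torch.cuda.FloatTensor)
#   loss_d_real = criterion_gan(netD(real_pair), True)
#   loss_d_fake = criterion_gan(netD(fake_pair.detach()), False)
#   loss_g = criterion_gan(netD(fake_pair), True)
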
class MultiscaleDiscriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
                 use_sigmoid=False, num_D=3, getIntermFeat=False, Ddownx2=False, Ddropout=False, spectral=False):
        super(MultiscaleDiscriminator, self).__init__()
        self.num_D = num_D
        self.n_layers = n_layers
        self.getIntermFeat = getIntermFeat
        self.Ddownx2 = Ddownx2

        # One NLayerDiscriminator per scale; registered layer-by-layer when
        # intermediate features are requested so they can be read back out.
        for i in range(num_D):
            netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat, Ddropout, spectral=spectral)
            if getIntermFeat:
                for j in range(n_layers + 2):
                    setattr(self, 'scale' + str(i) + '_layer' + str(j), getattr(netD, 'model' + str(j)))
            else:
                setattr(self, 'layer' + str(i), netD.model)

        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def singleD_forward(self, model, input):
        if self.getIntermFeat:
            result = [input]
            for i in range(len(model)):
                result.append(model[i](result[-1]))
            return result[1:]
        else:
            return [model(input)]

    def forward(self, input):
        num_D = self.num_D

        result = []
        if self.Ddownx2:
            input_downsampled = self.downsample(input)
        else:
            input_downsampled = input
        for i in range(num_D):
            if self.getIntermFeat:
                model = [getattr(self, 'scale' + str(num_D - 1 - i) + '_layer' + str(j))
                         for j in range(self.n_layers + 2)]
            else:
                model = getattr(self, 'layer' + str(num_D - 1 - i))
            result.append(self.singleD_forward(model, input_downsampled))
            if i != (num_D - 1):
                input_downsampled = self.downsample(input_downsampled)
        return result

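# Output structure (as implemented above): forward() returns one entry per scale.
# With getIntermFeat=False each entry is a one-element list holding that scale's
# patch-logits map; with getIntermFeat=True each entry is the list of intermediate
# feature maps, ending with the logits.
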
class NLayerDiscriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False, Ddropout=False, spectral=False):
        super(NLayerDiscriminator, self).__init__()
        self.getIntermFeat = getIntermFeat
        self.n_layers = n_layers
        # Optionally wrap convs in spectral norm; identity otherwise.
        self.spectral_norm = spectral_norm if spectral else lambda x: x

        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))
        sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]

        nf = ndf
        for n in range(1, n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)
            if Ddropout:
                sequence += [[
                    self.spectral_norm(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw)),
                    norm_layer(nf), nn.LeakyReLU(0.2, True), nn.Dropout(0.5)
                ]]
            else:
                sequence += [[
                    self.spectral_norm(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw)),
                    norm_layer(nf), nn.LeakyReLU(0.2, True)
                ]]

        nf_prev = nf
        nf = min(nf * 2, 512)
        sequence += [[
            nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
            norm_layer(nf),
            nn.LeakyReLU(0.2, True)
        ]]

        # Final 1-channel patch logits.
        sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]

        if use_sigmoid:
            sequence += [[nn.Sigmoid()]]

        if getIntermFeat:
            for n in range(len(sequence)):
                setattr(self, 'model' + str(n), nn.Sequential(*sequence[n]))
        else:
            sequence_stream = []
            for n in range(len(sequence)):
                sequence_stream += sequence[n]
            self.model = nn.Sequential(*sequence_stream)

    def forward(self, input):
        if self.getIntermFeat:
            res = [input]
            for n in range(self.n_layers + 2):
                model = getattr(self, 'model' + str(n))
                res.append(model(res[-1]))
            return res[1:]
        else:
            return self.model(input)

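# Shape sketch (illustrative): this is a PatchGAN head, so the output is a grid
# of per-patch logits rather than a single scalar, e.g.
#   netD = NLayerDiscriminator(input_nc=6, ndf=64, n_layers=3)
#   netD(torch.randn(1, 6, 256, 192)).shape  -> torch.Size([1, 1, 35, 27])
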
def save_checkpoint(model, save_path):
    if not os.path.exists(os.path.dirname(save_path)):
        os.makedirs(os.path.dirname(save_path))

    # Save from CPU so the checkpoint loads on any device, then move back.
    torch.save(model.cpu().state_dict(), save_path)
    model.cuda()


def load_checkpoint(model, checkpoint_path):
    if not os.path.exists(checkpoint_path):
        print(" [*] checkpoint does not exist!")
        return
    print(" [*] Loading checkpoint from %s" % checkpoint_path)
    state_dict = torch.load(checkpoint_path)
    model_state_dict = model.state_dict()

    # Remove keys that have shape mismatches
    for key in list(state_dict.keys()):
        if key in model_state_dict and state_dict[key].shape != model_state_dict[key].shape:
            print(f"Removing {key} due to shape mismatch: {state_dict[key].shape} vs {model_state_dict[key].shape}")
            del state_dict[key]

    log = model.load_state_dict(state_dict, strict=False)
    print(" [*] Load Success! log : ", log)

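# Usage sketch (illustrative paths):
#   save_checkpoint(model, './checkpoints/tocg_final.pth')
#   load_checkpoint(model, './checkpoints/tocg_final.pth')
# load_checkpoint is deliberately forgiving: it drops shape-mismatched keys and
# loads with strict=False, so architecture tweaks won't hard-fail the load.
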
def weights_init(m):
    # DCGAN-style initialization: N(0, 0.02) convs, N(1, 0.02) batch-norm scales.
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

def get_norm_layer(norm_type='instance'):
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer

def define_D(input_nc, ndf=64, n_layers_D=3, norm='instance', use_sigmoid=False, num_D=2, getIntermFeat=False, gpu_ids=[], Ddownx2=False, Ddropout=False, spectral=False):
    norm_layer = get_norm_layer(norm_type=norm)
    netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat, Ddownx2, Ddropout, spectral=spectral)
    print(netD)
    if len(gpu_ids) > 0:
        assert (torch.cuda.is_available())
        netD.cuda()
    netD.apply(weights_init)
    return netD
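
# Usage sketch (illustrative argument values):
#   netD = define_D(input_nc=6, ndf=64, n_layers_D=3, norm='instance',
#                   num_D=2, getIntermFeat=False, gpu_ids=[0], spectral=True)
# define_D prints the network, moves it to CUDA when gpu_ids is non-empty, and
# applies the DCGAN-style weights_init before returning it.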