SD-VITON / networks.py
Ubaida10's picture
Upload networks.py with huggingface_hub
fb413bf verified
raw
history blame
25.3 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import models
import os
from torch.nn.utils import spectral_norm
import numpy as np
import functools
class ConditionGenerator(nn.Module):
def __init__(self, opt, input1_nc, input2_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, num_layers=5):
super(ConditionGenerator, self).__init__()
self.warp_feature = opt.warp_feature
self.out_layer_opt = opt.out_layer
if num_layers == 5:
self.ClothEncoder = nn.Sequential(
ResBlock(input1_nc, ngf, norm_layer=norm_layer, scale='down'), # 256
ResBlock(ngf, ngf*2, norm_layer=norm_layer, scale='down'), # 128
ResBlock(ngf*2, ngf*4, norm_layer=norm_layer, scale='down'), # 64
ResBlock(ngf*4, ngf * 4, norm_layer=norm_layer, scale='down'), # 32
ResBlock(ngf * 4, ngf * 4, norm_layer=norm_layer, scale='down'), # 16
)
self.PoseEncoder = nn.Sequential(
ResBlock(input2_nc, ngf, norm_layer=norm_layer, scale='down'),
ResBlock(ngf, ngf * 2, norm_layer=norm_layer, scale='down'),
ResBlock(ngf * 2, ngf*4, norm_layer=norm_layer, scale='down'),
ResBlock(ngf*4, ngf * 4, norm_layer=norm_layer, scale='down'),
ResBlock(ngf * 4, ngf * 4, norm_layer=norm_layer, scale='down'),
)
if opt.warp_feature == 'T1':
# in_nc -> skip connection + T1, T2 channel
self.SegDecoder = nn.Sequential(
ResBlock(ngf * 8, ngf * 4, norm_layer=norm_layer, scale='up'), # 16
ResBlock(ngf * 4 * 2 + ngf * 4 , ngf * 4, norm_layer=norm_layer, scale='up'), # 32
ResBlock(ngf * 4 * 2 + ngf * 4, ngf * 2, norm_layer=norm_layer, scale='up'), # 64
ResBlock(ngf * 2 * 2 + ngf * 4, ngf, norm_layer=norm_layer, scale='up'), # 128
ResBlock(ngf * 1 * 2 + ngf * 4, ngf, norm_layer=norm_layer, scale='up'), # 256
)
# Cloth Conv 1x1
self.conv1 = nn.Sequential(
nn.Conv2d(ngf, ngf * 4, kernel_size=1, bias=True),
nn.Conv2d(ngf * 2, ngf * 4, kernel_size=1, bias=True),
nn.Conv2d(ngf * 4, ngf * 4, kernel_size=1, bias=True),
nn.Conv2d(ngf * 4, ngf * 4, kernel_size=1, bias=True),
)
# Person Conv 1x1
self.conv2 = nn.Sequential(
nn.Conv2d(ngf, ngf * 4, kernel_size=1, bias=True),
nn.Conv2d(ngf * 2, ngf * 4, kernel_size=1, bias=True),
nn.Conv2d(ngf * 4, ngf * 4, kernel_size=1, bias=True),
nn.Conv2d(ngf * 4, ngf * 4, kernel_size=1, bias=True),
)
self.flow_conv = nn.ModuleList([
nn.Conv2d(ngf * 8, 2, kernel_size=3, stride=1, padding=1, bias=True),
nn.Conv2d(ngf * 8, 2, kernel_size=3, stride=1, padding=1, bias=True),
nn.Conv2d(ngf * 8, 2, kernel_size=3, stride=1, padding=1, bias=True),
nn.Conv2d(ngf * 8, 2, kernel_size=3, stride=1, padding=1, bias=True),
nn.Conv2d(ngf * 8, 2, kernel_size=3, stride=1, padding=1, bias=True),
]
)
self.bottleneck = nn.Sequential(
nn.Sequential(nn.Conv2d(ngf * 4, ngf * 4, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU()),
nn.Sequential(nn.Conv2d(ngf * 4, ngf * 4, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU()),
nn.Sequential(nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=1, padding=1, bias=True) , nn.ReLU()),
nn.Sequential(nn.Conv2d(ngf, ngf * 4, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU()),
)
if num_layers == 6:
self.ClothEncoder = nn.Sequential(
ResBlock(input1_nc, ngf, norm_layer=norm_layer, scale='down'), # 512
ResBlock(ngf, ngf*2, norm_layer=norm_layer, scale='down'), # 256
ResBlock(ngf*2, ngf*4, norm_layer=norm_layer, scale='down'), # 128
ResBlock(ngf*4, ngf * 4, norm_layer=norm_layer, scale='down'), # 64
ResBlock(ngf * 4, ngf * 4, norm_layer=norm_layer, scale='down'), # 32
ResBlock(ngf * 4, ngf * 4, norm_layer=norm_layer, scale='down'), # 16
)
self.PoseEncoder = nn.Sequential(
ResBlock(input2_nc, ngf, norm_layer=norm_layer, scale='down'),
ResBlock(ngf, ngf * 2, norm_layer=norm_layer, scale='down'),
ResBlock(ngf * 2, ngf*4, norm_layer=norm_layer, scale='down'),
ResBlock(ngf*4, ngf * 4, norm_layer=norm_layer, scale='down'),
ResBlock(ngf * 4, ngf * 4, norm_layer=norm_layer, scale='down'),
ResBlock(ngf * 4, ngf * 4, norm_layer=norm_layer, scale='down'),
)
if opt.warp_feature == 'T1':
# in_nc -> skip connection + T1, T2 channel
self.SegDecoder = nn.Sequential(
ResBlock(ngf * 8, ngf * 4, norm_layer=norm_layer, scale='up'), # 16
ResBlock(ngf * 4 * 2 + ngf * 4 , ngf * 4, norm_layer=norm_layer, scale='up'), # 32
ResBlock(ngf * 4 * 2 + ngf * 4 , ngf * 4, norm_layer=norm_layer, scale='up'), # 64
ResBlock(ngf * 4 * 2 + ngf * 4, ngf * 2, norm_layer=norm_layer, scale='up'), # 128
ResBlock(ngf * 2 * 2 + ngf * 4, ngf, norm_layer=norm_layer, scale='up'), # 256
ResBlock(ngf * 1 * 2 + ngf * 4, ngf, norm_layer=norm_layer, scale='up'), # 512
)
# Cloth Conv 1x1
self.conv1 = nn.Sequential(
nn.Conv2d(ngf, ngf * 4, kernel_size=1, bias=True),
nn.Conv2d(ngf * 2, ngf * 4, kernel_size=1, bias=True),
nn.Conv2d(ngf * 4, ngf * 4, kernel_size=1, bias=True),
nn.Conv2d(ngf * 4, ngf * 4, kernel_size=1, bias=True),
nn.Conv2d(ngf * 4, ngf * 4, kernel_size=1, bias=True),
)
# Person Conv 1x1
self.conv2 = nn.Sequential(
nn.Conv2d(ngf, ngf * 4, kernel_size=1, bias=True),
nn.Conv2d(ngf * 2, ngf * 4, kernel_size=1, bias=True),
nn.Conv2d(ngf * 4, ngf * 4, kernel_size=1, bias=True),
nn.Conv2d(ngf * 4, ngf * 4, kernel_size=1, bias=True),
nn.Conv2d(ngf * 4, ngf * 4, kernel_size=1, bias=True),
)
self.flow_conv = nn.ModuleList([
nn.Conv2d(ngf * 8, 2, kernel_size=3, stride=1, padding=1, bias=True),
nn.Conv2d(ngf * 8, 2, kernel_size=3, stride=1, padding=1, bias=True),
nn.Conv2d(ngf * 8, 2, kernel_size=3, stride=1, padding=1, bias=True),
nn.Conv2d(ngf * 8, 2, kernel_size=3, stride=1, padding=1, bias=True),
nn.Conv2d(ngf * 8, 2, kernel_size=3, stride=1, padding=1, bias=True),
nn.Conv2d(ngf * 8, 2, kernel_size=3, stride=1, padding=1, bias=True),
]
)
self.bottleneck = nn.Sequential(
nn.Sequential(nn.Conv2d(ngf * 4, ngf * 4, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU()),
nn.Sequential(nn.Conv2d(ngf * 4, ngf * 4, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU()),
nn.Sequential(nn.Conv2d(ngf * 4, ngf * 4, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU()),
nn.Sequential(nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=1, padding=1, bias=True) , nn.ReLU()),
nn.Sequential(nn.Conv2d(ngf, ngf * 4, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU()),
)
self.conv = ResBlock(ngf * 4, ngf * 8, norm_layer=norm_layer, scale='same')
if opt.out_layer == 'relu':
self.out_layer = ResBlock(ngf + ngf, output_nc, norm_layer=norm_layer, scale='same')
self.residual_sequential_flow_list = nn.Sequential(
nn.Sequential(nn.Conv3d(ngf * 8, 3, kernel_size=3, stride=1, padding=1, bias=True)),
)
self.out_layer_input1_resblk = ResBlock(input1_nc + input2_nc, ngf, norm_layer=norm_layer, scale='same')
self.num_layers = num_layers
def normalize(self, x):
return x
def forward(self, input1, input2, upsample='bilinear'):
E1_list = []
E2_list = []
flow_list_tvob = []
flow_list_taco = []
layers_max_idx = self.num_layers - 1
# Feature Pyramid Network
for i in range(self.num_layers):
if i == 0:
E1_list.append(self.ClothEncoder[i](input1))
E2_list.append(self.PoseEncoder[i](input2))
else:
E1_list.append(self.ClothEncoder[i](E1_list[i - 1]))
E2_list.append(self.PoseEncoder[i](E2_list[i - 1]))
# Compute Clothflow
for i in range(self.num_layers):
N, _, iH, iW = E1_list[layers_max_idx - i].size()
grid = make_grid(N, iH, iW)
if i == 0:
T1 = E1_list[layers_max_idx - i] # (ngf * 4) x 8 x 6
T2 = E2_list[layers_max_idx - i]
E4 = torch.cat([T1, T2], 1)
flow = self.flow_conv[i](self.normalize(E4)).permute(0, 2, 3, 1)
flow_list_tvob.append(flow)
x = self.conv(T2)
x = self.SegDecoder[i](x)
else:
T1 = F.interpolate(T1, scale_factor=2, mode=upsample) + self.conv1[layers_max_idx - i](E1_list[layers_max_idx - i])
T2 = F.interpolate(T2, scale_factor=2, mode=upsample) + self.conv2[layers_max_idx - i](E2_list[layers_max_idx - i])
flow = F.interpolate(flow_list_tvob[i - 1].permute(0, 3, 1, 2), scale_factor=2, mode=upsample).permute(0, 2, 3, 1) # upsample n-1 flow
flow_norm = torch.cat([flow[:, :, :, 0:1] / ((iW/2 - 1.0) / 2.0), flow[:, :, :, 1:2] / ((iH/2 - 1.0) / 2.0)], 3)
warped_T1 = F.grid_sample(T1, flow_norm + grid, padding_mode='border')
flow = flow + self.flow_conv[i](self.normalize(torch.cat([warped_T1, self.bottleneck[i-1](x)], 1))).permute(0, 2, 3, 1) # F(n)
flow_list_tvob.append(flow)
# TACO layer of SD-VITON
if i == layers_max_idx:
## Eq.10 of SD-VITON
flow_norm = torch.cat([flow[:, :, :, 0:1] / ((iW - 1.0) / 2.0), flow[:, :, :, 1:2] / ((iH - 1.0) / 2.0)], 3)
warped_T1 = F.grid_sample(T1, flow_norm + grid, padding_mode='border')
input_3d_flow_out = self.normalize(torch.cat([warped_T1, T2], 1)).unsqueeze(2)
flow_out = self.residual_sequential_flow_list[0](torch.cat((input_3d_flow_out, torch.zeros_like(input_3d_flow_out).cuda()), dim=2)).permute(0,2,3,4,1)
flow_list_taco.append(flow_out)
if self.warp_feature == 'T1':
x = self.SegDecoder[i](torch.cat([x, E2_list[layers_max_idx-i], warped_T1], 1))
## Eq.11 of SD-VITON
N, _, iH, iW = input1.size()
grid = make_grid(N, iH, iW)
grid_3d = make_grid_3d(N, iH, iW)
flow_tvob = F.interpolate(flow_list_tvob[-1].permute(0, 3, 1, 2), scale_factor=2, mode=upsample).permute(0, 2, 3, 1)
flow_tvob_norm = torch.cat([flow_tvob[:, :, :, 0:1] / ((iW/2 - 1.0) / 2.0), flow_tvob[:, :, :, 1:2] / ((iH/2 - 1.0) / 2.0)], 3)
warped_input1_tvob = F.grid_sample(input1, flow_tvob_norm + grid, padding_mode='border')
flow_taco = F.interpolate(flow_list_taco[-1].permute(0, 4, 1, 2, 3), scale_factor=(1,2,2), mode='trilinear').permute(0, 2, 3, 4, 1)
flow_taco_norm = torch.cat([flow_taco[:, :, :, :, 0:1] / ((iW/2 - 1.0) / 2.0), flow_taco[:, :, :, :, 1:2] / ((iH/2 - 1.0) / 2.0), flow_taco[:, :, :, :, 2:3]], 4)
warped_input1_tvob = warped_input1_tvob.unsqueeze(2)
warped_input1_taco = F.grid_sample(torch.cat((warped_input1_tvob, torch.zeros_like(warped_input1_tvob).cuda()), dim=2), flow_taco_norm + grid_3d, padding_mode='border')
warped_input1_taco_non_roi = warped_input1_taco[:,:,1,:,:]
warped_input1_taco_roi = warped_input1_taco[:,:,0,:,:]
warped_input1_tvob = warped_input1_tvob[:,:,0,:,:]
out_inputs_resblk = self.out_layer_input1_resblk(torch.cat([input2, warped_input1_taco_roi], 1))
x = self.out_layer(torch.cat([x, out_inputs_resblk], 1))
warped_c_tvob = warped_input1_tvob[:, :-1, :, :]
warped_cm_tvob = warped_input1_tvob[:, -1:, :, :]
warped_c_taco_roi = warped_input1_taco_roi[:, :-1, :, :]
warped_cm_taco_roi = warped_input1_taco_roi[:, -1:, :, :]
warped_c_taco_non_roi = warped_input1_taco_non_roi[:,:-1,:,:]
warped_cm_taco_non_roi = warped_input1_taco_non_roi[:,-1:,:,:]
return flow_list_taco, x, warped_c_taco_roi, warped_cm_taco_roi, flow_list_tvob, warped_c_tvob, warped_cm_tvob
def make_grid_3d(N, iH, iW):
grid_x = torch.linspace(-1.0, 1.0, iW).view(1, 1, 1, iW, 1).expand(N, 2, iH, -1, -1)
grid_y = torch.linspace(-1.0, 1.0, iH).view(1, 1, iH, 1, 1).expand(N, 2, -1, iW, -1)
grid_z = torch.linspace(-1.0, 1.0, 2).view(1, 2, 1, 1, 1).expand(N, -1, iH, iW, -1)
grid = torch.cat([grid_x, grid_y, grid_z], 4).cuda()
return grid
def make_grid(N, iH, iW):
grid_x = torch.linspace(-1.0, 1.0, iW).view(1, 1, iW, 1).expand(N, iH, -1, -1)
grid_y = torch.linspace(-1.0, 1.0, iH).view(1, iH, 1, 1).expand(N, -1, iW, -1)
grid = torch.cat([grid_x, grid_y], 3).cuda()
return grid
class ResBlock(nn.Module):
def __init__(self, in_nc, out_nc, scale='down', norm_layer=nn.BatchNorm2d):
super(ResBlock, self).__init__()
use_bias = norm_layer == nn.InstanceNorm2d
assert scale in ['up', 'down', 'same'], "ResBlock scale must be in 'up' 'down' 'same'"
if scale == 'same':
self.scale = nn.Conv2d(in_nc, out_nc, kernel_size=1, bias=True)
if scale == 'up':
self.scale = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear'),
nn.Conv2d(in_nc, out_nc, kernel_size=1,bias=True)
)
if scale == 'down':
self.scale = nn.Conv2d(in_nc, out_nc, kernel_size=3, stride=2, padding=1, bias=use_bias)
self.block = nn.Sequential(
nn.Conv2d(out_nc, out_nc, kernel_size=3, stride=1, padding=1, bias=use_bias),
norm_layer(out_nc),
nn.ReLU(inplace=True),
nn.Conv2d(out_nc, out_nc, kernel_size=3, stride=1, padding=1, bias=use_bias),
norm_layer(out_nc)
)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
residual = self.scale(x)
return self.relu(residual + self.block(residual))
class Vgg19(nn.Module):
def __init__(self, requires_grad=False):
super(Vgg19, self).__init__()
vgg_pretrained_features = models.vgg19(pretrained=True).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
for x in range(2):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(2, 7):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(7, 12):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(12, 21):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(21, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h_relu1 = self.slice1(X)
h_relu2 = self.slice2(h_relu1)
h_relu3 = self.slice3(h_relu2)
h_relu4 = self.slice4(h_relu3)
h_relu5 = self.slice5(h_relu4)
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
return out
class VGGLoss(nn.Module):
def __init__(self, layids = None):
super(VGGLoss, self).__init__()
self.vgg = Vgg19()
self.vgg.cuda()
self.criterion = nn.L1Loss()
self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
self.layids = layids
def forward(self, x, y):
x_vgg, y_vgg = self.vgg(x), self.vgg(y)
loss = 0
if self.layids is None:
self.layids = list(range(len(x_vgg)))
for i in self.layids:
loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
return loss
class GANLoss(nn.Module):
def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
tensor=torch.FloatTensor):
super(GANLoss, self).__init__()
self.real_label = target_real_label
self.fake_label = target_fake_label
self.real_label_var = None
self.fake_label_var = None
self.Tensor = tensor
if use_lsgan:
self.loss = nn.MSELoss()
else:
self.loss = nn.BCELoss()
def get_target_tensor(self, input, target_is_real):
if target_is_real:
create_label = ((self.real_label_var is None) or
(self.real_label_var.numel() != input.numel()))
if create_label:
real_tensor = self.Tensor(input.size()).fill_(self.real_label)
self.real_label_var = Variable(real_tensor, requires_grad=False)
target_tensor = self.real_label_var
else:
create_label = ((self.fake_label_var is None) or
(self.fake_label_var.numel() != input.numel()))
if create_label:
fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
self.fake_label_var = Variable(fake_tensor, requires_grad=False)
target_tensor = self.fake_label_var
return target_tensor
def __call__(self, input, target_is_real):
if isinstance(input[0], list):
loss = 0
for input_i in input:
pred = input_i[-1]
target_tensor = self.get_target_tensor(pred, target_is_real)
loss += self.loss(pred, target_tensor)
return loss
else:
target_tensor = self.get_target_tensor(input[-1], target_is_real)
return self.loss(input[-1], target_tensor)
class MultiscaleDiscriminator(nn.Module):
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
use_sigmoid=False, num_D=3, getIntermFeat=False, Ddownx2=False, Ddropout=False, spectral=False):
super(MultiscaleDiscriminator, self).__init__()
self.num_D = num_D
self.n_layers = n_layers
self.getIntermFeat = getIntermFeat
self.Ddownx2 = Ddownx2
for i in range(num_D):
netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat, Ddropout, spectral=spectral)
if getIntermFeat:
for j in range(n_layers + 2):
setattr(self, 'scale' + str(i) + '_layer' + str(j), getattr(netD, 'model' + str(j)))
else:
setattr(self, 'layer' + str(i), netD.model)
self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
def singleD_forward(self, model, input):
if self.getIntermFeat:
result = [input]
for i in range(len(model)):
result.append(model[i](result[-1]))
return result[1:]
else:
return [model(input)]
def forward(self, input):
num_D = self.num_D
result = []
if self.Ddownx2:
input_downsampled = self.downsample(input)
else:
input_downsampled = input
for i in range(num_D):
if self.getIntermFeat:
model = [getattr(self, 'scale' + str(num_D - 1 - i) + '_layer' + str(j)) for j in
range(self.n_layers + 2)]
else:
model = getattr(self, 'layer' + str(num_D - 1 - i))
result.append(self.singleD_forward(model, input_downsampled))
if i != (num_D - 1):
input_downsampled = self.downsample(input_downsampled)
return result
class NLayerDiscriminator(nn.Module):
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False, Ddropout=False, spectral=False):
super(NLayerDiscriminator, self).__init__()
self.getIntermFeat = getIntermFeat
self.n_layers = n_layers
self.spectral_norm = spectral_norm if spectral else lambda x: x
kw = 4
padw = int(np.ceil((kw - 1.0) / 2))
sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]
nf = ndf
for n in range(1, n_layers):
nf_prev = nf
nf = min(nf * 2, 512)
if Ddropout:
sequence += [[
self.spectral_norm(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw)),
norm_layer(nf), nn.LeakyReLU(0.2, True), nn.Dropout(0.5)
]]
else:
sequence += [[
self.spectral_norm(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw)),
norm_layer(nf), nn.LeakyReLU(0.2, True)
]]
nf_prev = nf
nf = min(nf * 2, 512)
sequence += [[
nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
norm_layer(nf),
nn.LeakyReLU(0.2, True)
]]
sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
if use_sigmoid:
sequence += [[nn.Sigmoid()]]
if getIntermFeat:
for n in range(len(sequence)):
setattr(self, 'model' + str(n), nn.Sequential(*sequence[n]))
else:
sequence_stream = []
for n in range(len(sequence)):
sequence_stream += sequence[n]
self.model = nn.Sequential(*sequence_stream)
def forward(self, input):
if self.getIntermFeat:
res = [input]
for n in range(self.n_layers + 2):
model = getattr(self, 'model' + str(n))
res.append(model(res[-1]))
return res[1:]
else:
return self.model(input)
def save_checkpoint(model, save_path):
if not os.path.exists(os.path.dirname(save_path)):
os.makedirs(os.path.dirname(save_path))
torch.save(model.cpu().state_dict(), save_path)
model.cuda()
def load_checkpoint(model, checkpoint_path):
if not os.path.exists(checkpoint_path):
print(" [*] checkpoint does not exist!")
return
print(" [*] Loading checkpoint from %s" % checkpoint_path)
state_dict = torch.load(checkpoint_path)
model_state_dict = model.state_dict()
# Remove keys that have shape mismatches
for key in list(state_dict.keys()):
if key in model_state_dict and state_dict[key].shape != model_state_dict[key].shape:
print(f"Removing {key} due to shape mismatch: {state_dict[key].shape} vs {model_state_dict[key].shape}")
del state_dict[key]
log = model.load_state_dict(state_dict, strict=False)
print(" [*] Load Success! log : ", log)
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv2d') != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find('BatchNorm2d') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
def get_norm_layer(norm_type='instance'):
if norm_type == 'batch':
norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
elif norm_type == 'instance':
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
else:
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
return norm_layer
def define_D(input_nc, ndf=64, n_layers_D=3, norm='instance', use_sigmoid=False, num_D=2, getIntermFeat=False, gpu_ids=[], Ddownx2=False, Ddropout=False, spectral=False):
norm_layer = get_norm_layer(norm_type=norm)
netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat, Ddownx2, Ddropout, spectral=spectral)
print(netD)
if len(gpu_ids) > 0:
assert (torch.cuda.is_available())
netD.cuda()
netD.apply(weights_init)
return netD