from __future__ import absolute_import

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import functional as F

from . import pretrained_networks as pn
from losses import masked_lpips as util


def spatial_average(in_tens, mask=None, keepdim=True):
    # Plain spatial mean when no mask is given; otherwise a masked mean that
    # averages only over pixels where the mask is non-zero.
    if mask is None:
        return in_tens.mean([2, 3], keepdim=keepdim)
    in_tens = in_tens * mask
    # Sum the masked tensor across the spatial dims, then normalize by the
    # per-sample mask area so the result is a mean over unmasked pixels only.
    return in_tens.sum([2, 3], keepdim=keepdim) / mask.sum([2, 3], keepdim=keepdim)


def upsample(in_tens, out_H=64):  # assumes the scale factor is the same for H and W
    in_H = in_tens.shape[2]
    scale_factor = 1.0 * out_H / in_H
    return nn.Upsample(scale_factor=scale_factor, mode="bilinear", align_corners=False)(
        in_tens
    )
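
# A minimal self-check sketch (not in the original module; the helper name
# `_masked_average_example` and the shapes below are illustrative assumptions).
# It shows that the masked branch of spatial_average reduces to the mean over
# unmasked pixels only.
def _masked_average_example():
    x = torch.arange(16.0).view(1, 1, 4, 4)
    m = torch.zeros(1, 1, 4, 4)
    m[..., :2, :] = 1.0  # keep only the top half of the image
    avg = spatial_average(x, mask=m)  # (1, 1, 1, 1) tensor
    assert torch.isclose(avg.squeeze(), x[..., :2, :].mean())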
# Learned perceptual metric
class PNetLin(nn.Module):
    def __init__(
        self,
        pnet_type="vgg",
        pnet_rand=False,
        pnet_tune=False,
        use_dropout=True,
        spatial=False,
        version="0.1",
        lpips=True,
        vgg_blocks=(1, 2, 3, 4, 5),
    ):
        super(PNetLin, self).__init__()
        self.pnet_type = pnet_type
        self.pnet_tune = pnet_tune
        self.pnet_rand = pnet_rand
        self.spatial = spatial
        self.lpips = lpips
        self.version = version
        self.scaling_layer = ScalingLayer()

        if self.pnet_type in ["vgg", "vgg16"]:
            net_type = pn.vgg16
            self.chns = [64, 128, 256, 512, 512]
        elif self.pnet_type == "alex":
            net_type = pn.alexnet
            self.chns = [64, 192, 384, 256, 256]
        elif self.pnet_type == "squeeze":
            net_type = pn.squeezenet
            self.chns = [64, 128, 256, 384, 384, 512, 512]
        self.L = len(self.chns)

        # 1-indexed ids of the feature blocks whose distances are summed into
        # the final value (vgg_blocks applies to the VGG variants; other nets
        # use all of their blocks). Normalized to ints so the membership test
        # in forward() is reliable.
        if self.pnet_type in ["vgg", "vgg16"]:
            self.blocks = [int(b) for b in vgg_blocks]
        else:
            self.blocks = list(range(1, self.L + 1))

        self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)

        if lpips:
            self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
            self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
            self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
            self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
            self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
            self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
            if self.pnet_type == "squeeze":  # 7 layers for squeezenet
                self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
                self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
                self.lins += [self.lin5, self.lin6]

    def forward(self, in0, in1, mask=None, retPerLayer=False):
        # v0.0 - original release had a bug, where input was not scaled
        in0_input, in1_input = (
            (self.scaling_layer(in0), self.scaling_layer(in1))
            if self.version == "0.1"
            else (in0, in1)
        )
        outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
        feats0, feats1, diffs = {}, {}, {}

        # Prepare one mask per feature resolution by nearest-neighbor resizing
        # the input mask to each feature map's spatial size.
        if mask is not None:
            masks = []
            if len(mask.shape) == 3:
                mask = torch.unsqueeze(mask, dim=0)  # make the mask 4D
            for kk in range(self.L):
                N, C, H, W = outs0[kk].shape
                masks.append(F.interpolate(mask, size=(H, W), mode="nearest"))

        # For a 256x256 input, outs0 holds 5 feature maps with shapes:
        #   [1, 64, 256, 256], [1, 128, 128, 128], [1, 256, 64, 64],
        #   [1, 512, 32, 32], [1, 512, 16, 16]
        for kk in range(self.L):
            feats0[kk], feats1[kk] = (
                util.normalize_tensor(outs0[kk]),
                util.normalize_tensor(outs1[kk]),
            )
            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

        if self.lpips:
            if self.spatial:
                res = [
                    upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2])
                    for kk in range(self.L)
                ]
            else:
                # NOTE: this is the branch used in practice. Each entry of
                # self.lins applies a 1x1 conv that collapses the channel-wise
                # squared differences to a single channel; spatial_average then
                # reduces over space, and passing the resized mask restricts
                # that mean to unmasked pixels only.
                res = [
                    spatial_average(
                        self.lins[kk].model(diffs[kk]),
                        mask=masks[kk] if mask is not None else None,
                        keepdim=True,
                    )
                    for kk in range(self.L)
                ]
        else:
            if self.spatial:
                res = [
                    upsample(diffs[kk].sum(dim=1, keepdim=True), out_H=in0.shape[2])
                    for kk in range(self.L)
                ]
            else:
                res = [
                    spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True)
                    for kk in range(self.L)
                ]

        # Sum the per-layer distances, counting only the blocks selected in
        # __init__ (l runs over 0..L-1; block ids are 1-indexed).
        val = 0.0
        for l in range(self.L):
            if (l + 1) in self.blocks:
                val += res[l]

        if retPerLayer:
            return (val, res)
        return val


class ScalingLayer(nn.Module):
    def __init__(self):
        super(ScalingLayer, self).__init__()
        self.register_buffer(
            "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
        )
        self.register_buffer(
            "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
        )

    def forward(self, inp):
        return (inp - self.shift) / self.scale


class NetLinLayer(nn.Module):
    """A single linear layer which does a 1x1 conv"""

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()
        layers = [nn.Dropout()] if use_dropout else []
        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False)]
        self.model = nn.Sequential(*layers)


class Dist2LogitLayer(nn.Module):
    """Takes 2 distances, puts them through fc layers, and spits out a value
    in [0, 1] (if use_sigmoid is True)."""

    def __init__(self, chn_mid=32, use_sigmoid=True):
        super(Dist2LogitLayer, self).__init__()
        layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True)]
        layers += [nn.LeakyReLU(0.2, True)]
        layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True)]
        layers += [nn.LeakyReLU(0.2, True)]
        layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True)]
        if use_sigmoid:
            layers += [nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, d0, d1, eps=0.1):
        return self.model.forward(
            torch.cat((d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)), dim=1)
        )


class BCERankingLoss(nn.Module):
    def __init__(self, chn_mid=32):
        super(BCERankingLoss, self).__init__()
        self.net = Dist2LogitLayer(chn_mid=chn_mid)
        self.loss = torch.nn.BCELoss()

    def forward(self, d0, d1, judge):
        per = (judge + 1.0) / 2.0  # map judgments from [-1, 1] to [0, 1]
        self.logit = self.net.forward(d0, d1)
        return self.loss(self.logit, per)
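
# A minimal training-step sketch for the ranking loss (an assumption, not taken
# from the original training code; the helper name `_ranking_loss_example` and
# all shapes/labels below are illustrative). `judge` follows the convention in
# BCERankingLoss.forward: values in [-1, 1] become BCE targets in [0, 1].
def _ranking_loss_example():
    rank_loss = BCERankingLoss()
    d0 = torch.rand(8, 1, 1, 1)  # per-image distance of patch A to the reference
    d1 = torch.rand(8, 1, 1, 1)  # per-image distance of patch B to the reference
    judge = torch.randint(0, 2, (8, 1, 1, 1)).float() * 2.0 - 1.0  # fake judgments
    loss = rank_loss(d0, d1, judge)
    loss.backward()
    return loss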
# L2, DSSIM metrics
class FakeNet(nn.Module):
    def __init__(self, use_gpu=True, colorspace="Lab"):
        super(FakeNet, self).__init__()
        self.use_gpu = use_gpu
        self.colorspace = colorspace


class L2(FakeNet):
    def forward(self, in0, in1, retPerLayer=None):
        assert in0.size()[0] == 1  # currently only supports batchSize 1

        if self.colorspace == "RGB":
            # Mean squared error over channels and spatial dims, one value per image.
            value = torch.mean((in0 - in1) ** 2, dim=[1, 2, 3])
            return value
        elif self.colorspace == "Lab":
            value = util.l2(
                util.tensor2np(util.tensor2tensorlab(in0.data, to_norm=False)),
                util.tensor2np(util.tensor2tensorlab(in1.data, to_norm=False)),
                range=100.0,
            ).astype("float")
            ret_var = Variable(torch.Tensor((value,)))
            if self.use_gpu:
                ret_var = ret_var.cuda()
            return ret_var


class DSSIM(FakeNet):
    def forward(self, in0, in1, retPerLayer=None):
        assert in0.size()[0] == 1  # currently only supports batchSize 1

        if self.colorspace == "RGB":
            value = util.dssim(
                1.0 * util.tensor2im(in0.data),
                1.0 * util.tensor2im(in1.data),
                range=255.0,
            ).astype("float")
        elif self.colorspace == "Lab":
            value = util.dssim(
                util.tensor2np(util.tensor2tensorlab(in0.data, to_norm=False)),
                util.tensor2np(util.tensor2tensorlab(in1.data, to_norm=False)),
                range=100.0,
            ).astype("float")
        ret_var = Variable(torch.Tensor((value,)))
        if self.use_gpu:
            ret_var = ret_var.cuda()
        return ret_var


def print_network(net):
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print("Network", net)
    print("Total number of parameters: %d" % num_params)
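
if __name__ == "__main__":
    # Smoke-test sketch (an assumption, not part of the original file): it
    # presumes the pretrained backbone weights can be loaded and that inputs
    # are RGB tensors scaled to [-1, 1], as the ScalingLayer constants suggest.
    net = PNetLin(pnet_type="vgg", lpips=True)
    print_network(net)
    im0 = torch.rand(1, 3, 256, 256) * 2.0 - 1.0
    im1 = torch.rand(1, 3, 256, 256) * 2.0 - 1.0
    mask = torch.ones(1, 1, 256, 256)  # an all-ones mask reduces to a plain spatial mean
    with torch.no_grad():
        print("masked LPIPS:", net(im0, im1, mask=mask).flatten().item())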