# newTryOn / losses / masked_lpips / dist_model.py
# amanSethSmava - "new commit" (6d314be)
from __future__ import absolute_import
import sys
import numpy as np
import torch
from torch import nn
import os
from collections import OrderedDict
from torch.autograd import Variable
import itertools
from .base_model import BaseModel
from scipy.ndimage import zoom
import fractions
import functools
import skimage.transform
from tqdm import tqdm
from IPython import embed
from . import networks_basic as netw
from losses import masked_lpips as util
class DistModel(BaseModel):
    """Image-distance model: calibrated/uncalibrated pretrained nets or L2/SSIM baselines.

    ``initialize`` selects the underlying network. In training mode a
    ``netw.BCERankingLoss`` head maps the two distances (d0, d1) of a 2AFC
    triplet to a predicted human judgment, and both are optimized with Adam.
    """

    def name(self):
        """Return the display name built in ``initialize`` (e.g. ``"net-lin [alex]"``)."""
        return self.model_name

    def initialize(
        self,
        model="net-lin",
        net="alex",
        vgg_blocks=None,
        colorspace="Lab",
        pnet_rand=False,
        pnet_tune=False,
        model_path=None,
        use_gpu=True,
        printNet=False,
        spatial=False,
        is_train=False,
        lr=0.0001,
        beta1=0.5,
        version="0.1",
        gpu_ids=None,
    ):
        """
        INPUTS
            model - ['net-lin'] for linearly calibrated network
                    ['net'] for off-the-shelf network
                    ['L2'] (or 'l2') for L2 distance in the chosen colorspace
                    ['SSIM'] (or 'ssim', 'DSSIM', 'dssim') for ssim
            net - ['squeeze','alex','vgg'] backbone passed to netw.PNetLin
            vgg_blocks - list forwarded to netw.PNetLin ('net-lin' only);
                None (default) means [1, 2, 3, 4, 5]
            colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
            pnet_rand - bool forwarded to netw.PNetLin
            pnet_tune - bool forwarded to netw.PNetLin ('net-lin' only)
            model_path - weights file for 'net-lin'; if None, uses
                weights/v[VERSION]/[NET].pth inside this module's directory.
                Weights are only loaded when is_train is False.
            use_gpu - bool - whether or not to use a GPU
            printNet - bool - whether or not to print network architecture out
            spatial - bool - whether to output an array containing varying distances
                across spatial dimensions ('net-lin' only)
            is_train - bool - [True] for training mode
            lr - float - initial learning rate
            beta1 - float - initial momentum term for adam
            version - 0.1 for latest, 0.0 was original (with a bug)
            gpu_ids - int list - gpus to use; None (default) means [0]
        RAISES
            ValueError - if model is not one of the names listed above
        """
        # Build fresh lists per call: list defaults would be one object shared by
        # every instance (and gpu_ids is stored on self below).
        if vgg_blocks is None:
            vgg_blocks = [1, 2, 3, 4, 5]
        if gpu_ids is None:
            gpu_ids = [0]

        BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids)
        self.model = model
        self.net = net
        self.is_train = is_train
        self.spatial = spatial
        self.gpu_ids = gpu_ids
        self.model_name = "%s [%s]" % (model, net)

        if self.model == "net-lin":  # pretrained net + linear layer
            self.net = netw.PNetLin(
                pnet_rand=pnet_rand,
                pnet_tune=pnet_tune,
                pnet_type=net,
                use_dropout=True,
                spatial=spatial,
                version=version,
                lpips=True,
                vgg_blocks=vgg_blocks,
            )
            kw = {}
            if not use_gpu:
                kw["map_location"] = "cpu"
            if model_path is None:
                import inspect

                # ".." drops the file name, so the path resolves relative to the
                # directory containing this module.
                model_path = os.path.abspath(
                    os.path.join(
                        inspect.getfile(self.initialize),
                        "..",
                        "weights/v%s/%s.pth" % (version, net),
                    )
                )
            if not is_train:
                print("Loading model from: %s" % model_path)
                # strict=False: missing/unexpected checkpoint keys are ignored.
                self.net.load_state_dict(torch.load(model_path, **kw), strict=False)
        elif self.model == "net":  # pretrained network
            # NOTE(review): pnet_tune, spatial and vgg_blocks are not forwarded here.
            self.net = netw.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False)
        elif self.model in ["L2", "l2"]:
            # not really a network, only for testing
            self.net = netw.L2(use_gpu=use_gpu, colorspace=colorspace)
            self.model_name = "L2"
        elif self.model in ["DSSIM", "dssim", "SSIM", "ssim"]:
            self.net = netw.DSSIM(use_gpu=use_gpu, colorspace=colorspace)
            self.model_name = "SSIM"
        else:
            raise ValueError("Model [%s] not recognized." % self.model)

        self.parameters = list(self.net.parameters())

        if self.is_train:  # training mode
            # extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
            self.rankLoss = netw.BCERankingLoss()
            self.parameters += list(self.rankLoss.net.parameters())
            self.lr = lr
            self.old_lr = lr
            self.optimizer_net = torch.optim.Adam(
                self.parameters, lr=lr, betas=(beta1, 0.999)
            )
        else:  # test mode
            self.net.eval()

        if use_gpu:
            self.net.to(gpu_ids[0])
            self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
            if self.is_train:
                # just put this on GPU0
                self.rankLoss = self.rankLoss.to(device=gpu_ids[0])

        if printNet:
            print("---------- Networks initialized -------------")
            netw.print_network(self.net)
            print("-----------------------------------------------")

    def forward(self, in0, in1, mask=None, retPerLayer=False):
        """Function computes the distance between image patches in0 and in1

        INPUTS
            in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
            mask - optional; forwarded unchanged to self.net.forward
                (its semantics are defined in networks_basic - TODO confirm)
            retPerLayer - bool; forwarded unchanged to self.net.forward
        OUTPUT
            computed distances between in0 and in1
        """
        return self.net.forward(in0, in1, mask=mask, retPerLayer=retPerLayer)

    # ***** TRAINING FUNCTIONS *****
    def optimize_parameters(self):
        """Run one training step: forward, backward, Adam step, then weight clamping."""
        self.forward_train()
        self.optimizer_net.zero_grad()
        self.backward_train()
        self.optimizer_net.step()
        self.clamp_weights()

    def clamp_weights(self):
        """Clamp the weights of every 1x1 convolution in self.net to be non-negative."""
        for module in self.net.modules():
            # getattr: weighted modules without kernel_size (e.g. nn.Linear,
            # nn.BatchNorm2d) must be skipped, not raise AttributeError.
            if hasattr(module, "weight") and getattr(module, "kernel_size", None) == (1, 1):
                module.weight.data = torch.clamp(module.weight.data, min=0)

    def set_input(self, data):
        """Store a batch dict with keys "ref", "p0", "p1", "judge" (moved to GPU0 if use_gpu)."""
        self.input_ref = data["ref"]
        self.input_p0 = data["p0"]
        self.input_p1 = data["p1"]
        self.input_judge = data["judge"]

        if self.use_gpu:
            self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
            self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
            self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
            self.input_judge = self.input_judge.to(device=self.gpu_ids[0])

        self.var_ref = Variable(self.input_ref, requires_grad=True)
        self.var_p0 = Variable(self.input_p0, requires_grad=True)
        self.var_p1 = Variable(self.input_p1, requires_grad=True)

    def forward_train(self):  # run forward pass
        """Compute d0, d1, per-sample accuracy and the ranking loss for the current batch."""
        self.d0 = self.forward(self.var_ref, self.var_p0)
        self.d1 = self.forward(self.var_ref, self.var_p1)
        self.acc_r = self.compute_accuracy(self.d0, self.d1, self.input_judge)

        self.var_judge = Variable(1.0 * self.input_judge).view(self.d0.size())

        # Judgments in [0,1] are rescaled to [-1,1] for the ranking loss.
        self.loss_total = self.rankLoss.forward(
            self.d0, self.d1, self.var_judge * 2.0 - 1.0
        )
        return self.loss_total

    def backward_train(self):
        """Backpropagate the mean of the ranking loss."""
        torch.mean(self.loss_total).backward()

    def compute_accuracy(self, d0, d1, judge):
        """Per-sample agreement between the distance ranking and the judgments.

        d0, d1 are Variables, judge is a Tensor. Returns a flat numpy array that
        equals judge where d1 < d0 and 1 - judge otherwise (ties included).
        """
        d1_lt_d0 = (d1 < d0).cpu().data.numpy().flatten()
        judge_per = judge.cpu().numpy().flatten()
        return d1_lt_d0 * judge_per + (1 - d1_lt_d0) * (1 - judge_per)

    def get_current_errors(self):
        """Return an OrderedDict with the batch-mean "loss_total" and "acc_r"."""
        retDict = OrderedDict(
            [("loss_total", self.loss_total.data.cpu().numpy()), ("acc_r", self.acc_r)]
        )
        for key in retDict.keys():
            retDict[key] = np.mean(retDict[key])
        return retDict

    def get_current_visuals(self):
        """Return ref/p0/p1 images rescaled (nearest-neighbour) so dimension 2 maps to 256."""
        zoom_factor = 256 / self.var_ref.data.size()[2]

        ref_img = util.tensor2im(self.var_ref.data)
        p0_img = util.tensor2im(self.var_p0.data)
        p1_img = util.tensor2im(self.var_p1.data)

        # [z, z, 1]: assumes tensor2im returns HxWxC images - TODO confirm.
        ref_img_vis = zoom(ref_img, [zoom_factor, zoom_factor, 1], order=0)
        p0_img_vis = zoom(p0_img, [zoom_factor, zoom_factor, 1], order=0)
        p1_img_vis = zoom(p1_img, [zoom_factor, zoom_factor, 1], order=0)

        return OrderedDict(
            [("ref", ref_img_vis), ("p0", p0_img_vis), ("p1", p1_img_vis)]
        )

    def save(self, path, label):
        """Save the distance network and, when training, the ranking-loss head."""
        # With use_gpu the network is wrapped in DataParallel; save the inner module.
        if self.use_gpu:
            self.save_network(self.net.module, path, "", label)
        else:
            self.save_network(self.net, path, "", label)
        # rankLoss is only created when initialized with is_train=True.
        if self.is_train:
            self.save_network(self.rankLoss.net, path, "rank", label)

    def update_learning_rate(self, nepoch_decay):
        """Decrease the learning rate linearly by lr / nepoch_decay per call."""
        lrd = self.lr / nepoch_decay
        lr = self.old_lr - lrd

        for param_group in self.optimizer_net.param_groups:
            param_group["lr"] = lr

        print("update lr decay: %f -> %f" % (self.old_lr, lr))
        self.old_lr = lr
def score_2afc_dataset(data_loader, func, name=""):
    """Function computes Two Alternative Forced Choice (2AFC) score using
    distance function 'func' in dataset 'data_loader'

    INPUTS
        data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside;
            each item of load_data() is indexed by "ref", "p0", "p1" and "judge"
        func - callable distance function - calling d=func(in0,in1) should take 2
            pytorch tensors with shape Nx3xXxY, and return a pytorch tensor with N
            elements (any shape, e.g. N or Nx1x1x1; it is detached, moved to CPU
            and flattened here)
        name - str - description shown on the tqdm progress bar
    OUTPUTS
        [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
        [1] - dictionary with following elements
            d0s,d1s - N arrays containing distances between reference patch to perturbed patches
            gts - N array in [0,1], preferred patch selected by human evaluators
                (closer to "0" for left patch p0, "1" for right patch p1,
                "0.6" means 60pct people preferred right patch, 40pct preferred left)
            scores - N array in [0,1], corresponding to what percentage function agreed with humans
    CONSTS
        N - number of test triplets in data_loader
    """
    d0s = []
    d1s = []
    gts = []

    for data in tqdm(data_loader.load_data(), desc=name):
        # .data detaches from the autograd graph; flatten() makes each batch
        # contribute N scalars regardless of the output's trailing dims.
        d0s += func(data["ref"], data["p0"]).data.cpu().numpy().flatten().tolist()
        d1s += func(data["ref"], data["p1"]).data.cpu().numpy().flatten().tolist()
        gts += data["judge"].cpu().numpy().flatten().tolist()

    d0s = np.array(d0s)
    d1s = np.array(d1s)
    gts = np.array(gts)

    # gts is the fraction preferring p1: choosing p0 (d0 < d1) earns 1 - gts,
    # choosing p1 earns gts, and a tie earns 0.5.
    scores = (d0s < d1s) * (1.0 - gts) + (d1s < d0s) * gts + (d1s == d0s) * 0.5

    return (np.mean(scores), dict(d0s=d0s, d1s=d1s, gts=gts, scores=scores))
def score_jnd_dataset(data_loader, func, name=""):
    """Function computes JND score using distance function 'func' in dataset 'data_loader'

    INPUTS
        data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside;
            each item of load_data() is indexed by "p0", "p1" and "same"
        func - callable distance function - calling d=func(in0,in1) should take 2
            pytorch tensors with shape Nx3xXxY, and return a pytorch tensor with N
            elements (any shape, e.g. N or Nx1x1x1; it is flattened here)
        name - str - description shown on the tqdm progress bar
    OUTPUTS
        [0] - JND score in [0,1], mAP score (area under precision-recall curve)
        [1] - dictionary with following elements
            ds - N array containing distances between two patches shown to human evaluator
            sames - N array containing fraction of people who thought the two patches were identical
    CONSTS
        N - number of test pairs in data_loader
    """
    ds = []
    gts = []

    for data in tqdm(data_loader.load_data(), desc=name):
        # flatten(): an Nx1x1x1 output would otherwise produce nested lists, and
        # np.argsort below would then sort along a size-1 axis, scrambling the ranking.
        ds += func(data["p0"], data["p1"]).data.cpu().numpy().flatten().tolist()
        gts += data["same"].cpu().numpy().flatten().tolist()

    sames = np.array(gts)
    ds = np.array(ds)

    # Rank pairs by ascending distance; thresholding after the k-th pair
    # predicts "same" for exactly the first k pairs.
    sorted_inds = np.argsort(ds)
    sames_sorted = sames[sorted_inds]

    TPs = np.cumsum(sames_sorted)
    FPs = np.cumsum(1 - sames_sorted)
    FNs = np.sum(sames_sorted) - TPs

    precs = TPs / (TPs + FPs)
    # NOTE(review): TPs + FNs is sum(sames); if it is 0 recall is undefined (nan).
    recs = TPs / (TPs + FNs)
    score = util.voc_ap(recs, precs)

    return (score, dict(ds=ds, sames=sames))