"""Perceptual-loss wrapper and image/tensor helper utilities.

Contains ``PerceptualLoss`` (a thin ``torch.nn.Module`` around
``dist_model.DistModel``), a few simple image metrics (L2, PSNR, DSSIM,
VOC average precision) and conversions between image arrays, tensors and
the Lab colour space.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import inspect
import warnings

import numpy as np
from skimage.metrics import structural_similarity
import torch
from torch.autograd import Variable  # noqa: F401  (kept; may be used by importers)

from ..masked_lpips import dist_model


class PerceptualLoss(torch.nn.Module):
    """``torch.nn.Module`` that delegates to a ``dist_model.DistModel``."""

    def __init__(
        self,
        model="net-lin",
        net="alex",
        vgg_blocks=None,
        colorspace="rgb",
        spatial=False,
        use_gpu=True,
        gpu_ids=None,
    ):
        """Build and initialise the underlying distance model.

        All arguments are forwarded to ``DistModel.initialize``; their exact
        semantics are defined there.

        Args:
            model: Forwarded as ``model``. The original author notes that
                ``model="net"`` with ``net="vgg"`` is the "default" way of
                using VGG as a perceptual loss, whereas the defaults here use
                perceptually-learned weights (LPIPS metric).
            net: Forwarded as ``net``.
            vgg_blocks: Forwarded as ``vgg_blocks``. ``None`` selects
                ``[1, 2, 3, 4, 5]``.
            colorspace: Forwarded as ``colorspace``.
            spatial: Forwarded as ``spatial`` and stored on the instance.
            use_gpu: Forwarded as ``use_gpu`` and stored on the instance.
            gpu_ids: Forwarded as ``gpu_ids`` and stored on the instance.
                ``None`` selects ``[0]``.
        """
        super(PerceptualLoss, self).__init__()

        # Build the default lists per call: a list literal in the signature
        # would be one object shared (and mutable) across every instance.
        if vgg_blocks is None:
            vgg_blocks = [1, 2, 3, 4, 5]
        if gpu_ids is None:
            gpu_ids = [0]

        print("Setting up Perceptual loss...")
        self.use_gpu = use_gpu
        self.spatial = spatial
        self.gpu_ids = gpu_ids
        self.model = dist_model.DistModel()
        self.model.initialize(
            model=model,
            net=net,
            vgg_blocks=vgg_blocks,
            use_gpu=use_gpu,
            colorspace=colorspace,
            spatial=self.spatial,
            gpu_ids=gpu_ids,
        )
        print("...[%s] initialized" % self.model.name())
        print("...Done")

    def forward(self, pred, target, mask=None, normalize=False):
        """Return the model distance between ``target`` and ``pred``.

        Pred and target are Variables.
        If normalize is True, assumes the images are between [0,1] and then
        scales them between [-1,+1].
        If normalize is False, assumes the images are already between [-1,+1].

        Inputs pred and target are Nx3xHxW.
        Output pytorch Variable N long.

        ``mask`` is passed through unchanged to the underlying model.
        """
        if normalize:
            target = 2 * target - 1
            pred = 2 * pred - 1

        # NOTE: the underlying model is called with ``target`` first.
        return self.model.forward(target, pred, mask=mask)


def normalize_tensor(in_feat, eps=1e-10):
    """Divide ``in_feat`` by its L2 norm taken over ``dim=1``.

    ``eps`` is added to the norm so all-zero vectors do not divide by zero
    (per the original author, this takes care of masked tensors implicitly).
    """
    norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True))
    return in_feat / (norm_factor + eps)


def l2(p0, p1, range=255.0):
    """Return half the mean squared difference of ``p0`` and ``p1``.

    Both inputs are divided by ``range`` before the difference is taken.
    """
    return 0.5 * np.mean((p0 / range - p1 / range) ** 2)


def psnr(p0, p1, peak=255.0):
    """Return ``10 * log10(peak**2 / MSE(p0, p1))``.

    Inputs are multiplied by ``1.0`` first so integer arrays are promoted to
    float before subtraction.
    """
    return 10 * np.log10(peak ** 2 / np.mean((1.0 * p0 - 1.0 * p1) ** 2))


def dssim(p0, p1, range=255.0):
    """Return the structural dissimilarity ``(1 - SSIM) / 2``.

    The last axis of ``p0``/``p1`` is treated as the colour channel and
    ``range`` is passed as ``data_range``.
    """
    # scikit-image renamed ``multichannel=True`` to ``channel_axis=-1``.
    # Both variants accept ``**kwargs``, so a wrong keyword could be silently
    # swallowed; inspect the installed signature to choose the right one.
    ssim_params = inspect.signature(structural_similarity).parameters
    if "channel_axis" in ssim_params:
        channel_kwargs = {"channel_axis": -1}
    else:
        channel_kwargs = {"multichannel": True}

    ssim = structural_similarity(p0, p1, data_range=range, **channel_kwargs)
    return (1 - ssim) / 2.0


def rgb2lab(input):
    """Convert ``input`` to Lab after dividing it by 255.

    NOTE: the module used to define an earlier ``rgb2lab(in_img, mean_cent)``
    that this definition shadowed; only this effective version is kept.
    """
    from skimage import color

    return color.rgb2lab(input / 255.0)


def tensor2np(tensor_obj):
    """Convert the first item of ``tensor_obj`` to a numpy array.

    The item is moved to CPU, cast to float and its axes are reordered with
    ``transpose((1, 2, 0))`` (presumably CxHxW -> HxWxC).
    """
    return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0))


def np2tensor(np_obj):
    """Convert a 3-D numpy array to a 4-D ``torch.Tensor``.

    A trailing axis is appended and the axes are reordered with
    ``transpose((3, 2, 0, 1))``, so an A x B x C array becomes 1 x C x A x B.
    """
    return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))


def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False):
    """Convert an image tensor to a Lab tensor.

    Args:
        image_tensor: Tensor accepted by ``tensor2im``.
        to_norm: If true (and ``mc_only`` is false), subtract 50 from the
            first Lab channel and divide all channels by 100.
        mc_only: If true, only subtract 50 from the first Lab channel.
    """
    from skimage import color

    img = tensor2im(image_tensor)
    img_lab = color.rgb2lab(img)
    if mc_only:
        img_lab[:, :, 0] = img_lab[:, :, 0] - 50
    if to_norm and not mc_only:
        img_lab[:, :, 0] = img_lab[:, :, 0] - 50
        img_lab = img_lab / 100.0

    return np2tensor(img_lab)


def tensorlab2tensor(lab_tensor, return_inbnd=False):
    """Convert a normalised Lab tensor back to an RGB image tensor.

    The input is scaled by 100 and its first channel shifted by +50 (the
    inverse of the ``to_norm`` path in ``tensor2tensorlab``) before being
    converted to RGB and clipped to [0, 1].

    Args:
        lab_tensor: Tensor accepted by ``tensor2np``.
        return_inbnd: If true, also return a mask tensor that is 1 where the
            RGB result converts back to the input Lab values within an
            absolute tolerance of 2.0 on every channel, and 0 elsewhere.

    Returns:
        The RGB tensor, or ``(rgb_tensor, mask)`` when ``return_inbnd`` is
        true.
    """
    from skimage import color

    lab = tensor2np(lab_tensor) * 100.0
    lab[:, :, 0] = lab[:, :, 0] + 50

    # Suppress warnings only for the duration of the conversion instead of
    # installing a process-wide "ignore" filter that would never be undone.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        rgb_back = 255.0 * np.clip(color.lab2rgb(lab.astype("float")), 0, 1)
        if return_inbnd:
            # Convert back to Lab and see whether we match the input.
            lab_back = color.rgb2lab(rgb_back.astype("uint8"))

    if not return_inbnd:
        return im2tensor(rgb_back)

    mask = 1.0 * np.isclose(lab_back, lab, atol=2.0)
    mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis])
    return (im2tensor(rgb_back), mask)


def tensor2im(image_tensor, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0):
    """Convert the first item of ``image_tensor`` to an image array.

    The item is moved to CPU, cast to float, reordered with
    ``transpose((1, 2, 0))``, mapped through ``(x + cent) * factor`` and cast
    to ``imtype``. With the defaults this maps [-1, 1] to [0, 255].
    """
    image_numpy = image_tensor[0].cpu().float().numpy()
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
    return image_numpy.astype(imtype)


def im2tensor(image, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0):
    """Convert a 3-D image array to a 4-D ``torch.Tensor``.

    The image is mapped through ``x / factor - cent`` (inverse of
    ``tensor2im``), a trailing axis is appended and the axes are reordered
    with ``transpose((3, 2, 0, 1))``. ``imtype`` is accepted for signature
    symmetry with ``tensor2im`` but is not used.
    """
    return torch.Tensor(
        (image / factor - cent)[:, :, :, np.newaxis].transpose((3, 2, 0, 1))
    )


def tensor2vec(vector_tensor):
    """Return ``vector_tensor`` as a 2-D numpy array.

    Takes ``.data`` on CPU and indexes ``[:, :, 0, 0]``, i.e. drops the last
    two axes of a 4-D tensor by selecting index 0 on each.
    """
    return vector_tensor.data.cpu().numpy()[:, :, 0, 0]


def voc_ap(rec, prec, use_07_metric=False):
    """ap = voc_ap(rec, prec, [use_07_metric])

    Compute VOC AP given precision and recall.
    If use_07_metric is true, uses the
    VOC 07 11 point method (default: False).
    """
    if use_07_metric:
        # 11 point metric: average the best precision at recall >= t for
        # t = 0.0, 0.1, ..., 1.0.
        ap = 0.0
        for t in np.arange(0.0, 1.1, 0.1):
            if np.sum(rec >= t) == 0:
                p = 0
            else:
                p = np.max(prec[rec >= t])
            ap = ap + p / 11.0
    else:
        # Correct AP calculation.
        # First add sentinel values at both ends.
        mrec = np.concatenate(([0.0], rec, [1.0]))
        mpre = np.concatenate(([0.0], prec, [0.0]))

        # Compute the precision envelope (running maximum from the right).
        for i in range(mpre.size - 1, 0, -1):
            mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])

        # To calculate the area under the PR curve, look for points
        # where the X axis (recall) changes value...
        i = np.where(mrec[1:] != mrec[:-1])[0]

        # ...and sum (delta recall) * precision.
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
    return ap