# amanSethSmava: new commit (6d314be)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from skimage.metrics import structural_similarity
import torch
from torch.autograd import Variable
from ..masked_lpips import dist_model
class PerceptualLoss(torch.nn.Module):
    """Perceptual distance exposed as a ``torch.nn.Module``.

    The constructor builds a ``dist_model.DistModel`` and forwards its
    configuration to ``DistModel.initialize``; ``forward`` delegates the
    actual distance computation to that model.

    Args:
        model: Passed to ``DistModel.initialize`` as ``model``.
        net: Passed to ``DistModel.initialize`` as ``net``.
        vgg_blocks: Passed to ``DistModel.initialize``; ``None`` (default)
            means ``[1, 2, 3, 4, 5]``.
        colorspace: Passed to ``DistModel.initialize`` as ``colorspace``.
        spatial: Stored on the instance and passed to ``DistModel.initialize``.
        use_gpu: Stored on the instance and passed to ``DistModel.initialize``.
        gpu_ids: Stored on the instance and passed to
            ``DistModel.initialize``; ``None`` (default) means ``[0]``.
    """

    # The defaults (model="net-lin") select the perceptually-learned weights
    # (LPIPS metric). Per the original author's note, model='net', net='vgg'
    # is the "default" way of using VGG as a perceptual loss.
    def __init__(
        self,
        model="net-lin",
        net="alex",
        vgg_blocks=None,
        colorspace="rgb",
        spatial=False,
        use_gpu=True,
        gpu_ids=None,
    ):
        super(PerceptualLoss, self).__init__()
        # Build fresh lists per instance: list literals used as default
        # arguments are shared between every call, so a mutation made by one
        # instance (or by DistModel) would leak into the next one.
        if vgg_blocks is None:
            vgg_blocks = [1, 2, 3, 4, 5]
        if gpu_ids is None:
            gpu_ids = [0]

        print("Setting up Perceptual loss...")
        self.use_gpu = use_gpu
        self.spatial = spatial
        self.gpu_ids = gpu_ids
        self.model = dist_model.DistModel()
        self.model.initialize(
            model=model,
            net=net,
            vgg_blocks=vgg_blocks,
            use_gpu=use_gpu,
            colorspace=colorspace,
            spatial=self.spatial,
            gpu_ids=gpu_ids,
        )
        print("...[%s] initialized" % self.model.name())
        print("...Done")

    def forward(self, pred, target, mask=None, normalize=False):
        """Return the distance between ``pred`` and ``target``.

        Args:
            pred: Predicted images; per the original note, Nx3xHxW.
            target: Reference images with the same layout as ``pred``.
            mask: Optional mask forwarded unchanged to the underlying model.
            normalize: If True, the inputs are assumed to lie in [0, 1] and
                are rescaled to [-1, +1] via ``2 * x - 1``. If False, they
                are assumed to already lie in [-1, +1].

        Returns:
            Whatever ``self.model.forward`` returns; the original note
            describes it as N long (one value per batch item).
        """
        if normalize:
            target = 2 * target - 1
            pred = 2 * pred - 1

        # The underlying model is called with the target first, then pred.
        return self.model.forward(target, pred, mask=mask)
def normalize_tensor(in_feat, eps=1e-10):
    """Scale ``in_feat`` to unit L2 norm along dimension 1.

    ``eps`` is added to the norm before dividing, so all-zero vectors
    (for example masked-out positions) come back as zeros rather than NaN.
    """
    squared_sum = (in_feat ** 2).sum(dim=1, keepdim=True)
    l2_norm = squared_sum.sqrt()
    return in_feat / (l2_norm + eps)
def l2(p0, p1, range=255.0):
    """Return half the mean squared difference of ``p0`` and ``p1``.

    Both inputs are divided by ``range`` before the difference is taken.
    """
    scaled_diff = p0 / range - p1 / range
    mean_sq = np.mean(scaled_diff ** 2)
    return 0.5 * mean_sq
def psnr(p0, p1, peak=255.0):
    """Return the peak signal-to-noise ratio (dB) between ``p0`` and ``p1``.

    The inputs are multiplied by 1.0 before subtracting, so the difference
    is taken in floating point even when integer arrays are passed.
    """
    error = 1.0 * p0 - 1.0 * p1
    mse = np.mean(error ** 2)
    return 10 * np.log10(peak ** 2 / mse)
def dssim(p0, p1, range=255.0):
    """Return the structural dissimilarity ``(1 - SSIM) / 2`` of two images.

    Args:
        p0: First image; the last axis is treated as the channel axis.
        p1: Second image, same shape as ``p0``.
        range: Data range of the images, forwarded as ``data_range``.

    Returns:
        ``(1 - ssim) / 2.0``, where ``ssim`` is the value returned by
        ``skimage.metrics.structural_similarity``.
    """
    import inspect

    # The previous body called ``compare_ssim``, a name this module never
    # defines; the function actually imported is ``structural_similarity``.
    ssim_kwargs = {"data_range": range}

    # Newer scikit-image releases replaced ``multichannel=True`` with
    # ``channel_axis``. Because the function also accepts ``**kwargs``, the
    # unsupported spelling may be swallowed silently rather than rejected,
    # so pick the keyword from the installed signature.
    if "channel_axis" in inspect.signature(structural_similarity).parameters:
        ssim_kwargs["channel_axis"] = -1
    else:
        ssim_kwargs["multichannel"] = True

    return (1 - structural_similarity(p0, p1, **ssim_kwargs)) / 2.0
def rgb2lab(in_img, mean_cent=False):
    """Convert an RGB image to Lab using ``skimage.color.rgb2lab``.

    When ``mean_cent`` is True, 50 is subtracted from channel 0 of the
    result (indexed as ``[:, :, 0]``).

    NOTE: a later one-argument ``rgb2lab`` in this file rebinds the name,
    so this two-argument version is shadowed once the module is loaded.
    """
    from skimage import color

    lab = color.rgb2lab(in_img)
    if not mean_cent:
        return lab
    lab[:, :, 0] -= 50
    return lab
def tensor2np(tensor_obj):
    """Convert the first item of a tensor batch to a NumPy array.

    Only ``tensor_obj[0]`` is used. It is moved to the CPU, cast to float,
    and its axes are reordered with ``(1, 2, 0)``, i.e. the leading axis of
    the item becomes the trailing one.
    """
    first_item = tensor_obj[0].cpu().float()
    return np.transpose(first_item.numpy(), (1, 2, 0))
def np2tensor(np_obj):
    """Convert a 3-D NumPy array into a 4-D ``torch.Tensor``.

    A trailing axis is appended and the axes are reordered with
    ``(3, 2, 0, 1)``: the new axis comes first, followed by the original
    last axis and then the original first two axes.
    """
    with_batch_axis = np_obj[:, :, :, np.newaxis]
    reordered = with_batch_axis.transpose((3, 2, 0, 1))
    return torch.Tensor(reordered)
def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False):
    """Convert an image tensor to a Lab tensor.

    The tensor is turned into an image with ``tensor2im``, converted with
    ``skimage.color.rgb2lab`` and wrapped back with ``np2tensor``.

    * ``mc_only=True``: only subtract 50 from channel 0.
    * ``to_norm=True`` (and ``mc_only`` False): subtract 50 from channel 0,
      then divide every channel by 100.
    * otherwise: the Lab values are returned unmodified.
    """
    from skimage import color

    lab = color.rgb2lab(tensor2im(image_tensor))
    if mc_only:
        lab[:, :, 0] -= 50
    elif to_norm:
        lab[:, :, 0] -= 50
        lab = lab / 100.0
    return np2tensor(lab)
def tensorlab2tensor(lab_tensor, return_inbnd=False):
    """Convert a normalized Lab tensor back to an RGB image tensor.

    The input is un-normalized by multiplying by 100 and adding 50 to
    channel 0, converted with ``skimage.color.lab2rgb``, clipped to [0, 1]
    and scaled to [0, 255] before being passed to ``im2tensor``.

    Args:
        lab_tensor: Lab tensor; only its first batch item is used (see
            ``tensor2np``).
        return_inbnd: If True, also return a mask marking the pixels whose
            Lab value survives an RGB round trip.

    Returns:
        The RGB tensor, or ``(rgb_tensor, mask)`` when ``return_inbnd`` is
        True. The mask is 1 where all channels of the round-tripped Lab
        value are close to the input (``np.isclose`` with ``atol=2.0``) and
        0 elsewhere.
    """
    from skimage import color
    import warnings

    # Silence warnings only for the duration of this call. A bare
    # ``warnings.filterwarnings("ignore")`` would install a process-wide
    # filter that keeps hiding unrelated warnings after this returns.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        lab = tensor2np(lab_tensor) * 100.0
        lab[:, :, 0] = lab[:, :, 0] + 50
        rgb_back = 255.0 * np.clip(color.lab2rgb(lab.astype("float")), 0, 1)

        if not return_inbnd:
            return im2tensor(rgb_back)

        # Convert back to Lab and check whether the round trip matches.
        lab_back = color.rgb2lab(rgb_back.astype("uint8"))
        mask = 1.0 * np.isclose(lab_back, lab, atol=2.0)
        # A pixel is in bounds only when every channel matches.
        mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis])
        return (im2tensor(rgb_back), mask)
def rgb2lab(input):
    """Convert an RGB image to Lab after dividing it by 255.

    NOTE: this rebinds ``rgb2lab``, replacing the earlier two-argument
    definition (``in_img, mean_cent``) found higher up in this file.
    """
    from skimage import color

    unit_range_rgb = input / 255.0
    return color.rgb2lab(unit_range_rgb)
def tensor2im(image_tensor, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0):
    """Convert the first item of a tensor batch into an image array.

    The item is moved to the CPU as float, its leading axis is moved to the
    end (axes ``(1, 2, 0)``), and each value is mapped through
    ``(x + cent) * factor`` before being cast to ``imtype``. With the
    defaults, -1 maps to 0 and +1 maps to 255.
    """
    first_item = image_tensor[0].cpu().float().numpy()
    channels_last = first_item.transpose((1, 2, 0))
    rescaled = (channels_last + cent) * factor
    return rescaled.astype(imtype)
def im2tensor(image, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0):
    """Convert a 3-D image array into a 4-D ``torch.Tensor``.

    Each value is mapped through ``x / factor - cent`` (with the defaults,
    0 maps to -1 and 255 maps to +1). A trailing axis is then appended and
    the axes are reordered with ``(3, 2, 0, 1)``. ``imtype`` is accepted
    but not used.
    """
    rescaled = image / factor - cent
    with_batch_axis = rescaled[:, :, :, np.newaxis]
    return torch.Tensor(with_batch_axis.transpose((3, 2, 0, 1)))
def tensor2vec(vector_tensor):
    """Return ``vector_tensor`` as a 2-D NumPy array.

    The data is moved to the CPU and index 0 is taken along both of the
    last two axes of the 4-D input.
    """
    as_numpy = vector_tensor.data.cpu().numpy()
    return as_numpy[:, :, 0, 0]
def voc_ap(rec, prec, use_07_metric=False):
    """ap = voc_ap(rec, prec, [use_07_metric])

    Compute VOC average precision from recall and precision arrays.
    When ``use_07_metric`` is true, the VOC 07 11-point method is used
    (default: False); otherwise the area under the precision envelope is
    integrated over the points where recall changes.
    """
    if use_07_metric:
        # 11-point metric: average the best precision reachable at each
        # recall threshold 0.0, 0.1, ..., 1.0.
        ap = 0.0
        for threshold in np.arange(0.0, 1.1, 0.1):
            reached = rec >= threshold
            best_prec = np.max(prec[reached]) if reached.any() else 0
            ap = ap + best_prec / 11.0
        return ap

    # Exact area under the precision/recall curve.
    # Pad both curves with sentinel values at the ends.
    mrec = np.concatenate(([0.0], rec, [1.0]))
    mpre = np.concatenate(([0.0], prec, [0.0]))

    # Precision envelope: every entry becomes the maximum of itself and all
    # entries to its right (a running maximum taken from the end).
    mpre = np.maximum.accumulate(mpre[::-1])[::-1]

    # Indices at which the recall value changes.
    steps = np.flatnonzero(mrec[1:] != mrec[:-1])

    # Sum (delta recall) * precision over those steps.
    return np.sum((mrec[steps + 1] - mrec[steps]) * mpre[steps + 1])
def tensor2im(image_tensor, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0):
    """Convert the first item of a tensor batch into an image array.

    Values are mapped through ``(x + cent) * factor`` after the item's
    leading axis is moved to the end, then cast to ``imtype``. Using
    ``cent=1., factor=1.`` instead only shifts the values without scaling.

    NOTE: this repeats an identical ``tensor2im`` defined earlier in this
    file; being the later definition, it is the one bound at import time.
    """
    item = image_tensor[0].cpu().float().numpy()
    shifted = item.transpose((1, 2, 0)) + cent
    return (shifted * factor).astype(imtype)
def im2tensor(image, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0):
    """Convert a 3-D image array into a 4-D ``torch.Tensor``.

    Values are mapped through ``x / factor - cent``, a trailing axis is
    appended, and the axes are reordered with ``(3, 2, 0, 1)``. Using
    ``cent=1., factor=1.`` instead only shifts the values without scaling.
    ``imtype`` is accepted but not used.

    NOTE: this repeats an identical ``im2tensor`` defined earlier in this
    file; being the later definition, it is the one bound at import time.
    """
    normalized = image / factor - cent
    reordered = normalized[:, :, :, np.newaxis].transpose((3, 2, 0, 1))
    return torch.Tensor(reordered)