Spaces:
Runtime error
Runtime error
| import os | |
| import sys | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| class AlignLoss(nn.Module): | |
| def __init__(self, reduction='mean'): | |
| super().__init__() | |
| self.loss_fn = nn.L1Loss(reduction=reduction) | |
| def forward(self, frames, masks, aligned_vs, aligned_rs): | |
| """ | |
| :param frames: The original frames(GT) | |
| :param masks: Original masks | |
| :param aligned_vs: aligned visibility map from reference frame(List: B, C, T, H, W) | |
| :param aligned_rs: aligned reference frames(List: B, C, T, H, W) | |
| :return: | |
| """ | |
| try: | |
| B, C, T, H, W = frames.shape | |
| except ValueError: | |
| frames = frames.unsqueeze(2) | |
| masks = masks.unsqueeze(2) | |
| B, C, T, H, W = frames.shape | |
| loss = 0 | |
| for i in range(T): | |
| frame = frames[:, :, i] | |
| mask = masks[:, :, i] | |
| aligned_v = aligned_vs[i] | |
| aligned_r = aligned_rs[i] | |
| loss += self._singleFrameAlignLoss(frame, mask, aligned_v, aligned_r) | |
| return loss | |
| def _singleFrameAlignLoss(self, targetFrame, targetMask, aligned_v, aligned_r): | |
| """ | |
| :param targetFrame: targetFrame to be aligned-> B, C, H, W | |
| :param targetMask: the mask of target frames | |
| :param aligned_v: aligned visibility map from reference frame | |
| :param aligned_r: aligned reference frame-> B, C, T, H, W | |
| :return: | |
| """ | |
| targetVisibility = 1. - targetMask | |
| targetVisibility = targetVisibility.unsqueeze(2) | |
| targetFrame = targetFrame.unsqueeze(2) | |
| visibility_map = targetVisibility * aligned_v | |
| target_visibility = visibility_map * targetFrame | |
| reference_visibility = visibility_map * aligned_r | |
| loss = 0 | |
| for i in range(aligned_r.shape[2]): | |
| loss += self.loss_fn(target_visibility[:, :, i], reference_visibility[:, :, i]) | |
| return loss | |
| class HoleVisibleLoss(nn.Module): | |
| def __init__(self, reduction='mean'): | |
| super().__init__() | |
| self.loss_fn = nn.L1Loss(reduction=reduction) | |
| def forward(self, outputs, masks, GTs, c_masks): | |
| try: | |
| B, C, T, H, W = outputs.shape | |
| except ValueError: | |
| outputs = outputs.unsqueeze(2) | |
| masks = masks.unsqueeze(2) | |
| GTs = GTs.unsqueeze(2) | |
| c_masks = c_masks.unsqueeze(2) | |
| B, C, T, H, W = outputs.shape | |
| loss = 0 | |
| for i in range(T): | |
| loss += self._singleFrameHoleVisibleLoss(outputs[:, :, i], masks[:, :, i], c_masks[:, :, i], GTs[:, :, i]) | |
| return loss | |
| def _singleFrameHoleVisibleLoss(self, targetFrame, targetMask, c_mask, GT): | |
| return self.loss_fn(targetMask * c_mask * targetFrame, targetMask * c_mask * GT) | |
| class HoleInvisibleLoss(nn.Module): | |
| def __init__(self, reduction='mean'): | |
| super().__init__() | |
| self.loss_fn = nn.L1Loss(reduction=reduction) | |
| def forward(self, outputs, masks, GTs, c_masks): | |
| try: | |
| B, C, T, H, W = outputs.shape | |
| except ValueError: | |
| outputs = outputs.unsqueeze(2) | |
| masks = masks.unsqueeze(2) | |
| GTs = GTs.unsqueeze(2) | |
| c_masks = c_masks.unsqueeze(2) | |
| B, C, T, H, W = outputs.shape | |
| loss = 0 | |
| for i in range(T): | |
| loss += self._singleFrameHoleInvisibleLoss(outputs[:, :, i], masks[:, :, i], c_masks[:, :, i], GTs[:, :, i]) | |
| return loss | |
| def _singleFrameHoleInvisibleLoss(self, targetFrame, targetMask, c_mask, GT): | |
| return self.loss_fn(targetMask * (1. - c_mask) * targetFrame, targetMask * (1. - c_mask) * GT) | |
| class NonHoleLoss(nn.Module): | |
| def __init__(self, reduction='mean'): | |
| super().__init__() | |
| self.loss_fn = nn.L1Loss(reduction=reduction) | |
| def forward(self, outputs, masks, GTs): | |
| try: | |
| B, C, T, H, W = outputs.shape | |
| except ValueError: | |
| outputs = outputs.unsqueeze(2) | |
| masks = masks.unsqueeze(2) | |
| GTs = GTs.unsqueeze(2) | |
| B, C, T, H, W = outputs.shape | |
| loss = 0 | |
| for i in range(T): | |
| loss += self._singleNonHoleLoss(outputs[:, :, i], masks[:, :, i], GTs[:, :, i]) | |
| return loss | |
| def _singleNonHoleLoss(self, targetFrame, targetMask, GT): | |
| return self.loss_fn((1. - targetMask) * targetFrame, (1. - targetMask) * GT) | |
| class ReconLoss(nn.Module): | |
| def __init__(self, reduction='mean', masked=False): | |
| super().__init__() | |
| self.loss_fn = nn.L1Loss(reduction=reduction) | |
| self.masked = masked | |
| def forward(self, model_output, target, mask): | |
| outputs = model_output | |
| targets = target | |
| if self.masked: | |
| masks = mask | |
| return self.loss_fn(outputs * masks, targets * masks) # L1 loss in masked region | |
| else: | |
| return self.loss_fn(outputs, targets) # L1 loss in the whole region | |
| class VGGLoss(nn.Module): | |
| def __init__(self, vgg): | |
| super().__init__() | |
| self.l1_loss = nn.L1Loss() | |
| self.vgg = vgg | |
| def vgg_loss(self, output, target): | |
| output_feature = self.vgg(output) | |
| target_feature = self.vgg(target) | |
| loss = ( | |
| self.l1_loss(output_feature.relu2_2, target_feature.relu2_2) | |
| + self.l1_loss(output_feature.relu3_3, target_feature.relu3_3) | |
| + self.l1_loss(output_feature.relu4_3, target_feature.relu4_3) | |
| ) | |
| return loss | |
| def forward(self, data_input, model_output): | |
| targets = data_input | |
| outputs = model_output | |
| mean_image_loss = self.vgg_loss(outputs, targets) | |
| return mean_image_loss | |
| class StyleLoss(nn.Module): | |
| def __init__(self, vgg, original_channel_norm=True): | |
| super().__init__() | |
| self.l1_loss = nn.L1Loss() | |
| self.vgg = vgg | |
| self.original_channel_norm = original_channel_norm | |
| # From https://github.com/pytorch/tutorials/blob/master/advanced_source/neural_style_tutorial.py | |
| def gram_matrix(self, input): | |
| a, b, c, d = input.size() # a=batch size(=1) | |
| # b=number of feature maps | |
| # (c,d)=dimensions of a f. map (N=c*d) | |
| features = input.view(a * b, c * d) # resise F_XL into \hat F_XL | |
| G = torch.mm(features, features.t()) # compute the gram product | |
| # we 'normalize' the values of the gram matrix | |
| # by dividing by the number of element in each feature maps. | |
| return G.div(a * b * c * d) | |
| # Implement "Image Inpainting for Irregular Holes Using Partial Convolutions", Liu et al., 2018 | |
| def style_loss(self, output, target): | |
| output_features = self.vgg(output) | |
| target_features = self.vgg(target) | |
| layers = ['relu2_2', 'relu3_3', 'relu4_3'] # n_channel: 128 (=2 ** 7), 256 (=2 ** 8), 512 (=2 ** 9) | |
| loss = 0 | |
| for i, layer in enumerate(layers): | |
| output_feature = getattr(output_features, layer) | |
| target_feature = getattr(target_features, layer) | |
| B, C_P, H, W = output_feature.shape | |
| output_gram_matrix = self.gram_matrix(output_feature) | |
| target_gram_matrix = self.gram_matrix(target_feature) | |
| if self.original_channel_norm: | |
| C_P_square_divider = 2 ** (i + 1) # original design (avoid too small loss) | |
| else: | |
| C_P_square_divider = C_P ** 2 | |
| assert C_P == 128 * 2 ** i | |
| loss += self.l1_loss(output_gram_matrix, target_gram_matrix) / C_P_square_divider | |
| return loss | |
| def forward(self, data_input, model_output): | |
| targets = data_input | |
| outputs = model_output | |
| mean_image_loss = self.style_loss(outputs, targets) | |
| return mean_image_loss | |
| class L1LossMaskedMean(nn.Module): | |
| def __init__(self): | |
| super().__init__() | |
| self.l1 = nn.L1Loss(reduction='sum') | |
| def forward(self, x, y, mask): | |
| masked = 1 - mask # 默认missing region的mask值为0,原有区域为1 | |
| l1_sum = self.l1(x * masked, y * masked) | |
| return l1_sum / torch.sum(masked) | |
| class L2LossMaskedMean(nn.Module): | |
| def __init__(self, reduction='sum'): | |
| super().__init__() | |
| self.l2 = nn.MSELoss(reduction=reduction) | |
| def forward(self, x, y, mask): | |
| masked = 1 - mask | |
| l2_sum = self.l2(x * masked, y * masked) | |
| return l2_sum / torch.sum(masked) | |
| class ImcompleteVideoReconLoss(nn.Module): | |
| def __init__(self): | |
| super().__init__() | |
| self.loss_fn = L1LossMaskedMean() | |
| def forward(self, data_input, model_output): | |
| imcomplete_video = model_output['imcomplete_video'] | |
| targets = data_input['targets'] | |
| down_sampled_targets = nn.functional.interpolate( | |
| targets.transpose(1, 2), scale_factor=[1, 0.5, 0.5]) | |
| masks = data_input['masks'] | |
| down_sampled_masks = nn.functional.interpolate( | |
| masks.transpose(1, 2), scale_factor=[1, 0.5, 0.5]) | |
| return self.loss_fn( | |
| imcomplete_video, down_sampled_targets, | |
| down_sampled_masks | |
| ) | |
| class CompleteFramesReconLoss(nn.Module): | |
| def __init__(self): | |
| super().__init__() | |
| self.loss_fn = L1LossMaskedMean() | |
| def forward(self, data_input, model_output): | |
| outputs = model_output['outputs'] | |
| targets = data_input['targets'] | |
| masks = data_input['masks'] | |
| return self.loss_fn(outputs, targets, masks) | |
| class AdversarialLoss(nn.Module): | |
| r""" | |
| Adversarial loss | |
| https://arxiv.org/abs/1711.10337 | |
| """ | |
| def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0): | |
| r""" | |
| type = nsgan | lsgan | hinge | |
| """ | |
| super(AdversarialLoss, self).__init__() | |
| self.type = type | |
| self.register_buffer('real_label', torch.tensor(target_real_label)) | |
| self.register_buffer('fake_label', torch.tensor(target_fake_label)) | |
| if type == 'nsgan': | |
| self.criterion = nn.BCELoss() | |
| elif type == 'lsgan': | |
| self.criterion = nn.MSELoss() | |
| elif type == 'hinge': | |
| self.criterion = nn.ReLU() | |
| def __call__(self, outputs, is_real, is_disc=None): | |
| if self.type == 'hinge': | |
| if is_disc: | |
| if is_real: | |
| outputs = -outputs | |
| return self.criterion(1 + outputs).mean() | |
| else: | |
| return (-outputs).mean() | |
| else: | |
| labels = (self.real_label if is_real else self.fake_label).expand_as( | |
| outputs) | |
| loss = self.criterion(outputs, labels) | |
| return loss | |
| # # From https://github.com/phoenix104104/fast_blind_video_consistency | |
| # class TemporalWarpingLoss(nn.Module): | |
| # def __init__(self, opts, flownet_checkpoint_path=None, alpha=50): | |
| # super().__init__() | |
| # self.loss_fn = L1LossMaskedMean() | |
| # self.alpha = alpha | |
| # self.opts = opts | |
| # | |
| # assert flownet_checkpoint_path is not None, "Flownet2 pretrained models must be provided" | |
| # | |
| # self.flownet_checkpoint_path = flownet_checkpoint_path | |
| # raise NotImplementedError | |
| # | |
| # def get_flownet_checkpoint_path(self): | |
| # return self.flownet_checkpoint_path | |
| # | |
| # def _flownetwrapper(self): | |
| # Flownet = FlowNet2(self.opts, requires_grad=False) | |
| # Flownet2_ckpt = torch.load(self.flownet_checkpoint_path) | |
| # Flownet.load_state_dict(Flownet2_ckpt['state_dict']) | |
| # Flownet.to(device) | |
| # Flownet.exal() | |
| # return Flownet | |
| # | |
| # def _setup(self): | |
| # self.flownet = self._flownetwrapper() | |
| # | |
| # def _get_non_occlusuib_mask(self, targets, warped_targets): | |
| # non_occlusion_masks = torch.exp( | |
| # -self.alpha * torch.sum(targets[:, 1:] - warped_targets, dim=2).pow(2) | |
| # ).unsqueeze(2) | |
| # return non_occlusion_masks | |
| # | |
| # def _get_loss(self, outputs, warped_outputs, non_occlusion_masks, masks): | |
| # return self.loss_fn( | |
| # outputs[:, 1:] * non_occlusion_masks, | |
| # warped_outputs * non_occlusion_masks, | |
| # masks[:, 1:] | |
| # ) | |
| # | |
| # def forward(self, data_input, model_output): | |
| # if self.flownet is None: | |
| # self._setup() | |
| # | |
| # targets = data_input['targets'].to(device) | |
| # outputs = model_output['outputs'].to(device) | |
| # flows = self.flownet.infer_video(targets).to(device) | |
| # | |
| # from utils.flow_utils import warp_optical_flow | |
| # warped_targets = warp_optical_flow(targets[:, :-1], -flows).detach() | |
| # warped_outputs = warp_optical_flow(outputs[:, :-1], -flows).detach() | |
| # non_occlusion_masks = self._get_non_occlusion_mask(targets, warped_targets) | |
| # | |
| # # model_output is passed by name and dictionary is mutable | |
| # # These values are sent to trainer for visualization | |
| # model_output['warped_outputs'] = warped_outputs[0] | |
| # model_output['warped_targets'] = warped_targets[0] | |
| # model_output['non_occlusion_masks'] = non_occlusion_masks[0] | |
| # from utils.flow_utils import flow_to_image | |
| # flow_imgs = [] | |
| # for flow in flows[0]: | |
| # flow_img = flow_to_image(flow.cpu().permute(1, 2, 0).detach().numpy()).transpose(2, 0, 1) | |
| # flow_imgs.append(torch.Tensor(flow_img)) | |
| # model_output['flow_imgs'] = flow_imgs | |
| # | |
| # masks = data_input['masks'].to(device) | |
| # return self._get_loss(outputs, warped_outputs, non_occlusion_masks, masks) | |
| # | |
| # | |
| # class TemporalWarpingError(TemporalWarpingLoss): | |
| # def __init__(self, flownet_checkpoint_path, alpha=50): | |
| # super().__init__(flownet_checkpoint_path, alpha) | |
| # self.loss_fn = L2LossMaskedMean(reduction='none') | |
| # | |
| # def _get_loss(self, outputs, warped_outputs, non_occlusion_masks, masks): | |
| # # See https://arxiv.org/pdf/1808.00449.pdf 4.3 | |
| # # The sum of non_occlusion_masks is different for each video, | |
| # # So the batch dim is kept | |
| # loss = self.loss_fn( | |
| # outputs[:, 1:] * non_occlusion_masks, | |
| # warped_outputs * non_occlusion_masks, | |
| # masks[:, 1:] | |
| # ).sum(1).sum(1).sum(1).sum(1) | |
| # | |
| # loss = loss / non_occlusion_masks.sum(1).sum(1).sum(1).sum(1) | |
| # return loss.sum() | |
| class ValidLoss(nn.Module): | |
| def __init__(self): | |
| super(ValidLoss, self).__init__() | |
| self.loss_fn = nn.L1Loss(reduction='mean') | |
| def forward(self, model_output, target, mk): | |
| outputs = model_output | |
| targets = target | |
| return self.loss_fn(outputs * (1 - mk), targets * (1 - mk)) # L1 loss in masked region | |
| class TVLoss(nn.Module): | |
| def __init__(self): | |
| super(TVLoss, self).__init__() | |
| def forward(self, mask_input, model_output): | |
| # View 3D data as 2D | |
| outputs = model_output | |
| if len(mask_input.shape) == 4: | |
| mask_input = mask_input.unsqueeze(2) | |
| if len(outputs.shape) == 4: | |
| outputs = outputs.unsqueeze(2) | |
| outputs = outputs.permute((0, 2, 1, 3, 4)).contiguous() | |
| masks = mask_input.permute((0, 2, 1, 3, 4)).contiguous() | |
| B, L, C, H, W = outputs.shape | |
| x = outputs.view([B * L, C, H, W]) | |
| masks = masks.view([B * L, -1]) | |
| mask_areas = masks.sum(dim=1) | |
| h_x = x.size()[2] | |
| w_x = x.size()[3] | |
| h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum(1).sum(1).sum(1) # 差分是为了求梯度,本质上还是梯度平方和 | |
| w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum(1).sum(1).sum(1) | |
| return ((h_tv + w_tv) / mask_areas).mean() | |
| # for debug | |
| def show_images(image, name): | |
| import cv2 | |
| import numpy as np | |
| image = np.array(image) | |
| image[image > 0.5] = 255. | |
| image = image.transpose((1, 2, 0)) | |
| cv2.imwrite(name, image) | |
| if __name__ == '__main__': | |
| # test align loss, | |
| targetFrame = torch.ones(1, 3, 32, 32) | |
| GT = torch.ones(1, 3, 32, 32) | |
| GT += 1 | |
| mask = torch.zeros(1, 1, 32, 32) | |
| mask[:, :, 8:24, 8:24] = 1. | |
| # referenceFrames = torch.ones(1, 3, 4, 32, 32) | |
| # referenceMasks = torch.zeros(1, 1, 4, 32, 32) | |
| # referenceMasks[:, :, 0, 4:12, 4:12] = 1. | |
| # referenceFrames[:, :, 0, 4:12, 4:12] = 2. | |
| # referenceMasks[:, :, 1, 4:12, 20:28] = 1. | |
| # referenceFrames[:, :, 1, 4:12, 20:28] = 2. | |
| # referenceMasks[:, :, 2, 20:28, 4:12] = 1. | |
| # referenceFrames[:, :, 2, 20:28, 4:12] = 2. | |
| # referenceMasks[:, :, 3, 20:28, 20:28] = 1. | |
| # referenceFrames[:, :, 3, 20:28, 20:28] = 2. | |
| # | |
| # aligned_v = referenceMasks | |
| # aligned_v, referenceFrames = [aligned_v], [referenceFrames] | |
| # | |
| # result = AlignLoss()(targetFrame, mask, aligned_v, referenceFrames) | |
| # print(result) | |
| c_mask = torch.zeros(1, 1, 32, 32) | |
| c_mask[:, :, 8:16, 16:24] = 1. | |
| result1 = HoleVisibleLoss()(targetFrame, mask, GT, c_mask) | |
| result2 = HoleInvisibleLoss()(targetFrame, mask, GT, c_mask) | |
| result3 = NonHoleLoss()(targetFrame, mask, GT) | |
| print('vis: {}, invis: {}, gt: {}'.format(result1, result2, result3)) | |