# Modified partialconv source code based on implementation from
# https://github.com/NVIDIA/partialconv/blob/master/models/partialconv2d.py
###############################################################################
# BSD 3-Clause License
#
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Author & Contact: Guilin Liu ([email protected])
###############################################################################
# Original Author & Contact: Guilin Liu ([email protected])
# Modified by Kevin Shih ([email protected])

import torch
import torch.nn.functional as F
from torch import nn
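

# Partial convolution (Liu et al., 2018): invalid input positions are masked
# out and each output is renormalized by the fraction of valid samples under
# the kernel window, so values near holes or padding are not biased toward zero.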
class PartialConv1d(nn.Conv1d):
    def __init__(self, *args, **kwargs):
        # multi-channel masking and mask return are disabled by default
        self.multi_channel = False
        self.return_mask = False
        super(PartialConv1d, self).__init__(*args, **kwargs)

        # all-ones kernel used to count valid input positions per output step
        self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0])
        # number of elements in one sliding window (1 x 1 x kernel_size)
        self.slide_winsize = (
            self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2]
        )
        # cache the last input shape so the mask is only recomputed when needed
        self.last_size = (None, None, None)
        self.update_mask = None
        self.mask_ratio = None
    def forward(self, input: torch.Tensor, mask_in: torch.Tensor = None):
        """
        input: standard input to a 1D conv
        mask_in: binary mask for valid values, same shape as input
        """
        assert len(input.shape) == 3
        # if a mask is provided, or the tensor shape changed, update the mask ratio
        if mask_in is not None or self.last_size != tuple(input.shape):
            self.last_size = tuple(input.shape)
            with torch.no_grad():
                if self.weight_maskUpdater.type() != input.type():
                    self.weight_maskUpdater = self.weight_maskUpdater.to(input)
                if mask_in is None:
                    # no mask given: treat every position as valid
                    mask = torch.ones(1, 1, input.data.shape[2]).to(input)
                else:
                    mask = mask_in
                # count the valid input positions under each kernel window
                self.update_mask = F.conv1d(
                    mask,
                    self.weight_maskUpdater,
                    bias=None,
                    stride=self.stride,
                    padding=self.padding,
                    dilation=self.dilation,
                    groups=1,
                )
                # 1e-6 (rather than 1e-8) keeps the division stable under
                # mixed precision training
                self.mask_ratio = self.slide_winsize / (self.update_mask + 1e-6)
                # update_mask is 1 wherever at least one input position was valid
                self.update_mask = torch.clamp(self.update_mask, 0, 1)
                # zero the ratio at fully-invalid output positions
                self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)
        raw_out = super(PartialConv1d, self).forward(
            torch.mul(input, mask) if mask_in is not None else input
        )

        if self.bias is not None:
            # renormalize only the data term, then re-add the bias
            bias_view = self.bias.view(1, self.out_channels, 1)
            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
            output = torch.mul(output, self.update_mask)
        else:
            output = torch.mul(raw_out, self.mask_ratio)

        if self.return_mask:
            return output, self.update_mask
        return output
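

# Minimal usage sketch (not part of the original module): the shapes and layer
# hyperparameters below are illustrative assumptions, showing how a masked-out
# span renormalizes the convolution near the hole.
if __name__ == "__main__":
    torch.manual_seed(0)
    conv = PartialConv1d(in_channels=1, out_channels=4, kernel_size=5, padding=2)
    conv.return_mask = True

    x = torch.randn(2, 1, 16)   # (batch, channels, time)
    mask = torch.ones(2, 1, 16)
    mask[:, :, 4:12] = 0.0      # mark an 8-sample span as invalid

    out, update_mask = conv(x, mask_in=mask)
    print(out.shape)            # torch.Size([2, 4, 16])
    # update_mask is 0 only where the 5-wide window saw no valid sample
    print(update_mask[0, 0])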