| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
class SRMConv2d(nn.Module):
    """Fixed, non-trainable bank of three 5x5 high-pass filters for 3-channel input.

    The layer wraps an ``nn.Conv2d(3, 3, kernel_size=5)`` whose weights are
    set to three hand-written residual kernels and whose bias is zero; both
    are frozen (``requires_grad=False``).  Every output channel applies the
    same kernel to each of the three input channels and sums the results.
    The convolution output is optionally clamped to ``[-clip, clip]``.

    NOTE(review): "SRM" presumably stands for the Spatial Rich Model filters
    used in steganalysis / forgery detection -- confirm against the project
    documentation.

    Args:
        stride: Stride of the convolution.
        padding: Zero padding added on each side; the default of 2 keeps the
            spatial size unchanged for a 5x5 kernel at stride 1.
        clip: Symmetric clamp bound applied to the output.  A value of ``0``
            disables clamping.
    """

    def __init__(self, stride: int = 1, padding: int = 2, clip: float = 2):
        super().__init__()
        self.stride = stride
        self.padding = padding
        self.clip = clip
        self.conv = self._get_srm_filter()

    def _get_srm_filter(self) -> nn.Conv2d:
        """Build the frozen ``nn.Conv2d`` holding the three residual kernels.

        Returns:
            A 3-in / 3-out 5x5 convolution with fixed weights of shape
            ``(3, 3, 5, 5)`` and a zero bias, neither of which requires grad.
        """
        filter1 = [
            [0, 0, 0, 0, 0],
            [0, -1, 2, -1, 0],
            [0, 2, -4, 2, 0],
            [0, -1, 2, -1, 0],
            [0, 0, 0, 0, 0],
        ]
        filter2 = [
            [-1, 2, -2, 2, -1],
            [2, -6, 8, -6, 2],
            [-2, 8, -12, 8, -2],
            [2, -6, 8, -6, 2],
            [-1, 2, -2, 2, -1],
        ]
        filter3 = [
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 1, -2, 1, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
        ]

        # Each kernel is divided by its own normalisation constant
        # (4, 12 and 2 respectively).
        kernels = [
            np.asarray(filter1, dtype=float) / 4.0,
            np.asarray(filter2, dtype=float) / 12.0,
            np.asarray(filter3, dtype=float) / 2.0,
        ]

        # Assemble a single (out=3, in=3, 5, 5) ndarray before handing it to
        # torch.  Passing a nested *list* of ndarrays to ``torch.tensor`` goes
        # through a slow element-wise path and emits a UserWarning.
        bank = np.stack([np.stack([kernel, kernel, kernel]) for kernel in kernels])
        filters = torch.from_numpy(bank).float()

        conv2d = nn.Conv2d(
            3,
            3,
            kernel_size=5,
            stride=self.stride,
            padding=self.padding,
            padding_mode="zeros",
        )
        # Replace the randomly initialised parameters with the fixed kernels
        # and a zero bias; both are frozen so an optimiser never updates them.
        conv2d.weight = nn.Parameter(filters, requires_grad=False)
        conv2d.bias = nn.Parameter(torch.zeros_like(conv2d.bias), requires_grad=False)
        return conv2d

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Filter ``x`` and clamp the result to ``[-clip, clip]``.

        Args:
            x: Input with 3 channels, as required by the wrapped
                ``nn.Conv2d(3, 3, ...)``.

        Returns:
            The 3-channel filtered output, clamped unless ``clip`` is ``0``.
        """
        x = self.conv(x)
        if self.clip != 0.0:
            x = x.clamp(-self.clip, self.clip)
        return x
if __name__ == "__main__":
    # Smoke test: push one random batch of 63 three-channel 64x64 inputs
    # through the filter layer. The result is not inspected.
    srm_layer = SRMConv2d()
    batch = torch.rand(63, 3, 64, 64)
    residual = srm_layer(batch)