|
import torch |
|
import torch.nn as nn |
|
from loss import LossFunction, TextureDifference |
|
from utils import blur, pair_downsampler |
|
|
|
|
|
|
|
class Denoise_1(nn.Module):
    """Three-layer CNN mapping a 3-channel image to a 3-channel output of the same size.

    Layout: 3x3 conv -> LeakyReLU(0.2) -> 3x3 conv -> LeakyReLU(0.2) -> 1x1 conv.
    Callers in this file subtract its output from the input
    (``x - denoise_1(x)``), i.e. use it as a residual predictor.

    Args:
        chan_embed: number of hidden feature channels.
    """

    def __init__(self, chan_embed=48):
        super().__init__()

        # Parameter-free activation, shared by both hidden layers.
        self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # Construction order (conv1, conv2, conv3) fixes the RNG draw order at init.
        self.conv1 = nn.Conv2d(3, chan_embed, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(chan_embed, chan_embed, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(chan_embed, 3, kernel_size=1)

    def forward(self, x):
        """Return the 3-channel prediction for ``x``; spatial size is preserved."""
        for hidden in (self.conv1, self.conv2):
            # In-place activation acts on the fresh conv output, never on ``x`` itself.
            x = self.act(hidden(x))
        return self.conv3(x)
|
|
|
|
|
class Denoise_2(nn.Module):
    """Three-layer CNN mapping a 6-channel stack to a 6-channel output of the same size.

    Layout: 3x3 conv -> LeakyReLU(0.2) -> 3x3 conv -> LeakyReLU(0.2) -> 1x1 conv.
    In this file it receives ``torch.cat([H, s], 1)`` (two 3-channel maps) and
    its output is subtracted from that stack as a residual prediction.

    Args:
        chan_embed: number of hidden feature channels.
    """

    def __init__(self, chan_embed=96):
        super().__init__()

        # Parameter-free activation, shared by both hidden layers.
        self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # Construction order (conv1, conv2, conv3) fixes the RNG draw order at init.
        self.conv1 = nn.Conv2d(6, chan_embed, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(chan_embed, chan_embed, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(chan_embed, 6, kernel_size=1)

    def forward(self, x):
        """Return the 6-channel prediction for ``x``; spatial size is preserved."""
        for hidden in (self.conv1, self.conv2):
            # In-place activation acts on the fresh conv output, never on ``x`` itself.
            x = self.act(hidden(x))
        return self.conv3(x)
|
|
|
|
|
class Enhancer(nn.Module):
    """Residual CNN mapping a 3-channel image to a 3-channel map clamped to [1e-4, 1].

    Structure: 3x3 conv + ReLU stem, then ``layers`` residual applications of
    ONE shared conv/BatchNorm/ReLU stage, then a 3x3 conv + Sigmoid head.

    Args:
        layers: how many times the shared residual stage is applied.
        channels: width of the hidden feature maps.
    """

    def __init__(self, layers, channels):
        super().__init__()

        kernel_size, dilation = 3, 1
        # "Same" padding for an odd kernel with the given dilation.
        padding = (kernel_size - 1) // 2 * dilation

        self.in_conv = nn.Sequential(
            nn.Conv2d(3, channels, kernel_size, stride=1, padding=padding),
            nn.ReLU(),
        )

        # Single stage whose weights are reused: ``blocks`` holds ``layers``
        # references to this very module (not copies), so every residual step
        # shares parameters and BatchNorm running statistics.
        self.conv = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
        )
        self.blocks = nn.ModuleList([self.conv] * layers)

        self.out_conv = nn.Sequential(
            nn.Conv2d(channels, 3, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Return the enhancer output for ``input``, clamped to [0.0001, 1]."""
        features = self.in_conv(input)
        for stage in self.blocks:
            # Out-of-place add: the ReLU output feeding this sum is needed for backward.
            features = features + stage(features)
        # Lower clamp keeps the map strictly positive (callers divide by it).
        return self.out_conv(features).clamp(0.0001, 1)
|
|
|
|
|
class Network(nn.Module):
    """Training-time model: denoise_1 -> enhance -> denoise_2, plus loss wiring.

    ``forward`` returns every intermediate tensor consumed by ``LossFunction``;
    ``_loss`` runs ``forward`` and feeds ``input`` plus all of those tensors,
    in the same order, to the criterion.
    """

    def __init__(self):
        super(Network, self).__init__()

        self.enhance = Enhancer(layers=3, channels=64)
        self.denoise_1 = Denoise_1(chan_embed=48)
        self.denoise_2 = Denoise_2(chan_embed=48)
        # NOTE(review): _l2_loss, _l1_loss and avgpool are not used in this file;
        # kept in case external code references them.
        self._l2_loss = nn.MSELoss()
        self._l1_loss = nn.L1Loss()
        self._criterion = LossFunction()
        self.avgpool = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
        self.TextureDifference = TextureDifference()

    def enhance_weights_init(self, m):
        """Initialise ``m`` in place; intended for use with ``Module.apply``.

        Conv2d: weights ~ N(0, 0.02), bias (if present) zeroed.
        BatchNorm2d: scale ~ N(1, 0.02). Other module types are untouched.
        """
        if isinstance(m, nn.Conv2d):
            m.weight.data.normal_(0.0, 0.02)
            if m.bias is not None:
                m.bias.data.zero_()

        if isinstance(m, nn.BatchNorm2d):
            m.weight.data.normal_(1., 0.02)

    def denoise_weights_init(self, m):
        """Initialise ``m`` in place with the same scheme as ``enhance_weights_init``."""
        if isinstance(m, nn.Conv2d):
            m.weight.data.normal_(0, 0.02)
            if m.bias is not None:
                m.bias.data.zero_()

        if isinstance(m, nn.BatchNorm2d):
            m.weight.data.normal_(1., 0.02)

    def _residual_denoise(self, h, s, eps):
        """Return ``clamp(cat([h, s]).detach() - denoise_2(cat([h, s])), eps, 1)``.

        The stack is built once and reused. Because only the minuend is
        detached, gradients reach ``h``/``s`` solely through ``denoise_2``.
        The first 3 output channels correspond to ``h``, the rest to ``s``.
        """
        stacked = torch.cat([h, s], 1)
        return torch.clamp(stacked.detach() - self.denoise_2(stacked), eps, 1)

    def forward(self, input):
        """Run the full pipeline and return all intermediates used by the loss.

        Returns a 21-tuple in this exact order (``_loss`` relies on it):
        L_pred1, L_pred2, L2, s2, s21, s22, H2, H11, H12, H13, s13, H14, s14,
        H3, s3, H3_pred, H4_pred, L_pred1_L_pred2_diff,
        H3_denoised1_H3_denoised2_diff, H2_blur, H3_blur.
        """
        eps = 1e-4
        # Offset keeps the input strictly positive.
        input = input + eps

        # Stage 1: residual denoising of the two pair_downsampler sub-images
        # and of the full image.
        L11, L12 = pair_downsampler(input)
        L_pred1 = L11 - self.denoise_1(L11)
        L_pred2 = L12 - self.denoise_1(L12)
        L2 = torch.clamp(input - self.denoise_1(input), eps, 1)

        # Stage 2: enhancer runs on the detached denoised image; the input and
        # sub-images are then divided by its output (and its sub-sampled copies).
        s2 = self.enhance(L2.detach())
        s21, s22 = pair_downsampler(s2)
        H2 = torch.clamp(input / s2, eps, 1)
        H11 = torch.clamp(L11 / s21, eps, 1)
        H12 = torch.clamp(L12 / s22, eps, 1)

        # Stage 3: joint residual denoising of each (H, s) pair.
        H3_pred = self._residual_denoise(H11, s21, eps)
        H13, s13 = H3_pred[:, :3, :, :], H3_pred[:, 3:, :, :]

        H4_pred = self._residual_denoise(H12, s22, eps)
        H14, s14 = H4_pred[:, :3, :, :], H4_pred[:, 3:, :, :]

        H5_pred = self._residual_denoise(H2, s2, eps)
        H3, s3 = H5_pred[:, :3, :, :], H5_pred[:, 3:, :, :]

        L_pred1_L_pred2_diff = self.TextureDifference(L_pred1, L_pred2)
        H3_denoised1, H3_denoised2 = pair_downsampler(H3)
        H3_denoised1_H3_denoised2_diff = self.TextureDifference(H3_denoised1, H3_denoised2)

        # NOTE(review): clamped to [0, 1] rather than [eps, 1] like the other
        # ratios; behaviour kept as-is — confirm this is intentional.
        H1 = torch.clamp(L2 / s2, 0, 1)
        H2_blur = blur(H1)
        H3_blur = blur(H3)

        return (L_pred1, L_pred2, L2, s2, s21, s22, H2, H11, H12, H13, s13, H14, s14, H3, s3,
                H3_pred, H4_pred, L_pred1_L_pred2_diff, H3_denoised1_H3_denoised2_diff,
                H2_blur, H3_blur)

    def _loss(self, input):
        """Return ``LossFunction(input, *forward(input))``.

        ``forward`` already yields its tensors in the positional order the
        criterion expects, so they are forwarded unchanged.
        """
        outputs = self(input)
        return self._criterion(input, *outputs)
|
|
|
|
|
class Finetunemodel(nn.Module):
    """Inference model: denoise_1 -> enhance -> denoise_2, weights loaded from a checkpoint.

    Only checkpoint entries whose keys match this model's ``state_dict`` are
    loaded. This lets a checkpoint saved from ``Network`` (which has extra
    submodules) be used. A leading ``module.`` prefix, as added by
    ``nn.DataParallel``, is stripped before matching. Loading is best-effort:
    on any failure a message is printed and the random initialisation is kept.

    Args:
        weights: path (or file-like) accepted by ``torch.load``.
    """

    def __init__(self, weights):
        super(Finetunemodel, self).__init__()

        self.enhance = Enhancer(layers=3, channels=64)
        self.denoise_1 = Denoise_1(chan_embed=48)
        self.denoise_2 = Denoise_2(chan_embed=48)

        # Only controls where the checkpoint tensors are materialised; the module
        # itself is not moved (load_state_dict copies into the existing params).
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        try:
            # NOTE(review): torch.load unpickles the file and can execute arbitrary
            # code; only pass trusted checkpoints (consider weights_only=True).
            base_weights = torch.load(weights, map_location=device)
            model_dict = self.state_dict()
            pretrained_dict = self._match_pretrained(base_weights, model_dict)
            if pretrained_dict:
                model_dict.update(pretrained_dict)
                self.load_state_dict(model_dict)
                print(f"✅ Loaded weights from {weights} on {device}")
            else:
                # Previously this case reported success while loading nothing.
                print(f"⚠️ No parameters in {weights} match this model")
                print("Using random initialization")
        except Exception as e:
            print(f"⚠️ Could not load weights: {e}")
            print("Using random initialization")

    @staticmethod
    def _match_pretrained(pretrained, model_keys):
        """Return the subset of ``pretrained`` whose keys exist in ``model_keys``.

        An exact key match wins; otherwise a leading ``module.`` prefix is
        stripped and the key is retried. Entries that still do not match are
        dropped.
        """
        prefix = 'module.'
        matched = {}
        for key, value in pretrained.items():
            if key not in model_keys and key.startswith(prefix):
                key = key[len(prefix):]
            if key in model_keys:
                matched[key] = value
        return matched

    def weights_init(self, m):
        """Initialise ``m`` in place; intended for use with ``Module.apply``.

        Conv2d: weights ~ N(0, 0.02), bias (if present) zeroed.
        BatchNorm2d: scale ~ N(1, 0.02). Other module types are untouched.
        """
        if isinstance(m, nn.Conv2d):
            m.weight.data.normal_(0, 0.02)
            if m.bias is not None:
                m.bias.data.zero_()

        if isinstance(m, nn.BatchNorm2d):
            m.weight.data.normal_(1., 0.02)

    def forward(self, input):
        """Return ``(H2, H3)``: the input divided by the enhancer output, and its denoised form.

        Both results are clamped to [1e-4, 1].
        """
        eps = 1e-4
        # Offset keeps the input strictly positive.
        input = input + eps
        L2 = torch.clamp(input - self.denoise_1(input), eps, 1)
        s2 = self.enhance(L2)
        H2 = torch.clamp(input / s2, eps, 1)
        # Build the 6-channel stack once and reuse it for both operands.
        stacked = torch.cat([H2, s2], 1)
        H5_pred = torch.clamp(stacked.detach() - self.denoise_2(stacked), eps, 1)
        H3 = H5_pred[:, :3, :, :]
        return H2, H3