|
import os
|
|
import numpy as np
|
|
import torch
|
|
import shutil
|
|
from torch.autograd import Variable
|
|
import matplotlib.pyplot as plt
|
|
from PIL import Image
|
|
|
|
|
|
|
|
def pair_downsampler(img):
    """Split an image batch into two half-resolution images.

    Every non-overlapping 2x2 window of each channel is reduced twice with a
    stride-2 depthwise convolution: once by averaging its anti-diagonal
    pixels (top-right and bottom-left) and once by averaging its
    main-diagonal pixels (top-left and bottom-right).

    Args:
        img: Floating-point 4-D tensor; dimension 1 is used as the channel
            count for the grouped convolution.

    Returns:
        Tuple ``(output1, output2)``: the anti-diagonal average and the
        main-diagonal average, each with the spatial size halved.
    """
    c = img.shape[1]

    # Build the kernels with the input's dtype and device. A fixed float32
    # kernel makes conv2d reject double (or half) precision inputs.
    filter1 = torch.tensor(
        [[[[0, 0.5], [0.5, 0]]]], dtype=img.dtype, device=img.device
    )
    filter2 = torch.tensor(
        [[[[0.5, 0], [0, 0.5]]]], dtype=img.dtype, device=img.device
    )

    # One copy of the kernel per channel; groups=c keeps channels independent.
    filter1 = filter1.repeat(c, 1, 1, 1)
    filter2 = filter2.repeat(c, 1, 1, 1)

    output1 = torch.nn.functional.conv2d(img, filter1, stride=2, groups=c)
    output2 = torch.nn.functional.conv2d(img, filter2, stride=2, groups=c)
    return output1, output2
|
|
|
|
def gauss_cdf(x):
    """Evaluate the standard normal cumulative distribution at ``x``.

    Computes ``0.5 * (1 + erf(x / sqrt(2)))`` element-wise.

    Args:
        x: Tensor of evaluation points.

    Returns:
        Tensor with the same shape as ``x`` holding the CDF values.
    """
    root_two = torch.sqrt(torch.tensor(2.))
    scaled = x / root_two
    return 0.5 * (1 + torch.erf(scaled))
|
|
|
|
def gauss_kernel(kernlen=21, nsig=3, channels=1, device=None):
    """Build a normalised 2-D Gaussian kernel for depthwise convolution.

    The 1-D weights are the standard-normal probability mass of ``kernlen``
    equal-width bins (differences of the CDF at the bin edges). The 2-D
    kernel is the element-wise square root of their outer product, rescaled
    so that all entries sum to one.

    Args:
        kernlen: Side length of the square kernel.
        nsig: Controls how far the bin edges extend; they run from
            ``-nsig - interval / 2`` to ``nsig + interval / 2``.
        channels: Number of copies of the kernel stacked along dimension 0.
        device: Device on which the kernel is built. When ``None`` the
            kernel is placed on CUDA if it is available and on the CPU
            otherwise, so the function also works on CPU-only machines.

    Returns:
        Tensor of shape ``(channels, 1, kernlen, kernlen)``.
    """
    if device is None:
        # Previously this was an unconditional .cuda(), which raises on
        # hosts without a GPU.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    interval = (2 * nsig + 1.) / kernlen
    # kernlen + 1 edges delimit kernlen bins.
    x = torch.linspace(
        -nsig - interval / 2., nsig + interval / 2., kernlen + 1
    ).to(device)

    kern1d = torch.diff(gauss_cdf(x))
    kernel_raw = torch.sqrt(torch.outer(kern1d, kern1d))
    kernel = kernel_raw / torch.sum(kernel_raw)

    out_filter = kernel.view(1, 1, kernlen, kernlen)
    out_filter = out_filter.repeat(channels, 1, 1, 1)
    return out_filter
|
|
|
|
class LocalMean(torch.nn.Module):
    """Module that averages a square neighbourhood around every pixel.

    The input is reflect-padded by ``patch_size // 2`` on each side of its
    last two dimensions, sliding ``patch_size`` x ``patch_size`` windows are
    extracted with stride 1, and each window is replaced by its mean.
    """

    def __init__(self, patch_size=5):
        """Store the window size and the matching border width."""
        super().__init__()
        self.patch_size = patch_size
        self.padding = patch_size // 2

    def forward(self, image):
        """Return the per-window mean of ``image``.

        Args:
            image: 4-D tensor; windows slide over dimensions 2 and 3.

        Returns:
            Tensor of window means, one value per window position.
        """
        border = (self.padding,) * 4
        padded = torch.nn.functional.pad(image, border, mode='reflect')
        # After both unfolds the window contents live in dimensions 4 and 5.
        windows = padded.unfold(2, self.patch_size, 1)
        windows = windows.unfold(3, self.patch_size, 1)
        return windows.mean(dim=(4, 5))
|
|
|
|
def blur(x):
    """Smooth ``x`` with a 21x21 Gaussian kernel, keeping its spatial size.

    The input is reflect-padded by 10 pixels on every side and each channel
    is then convolved independently with the kernel returned by
    ``gauss_kernel(21, 1, channels)``.

    Args:
        x: 4-D tensor; dimension 1 is used as the channel count.

    Returns:
        Blurred tensor with the same shape as ``x``.
    """
    channels = x.size(1)
    ksize = 21
    half = ksize // 2

    weights = gauss_kernel(ksize, 1, channels).to(x.device)
    padded = torch.nn.functional.pad(x, (half,) * 4, mode='reflect')
    # The border was added explicitly above, so the convolution adds none.
    return torch.nn.functional.conv2d(
        padded, weights, padding=0, groups=channels
    )
|
|
|
|
def padr_tensor(img):
    """Surround the last two dimensions of ``img`` with a 2-pixel zero border.

    Args:
        img: Tensor to pad.

    Returns:
        Copy of ``img`` whose last two dimensions are each 4 larger.
    """
    border = 2
    return torch.nn.functional.pad(
        img, (border, border, border, border), mode='constant', value=0
    )
|
|
|
|
def calculate_local_variance(train_noisy):
    """Estimate a per-pixel variance from 5x5 neighbourhoods.

    A 5x5 average-pooled copy of the input is computed first. Both the input
    and that local-average map are zero-padded by 2 pixels and cut into 5x5
    windows (stride 1). For every window position the squared difference
    between the input window and the local-average window is averaged.

    Args:
        train_noisy: 4-D tensor.

    Returns:
        Tensor with the same shape as ``train_noisy`` holding the mean
        squared deviation for each pixel position.
    """
    b, c, w, h = train_noisy.shape
    window = 5

    def _windows(t):
        # Zero-pad, then flatten all window positions into dimension 1.
        tiles = padr_tensor(t).unfold(2, window, 1).unfold(3, window, 1)
        return tiles.reshape(tiles.shape[0], -1, window, window)

    smoother = torch.nn.AvgPool2d(kernel_size=5, stride=1, padding=2)
    local_avg_windows = _windows(smoother(train_noisy))
    noisy_windows = _windows(train_noisy)

    squared_dev = (noisy_windows - local_avg_windows) ** 2
    return torch.mean(squared_dev, dim=(2, 3)).view(b, c, w, h)
|
|
|
|
def count_parameters_in_MB(model):
    """Return the model's parameter count divided by one million.

    Parameters whose name contains ``"auxiliary"`` are left out of the
    count. Despite the ``MB`` in the name, the value is a number of
    parameters scaled by 1e6, not a size in bytes.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, tensor)`` pairs.

    Returns:
        Float: number of counted parameter elements / 1e6.
    """
    # The builtin sum replaces np.sum over a generator, which NumPy has
    # deprecated; numel() gives each tensor's element count directly.
    total = sum(
        v.numel()
        for name, v in model.named_parameters()
        if "auxiliary" not in name
    )
    return total / 1e6
|
|
|
|
|
|
|
|
def save_checkpoint(state, is_best, save):
    """Write a training checkpoint and optionally mark it as the best one.

    Args:
        state: Object serialised with ``torch.save``.
        is_best: When truthy, the checkpoint file is also copied to
            ``model_best.pth.tar`` in the same directory.
        save: Directory that receives ``checkpoint.pth.tar``.
    """
    checkpoint_path = os.path.join(save, 'checkpoint.pth.tar')
    torch.save(state, checkpoint_path)

    if not is_best:
        return
    shutil.copyfile(checkpoint_path, os.path.join(save, 'model_best.pth.tar'))
|
|
|
|
|
|
def save(model, model_path):
    """Serialise the model's ``state_dict()`` to ``model_path``.

    Args:
        model: Object exposing ``state_dict()``.
        model_path: Destination passed to ``torch.save``.
    """
    weights = model.state_dict()
    torch.save(weights, model_path)
|
|
|
|
|
|
def load(model, model_path, map_location=None):
    """Load a saved state dict from ``model_path`` into ``model``.

    Args:
        model: Object exposing ``load_state_dict()``.
        model_path: File previously written with ``torch.save``.
        map_location: Forwarded to ``torch.load``. Pass e.g. ``'cpu'`` to
            read a checkpoint that was saved from GPU tensors on a machine
            without CUDA. ``None`` keeps ``torch.load``'s default behaviour.
    """
    # NOTE(review): torch.load is pickle-based; only load checkpoint files
    # that come from a trusted source.
    state = torch.load(model_path, map_location=map_location)
    model.load_state_dict(state)
|
|
|
|
def drop_path(x, drop_prob):
    """Randomly zero whole samples of ``x`` in place and rescale the rest.

    For ``drop_prob > 0`` one Bernoulli(``1 - drop_prob``) value is drawn per
    entry of dimension 0. Samples whose draw is 0 are zeroed; the others are
    divided by ``1 - drop_prob``. ``x`` itself is modified and returned.

    Args:
        x: 4-D floating-point tensor; the mask has shape
            ``(x.size(0), 1, 1, 1)`` and is broadcast over the rest.
        drop_prob: Probability of dropping a sample. Values ``<= 0`` leave
            ``x`` untouched.

    Returns:
        The same tensor object ``x``.
    """
    if drop_prob > 0.:
        keep_prob = 1. - drop_prob
        # Allocate the mask next to x (same device and dtype) rather than as
        # a hard-coded CUDA float tensor, so CPU tensors work as well.
        mask = x.new_empty((x.size(0), 1, 1, 1)).bernoulli_(keep_prob)
        # With keep_prob == 0 every sample is dropped; skipping the division
        # avoids the inf * 0 = NaN result of dividing by zero first.
        if keep_prob > 0.:
            x.div_(keep_prob)
        x.mul_(mask)
    return x
|
|
|
|
def create_exp_dir(path, scripts_to_save=None):
    """Create an experiment directory and optionally snapshot scripts into it.

    Args:
        path: Directory to create (including missing parents) when nothing
            exists at that location yet.
        scripts_to_save: Optional iterable of file paths. Each file is copied
            into ``<path>/scripts`` under its base name.
    """
    already_there = os.path.exists(path)
    if not already_there:
        os.makedirs(path, exist_ok=True)
    print('Experiment dir : {}'.format(path))

    if scripts_to_save is None:
        return

    scripts_dir = os.path.join(path, 'scripts')
    os.makedirs(scripts_dir, exist_ok=True)
    for script in scripts_to_save:
        target = os.path.join(scripts_dir, os.path.basename(script))
        shutil.copyfile(script, target)
|
|
|
|
def show_pic(pic, name, path):
    """Draw a labelled grid of images and save the figure to ``path``.

    For each entry of ``pic`` the first element along dimension 0 is
    converted to a NumPy array, scaled by 255, clipped to ``[0, 255]`` and
    drawn into a 5x6 subplot grid at position ``i + 1``. Arrays with 3
    leading channels are drawn as colour images and arrays with 1 leading
    channel as grayscale; any other channel count is skipped.

    Args:
        pic: Sequence of tensors; ``pic[i][0]`` is the image that is drawn.
        name: Sequence of labels; ``name[i]`` becomes the x-label of image i.
        path: Destination passed to ``plt.savefig``.
    """
    # NOTE(review): drawing goes to pyplot's current figure, which is not
    # cleared or closed here, and the 5x6 grid has room for 30 images only.
    for i, img in enumerate(pic):
        # detach() lets tensors that still track gradients be converted;
        # .numpy() raises for tensors that require grad.
        image_numpy = img[0].detach().cpu().float().numpy()
        channels = image_numpy.shape[0]

        if channels == 3:
            # Channel-first -> channel-last, as expected for colour images.
            pixels = np.transpose(image_numpy, (1, 2, 0))
            cmap = None
        elif channels == 1:
            pixels = image_numpy[0]
            cmap = plt.cm.gray
        else:
            continue

        im = Image.fromarray(np.clip(pixels * 255.0, 0, 255.0).astype('uint8'))
        img_name = name[i]
        plt.subplot(5, 6, i + 1)
        plt.xlabel(str(img_name))
        plt.xticks([])
        plt.yticks([])
        plt.imshow(im, cmap)

    plt.savefig(path)
|
|
|
|
|
|
|
|
|