# Cat-AIR / test.py
# Hosting-page metadata: jiachen · "more images" · 87b5869 · 3.81 kB
import argparse
import subprocess
from tqdm import tqdm
import numpy as np
import torch
from torch.utils.data import DataLoader
import os
import torch.nn as nn
# from utils.dataset_utils import DenoiseTestDataset, DerainDehazeDataset
# from utils.val_utils import AverageMeter, compute_psnr_ssim
# from utils.image_io import save_image_tensor
from PIL import Image
from torchvision.transforms import ToTensor
import lightning.pytorch as pl
import torch.nn.functional as F
from net.prompt_xrestormer import PromptXRestormer
import json
def crop_img(image, base=64):
    """Center-crop an image so its height and width are multiples of ``base``.

    The surplus rows/columns (``size % base``) are trimmed from both sides;
    when the surplus is odd, the smaller half comes off the top/left edge.

    :param image: array indexed as ``[row, column, channel]``.
    :param base: the cropped height and width are multiples of this value.
    :return: the cropped slice of ``image``.
    """
    height, width = image.shape[0], image.shape[1]
    surplus_rows = height % base
    surplus_cols = width % base

    # First kept row/column, then one-past-the-last kept row/column.
    top = surplus_rows // 2
    left = surplus_cols // 2
    bottom = top + (height - surplus_rows)
    right = left + (width - surplus_cols)

    return image[top:bottom, left:right, :]
class PromptXRestormerIRModel(pl.LightningModule):
    """Lightning module wrapping a ``PromptXRestormer`` network.

    The network is stored as ``self.net`` and an L1 loss as ``self.loss_fn``;
    ``forward`` simply delegates to the network.
    """

    def __init__(self):
        super().__init__()
        # Architecture settings passed verbatim to PromptXRestormer.
        # NOTE(review): presumably these must match the checkpoint that is
        # loaded into this module - confirm before changing any of them.
        arch_kwargs = dict(
            inp_channels=3,
            out_channels=3,
            dim=48,
            num_blocks=[2, 4, 4, 4],
            num_refinement_blocks=4,
            channel_heads=[1, 1, 1, 1],
            spatial_heads=[1, 2, 4, 8],
            overlap_ratio=[0.5, 0.5, 0.5, 0.5],
            ffn_expansion_factor=2.66,
            bias=False,
            # Other option: 'BiasFree'
            LayerNorm_type='WithBias',
            # True for dual-pixel defocus deblurring only (then also set
            # inp_channels=6).
            dual_pixel_task=False,
            scale=1,
            prompt=True,
        )
        self.net = PromptXRestormer(**arch_kwargs)
        self.loss_fn = nn.L1Loss()

    def forward(self, x):
        """Run the wrapped network on ``x`` and return its output."""
        return self.net(x)
def np_to_pil(img_np):
    """
    Converts a channel-first image in np.array format to a PIL image.
    From C x H x W [0..1] to H x W x C [0...255] (H x W when C == 1).

    Values are multiplied by 255, clipped to [0, 255] and cast to uint8
    (the cast truncates the fractional part; it does not round).

    :param img_np: array whose first dimension is the channel count (1 or 3)
    :return: PIL image built from the converted uint8 array
    :raises ValueError: if the first dimension is neither 1 nor 3
    """
    channels = img_np.shape[0]
    # Validate up front with a real exception: an `assert` disappears under
    # `python -O`, which would let an unsupported layout through.
    if channels not in (1, 3):
        raise ValueError(
            "expected 1 or 3 channels in the first dimension, got shape {}".format(img_np.shape)
        )

    ar = np.clip(img_np * 255, 0, 255).astype(np.uint8)
    if channels == 1:
        # Single channel: drop the channel axis to get a 2-D array.
        ar = ar[0]
    else:
        # Three channels: move the channel axis last.
        ar = ar.transpose(1, 2, 0)
    return Image.fromarray(ar)
def torch_to_np(img_var):
    """
    Turns a batched torch.Tensor into an np.array holding its first element.
    From 1 x C x H x W to C x H x W; values are left untouched.

    :param img_var: tensor with a leading batch dimension
    :return: np.array for index 0 along the first dimension
    """
    # Detach from autograd and move to host memory before converting.
    host_tensor = img_var.detach().cpu()
    batch_np = host_tensor.numpy()
    return batch_np[0]
def save_image_tensor(image_tensor, output_path="output/"):
    """Save the first image of a batched tensor to ``output_path``.

    The tensor is converted with ``torch_to_np`` (which takes index 0 of the
    batch) and ``np_to_pil``, then written with the PIL image's ``save``.

    :param image_tensor: batched image tensor accepted by ``torch_to_np``
    :param output_path: destination passed straight to ``save``
    """
    # NOTE(review): the default "output/" has no file name or extension;
    # callers in this file always pass an explicit path - confirm whether
    # the default is ever meant to be used.
    np_to_pil(torch_to_np(image_tensor)).save(output_path)
def _parse_args():
    """Parse command-line options; every default is the value the script used to hard-code."""
    parser = argparse.ArgumentParser(
        description="Restore one degraded image with a PromptXRestormer checkpoint."
    )
    parser.add_argument(
        "--ckpt_path",
        default="/home/jiachen/MyGradio/ckpt/promptxrestormer_epoch=64-step=578630.ckpt",
        help="checkpoint file to load",
    )
    parser.add_argument(
        "--degraded_path",
        default="/home/jiachen/MyGradio/test_images/noisy_myimage.jpg",
        help="degraded input image",
    )
    parser.add_argument(
        "--output_path",
        default="output.png",
        help="where the restored image is written",
    )
    parser.add_argument(
        "--cuda",
        type=int,
        default=0,
        help="index of the CUDA device to use",
    )
    return parser.parse_args()


def main():
    """Load the checkpoint, restore a single image on the GPU and save the result."""
    args = _parse_args()

    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.set_device(args.cuda)

    print("CKPT name : {}".format(args.ckpt_path))
    # load_from_checkpoint is a classmethod: call it on the class rather than
    # on a throwaway instance, which would build the network a second time.
    net = PromptXRestormerIRModel.load_from_checkpoint(args.ckpt_path).cuda()
    net.eval()

    # Close the image file as soon as the pixels have been copied out.
    with Image.open(args.degraded_path) as pil_img:
        degraded_img = crop_img(np.array(pil_img.convert('RGB')), base=16)

    degraded_img = ToTensor()(degraded_img)
    print(degraded_img.shape)

    with torch.no_grad():
        degraded_img = degraded_img.unsqueeze(0).cuda()
        _, _, H_old, W_old = degraded_img.shape

        # Pad up to the next multiple of 64. The amount is always in 1..64,
        # so a size that is already a multiple of 64 still grows by 64.
        h_pad = (H_old // 64 + 1) * 64 - H_old
        w_pad = (W_old // 64 + 1) * 64 - W_old

        # Mirror padding: append a flipped copy along the axis, then keep
        # only the first (old size + pad) entries.
        # NOTE(review): the flipped copy supplies at most H_old / W_old extra
        # entries, so inputs smaller than their pad end up short of the
        # intended padded size.
        degraded_img = torch.cat([degraded_img, torch.flip(degraded_img, [2])], 2)[:, :, :H_old + h_pad, :]
        degraded_img = torch.cat([degraded_img, torch.flip(degraded_img, [3])], 3)[:, :, :, :W_old + w_pad]
        print("inputImage size", degraded_img.shape)

        restored = net(degraded_img)
        # Drop the padding again.
        restored = restored[:, :, :H_old, :W_old]

    save_image_tensor(restored, args.output_path)


if __name__ == '__main__':
    main()