import numpy as np
import torch
import torch.nn as nn
# from utils.dataset_utils import DenoiseTestDataset, DerainDehazeDataset
# from utils.val_utils import AverageMeter, compute_psnr_ssim
# from utils.image_io import save_image_tensor  # reimplemented below so this script stands alone
from PIL import Image
from torchvision.transforms import ToTensor
import lightning.pytorch as pl
from net.cata_prompt_xrestormer import CATAPromptXRestormerOnlyAttn
from einops import rearrange


def crop_img(image, base=64):
    """Center-crop an (H, W, C) image so that H and W are multiples of `base`."""
    h = image.shape[0]
    w = image.shape[1]
    crop_h = h % base
    crop_w = w % base
    return image[crop_h // 2:h - crop_h + crop_h // 2,
                 crop_w // 2:w - crop_w + crop_w // 2, :]


class CATAPromptXRestormerIRModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = CATAPromptXRestormerOnlyAttn(
            inp_channels=3,
            out_channels=3,
            dim=48,
            num_blocks=[2, 4, 4, 4],
            num_refinement_blocks=4,
            channel_heads=[1, 1, 1, 1],
            spatial_heads=[1, 2, 4, 8],
            overlap_ratio=0.5,
            dim_head=16,
            ratio=0.5,
            window_size=8,
            bias=False,
            ffn_expansion_factor=2.66,
            LayerNorm_type='WithBias',  # other option: 'BiasFree'
            dual_pixel_task=False,      # True for dual-pixel defocus deblurring only; also set inp_channels=6
            scale=1,
            prompt=True,
            hard_ratio=0.5,
        )
        self.loss_fn = nn.L1Loss()

    def forward(self, x, training=False):
        return self.net(x, training)


def np_to_pil(img_np):
    """Convert a np.array image to a PIL image:
    from C x H x W [0..1] to H x W x C [0..255]."""
    ar = np.clip(img_np * 255, 0, 255).astype(np.uint8)
    if img_np.shape[0] == 1:
        ar = ar[0]
    else:
        assert img_np.shape[0] == 3, img_np.shape
        ar = ar.transpose(1, 2, 0)
    return Image.fromarray(ar)


def torch_to_np(img_var):
    """Convert a torch.Tensor image to a np.array:
    from 1 x C x H x W [0..1] to C x H x W [0..1]."""
    return img_var.detach().cpu().numpy()[0]


def save_image_tensor(image_tensor, output_path="output.png"):
    image_np = torch_to_np(image_tensor)
    p = np_to_pil(image_np)
    p.save(output_path)


if __name__ == '__main__':
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.set_device(0)

    ckpt_path = "ckpt/cata_promptxrestormeronlyattn_epoch=30-step=275962.ckpt"
    print("CKPT name : {}".format(ckpt_path))

    net = CATAPromptXRestormerIRModel.load_from_checkpoint(ckpt_path).cuda()
    net.eval()

    degraded_path = "/home/jiachen/MyGradio/test_images/rain-01.png"
    degraded_img = crop_img(np.array(Image.open(degraded_path).convert('RGB')), base=16)
    toTensor = ToTensor()
    degraded_img = toTensor(degraded_img)
    print(degraded_img.shape)

    with torch.no_grad():
        degraded_img = degraded_img.unsqueeze(0).cuda()

        # Mirror-pad H and W up to the next multiple of 64 so the network's
        # window partitioning divides the input evenly.
        _, _, H_old, W_old = degraded_img.shape
        h_pad = (H_old // 64 + 1) * 64 - H_old
        w_pad = (W_old // 64 + 1) * 64 - W_old
        degraded_img = torch.cat([degraded_img, torch.flip(degraded_img, [2])], 2)[:, :, :H_old + h_pad, :]
        degraded_img = torch.cat([degraded_img, torch.flip(degraded_img, [3])], 3)[:, :, :, :W_old + w_pad]
        print("inputImage size", degraded_img.shape)

        restored, spatial_mask, channel_mask = net(degraded_img, training=False)

        # Window indices selected by the first encoder level's spatial mask.
        encoder_level1_mask = spatial_mask['encoder_level1'][0][0][0]
        window_size = 8
        _, c, h, w = restored.shape

        # Split the restored image into non-overlapping 8x8 windows.
        restored_windows = rearrange(restored, 'b c (h w1) (w w2) -> b c (h w) w1 w2',
                                     w1=window_size, w2=window_size)

        # Mask out the windows listed in encoder_level1_mask by setting them to one (white).
        for idx in encoder_level1_mask:
            restored_windows[:, :, idx, :, :] = 1

        # Reassemble the image from the masked windows.
        restored_masked = rearrange(restored_windows, 'b c (h w) w1 w2 -> b c (h w1) (w w2)',
                                    h=h // window_size, w=w // window_size)

        # Crop the mirror padding back off.
        restored = restored[:, :, :H_old, :W_old]
        restored_masked = restored_masked[:, :, :H_old, :W_old]

        save_image_tensor(restored, "output.png")
        save_image_tensor(restored_masked, "output_masked.png")
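

# ---------------------------------------------------------------------------
# Sketch (not part of the original pipeline): the mirror-padding trick above,
# factored into a reusable helper. `pad_to_multiple` is a hypothetical name;
# like the inline code, it assumes a (B, C, H, W) tensor with H, W >= base and
# always pads up to the *next* multiple of `base`.
# ---------------------------------------------------------------------------
def pad_to_multiple(x, base=64):
    """Mirror-pad H and W up to the next multiple of `base`.

    Returns the padded tensor plus the original (H, W), so the model output
    can be cropped back with `out[:, :, :H, :W]`.
    """
    _, _, h, w = x.shape
    h_pad = (h // base + 1) * base - h
    w_pad = (w // base + 1) * base - w
    x = torch.cat([x, torch.flip(x, [2])], 2)[:, :, :h + h_pad, :]
    x = torch.cat([x, torch.flip(x, [3])], 3)[:, :, :, :w + w_pad]
    return x, (h, w)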