import functools

import numpy as np
import gradio as gr
import torch
import spaces
from PIL import Image
from torchvision.transforms import ToTensor
from net.prompt_xrestormer import PromptXRestormer
import lightning.pytorch as pl

# Checkpoint of the all-in-one restoration model (path relative to the app's working dir).
CKPT_PATH = "ckpt/promptxrestormer_epoch=64-step=578630.ckpt"
# The input image is center-cropped so H and W are multiples of this.
CROP_BASE = 16
# The network input is mirror-padded so H and W reach a multiple of this.
PAD_BASE = 64


# crop an image to the multiple of base
def crop_img(image, base=64):
    """Center-crop an array indexed as ``image[h, w, c]`` so H and W are multiples of ``base``.

    The ``h % base`` surplus rows (and likewise columns) are removed, split as
    evenly as possible between both sides; when the surplus is odd, the extra
    row/column comes off the bottom/right.

    :param image: array with at least 3 dims, indexed ``[h, w, c]``.
    :param base: the multiple both spatial dims are cropped to.
    :return: the cropped view. A spatial dim smaller than ``base`` becomes 0.
    """
    h = image.shape[0]
    w = image.shape[1]
    crop_h = h % base
    crop_w = w % base
    return image[crop_h // 2:h - crop_h + crop_h // 2, crop_w // 2:w - crop_w + crop_w // 2, :]


class PromptXRestormerIRModel(pl.LightningModule):
    """LightningModule wrapper around a fixed-configuration PromptXRestormer.

    The hyperparameters below must match the ones the checkpoint was trained
    with; ``forward`` simply delegates to the wrapped network.
    """

    def __init__(self):
        super().__init__()
        self.net = PromptXRestormer(
            inp_channels=3,
            out_channels=3,
            dim=48,
            num_blocks=[2, 4, 4, 4],
            num_refinement_blocks=4,
            channel_heads=[1, 1, 1, 1],
            spatial_heads=[1, 2, 4, 8],
            overlap_ratio=[0.5, 0.5, 0.5, 0.5],
            ffn_expansion_factor=2.66,
            bias=False,
            LayerNorm_type='WithBias',  ## Other option 'BiasFree'
            dual_pixel_task=False,  ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
            scale=1,
            prompt=True,
        )

    def forward(self, x):
        return self.net(x)


def np_to_pil(img_np):
    """Convert a C x H x W float array in [0, 1] to a PIL image.

    Values are scaled to [0, 255] and clipped. A single-channel input yields
    an H x W image; a 3-channel input is transposed to H x W x 3.

    :param img_np: array of shape (1, H, W) or (3, H, W).
    :return: ``PIL.Image.Image``.
    """
    ar = np.clip(img_np * 255, 0, 255).astype(np.uint8)

    if img_np.shape[0] == 1:
        ar = ar[0]
    else:
        assert img_np.shape[0] == 3, img_np.shape
        ar = ar.transpose(1, 2, 0)

    return Image.fromarray(ar)


def torch_to_np(img_var):
    """Convert a 1 x C x H x W tensor to a C x H x W numpy array.

    The tensor is detached and copied to host memory; only the first batch
    element is returned.

    :param img_var: ``torch.Tensor`` with a leading batch dim.
    :return: ``np.ndarray`` of shape (C, H, W).
    """
    return img_var.detach().cpu().numpy()[0]


@functools.lru_cache(maxsize=1)
def _load_model(ckpt_path=CKPT_PATH):
    """Load the restoration model from ``ckpt_path`` once (on CPU, eval mode) and cache it."""
    print("CKPT name : {}".format(ckpt_path))
    net = PromptXRestormerIRModel.load_from_checkpoint(ckpt_path, map_location="cpu")
    net.eval()
    return net


def _mirror_pad(x, dim, target):
    """Extend ``x`` along ``dim`` to exactly ``target`` by appending its mirror image.

    Mirroring is repeated until the tensor is long enough, so padding larger
    than the tensor itself (e.g. a 16-px side padded to 64) still works. With a
    single repetition this equals ``cat([x, flip(x)], dim)`` truncated to ``target``.
    """
    if x.shape[dim] == 0:
        raise ValueError("cannot mirror-pad an empty dimension")
    while x.shape[dim] < target:
        x = torch.cat([x, torch.flip(x, [dim])], dim)
    return x.narrow(dim, 0, target)


@spaces.GPU(duration=200)
def restore_image(input_img):
    """Restore a degraded image (noise / rain / haze) with the all-in-one model.

    The image is converted to RGB, center-cropped to a multiple of CROP_BASE,
    mirror-padded past a multiple of PAD_BASE, run through the network on
    CUDA, and cropped back to the cropped input size.

    :param input_img: ``PIL.Image.Image`` from the Gradio input (``None`` if empty).
    :return: uint8 ``np.ndarray`` of shape (H, W, 3).
    :raises gr.Error: if no image is given or it is smaller than CROP_BASE pixels.
    """
    if input_img is None:
        raise gr.Error("Please upload an input image.")

    np.random.seed(0)
    torch.manual_seed(0)

    # Loaded once per process; .cuda() is a no-op once the cached module is on the GPU.
    # NOTE(review): confirm this caching behaves as expected under ZeroGPU's @spaces.GPU.
    net = _load_model().cuda()

    degraded_img = crop_img(np.array(input_img.convert('RGB')), base=CROP_BASE)
    if degraded_img.shape[0] == 0 or degraded_img.shape[1] == 0:
        raise gr.Error(f"Input image must be at least {CROP_BASE}x{CROP_BASE} pixels.")
    degraded_img = ToTensor()(degraded_img)

    with torch.no_grad():
        degraded_img = degraded_img.unsqueeze(0).cuda()
        _, _, H_old, W_old = degraded_img.shape
        # Always pads by 1..PAD_BASE (a full PAD_BASE when already aligned), as in the original.
        h_pad = (H_old // PAD_BASE + 1) * PAD_BASE - H_old
        w_pad = (W_old // PAD_BASE + 1) * PAD_BASE - W_old
        degraded_img = _mirror_pad(degraded_img, 2, H_old + h_pad)
        degraded_img = _mirror_pad(degraded_img, 3, W_old + w_pad)

        restored = net(degraded_img)
        # Drop the padded border so the output matches the cropped input size.
        restored = restored[:, :, :H_old, :W_old]

    # [C, H, W] -> [H, W, C], then float [0, 1] -> uint8 [0, 255].
    restored_image = torch_to_np(restored).transpose(1, 2, 0)
    return np.clip(restored_image * 255, 0, 255).astype(np.uint8)


title = "Content & Task Awareness All-In-One Image Restoration✏️🖼️ 🤗"

description = '''
## [Content & Task Awareness All-In-One Image Restoration]
The Ohio State University | Microsoft Research

### TL;DR: quickstart
***One single model can perform several restoration tasks including image denoising, deraining and dehazing 🚀 . Our content & task awareness model would have better efficiency***

The (single) neural model performs all-in-one image restoration.

**🚀 You can start with the [demo tutorial.]** Check [our github] for more information.
'''

article = """

Content & Task Awareness All-In-One Image Restoration

"""

# Example inputs: noisy, rainy and hazy test images.
examples = [
    ['test_images/noisy_0000.png'],
    ['test_images/noisy_0001.png'],
    ['test_images/noisy_0002.png'],
    ['test_images/noisy_0003.png'],
    ['test_images/noisy_0004.png'],
    ['test_images/rain-01.png'],
    ['test_images/rain-02.png'],
    ['test_images/rain-03.png'],
    ['test_images/rain-04.png'],
    ['test_images/rain-05.png'],
    ['test_images/rain-06.png'],
    ['test_images/hazy-00.jpg'],
    ['test_images/hazy-01.jpg'],
    ['test_images/hazy-02.jpg'],
    ['test_images/hazy-03.jpg'],
    ['test_images/hazy-04.jpg'],
]

css = """
.image-frame img, .image-container img {
    width: auto;
    height: auto;
    max-width: none;
}
"""

demo = gr.Interface(
    fn=restore_image,
    inputs=[gr.Image(type="pil", label="Input")],
    outputs=[gr.Image(type="pil", label="Output")],
    title=title,
    description=description,
    article=article,
    examples=examples,
    css=css,
)

# Launched at import time (as in the original), which is how the Space runs this file.
demo.launch(debug=True, show_error=True)