# Spaces: Paused
import numpy as np | |
import gradio as gr | |
import torch | |
import torch.nn as nn | |
from PIL import Image | |
from torchvision.transforms import ToTensor | |
import lightning.pytorch as pl | |
from net.cata_prompt_xrestormer import CATAPromptXRestormerOnlyAttn | |
from einops import rearrange | |
import spaces | |
def crop_img(image, base=64):
    """Center-crop an image so its height and width are multiples of ``base``.

    The surplus ``h % base`` rows and ``w % base`` columns are dropped. Half of
    the surplus (rounded down) is removed from the top/left and the remainder
    from the bottom/right. Only the first two axes are sliced, so both
    ``H x W x C`` and plain ``H x W`` arrays are accepted.

    :param image: array-like with a ``shape`` whose first two axes are
        height and width.
    :param base: positive integer the output height and width must divide by.
    :return: the cropped slice of ``image``; a side shorter than ``base``
        yields a length of 0 along that axis.
    """
    h, w = image.shape[:2]
    surplus_h = h % base
    surplus_w = w % base
    top = surplus_h // 2
    left = surplus_w // 2
    # Trailing axes (e.g. channels) are left untouched by not indexing them.
    return image[top:top + (h - surplus_h), left:left + (w - surplus_w)]
class CATAPromptXRestormerIRModel(pl.LightningModule):
    """LightningModule wrapping a ``CATAPromptXRestormerOnlyAttn`` network.

    The network is built with a fixed set of hyper-parameters and stored as
    ``self.net``; an L1 loss is stored as ``self.loss_fn``.
    """

    def __init__(self):
        super().__init__()
        net_config = dict(
            inp_channels=3,
            out_channels=3,
            dim=48,
            num_blocks=[2, 4, 4, 4],
            num_refinement_blocks=4,
            channel_heads=[1, 1, 1, 1],
            spatial_heads=[1, 2, 4, 8],
            overlap_ratio=0.5,
            dim_head=16,
            ratio=0.5,
            window_size=8,
            bias=False,
            ffn_expansion_factor=2.66,
            LayerNorm_type='WithBias',  # other option: 'BiasFree'
            # True for dual-pixel defocus deblurring only (then also set
            # inp_channels=6).
            dual_pixel_task=False,
            scale=1,
            prompt=True,
            hard_ratio=0.5,
        )
        self.net = CATAPromptXRestormerOnlyAttn(**net_config)
        self.loss_fn = nn.L1Loss()

    def forward(self, x, training=False):
        """Run the wrapped network on ``x`` and return its output unchanged.

        :param x: input passed straight to ``self.net``.
        :param training: flag forwarded positionally to ``self.net``.
        :return: whatever ``self.net(x, training)`` returns.
        """
        return self.net(x, training)
def np_to_pil(img_np):
    """
    Build a PIL image from a channel-first float array.

    Values are multiplied by 255, clipped to [0, 255] and cast to uint8.
    A 1 x H x W input is passed to PIL as a 2-D H x W array; a 3 x H x W
    input is reordered to H x W x 3. Any other channel count trips the
    assertion.
    :param img_np: array of shape C x H x W with values nominally in [0..1]
    :return: PIL image created with ``Image.fromarray``
    """
    pixels = np.clip(img_np * 255, 0, 255).astype(np.uint8)
    channels = img_np.shape[0]

    if channels != 1:
        assert channels == 3, img_np.shape
        # channel-first -> channel-last, the layout Image.fromarray expects
        return Image.fromarray(pixels.transpose(1, 2, 0))

    # single channel: drop the channel axis
    return Image.fromarray(pixels[0])
def torch_to_np(img_var):
    """
    Return the first item of a batched tensor as a NumPy array.

    The tensor is detached from the autograd graph and moved to the CPU
    before conversion; indexing with 0 then drops the leading batch axis.
    :param img_var: tensor whose first axis is the batch axis (e.g. 1 x C x H x W)
    :return: NumPy array holding batch element 0 (e.g. C x H x W)
    """
    host_tensor = img_var.detach().cpu()
    batch_array = host_tensor.numpy()
    return batch_array[0]
_CKPT_PATH = "ckpt/cata_promptxrestormeronlyattn_epoch=30-step=275962.ckpt"
_PAD_MULTIPLE = 64  # network input is padded up to a multiple of this
_WINDOW_SIZE = 8  # side length of the square windows used for the mask overlay
_MODEL_CACHE = {}  # device string -> model already loaded and set to eval mode


def _load_model(device):
    """Return the restoration model on ``device``, loading the checkpoint once."""
    model = _MODEL_CACHE.get(device)
    if model is None:
        print("CKPT name : {}".format(_CKPT_PATH))
        model = CATAPromptXRestormerIRModel.load_from_checkpoint(
            _CKPT_PATH, map_location=device
        ).to(device)
        model.eval()
        _MODEL_CACHE[device] = model
    return model


def _mirror_extend(tensor, dim, target):
    """Grow ``tensor`` along ``dim`` to ``target`` entries by appending flipped copies."""
    # One concat doubles the length; repeat so that sides shorter than the
    # required padding still reach ``target``.
    while tensor.shape[dim] < target:
        tensor = torch.cat([tensor, torch.flip(tensor, [dim])], dim)
    return tensor.narrow(dim, 0, target)


def _to_uint8_hwc(batch, height, width):
    """Crop batch item 0 to ``height`` x ``width`` and return it as an H x W x C uint8 array."""
    chw = torch_to_np(batch[:, :, :height, :width])
    return np.clip(chw.transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8)


def restore_image(input_img):
    """
    Restore a degraded image and also return a copy with masked windows painted white.

    The input is converted to RGB, center-cropped to a multiple of 16, mirror-padded
    at the bottom/right up to the next multiple of 64, and passed through the model.
    Both outputs are cropped back to the size of the 16-aligned input.

    The second output is the restored image in which every 8 x 8 window whose index
    appears in ``spatial_mask['encoder_level1'][0][0][0]`` is set to 1 (white).
    NOTE(review): the entries are used as flat, row-major window indices; confirm
    this matches the indexing used inside the network.

    :param input_img: PIL image supplied by the Gradio input component
    :return: tuple ``(restored_image, restored_masked_image)`` of H x W x 3 uint8 arrays
    :raises gr.Error: if no image was supplied or a side is shorter than 16 pixels
    """
    if input_img is None:
        raise gr.Error("Please provide an input image.")

    np.random.seed(0)
    torch.manual_seed(0)

    # Fall back to the CPU instead of crashing when CUDA is unavailable.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = _load_model(device)

    degraded_np = crop_img(np.array(input_img.convert('RGB')), base=16)
    if degraded_np.shape[0] == 0 or degraded_np.shape[1] == 0:
        raise gr.Error("Input image is too small: both sides must be at least 16 pixels.")

    with torch.no_grad():
        batch = ToTensor()(degraded_np).unsqueeze(0).to(device)
        _, _, H_old, W_old = batch.shape

        # Always pads by at least one pixel: an already aligned side still
        # grows by a full multiple (kept as in the original pipeline).
        padded_h = (H_old // _PAD_MULTIPLE + 1) * _PAD_MULTIPLE
        padded_w = (W_old // _PAD_MULTIPLE + 1) * _PAD_MULTIPLE
        batch = _mirror_extend(batch, 2, padded_h)
        batch = _mirror_extend(batch, 3, padded_w)

        restored, spatial_mask, _channel_mask = net(batch, training=False)

        # process spatial mask
        encoder_level1_mask = spatial_mask['encoder_level1'][0][0][0]
        _, _, h, w = restored.shape
        # clone() guarantees the in-place writes below never touch ``restored``,
        # whether or not rearrange returned a view.
        restored_windows = rearrange(
            restored,
            'b c (h w1) (w w2) -> b c (h w) w1 w2',
            w1=_WINDOW_SIZE,
            w2=_WINDOW_SIZE,
        ).clone()
        for idx in encoder_level1_mask:
            restored_windows[:, :, idx, :, :] = 1  # mask out the window by setting it to one
        restored_masked = rearrange(
            restored_windows,
            'b c (h w) w1 w2 -> b c (h w1) (w w2)',
            h=h // _WINDOW_SIZE,
            w=w // _WINDOW_SIZE,
        )

        restored_image = _to_uint8_hwc(restored, H_old, W_old)
        restored_masked_image = _to_uint8_hwc(restored_masked, H_old, W_old)

    return restored_image, restored_masked_image
# Static text and assets shown in the Gradio interface.
title = "Content & Task Awareness All-In-One Image Restoration✏️🖼️ 🤗"

# Markdown rendered under the title.
description = ''' ## [Content & Task Awareness All-In-One Image Restoration]
The Ohio State University | Microsoft Research
### TL;DR: quickstart
***One single model can perform several restoration tasks including image denoising, deraining and dehazing 🚀 . Our content & task awareness model would have better efficiency***
The (single) neural model performs all-in-one image restoration.
**🚀 You can start with the [demo tutorial.]** Check [our github] for more information.
<br>
'''

# HTML footer.
# NOTE(review): the link targets the InstructIR repository while the text names
# this project; confirm the intended URL.
article = "<p style='text-align: center'><a href='https://github.com/mv-lab/InstructIR' target='_blank'>Content & Task Awareness All-In-One Image Restoration</a></p>"

# Example inputs: one image path per row (noisy, rainy and hazy samples).
examples = [
    ['test_images/noisy_0000.png'],
    ['test_images/noisy_0001.png'],
    ['test_images/noisy_0002.png'],
    ['test_images/noisy_0003.png'],
    ['test_images/noisy_0004.png'],
    ['test_images/rain-001.png'],
    ['test_images/rain-002.png'],
    ['test_images/rain-003.png'],
    ['test_images/rain-004.png'],
    ['test_images/rain-005.png'],
    ['test_images/hazy-01.jpg'],
    ['test_images/hazy-02.jpg'],
    ['test_images/hazy-03.jpg'],
    ['test_images/hazy-04.jpg'],
    ['test_images/hazy-05.jpg'],
]

# Let images keep their natural size instead of being scaled to the container.
css = """
.image-frame img, .image-container img {
width: auto;
height: auto;
max-width: none;
}
"""
# Gradio UI: one input image, two output images (restored, restored with mask overlay).
demo = gr.Interface(
    fn=restore_image,
    inputs=[gr.Image(type="pil", label="Input")],
    outputs=[
        gr.Image(type="pil", label="Output"),
        gr.Image(type="pil", label="Output-Mask"),
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
    css=css,
)

if __name__ == "__main__":
    demo.launch(debug=True, show_error=True)