# Cat-AIR / app_cata.py
import numpy as np
import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from torchvision.transforms import ToTensor
import lightning.pytorch as pl
from net.cata_prompt_xrestormer import CATAPromptXRestormerOnlyAttn
from einops import rearrange
import spaces
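# Note: `spaces` is the Hugging Face ZeroGPU helper package; its `GPU`
# decorator (used on restore_image below) requests a GPU allocation per call.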
# Center-crop an image so its height and width are multiples of `base`.
def crop_img(image, base=64):
    h = image.shape[0]
    w = image.shape[1]
    crop_h = h % base
    crop_w = w % base
    return image[crop_h // 2:h - crop_h + crop_h // 2, crop_w // 2:w - crop_w + crop_w // 2, :]
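# Shape check (illustrative values only): a 70x85 RGB array cropped with
# base=64 keeps the central 64x64 region, i.e.
#   crop_img(np.zeros((70, 85, 3)), base=64).shape == (64, 64, 3)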
class CATAPromptXRestormerIRModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = CATAPromptXRestormerOnlyAttn(
            inp_channels=3,
            out_channels=3,
            dim=48,
            num_blocks=[2, 4, 4, 4],
            num_refinement_blocks=4,
            channel_heads=[1, 1, 1, 1],
            spatial_heads=[1, 2, 4, 8],
            overlap_ratio=0.5,
            dim_head=16,
            ratio=0.5,
            window_size=8,
            bias=False,
            ffn_expansion_factor=2.66,
            LayerNorm_type='WithBias',  ## Other option 'BiasFree'
            dual_pixel_task=False,  ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
            scale=1,
            prompt=True,
            hard_ratio=0.5,
        )
        self.loss_fn = nn.L1Loss()

    def forward(self, x, training=False):
        return self.net(x, training)
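# Quick smoke test (a sketch; assumes the `net` package is importable and,
# per the inference code below, that the network returns a
# (restored, spatial_mask, channel_mask) tuple when training=False):
#   model = CATAPromptXRestormerIRModel()
#   restored, s_mask, c_mask = model(torch.randn(1, 3, 64, 64), training=False)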
def np_to_pil(img_np):
    """
    Converts an image in np.array format to a PIL image.
    From C x H x W in [0, 1] to H x W x C in [0, 255].
    :param img_np:
    :return:
    """
    ar = np.clip(img_np * 255, 0, 255).astype(np.uint8)
    if img_np.shape[0] == 1:
        ar = ar[0]
    else:
        assert img_np.shape[0] == 3, img_np.shape
        ar = ar.transpose(1, 2, 0)
    return Image.fromarray(ar)
def torch_to_np(img_var):
    """
    Converts an image in torch.Tensor format to np.array.
    From 1 x C x H x W in [0, 1] to C x H x W in [0, 1].
    :param img_var:
    :return:
    """
    return img_var.detach().cpu().numpy()[0]
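# Round trip (illustrative): a 1 x C x H x W tensor in [0, 1] becomes a PIL image:
#   np_to_pil(torch_to_np(torch.rand(1, 3, 32, 32)))  # -> 32x32 RGB PIL.Image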
@spaces.GPU(duration=200)
def restore_image(input_img):
    # Fix seeds for reproducible results.
    np.random.seed(0)
    torch.manual_seed(0)
    ckpt_path = "ckpt/cata_promptxrestormeronlyattn_epoch=30-step=275962.ckpt"
    print("CKPT name : {}".format(ckpt_path))
    net = CATAPromptXRestormerIRModel.load_from_checkpoint(ckpt_path).cuda()
    net.eval()
    degraded_img = crop_img(np.array(input_img.convert('RGB')), base=16)
    toTensor = ToTensor()
    degraded_img = toTensor(degraded_img)
    with torch.no_grad():
        degraded_img = degraded_img.unsqueeze(0).cuda()
        _, _, H_old, W_old = degraded_img.shape
        # Reflection-pad height and width up to the next multiple of 64 so the
        # network's downsampling and window partitioning divide evenly (note:
        # this pads by a full 64 even when a side is already aligned).
        h_pad = (H_old // 64 + 1) * 64 - H_old
        w_pad = (W_old // 64 + 1) * 64 - W_old
        degraded_img = torch.cat([degraded_img, torch.flip(degraded_img, [2])], 2)[:, :, :H_old + h_pad, :]
        degraded_img = torch.cat([degraded_img, torch.flip(degraded_img, [3])], 3)[:, :, :, :W_old + w_pad]
        restored, spatial_mask, channel_mask = net(degraded_img, training=False)
        # Process the spatial mask: split the restored image into
        # window_size x window_size windows and whiten the windows flagged by
        # the encoder-level-1 spatial mask for visualization.
        encoder_level1_mask = spatial_mask['encoder_level1'][0][0][0]
        window_size = 8
        _, c, h, w = restored.shape
        restored_windows = rearrange(restored, 'b c (h w1) (w w2) -> b c (h w) w1 w2', w1=window_size, w2=window_size)
        for idx in encoder_level1_mask:
            restored_windows[:, :, idx, :, :] = 1  # Mask out the window by setting it to one
        restored_masked = rearrange(restored_windows, 'b c (h w) w1 w2 -> b c (h w1) (w w2)', h=h // window_size, w=w // window_size)
    # Undo the padding, then convert to uint8 H x W x C for display.
    restored = restored[:, :, :H_old, :W_old]
    restored_image = torch_to_np(restored)
    restored_image = restored_image.transpose(1, 2, 0)
    restored_image = np.clip(restored_image * 255, 0, 255).astype(np.uint8)
    restored_masked = restored_masked[:, :, :H_old, :W_old]
    restored_masked_image = torch_to_np(restored_masked)
    restored_masked_image = restored_masked_image.transpose(1, 2, 0)
    restored_masked_image = np.clip(restored_masked_image * 255, 0, 255).astype(np.uint8)
    return restored_image, restored_masked_image
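# Local sanity check (a sketch; the input path is illustrative, and a CUDA
# device plus the checkpoint above are required):
#   out, mask = restore_image(Image.open("test_images/noisy_0000.png"))
#   Image.fromarray(out).save("restored.png")
#   Image.fromarray(mask).save("restored_mask.png")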
title = "Content & Task Awareness All-In-One Image Restoration✏️🖼️ 🤗"
description = ''' ## [Content & Task Awareness All-In-One Image Restoration]
The Ohio State Unviersity | Microsoft Research
### TL;DR: quickstart
***One single model can perform several restoration tasks including image denoising, deraining and dehazing 🚀 . Our content & task awareness model would have better efficiency***
The (single) neural model performs all-in-one image restoration.
**🚀 You can start with the [demo tutorial.]** Check [our github] for more information.
<br>
'''
article = "<p style='text-align: center'><a href='https://github.com/mv-lab/InstructIR' target='_blank'>Content & Task Awareness All-In-One Image Restoration</a></p>"
#### Image examples
examples = [
    ['test_images/noisy_0000.png'],
    ['test_images/noisy_0001.png'],
    ['test_images/noisy_0002.png'],
    ['test_images/noisy_0003.png'],
    ['test_images/noisy_0004.png'],
    ['test_images/rain-001.png'],
    ['test_images/rain-002.png'],
    ['test_images/rain-003.png'],
    ['test_images/rain-004.png'],
    ['test_images/rain-005.png'],
    ['test_images/hazy-01.jpg'],
    ['test_images/hazy-02.jpg'],
    ['test_images/hazy-03.jpg'],
    ['test_images/hazy-04.jpg'],
    ['test_images/hazy-05.jpg'],
]
css = """
.image-frame img, .image-container img {
width: auto;
height: auto;
max-width: none;
}
"""
demo = gr.Interface(
    fn=restore_image,
    inputs=[gr.Image(type="pil", label="Input")],
    outputs=[gr.Image(type="pil", label="Output"), gr.Image(type="pil", label="Output-Mask")],
    title=title,
    description=description,
    article=article,
    examples=examples,
    css=css,
)
if __name__ == "__main__":
    demo.launch(debug=True, show_error=True)