# Cat-AIR / app_local.py
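# Gradio demo for Cat-AIR: a single all-in-one model that performs image
# denoising, deraining, and dehazing.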
import numpy as np
import gradio as gr
import torch
from PIL import Image
from torchvision.transforms import ToTensor
from net.prompt_xrestormer import PromptXRestormer
import lightning.pytorch as pl
# Center-crop an image so each spatial dimension is a multiple of `base`.
def crop_img(image, base=64):
    h = image.shape[0]
    w = image.shape[1]
    crop_h = h % base
    crop_w = w % base
    return image[crop_h // 2:h - crop_h + crop_h // 2, crop_w // 2:w - crop_w + crop_w // 2, :]
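# Example: with base=16, a (721, 481, 3) input is center-cropped to (720, 480, 3).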
class PromptXRestormerIRModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = PromptXRestormer(
            inp_channels=3,
            out_channels=3,
            dim=48,
            num_blocks=[2, 4, 4, 4],
            num_refinement_blocks=4,
            channel_heads=[1, 1, 1, 1],
            spatial_heads=[1, 2, 4, 8],
            overlap_ratio=[0.5, 0.5, 0.5, 0.5],
            ffn_expansion_factor=2.66,
            bias=False,
            LayerNorm_type='WithBias',  ## Other option 'BiasFree'
            dual_pixel_task=False,      ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
            scale=1,
            prompt=True,
        )

    def forward(self, x):
        return self.net(x)
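# NOTE: these hyperparameters must match the configuration the checkpoint below
# was trained with; load_from_checkpoint restores the weights into this exact
# architecture and will generally fail on a mismatch.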
def np_to_pil(img_np):
    """
    Converts an image in np.array format to a PIL image.
    From C x H x W in [0, 1] to H x W x C in [0, 255].
    :param img_np: float image array of shape (C, H, W)
    :return: PIL.Image
    """
    ar = np.clip(img_np * 255, 0, 255).astype(np.uint8)
    if img_np.shape[0] == 1:
        ar = ar[0]
    else:
        assert img_np.shape[0] == 3, img_np.shape
        ar = ar.transpose(1, 2, 0)
    return Image.fromarray(ar)
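# np_to_pil is unused by the Gradio path below (restore_image returns a raw
# np.array), but is kept as a convenience helper.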
def torch_to_np(img_var):
    """
    Converts an image in torch.Tensor format to np.array.
    From 1 x C x H x W in [0, 1] to C x H x W in [0, 1].
    :param img_var: image tensor with a leading batch dimension
    :return: np.array
    """
    return img_var.detach().cpu().numpy()[0]
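# NOTE: restore_image reloads the checkpoint on every call for simplicity; a
# deployed demo would typically load the model once at module scope and reuse it.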
# @spaces.GPU(duration=200)  # uncomment (and `import spaces`) when deploying on Hugging Face ZeroGPU
def restore_image(input_img):
    # fix seeds so repeated runs on the same image give identical results
    np.random.seed(0)
    torch.manual_seed(0)

    ckpt_path = "ckpt/promptxrestormer_epoch=64-step=578630.ckpt"
    print("CKPT name : {}".format(ckpt_path))
    net = PromptXRestormerIRModel.load_from_checkpoint(ckpt_path).cuda()
    net.eval()

    # crop so both spatial dims are multiples of 16, then convert HWC uint8 to a CHW float tensor in [0, 1]
    degraded_img = crop_img(np.array(input_img.convert('RGB')), base=16)
    toTensor = ToTensor()
    degraded_img = toTensor(degraded_img)
    with torch.no_grad():
        degraded_img = degraded_img.unsqueeze(0).cuda()

        # mirror-pad height and width up to the next multiple of 64 by
        # concatenating a flipped copy, then trimming to the padded size
        _, _, H_old, W_old = degraded_img.shape
        h_pad = (H_old // 64 + 1) * 64 - H_old
        w_pad = (W_old // 64 + 1) * 64 - W_old
        degraded_img = torch.cat([degraded_img, torch.flip(degraded_img, [2])], 2)[:, :, :H_old + h_pad, :]
        degraded_img = torch.cat([degraded_img, torch.flip(degraded_img, [3])], 3)[:, :, :, :W_old + w_pad]

        restored = net(degraded_img)
        # drop the padding
        restored = restored[:, :, :H_old, :W_old]

    restored_image = torch_to_np(restored)
    # change shape from [C, H, W] to [H, W, C] and map [0, 1] floats to uint8
    restored_image = restored_image.transpose(1, 2, 0)
    restored_image = np.clip(restored_image * 255, 0, 255).astype(np.uint8)
    return restored_image
# Local smoke test (restore_image expects a PIL image, as Gradio supplies one):
# degraded_path = "/home/jiachen/MyGradio/test_images/rain-070.png"
# input_img = Image.open(degraded_path).convert('RGB')
# restored_image = restore_image(input_img)
# print(restored_image.shape)
title = "Content & Task Awareness All-In-One Image Restoration ✏️🖼️ 🤗"
description = ''' ## [Content & Task Awareness All-In-One Image Restoration]
The Ohio State University | Microsoft Research
### TL;DR: quickstart
***A single model performs several restoration tasks, including image denoising, deraining, and dehazing 🚀. Our content- and task-aware design makes the model more efficient.***
**🚀 You can start with the [demo tutorial].** Check [our GitHub] for more information.
<br>
'''
article = "<p style='text-align: center'><a href='https://github.com/mv-lab/InstructIR' target='_blank'>Content & Task Awareness All-In-One Image Restoration</a></p>"
#### Image examples (noise, rain, haze)
examples = [
    ['test_images/noisy_0000.png'],
    ['test_images/noisy_0001.png'],
    ['test_images/noisy_0002.png'],
    ['test_images/noisy_0003.png'],
    ['test_images/noisy_0004.png'],
    ['test_images/rain-01.png'],
    ['test_images/rain-02.png'],
    ['test_images/rain-03.png'],
    ['test_images/rain-04.png'],
    ['test_images/rain-05.png'],
    ['test_images/rain-06.png'],
    ['test_images/hazy-00.jpg'],
    ['test_images/hazy-01.jpg'],
    ['test_images/hazy-02.jpg'],
    ['test_images/hazy-03.jpg'],
    ['test_images/hazy-04.jpg'],
]
css = """
.image-frame img, .image-container img {
    width: auto;
    height: auto;
    max-width: none;
}
"""
demo = gr.Interface(
    fn=restore_image,
    inputs=[gr.Image(type="pil", label="Input")],
    outputs=[gr.Image(type="pil", label="Output")],
    title=title,
    description=description,
    article=article,
    examples=examples,
    css=css,
)
if __name__ == "__main__":
    demo.launch(server_port=8085)