Spaces: Running on Zero
| #!/usr/bin/env python | |
| import pathlib | |
| import sys | |
| import cv2 | |
| import gradio as gr | |
| import numpy as np | |
| import spaces | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
| from torch import nn | |
| submodule_dir = pathlib.Path(__file__).parent / "MangaLineExtraction_PyTorch" | |
| sys.path.append(submodule_dir.as_posix()) | |
| from model_torch import res_skip # type: ignore # noqa: E402 | |
# Markdown heading shown at the top of the demo page, linking to the upstream repository.
DESCRIPTION = (
    "# [MangaLineExtraction_PyTorch]"
    "(https://github.com/ljsabc/MangaLineExtraction_PyTorch)"
)
def load_model(device: torch.device) -> nn.Module:
    """Download the ``erika.pth`` checkpoint and build a ready-to-run ``res_skip`` model.

    Args:
        device: Device the model parameters (and the loaded checkpoint
            tensors) are placed on.

    Returns:
        A ``res_skip`` network with the checkpoint weights loaded, moved to
        ``device`` and switched to eval mode.
    """
    ckpt_path = hf_hub_download("public-data/MangaLineExtraction_PyTorch", "erika.pth")
    # map_location remaps stored tensors onto `device`. Without it, a checkpoint
    # saved from a CUDA session cannot be deserialized on a CPU-only host,
    # which is exactly the fallback path the caller supports.
    # NOTE(review): torch.load unpickles the file. This is acceptable only because
    # the checkpoint comes from a fixed, trusted repo; consider weights_only=True
    # once the minimum supported torch version allows it.
    state_dict = torch.load(ckpt_path, map_location=device)
    model = res_skip()
    model.load_state_dict(state_dict)
    model.to(device)
    # eval() switches layers such as dropout/batch-norm to inference behaviour.
    model.eval()
    return model
# Longest allowed image side in pixels; larger inputs are scaled down before inference.
MAX_SIZE = 1000

# Prefer the first CUDA device when one is visible, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Load the network once at import time so every request reuses it.
model = load_model(device)
@spaces.GPU
@torch.inference_mode()
def predict(image: np.ndarray) -> np.ndarray:
    """Extract manga line art from an RGB image.

    ``@spaces.GPU`` requests a GPU for the duration of the call when the app
    runs on ZeroGPU. ``inference_mode`` disables autograd bookkeeping, which is
    not needed for a forward-only pass.

    Args:
        image: RGB image from ``gr.Image(type="numpy")``. This is ``None`` when
            Run is clicked before an image has been uploaded.

    Returns:
        A 2-D ``uint8`` array with the same height and width as the grayscale
        input (after any downscaling to ``MAX_SIZE``).

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Fail with a user-visible message instead of an opaque cv2 error.
        raise gr.Error("Please upload an image first.")

    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    longest = max(gray.shape)
    if longest > MAX_SIZE:
        # Uniform scale factor so the longest side becomes MAX_SIZE.
        scale = MAX_SIZE / longest
        gray = cv2.resize(gray, None, fx=scale, fy=scale)
    h, w = gray.shape

    # Round both sides up to a multiple of 16, presumably the network's total
    # down-sampling factor (TODO confirm against res_skip).
    size = 16
    new_w = (w + size - 1) // size * size
    new_h = (h + size - 1) // size * size
    # Batch of one single-channel image. The image sits in the top-left corner.
    # NOTE(review): the padding value is 1.0, i.e. near-black on the 0-255
    # scale of `gray`; confirm this is intended rather than white (255).
    patch = np.ones((1, 1, new_h, new_w), dtype=np.float32)
    patch[0, 0, :h, :w] = gray

    tensor = torch.from_numpy(patch).to(device)
    out = model(tensor)
    # Drop batch/channel axes and crop the padding away again.
    res = out.cpu().numpy()[0, 0, :h, :w]
    return np.clip(res, 0, 255).astype(np.uint8)
with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    # Two side-by-side columns: the upload + trigger on the left, the output on the right.
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input", type="numpy")
            run_button = gr.Button()
        with gr.Column():
            result = gr.Image(label="Result", elem_id="result")

    # Clicking the button feeds the uploaded image through predict() into the result pane.
    run_button.click(predict, input_image, result)
if __name__ == "__main__":
    # Enable request queuing with at most 20 pending requests, then start the server.
    demo.queue(max_size=20)
    demo.launch()