Spaces:
Running
Running
| import argparse | |
| import cv2 | |
| import os | |
| from imutils import paths | |
| from tqdm import tqdm | |
| from config import * | |
| from utils import get_face_enhancer, get_upsampler | |
def process(image_path, upsampler_name, face_enhancer_name=None, scale=2, device="cpu"):
    """Upsample a single image and return the result, or ``None`` on failure.

    Errors are reported with ``print`` and never propagate to the caller, so
    a batch run can continue with the next image.

    Args:
        image_path: Path of the image file, read with ``cv2.imread`` using
            ``cv2.IMREAD_UNCHANGED``.
        upsampler_name: Model name forwarded to ``get_upsampler``.
        face_enhancer_name: Optional model name forwarded to
            ``get_face_enhancer``. When falsy, only the upsampler is used.
        scale: Requested output scale; values above 4 are clamped to 4.
        device: Device string forwarded to the model factories.

    Returns:
        The processed image array, or ``None`` when the image could not be
        read or any processing step raised. Two shortcuts bypass the models:

        * images taller or wider than 3500 px are returned exactly as loaded;
        * images with both sides under 300 px are doubled with Lanczos
          interpolation and returned, unless the upsampler is ``"srcnn"``.
    """
    scale = min(scale, 4)  # avoid too large scale value

    try:
        img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread reports a missing/undecodable file by returning None
            # instead of raising, so handle it explicitly.
            print(f"Could not read image: {image_path}")
            return None

        h, w = img.shape[0:2]

        if h > 3500 or w > 3500:
            # Too large to upsample: pass the image through untouched. It is
            # returned in the channel layout it was loaded with, the same as
            # the small-image shortcut below, so a caller saving it with
            # cv2.imwrite gets correct colours, and grayscale / 4-channel
            # inputs are not rejected by a colour conversion.
            return img

        if (h < 300 and w < 300) and upsampler_name != "srcnn":
            return cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)

        upsampler = get_upsampler(upsampler_name, device=device)
        face_enhancer = (
            get_face_enhancer(face_enhancer_name, scale, upsampler, device=device)
            if face_enhancer_name
            else None
        )

        try:
            if face_enhancer is not None:
                _, _, output = face_enhancer.enhance(
                    img, has_aligned=False, only_center_face=False, paste_back=True
                )
            else:
                output, _ = upsampler.enhance(img, outscale=scale)
        except RuntimeError as error:
            print(f"Runtime error: {error}")
            # No result was produced; return explicitly instead of falling
            # through to a reference to the unassigned ``output``.
            return None

        return output
    except Exception as error:
        # Deliberate best-effort boundary: report and skip this image.
        print(f"global exception: {error}")
        return None
def main(args: argparse.Namespace) -> None:
    """Upsample the image(s) named by ``args.input`` into ``args.output``.

    Args:
        args: Parsed command-line options. Reads ``input``, ``output``,
            ``upsampler``, ``face_enhancer``, ``scale`` and ``device``.

    Raises:
        ValueError: If ``args.input`` does not exist.

    Each result is saved in ``args.output`` under the input's base file name,
    so inputs that share a file name map to the same output file. Images for
    which ``process`` returns ``None`` are skipped.
    """
    device = args.device
    scale = args.scale
    upsampler_name = args.upsampler
    face_enhancer_name = args.face_enhancer

    if face_enhancer_name and ("srcnn" in upsampler_name or "anime" in upsampler_name):
        print(
            "Warnings: SRCNN and Anime model aren't compatible with face enhance. We will turn it off for you"
        )
        face_enhancer_name = None

    # Validate the input first so a bad path does not leave behind a freshly
    # created (and empty) output directory.
    if not os.path.exists(args.input):
        raise ValueError("The input directory doesn't exist!")

    if os.path.isdir(args.input):
        # Collect the paths up front: tqdm then knows the total, and files
        # written during this run cannot be picked up as new inputs if the
        # lister is lazy and the output directory lies inside the input one.
        image_paths = list(paths.list_images(args.input))
    else:
        image_paths = [args.input]

    os.makedirs(args.output, exist_ok=True)

    with tqdm(image_paths) as pbar:
        for image_path in pbar:
            pbar.set_postfix_str(f"Processing {image_path}")
            upsampled_image = process(
                image_path=image_path,
                upsampler_name=upsampler_name,
                face_enhancer_name=face_enhancer_name,
                scale=scale,
                device=device,
            )
            if upsampled_image is None:
                continue

            save_path = os.path.join(args.output, os.path.basename(image_path))
            # cv2.imwrite returns False when the file could not be written;
            # report it instead of silently losing the result.
            if not cv2.imwrite(save_path, upsampled_image):
                print(f"Failed to write {save_path}")
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the upsampling script.

    Kept separate from the ``__main__`` guard so the option definitions can
    be inspected or tested without running the pipeline.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Upsamples an input image or a directory of images, "
            "optionally enhancing faces"
        )
    )
    parser.add_argument(
        "--input",
        "-i",
        type=str,
        required=True,
        help="Path to either a single input image or folder of images.",
    )
    parser.add_argument(
        "--output",
        "-o",
        type=str,
        required=True,
        help="Path to the output directory.",
    )
    parser.add_argument(
        "--upsampler",
        type=str,
        default="realesr-general-x4v3",
        choices=[
            "srcnn",
            "RealESRGAN_x2plus",
            "RealESRGAN_x4plus",
            "RealESRNet_x4plus",
            "realesr-general-x4v3",
            "RealESRGAN_x4plus_anime_6B",
            "realesr-animevideov3",
        ],
        help="The type of upsampler model to load",
    )
    parser.add_argument(
        "--face-enhancer",
        type=str,
        choices=["GFPGANv1.3", "GFPGANv1.4", "RestoreFormer"],
        help="The type of face enhancer model to load",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=2,
        choices=[1.5, 2, 2.5, 3, 3.5, 4],
        help="scaling factor",
    )
    parser.add_argument(
        "--device", type=str, default="cuda", help="The device to run upsampling on."
    )
    return parser


if __name__ == "__main__":
    main(build_parser().parse_args())