from __future__ import annotations

import gc

import numpy as np
import PIL.Image
import torch
from diffusers import (
    ControlNetModel,
    DiffusionPipeline,
    StableDiffusionControlNetPipeline,
    UniPCMultistepScheduler,
)

from preprocessor import Preprocessor
from settings import *


class Model:
    """Stable Diffusion + ControlNet wrapper that generates images from lineart.

    The diffusion pipeline and its ControlNet are cached on the instance and are
    only rebuilt when the requested base model id or task name changes.

    NOTE(review): several names used below are not defined in this file:
    ``model_id``, ``randomize_seed``, ``prompt``, ``a_prompt``, ``n_prompt``,
    ``guidance_scale``, ``num_steps``, ``preprocess_resolution``,
    ``DEFAULT_NUM_IMAGES`` and ``DEFAULT_IMAGE_RESOLUTION``. They are
    presumably provided by ``from settings import *``; confirm.
    """

    def __init__(
        self,
        base_model_id: str = "runwayml/stable-diffusion-v1-5",
        task_name: str = "lineart",
    ):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # Empty sentinels guarantee the first load_pipe() call misses the cache.
        self.base_model_id = ""
        self.task_name = ""
        self.pipe = self.load_pipe(base_model_id, task_name)
        self.preprocessor = Preprocessor()

    @property
    def _torch_dtype(self) -> torch.dtype:
        """Return the weight dtype to load models with on ``self.device``.

        Half precision is used only on CUDA. Many PyTorch CPU kernels lack a
        (usable) float16 implementation, so CPU runs use float32.
        """
        return torch.float16 if self.device.type == "cuda" else torch.float32

    def load_pipe(self, base_model_id: str, task_name: str) -> DiffusionPipeline:
        """Return a ControlNet pipeline for ``base_model_id`` / ``task_name``.

        If both ids match the currently loaded ones and a pipeline exists,
        the existing ``self.pipe`` is returned as-is. Otherwise a new pipeline
        is built, moved to ``self.device``, and the ids are recorded on
        ``self``. Assigning the result to ``self.pipe`` is left to the caller.
        """
        if (
            base_model_id == self.base_model_id
            and task_name == self.task_name
            and getattr(self, "pipe", None) is not None
        ):
            return self.pipe

        dtype = self._torch_dtype
        # NOTE(review): the ControlNet checkpoint comes from ``model_id`` and does
        # not vary with ``task_name``; confirm this is intended.
        controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=dtype)
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            base_model_id,
            safety_checker=None,
            controlnet=controlnet,
            torch_dtype=dtype,
        )
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        if self.device.type == "cuda":
            pipe.enable_xformers_memory_efficient_attention()
        pipe.to(self.device)
        torch.cuda.empty_cache()
        gc.collect()
        # Only record the new ids once loading has fully succeeded, so a failed
        # load leaves the previous ids in place (set_base_model relies on this).
        self.base_model_id = base_model_id
        self.task_name = task_name
        return pipe

    def set_base_model(self, base_model_id: str) -> str:
        """Switch to another base model, falling back to the current one on failure.

        Returns the base model id that is loaded after the call. That is the
        new id on success, or the previous id if loading the new model raised.
        """
        if not base_model_id or base_model_id == self.base_model_id:
            return self.base_model_id
        # Release the old pipeline before loading the new one to lower peak memory.
        # Rebinding to None (rather than ``del``) keeps ``self.pipe`` defined even
        # if both loads below fail.
        self.pipe = None
        torch.cuda.empty_cache()
        gc.collect()
        try:
            self.pipe = self.load_pipe(base_model_id, self.task_name)
        except Exception:
            # Best effort: reload the previous model. self.base_model_id still
            # holds the old id because load_pipe only updates it on success.
            self.pipe = self.load_pipe(self.base_model_id, self.task_name)
        return self.base_model_id

    def load_controlnet_weight(self, task_name: str) -> None:
        """Replace the pipeline's ControlNet with the one for ``task_name``.

        Does nothing when ``task_name`` is already the loaded task.

        Raises:
            RuntimeError: if no pipeline is currently loaded.
        """
        if task_name == self.task_name:
            return
        if self.pipe is None:
            raise RuntimeError("No pipeline is loaded; cannot set ControlNet weights.")
        if hasattr(self.pipe, "controlnet"):
            # Drop the old ControlNet first to keep peak memory low.
            del self.pipe.controlnet
            # The pipe no longer holds the old task's weights. Invalidate the
            # cached task so a failed load below is retried on the next call
            # rather than skipped by the early return.
            self.task_name = ""
            torch.cuda.empty_cache()
            gc.collect()
        controlnet = ControlNetModel.from_pretrained(
            model_id, torch_dtype=self._torch_dtype
        )
        controlnet.to(self.device)
        torch.cuda.empty_cache()
        gc.collect()
        self.pipe.controlnet = controlnet
        self.task_name = task_name

    def get_prompt(self, prompt: str, additional_prompt: str) -> str:
        """Join ``prompt`` and ``additional_prompt`` with ", ".

        If ``prompt`` is empty, only ``additional_prompt`` is returned.
        """
        if not prompt:
            return additional_prompt
        return f"{prompt}, {additional_prompt}"

    @torch.autocast("cuda")
    def run_pipe(
        self,
        control_image: PIL.Image.Image,
    ) -> list[PIL.Image.Image]:
        """Generate images with ``self.pipe`` conditioned on ``control_image``.

        Prompt, seed and sampler settings are read from module-level names.
        These are presumably supplied by ``settings``; see the class docstring.
        """
        generator = torch.Generator().manual_seed(randomize_seed)
        # NOTE(review): this joins the prompts with a space, whereas get_prompt()
        # joins with ", "; confirm which is intended.
        return self.pipe(
            prompt=prompt + " " + a_prompt,
            negative_prompt=n_prompt,
            guidance_scale=guidance_scale,
            num_images_per_prompt=DEFAULT_NUM_IMAGES,
            num_inference_steps=num_steps,
            generator=generator,
            image=control_image,
        ).images

    def process_lineart(
        self,
        image: np.ndarray,
    ) -> list[PIL.Image.Image]:
        """Preprocess ``image`` with the "Lineart" preprocessor and generate from it.

        Returns:
            The control image produced by the preprocessor, followed by the
            images generated by ``run_pipe``.

        Raises:
            ValueError: if ``image`` is None.
        """
        if image is None:
            raise ValueError("image must not be None")
        self.preprocessor.load("Lineart")
        control_image = self.preprocessor(
            image=image,
            image_resolution=DEFAULT_IMAGE_RESOLUTION,
            detect_resolution=preprocess_resolution,
        )
        self.load_controlnet_weight("lineart")
        results = self.run_pipe(control_image=control_image)
        return [control_image] + results