# RenderingModel / model.py
# Ahmed Essam
# Upload 5 files
# 402cce1 verified
from __future__ import annotations
import gc
import numpy as np
import PIL.Image
import torch
from diffusers import (
ControlNetModel,
DiffusionPipeline,
StableDiffusionControlNetPipeline,
UniPCMultistepScheduler,
)
from preprocessor import Preprocessor
from settings import *
class Model:
    """Wrapper around a ControlNet-conditioned Stable Diffusion pipeline.

    The instance owns a single diffusers pipeline (``self.pipe``), remembers
    which base model / task it was built for, and exposes a line-art entry
    point that preprocesses an input image and runs the pipeline on it.

    NOTE(review): several names used below (``model_id``, ``prompt``,
    ``a_prompt``, ``n_prompt``, ``guidance_scale``, ``num_steps``,
    ``randomize_seed``, ``preprocess_resolution``, ``DEFAULT_NUM_IMAGES``,
    ``DEFAULT_IMAGE_RESOLUTION``) are not defined in this class; presumably
    they are provided by ``from settings import *`` -- confirm.
    """

    def __init__(self, base_model_id: str = "runwayml/stable-diffusion-v1-5", task_name: str = "lineart"):
        """Select the compute device, build the pipeline and the preprocessor.

        Args:
            base_model_id: Identifier of the Stable Diffusion base model.
            task_name: Name of the ControlNet task the pipeline is built for.
        """
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # Start with empty identifiers so the first load_pipe() call cannot be
        # mistaken for a cache hit.
        self.base_model_id = ""
        self.task_name = ""
        self.pipe = self.load_pipe(base_model_id, task_name)
        self.preprocessor = Preprocessor()

    @staticmethod
    def _release_memory() -> None:
        """Collect garbage, then hand unused cached CUDA memory back.

        The collector runs first: ``torch.cuda.empty_cache()`` can only
        release blocks that are no longer occupied, so objects that are merely
        unreachable must be destroyed before the cache is emptied.
        """
        gc.collect()
        torch.cuda.empty_cache()

    def load_pipe(self, base_model_id: str, task_name: str) -> DiffusionPipeline:
        """Return a pipeline for ``base_model_id`` / ``task_name``.

        The currently held pipeline is reused when both identifiers match the
        recorded ones. Otherwise a new ControlNet and pipeline are loaded in
        float16, the scheduler is replaced by ``UniPCMultistepScheduler``, and
        the pipeline is moved to ``self.device``. On success the recorded
        identifiers are updated; the caller is responsible for storing the
        returned pipeline.

        Args:
            base_model_id: Identifier of the Stable Diffusion base model.
            task_name: Task name recorded alongside the pipeline.

        Returns:
            The cached or newly built pipeline.
        """
        cached_pipe = getattr(self, "pipe", None)
        if (
            cached_pipe is not None
            and base_model_id == self.base_model_id
            and task_name == self.task_name
        ):
            return cached_pipe

        # NOTE(review): model_id is not derived from task_name here; it is
        # presumably a module-level name from settings -- confirm.
        controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            base_model_id, safety_checker=None, controlnet=controlnet, torch_dtype=torch.float16
        )
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        if self.device.type == "cuda":
            pipe.enable_xformers_memory_efficient_attention()
        pipe.to(self.device)
        self._release_memory()

        self.base_model_id = base_model_id
        self.task_name = task_name
        return pipe

    def set_base_model(self, base_model_id: str) -> str:
        """Switch to another base model, reverting to the previous one on failure.

        Args:
            base_model_id: Identifier of the requested base model. An empty
                value or the currently recorded identifier is a no-op.

        Returns:
            The identifier of the base model that is recorded after the call.
        """
        if not base_model_id or base_model_id == self.base_model_id:
            return self.base_model_id

        # Drop the old pipeline before loading the new one to keep peak
        # memory usage down.
        del self.pipe
        self._release_memory()
        try:
            self.pipe = self.load_pipe(base_model_id, self.task_name)
        except Exception:
            # Deliberate best-effort: report the failure, then rebuild the
            # pipeline for the previously recorded base model.
            import logging

            logging.getLogger(__name__).exception(
                "Failed to load base model %r; reverting to %r", base_model_id, self.base_model_id
            )
            self.pipe = self.load_pipe(self.base_model_id, self.task_name)
        return self.base_model_id

    def load_controlnet_weight(self, task_name: str) -> None:
        """Replace the pipeline's ControlNet when the task changes.

        Does nothing when ``task_name`` equals the recorded task name.

        Args:
            task_name: Name of the task to record for the new ControlNet.
        """
        if task_name == self.task_name:
            return

        if self.pipe is not None and hasattr(self.pipe, "controlnet"):
            del self.pipe.controlnet
        self._release_memory()

        # NOTE(review): as in load_pipe, model_id does not depend on
        # task_name -- confirm this is intended.
        controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
        controlnet.to(self.device)
        self._release_memory()

        self.pipe.controlnet = controlnet
        self.task_name = task_name

    def get_prompt(self, prompt: str, additional_prompt: str) -> str:
        """Combine a prompt with an additional prompt.

        Args:
            prompt: Main prompt; may be empty.
            additional_prompt: Text appended after the main prompt.

        Returns:
            ``additional_prompt`` alone when ``prompt`` is empty, otherwise
            ``"<prompt>, <additional_prompt>"``.
        """
        if not prompt:
            return additional_prompt
        return f"{prompt}, {additional_prompt}"

    @torch.autocast("cuda")
    def run_pipe(
        self,
        control_image: PIL.Image.Image,
    ) -> list[PIL.Image.Image]:
        """Run the pipeline on a control image and return the generated images.

        All generation parameters other than the control image are read from
        module-level names rather than passed in.

        Args:
            control_image: Conditioning image passed to the pipeline.

        Returns:
            The ``images`` attribute of the pipeline output.
        """
        # NOTE(review): the name randomize_seed suggests a flag, but it is
        # used as the seed value itself -- confirm it holds an integer.
        generator = torch.Generator().manual_seed(randomize_seed)
        output = self.pipe(
            prompt=prompt + ' ' + a_prompt,
            negative_prompt=n_prompt,
            guidance_scale=guidance_scale,
            num_images_per_prompt=DEFAULT_NUM_IMAGES,
            num_inference_steps=num_steps,
            generator=generator,
            image=control_image,
        )
        return output.images

    def process_lineart(
        self,
        image: np.ndarray,
    ) -> list[PIL.Image.Image]:
        """Generate images from an input image using line-art conditioning.

        Args:
            image: Input image array.

        Returns:
            A list whose first element is the control image produced by the
            preprocessor, followed by the images generated by the pipeline.

        Raises:
            ValueError: If ``image`` is ``None``.
        """
        if image is None:
            raise ValueError("image must not be None")

        self.preprocessor.load("Lineart")
        control_image = self.preprocessor(
            image=image,
            image_resolution=DEFAULT_IMAGE_RESOLUTION,
            detect_resolution=preprocess_resolution,
        )
        self.load_controlnet_weight("lineart")
        results = self.run_pipe(control_image=control_image)
        return [control_image] + results