RenderingModel / preprocessor.py
Ahmed Essam
Upload 5 files
402cce1 verified
raw
history blame contribute delete
684 Bytes
import gc
import PIL.Image
import torch
from controlnet_aux import LineartDetector
class Preprocessor:
    """Hold at most one ControlNet annotator and apply it to images.

    ``load`` selects (and lazily loads) an annotator by name; calling the
    instance runs the currently loaded annotator on an image.
    """

    # Model repo / path passed to ``from_pretrained`` for the annotator weights.
    MODEL_ID = "lllyasviel/Annotators"

    def __init__(self):
        # Currently loaded annotator, or None until ``load`` succeeds.
        self.model = None
        # Name of the loaded annotator; "" means nothing is loaded.
        self.name = ""

    def load(self, name: str) -> None:
        """Load the annotator called *name*, replacing any previous one.

        Calling with the name that is already loaded is a no-op.

        Args:
            name: Annotator to load. Only ``"Lineart"`` is supported.

        Raises:
            ValueError: If *name* is not a supported annotator. The currently
                loaded annotator (if any) is left untouched in that case.
        """
        if name == self.name:
            return
        # Validate before touching state so a bad name keeps the old model.
        if name != "Lineart":
            raise ValueError(f"Unsupported preprocessor: {name!r}")

        # Release the previous model before loading the next so both are
        # never resident at once. Reset the name too, so that if loading
        # below fails we don't keep a name that points at no model.
        self.model = None
        self.name = ""
        # Collect first so the dropped model's tensors are actually freed,
        # then hand the now-unused cached blocks back to the CUDA driver.
        gc.collect()
        torch.cuda.empty_cache()

        self.model = LineartDetector.from_pretrained(self.MODEL_ID)
        self.name = name

    def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
        """Run the loaded annotator on *image*.

        Extra keyword arguments are forwarded unchanged to the annotator.
        NOTE(review): the return value is whatever the annotator returns; the
        ``PIL.Image.Image`` annotation presumes its default output type —
        confirm if callers pass an output-type override via ``kwargs``.

        Raises:
            RuntimeError: If no annotator has been loaded yet.
        """
        if self.model is None:
            raise RuntimeError("No preprocessor loaded; call load() first.")
        return self.model(image, **kwargs)