| import gc | |
| import PIL.Image | |
| import torch | |
| from controlnet_aux import LineartDetector | |
class Preprocessor:
    """Callable wrapper that lazily loads an annotator model and applies it.

    An instance starts with no model. Call ``load(name)`` to instantiate the
    detector for ``name``, then call the instance itself to run that detector
    on an image. Only the ``"Lineart"`` detector is currently supported.
    """

    # Identifier handed to ``LineartDetector.from_pretrained``.
    # NOTE(review): presumably a Hugging Face Hub repo id -- confirm.
    MODEL_ID = "lllyasviel/Annotators"

    def __init__(self) -> None:
        # No detector is held until ``load`` succeeds.
        self.model = None
        # Name of the currently loaded detector; "" means none loaded yet.
        self.name = ""

    def load(self, name: str) -> None:
        """Load the detector called ``name`` unless it is already active.

        Args:
            name: Detector to load. Only ``"Lineart"`` is recognised.

        Raises:
            ValueError: If ``name`` is not a supported detector. ``self.model``
                and ``self.name`` are left untouched in that case.
        """
        if name == self.name:
            # Already the active detector: skip the reload.
            return
        if name == "Lineart":
            self.model = LineartDetector.from_pretrained(self.MODEL_ID)
        else:
            raise ValueError(
                f"Unsupported preprocessor {name!r}; expected 'Lineart'"
            )
        # Collect garbage *before* emptying the CUDA cache: tensors kept alive
        # only by unreachable reference cycles of the replaced model must be
        # freed first, otherwise their blocks are still allocated and
        # ``empty_cache`` cannot hand them back.
        gc.collect()
        torch.cuda.empty_cache()
        self.name = name

    def __call__(self, image: "PIL.Image.Image", **kwargs) -> "PIL.Image.Image":
        """Run the loaded detector on ``image``.

        Args:
            image: Input image, forwarded unchanged to the detector.
            **kwargs: Extra keyword arguments forwarded to the detector.

        Returns:
            Whatever the loaded detector returns for ``image``.

        Raises:
            TypeError: If no detector has been loaded yet, because
                ``self.model`` is still ``None`` and cannot be called.
        """
        return self.model(image, **kwargs)