Spaces:
Runtime error
Runtime error
| from typing import Union | |
| import time | |
| import diffusers.utils | |
| from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline | |
| from modules.shared import log, opts | |
| from modules.control.units import detect | |
| from modules import sd_models | |
# Display name of this control unit; interpolated into the log messages emitted
# by ReferencePipeline.
what = 'Reference'


def list_models():
    """Return the model names selectable for the Reference control unit.

    A new list is built on every call, so callers may mutate the result
    without affecting later calls.
    """
    models = ['Reference']
    return models
class ReferencePipeline():
    """Wrap a loaded SD 1.5 or SDXL pipeline in a diffusers "reference" pipeline.

    The wrapper is built from the original pipeline's own components (VAE,
    text encoder(s), tokenizer(s), UNet, scheduler, feature extractor) rather
    than from newly loaded weights. The pipeline class is resolved at runtime
    with ``diffusers.utils.get_class_from_dynamic_module``, presumably from the
    diffusers community pipelines ``stable_diffusion_xl_reference`` /
    ``stable_diffusion_reference`` (TODO confirm whether this needs network
    access on first use).

    Attributes:
        orig_pipeline: the pipeline given to the constructor; handed back by
            ``restore()``.
        pipeline: the constructed reference pipeline, or ``None`` when the
            input was ``None`` or its model type is unsupported.
    """

    def __init__(self, pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline], dtype = None):
        """Build the reference pipeline around ``pipeline``.

        Args:
            pipeline: loaded SDXL or SD 1.5 pipeline to wrap. ``None`` logs an
                error and leaves ``self.pipeline`` as ``None``.
            dtype: optional dtype; when given, the new pipeline is converted
                with ``.to(dtype)`` after being moved to the original device.
        """
        started = time.time()
        self.orig_pipeline = pipeline
        self.pipeline = None
        if pipeline is None:
            log.error(f'Control {what} model pipeline: model not loaded')
            return

        # NOTE(review): QKV projections are unfused on the *original* pipeline
        # here, and restore() does not re-fuse them. Confirm that fusing is
        # re-applied elsewhere when opts.diffusers_fuse_projections is set.
        if opts.diffusers_fuse_projections and hasattr(pipeline, 'unfuse_qkv_projections'):
            pipeline.unfuse_qkv_projections()

        # SDXL is checked first; the SD 1.5 check only runs for non-SDXL models.
        is_xl = detect.is_sdxl(pipeline)
        if not is_xl and not detect.is_sd15(pipeline):
            log.error(f'Control {what} pipeline: class={pipeline.__class__.__name__} unsupported model type')
            return

        module_name = 'stable_diffusion_xl_reference' if is_xl else 'stable_diffusion_reference'
        reference_cls = diffusers.utils.get_class_from_dynamic_module(module_name, module_file='pipeline.py')

        # Components common to both architectures, taken from the original pipeline.
        shared = {
            'vae': pipeline.vae,
            'text_encoder': pipeline.text_encoder,
            'tokenizer': pipeline.tokenizer,
            'unet': pipeline.unet,
            'scheduler': pipeline.scheduler,
            'feature_extractor': getattr(pipeline, 'feature_extractor', None),
        }
        if is_xl:
            # SDXL carries a second text encoder / tokenizer pair.
            extra = {
                'text_encoder_2': pipeline.text_encoder_2,
                'tokenizer_2': pipeline.tokenizer_2,
            }
        else:
            # SD 1.5: construct without a safety checker.
            extra = {
                'requires_safety_checker': False,
                'safety_checker': None,
            }
        self.pipeline = reference_cls(**shared, **extra)
        sd_models.move_model(self.pipeline, pipeline.device)

        if dtype is not None and self.pipeline is not None:
            self.pipeline = self.pipeline.to(dtype)
        finished = time.time()
        if self.pipeline is None:
            log.error(f'Control {what} pipeline: not initialized')
        else:
            log.debug(f'Control {what} pipeline: class={self.pipeline.__class__.__name__} time={finished-started:.2f}')

    def restore(self):
        """Discard the reference pipeline and return the original pipeline."""
        original = self.orig_pipeline
        self.pipeline = None
        return original