import os
from typing import Any, Dict, Callable, Optional
import numpy as np
import torch
import diffusers
import onnxruntime as ort
import optimum.onnxruntime

initialized = False
run_olive_workflow = None

class DynamicSessionOptions(ort.SessionOptions):
    config: Optional[Dict] = None

    def __init__(self):
        super().__init__()
        self.enable_mem_pattern = False

    @classmethod
    def from_sess_options(cls, sess_options: ort.SessionOptions):
        if isinstance(sess_options, DynamicSessionOptions):
            return sess_options.copy()
        return DynamicSessionOptions()

    def enable_static_dims(self, config: Dict):
        # Pin the UNet's free (dynamic) dimensions to static values so the
        # session can be optimized for one specific shape.
        self.config = config
        self.add_free_dimension_override_by_name("unet_sample_batch", config["hidden_batch_size"])
        self.add_free_dimension_override_by_name("unet_sample_channels", 4)
        self.add_free_dimension_override_by_name("unet_sample_height", config["height"] // 8)
        self.add_free_dimension_override_by_name("unet_sample_width", config["width"] // 8)
        self.add_free_dimension_override_by_name("unet_time_batch", 1)
        self.add_free_dimension_override_by_name("unet_hidden_batch", config["hidden_batch_size"])
        self.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
        if config["is_sdxl"] and not config["is_refiner"]:
            self.add_free_dimension_override_by_name("unet_text_embeds_batch", config["hidden_batch_size"])
            self.add_free_dimension_override_by_name("unet_text_embeds_size", 1280)
            self.add_free_dimension_override_by_name("unet_time_ids_batch", config["hidden_batch_size"])
            self.add_free_dimension_override_by_name("unet_time_ids_size", 6)

    def copy(self):
        sess_options = DynamicSessionOptions()
        if self.config is not None:
            sess_options.enable_static_dims(self.config)
        return sess_options
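
# Example (hypothetical usage sketch, kept commented out so the module stays
# import-safe): building statically-shaped session options for a 512x512 SD 1.x
# run with classifier-free guidance, where the hidden batch is doubled
# (positive + negative prompt). The model path is illustrative, not a real one.
#
#   opts = DynamicSessionOptions()
#   opts.enable_static_dims({
#       "hidden_batch_size": 2,   # batch_size * 2 when CFG is enabled
#       "height": 512,
#       "width": 512,             # latent dims become 64x64 (// 8)
#       "is_sdxl": False,
#       "is_refiner": False,
#   })
#   session = ort.InferenceSession("unet/model.onnx", opts, providers=["CPUExecutionProvider"])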

class TorchCompatibleModule:
    device = torch.device("cpu")
    dtype = torch.float32

    def to(self, *_, **__):
        raise NotImplementedError

    def type(self, *_, **__):
        return self

class TemporalModule(TorchCompatibleModule):
    """
    Placeholder for models that cannot be moved to CPU: stores only the provider,
    model path and session options, and loads the real ONNX model once the module
    is moved to a non-CPU device.
    """
    provider: Any
    path: str
    sess_options: ort.SessionOptions

    def __init__(self, provider: Any, path: str, sess_options: ort.SessionOptions):
        self.provider = provider
        self.path = path
        self.sess_options = sess_options

    def to(self, *args, **kwargs):
        from .utils import extract_device
        device = extract_device(args, kwargs)
        if device is not None and device.type != "cpu":
            from .execution_providers import TORCH_DEVICE_TO_EP
            provider = TORCH_DEVICE_TO_EP.get(device.type, self.provider)
            return OnnxRuntimeModel.load_model(self.path, provider, DynamicSessionOptions.from_sess_options(self.sess_options))
        return self
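
# Example (hypothetical sketch): a UNet parked as a TemporalModule only becomes a
# real ONNX session once it is moved to an accelerator; the path and provider
# below are assumptions for illustration.
#
#   unet = TemporalModule("CPUExecutionProvider", "unet/model.onnx", DynamicSessionOptions())
#   unet = unet.to(torch.device("cuda"))  # loads via the mapped CUDA/ROCm execution provider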

class OnnxRuntimeModel(TorchCompatibleModule, diffusers.OnnxRuntimeModel):
    config = {}  # dummy

    def named_modules(self):  # dummy
        return ()

    def to(self, *args, **kwargs):
        from modules.onnx_impl.utils import extract_device, move_inference_session
        device = extract_device(args, kwargs)
        if device is not None:
            self.device = device
            self.model = move_inference_session(self.model, device)
        return self
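
# Example (hypothetical): unlike the base TorchCompatibleModule, .to() here moves
# the live inference session between devices instead of raising. The checkpoint
# id is an assumption.
#
#   model = OnnxRuntimeModel.from_pretrained("some/onnx-model")
#   model = model.to(torch.device("cuda"))  # session is re-created on the new device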

class VAEConfig:
    DEFAULTS = { "scaling_factor": 0.18215 }
    config: Dict

    def __init__(self, config: Dict):
        self.config = config

    def __getattr__(self, key):
        # Fall back to DEFAULTS for keys missing from the wrapped config.
        return self.config.get(key, VAEConfig.DEFAULTS[key])
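
# Example: attribute lookup falls back to DEFAULTS when a key is missing from the
# wrapped config (and raises KeyError for keys known to neither).
#
#   VAEConfig({}).scaling_factor                           # -> 0.18215 (SD 1.x default)
#   VAEConfig({"scaling_factor": 0.13025}).scaling_factor  # -> 0.13025 (SDXL-style value)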

class VAE(TorchCompatibleModule):
    pipeline: Any

    def __init__(self, pipeline: Any):
        self.pipeline = pipeline

    @property
    def config(self):
        return VAEConfig(self.pipeline.vae_decoder.config)

    @property
    def device(self):
        return self.pipeline.vae_decoder.device

    def encode(self, sample: torch.Tensor, *_, **__):
        # Run the ONNX VAE encoder one sample at a time, then reassemble the batch.
        sample_np = sample.cpu().numpy()
        return [
            torch.from_numpy(np.concatenate(
                [self.pipeline.vae_encoder(sample=sample_np[i : i + 1])[0] for i in range(sample_np.shape[0])]
            )).to(sample.device)
        ]

    def decode(self, latent_sample: torch.Tensor, *_, **__):
        latents_np = latent_sample.cpu().numpy()
        return [
            torch.from_numpy(np.concatenate(
                [self.pipeline.vae_decoder(latent_sample=latents_np[i : i + 1])[0] for i in range(latents_np.shape[0])]
            )).to(latent_sample.device)
        ]

    def to(self, *args, **kwargs):
        self.pipeline.vae_encoder = self.pipeline.vae_encoder.to(*args, **kwargs)
        self.pipeline.vae_decoder = self.pipeline.vae_decoder.to(*args, **kwargs)
        return self
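
# Example (hypothetical sketch): wrapping an ONNX pipeline so torch-style code can
# call encode()/decode(); `pipeline` is assumed to expose vae_encoder/vae_decoder.
#
#   vae = VAE(pipeline)
#   latents = vae.encode(images)[0] * vae.config.scaling_factor
#   decoded = vae.decode(latents / vae.config.scaling_factor)[0]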

def check_parameters_changed(p, refiner_enabled: bool):
    from modules import shared, sd_models
    if shared.sd_model.__class__.__name__ == "OnnxRawPipeline" or not shared.sd_model.__class__.__name__.startswith("Onnx"):
        return shared.sd_model
    compile_height = p.height
    compile_width = p.width
    if (shared.compiled_model_state is None
            or shared.compiled_model_state.height != compile_height
            or shared.compiled_model_state.width != compile_width
            or shared.compiled_model_state.batch_size != p.batch_size):
        shared.log.info("Olive: Parameter change detected")
        shared.log.info("Olive: Recompiling base model")
        sd_models.unload_model_weights(op='model')
        sd_models.reload_model_weights(op='model')
        if refiner_enabled:
            shared.log.info("Olive: Recompiling refiner")
            sd_models.unload_model_weights(op='refiner')
            sd_models.reload_model_weights(op='refiner')
    shared.compiled_model_state.height = compile_height
    shared.compiled_model_state.width = compile_width
    shared.compiled_model_state.batch_size = p.batch_size
    return shared.sd_model
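
# Example (hypothetical call site): invoked before generation so that a changed
# resolution or batch size triggers a recompile of the statically-shaped model.
#
#   shared.sd_model = check_parameters_changed(p, refiner_enabled=shared.sd_refiner is not None)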

def preprocess_pipeline(p):
    from modules import shared, sd_models
    if "ONNX" not in shared.opts.diffusers_pipeline:
        shared.log.warning(f"Unsupported pipeline for 'olive-ai' compile backend: {shared.opts.diffusers_pipeline}. You should select one of the ONNX pipelines.")
        return shared.sd_model
    if hasattr(shared.sd_model, "preprocess"):
        shared.sd_model = shared.sd_model.preprocess(p)
    if hasattr(shared.sd_refiner, "preprocess"):
        if shared.opts.onnx_unload_base:
            sd_models.unload_model_weights(op='model')
        shared.sd_refiner = shared.sd_refiner.preprocess(p)
        if shared.opts.onnx_unload_base:
            sd_models.reload_model_weights(op='model')
            shared.sd_model = shared.sd_model.preprocess(p)
    return shared.sd_model
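
# Example (hypothetical): run once per job before sampling so both base and
# refiner pipelines are converted/optimized for the current parameters.
#
#   shared.sd_model = preprocess_pipeline(p)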

def ORTDiffusionModelPart_to(self, *args, **kwargs):
    # Patched onto optimum's _ORTDiffusionModelPart: delegate device moves to the parent model.
    self.parent_model = self.parent_model.to(*args, **kwargs)
    return self

def initialize_onnx():
    global initialized  # pylint: disable=global-statement
    if initialized:
        return
    from installer import log, installed
    from modules import devices
    from modules.shared import opts
    if not installed('onnx', quiet=True):
        return
    try:  # may fail on onnx import
        import onnx  # pylint: disable=unused-import
        from .execution_providers import ExecutionProvider, TORCH_DEVICE_TO_EP, available_execution_providers
        if devices.backend == "rocm":
            TORCH_DEVICE_TO_EP["cuda"] = ExecutionProvider.ROCm
        from .pipelines.onnx_stable_diffusion_pipeline import OnnxStableDiffusionPipeline
        from .pipelines.onnx_stable_diffusion_img2img_pipeline import OnnxStableDiffusionImg2ImgPipeline
        from .pipelines.onnx_stable_diffusion_inpaint_pipeline import OnnxStableDiffusionInpaintPipeline
        from .pipelines.onnx_stable_diffusion_upscale_pipeline import OnnxStableDiffusionUpscalePipeline
        from .pipelines.onnx_stable_diffusion_xl_pipeline import OnnxStableDiffusionXLPipeline
        from .pipelines.onnx_stable_diffusion_xl_img2img_pipeline import OnnxStableDiffusionXLImg2ImgPipeline
        OnnxRuntimeModel.__module__ = 'diffusers'  # OnnxRuntimeModel hijack
        diffusers.OnnxRuntimeModel = OnnxRuntimeModel
        diffusers.OnnxStableDiffusionPipeline = OnnxStableDiffusionPipeline
        diffusers.pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["onnx-stable-diffusion"] = diffusers.OnnxStableDiffusionPipeline
        diffusers.OnnxStableDiffusionImg2ImgPipeline = OnnxStableDiffusionImg2ImgPipeline
        diffusers.pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["onnx-stable-diffusion"] = diffusers.OnnxStableDiffusionImg2ImgPipeline
        diffusers.OnnxStableDiffusionInpaintPipeline = OnnxStableDiffusionInpaintPipeline
        diffusers.pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["onnx-stable-diffusion"] = diffusers.OnnxStableDiffusionInpaintPipeline
        diffusers.OnnxStableDiffusionUpscalePipeline = OnnxStableDiffusionUpscalePipeline
        diffusers.OnnxStableDiffusionXLPipeline = OnnxStableDiffusionXLPipeline
        diffusers.pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["onnx-stable-diffusion-xl"] = diffusers.OnnxStableDiffusionXLPipeline
        diffusers.OnnxStableDiffusionXLImg2ImgPipeline = OnnxStableDiffusionXLImg2ImgPipeline
        diffusers.pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["onnx-stable-diffusion-xl"] = diffusers.OnnxStableDiffusionXLImg2ImgPipeline
        diffusers.ORTStableDiffusionXLPipeline = diffusers.OnnxStableDiffusionXLPipeline  # Huggingface model compatibility
        diffusers.ORTStableDiffusionXLImg2ImgPipeline = diffusers.OnnxStableDiffusionXLImg2ImgPipeline
        optimum.onnxruntime.modeling_diffusion._ORTDiffusionModelPart.to = ORTDiffusionModelPart_to  # pylint: disable=protected-access
        log.debug(f'ONNX: version={ort.__version__} provider={opts.onnx_execution_provider}, available={available_execution_providers}')
    except Exception as e:
        log.error(f'ONNX failed to initialize: {e}')
    initialized = True
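
# Example (hypothetical): safe to call from startup code; the module-level
# `initialized` flag makes repeated calls no-ops.
#
#   initialize_onnx()
#   initialize_onnx()  # second call returns immediately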

def initialize_olive():
    global run_olive_workflow  # pylint: disable=global-statement
    from installer import installed, log
    if not installed('olive-ai', quiet=True) or not installed('onnx', quiet=True):
        return
    import sys
    import importlib
    orig_sys_path = sys.path
    venv_dir = os.environ.get("VENV_DIR", os.path.join(os.getcwd(), 'venv'))
    try:
        spec = importlib.util.find_spec('onnxruntime.transformers')
        sys.path = [d for d in spec.submodule_search_locations + sys.path if sys.path[1] not in d or venv_dir in d]
        from onnxruntime.transformers import convert_generation  # pylint: disable=unused-import
        spec = importlib.util.find_spec('olive')
        sys.path = spec.submodule_search_locations + sys.path
        run_olive_workflow = importlib.import_module('olive.workflows').run
    except Exception as e:
        run_olive_workflow = None
        log.error(f'Olive: Failed to load olive-ai: {e}')
    sys.path = orig_sys_path
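
# Example (hypothetical): after initialization the Olive workflow runner is
# exposed as a module-level callable, or None if olive-ai could not be loaded;
# `olive_config` below is an assumed workflow config dict/path.
#
#   initialize_olive()
#   if run_olive_workflow is not None:
#       run_olive_workflow(olive_config)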

def install_olive():
    from installer import installed, install, log
    if installed("olive-ai"):
        return
    try:
        log.info('Installing Olive')
        install('onnx', 'onnx', ignore=True)
        install('olive-ai', 'olive-ai', ignore=True)
        import olive.workflows  # pylint: disable=unused-import
    except Exception as e:
        log.error(f'Olive: Failed to load olive-ai: {e}')
    else:
        log.info('Olive: Please restart webui session.')
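
# Example (hypothetical flow): install on demand, then initialize after the
# suggested restart so the freshly installed packages are importable.
#
#   install_olive()      # no-op when olive-ai is already installed
#   initialize_olive()   # binds run_olive_workflow on success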