Spaces:
Runtime error
Runtime error
| import os | |
| import json | |
| import importlib | |
| from typing import Type, Tuple, Union, List, Dict, Any | |
| import torch | |
| import diffusers | |
| import onnxruntime as ort | |
def extract_device(args: List, kwargs: Dict):
    """Pick the target device out of a call's positional and keyword arguments.

    An explicit ``device`` keyword argument (when not ``None``) is returned as-is,
    without inspecting the positional arguments. Otherwise the LAST positional
    argument that is a ``torch.device`` is returned; ``None`` if there is none.
    """
    explicit = kwargs.get("device", None)
    if explicit is not None:
        return explicit
    positional = [candidate for candidate in args if isinstance(candidate, torch.device)]
    # Last match wins, mirroring a scan that keeps overwriting its result.
    return positional[-1] if positional else None
def move_inference_session(session: ort.InferenceSession, device: torch.device):
    """Rebuild an ONNX Runtime session on the execution provider matching *device*.

    Returns *session* unchanged when the application's default torch device is CPU.
    Otherwise a new session is loaded from the same model path via
    ``diffusers.OnnxRuntimeModel.load_model``. If that raises, a ``TemporalModule``
    built from the previous providers, model path and session options is returned.

    Args:
        session: the existing inference session (its private ``_providers``,
            ``_model_path`` and ``_sess_options`` attributes are read).
        device: target torch device; only ``device.type`` is used.
    """
    from modules.devices import device as default_device
    # CPU-only torch with no external ops overriding it: transferring the session
    # would only cause mistakes, so keep the original one.
    if default_device.type == "cpu":
        return session
    from . import DynamicSessionOptions, TemporalModule
    from .execution_providers import TORCH_DEVICE_TO_EP
    previous_provider = session._providers  # pylint: disable=protected-access
    if device.type in TORCH_DEVICE_TO_EP:
        provider = TORCH_DEVICE_TO_EP[device.type]
    else:
        # NOTE(review): ``_providers`` is presumably a list of provider names, while
        # ``load_model`` seems to expect a single provider; this path may always
        # fall through to TemporalModule below — confirm this is intended.
        provider = previous_provider
    model_path = session._model_path  # pylint: disable=protected-access
    try:
        options = DynamicSessionOptions.from_sess_options(session._sess_options)  # pylint: disable=protected-access
        return diffusers.OnnxRuntimeModel.load_model(model_path, provider, options)
    except Exception:
        return TemporalModule(previous_provider, model_path, session._sess_options)  # pylint: disable=protected-access
def check_diffusers_cache(path: os.PathLike):
    """Return True when the configured diffusers directory appears in *path*'s absolute form.

    NOTE(review): this is a plain substring test, not a "path is under directory"
    check, so it can match unrelated paths that merely contain the same text.
    """
    from modules.shared import opts
    cache_dir = opts.diffusers_dir
    absolute_path = os.path.abspath(path)
    return cache_dir in absolute_path
def check_pipeline_sdxl(cls: Type[diffusers.DiffusionPipeline]) -> bool:
    """Classify a pipeline class as SDXL purely by name: True if "XL" occurs in it."""
    class_name = cls.__name__
    return class_name.find("XL") != -1
def check_cache_onnx(path: os.PathLike) -> bool:
    """Tell whether *path* is a pipeline directory whose index references ONNX models.

    Returns False if *path* is not a directory or has no ``model_index.json``.
    Otherwise returns True iff the raw text of that file contains
    ``"OnnxRuntimeModel"`` (a text search; the JSON is not parsed).
    """
    if not os.path.isdir(path):
        return False
    index_file = os.path.join(path, "model_index.json")
    if not os.path.isfile(index_file):
        return False
    with open(index_file, "r", encoding="utf-8") as handle:
        contents = handle.read()
    return "OnnxRuntimeModel" in contents
def load_init_dict(cls: Type[diffusers.DiffusionPipeline], path: os.PathLike):
    """Read the pipeline config at *path* and keep only the component entries *cls* accepts.

    All dicts returned by ``cls.extract_init_dict`` are merged; later ones override
    earlier keys. Only entries whose value is a list, and whose key is annotated on
    ``cls.__init__``, are kept. Presumably these are ``[library, class_name]``
    component specs — confirm against the config format.
    """
    config_parts = cls.extract_init_dict(diffusers.DiffusionPipeline.load_config(path))
    combined: Dict[str, Any] = {}
    for part in config_parts:
        combined.update(part)
    # ``and`` short-circuits, so ``__init__.__annotations__`` is only consulted for list values.
    return {
        name: spec
        for name, spec in combined.items()
        if isinstance(spec, list) and name in cls.__init__.__annotations__
    }
def load_submodel(path: os.PathLike, is_sdxl: bool, submodel_name: str, item: List[Union[str, None]], **kwargs_ort):
    """Instantiate one pipeline component from ``<path>/<submodel_name>``.

    *item* is a ``(library, class_name)`` pair; ``None`` is returned if either part is None.
    ONNX components (subclasses of ``diffusers.OnnxRuntimeModel``) receive *kwargs_ort*:
    for SDXL, ``OnnxRuntimeModel.load_model`` is called on ``model.onnx`` inside the folder;
    otherwise ``OnnxRuntimeModel.from_pretrained`` is called on the folder. Any other
    component class is loaded with its own ``from_pretrained(folder)``.
    """
    library_name, attribute_name = item
    if library_name is None or attribute_name is None:
        return None
    component_class = getattr(importlib.import_module(library_name), attribute_name)
    submodel_dir = os.path.join(path, submodel_name)
    if not issubclass(component_class, diffusers.OnnxRuntimeModel):
        return component_class.from_pretrained(submodel_dir)
    if is_sdxl:
        return diffusers.OnnxRuntimeModel.load_model(os.path.join(submodel_dir, "model.onnx"), **kwargs_ort)
    return diffusers.OnnxRuntimeModel.from_pretrained(submodel_dir, **kwargs_ort)
def load_submodels(path: os.PathLike, is_sdxl: bool, init_dict: Dict[str, Type], **kwargs_ort):
    """Load every pipeline component described by *init_dict*.

    List values are handed to ``load_submodel`` together with *path*, *is_sdxl* and
    *kwargs_ort*. Any other value is copied to the result unchanged.

    Loading is best-effort: if loading a component raises, the error is logged with
    its traceback and that component is left out of the returned dict, so the
    remaining components still load. (A missing required component will presumably
    surface later as a constructor error — the log entry explains why.)

    Returns:
        dict mapping component name to the loaded object or passthrough value.
    """
    import logging
    loaded = {}
    for name, spec in init_dict.items():
        if not isinstance(spec, list):
            loaded[name] = spec
            continue
        try:
            loaded[name] = load_submodel(path, is_sdxl, name, spec, **kwargs_ort)
        except Exception:  # pylint: disable=broad-except
            # Previously swallowed silently; keep going, but leave a trace of the failure.
            logging.getLogger(__name__).warning("Failed to load submodel '%s' from %s", name, path, exc_info=True)
    return loaded
def load_pipeline(cls: Type[diffusers.DiffusionPipeline], path: os.PathLike, **kwargs_ort) -> diffusers.DiffusionPipeline:
    """Construct a *cls* pipeline from *path*.

    A non-directory *path* is delegated to ``cls.from_single_file(path)``; *kwargs_ort*
    is not forwarded in that case. For a directory, the component specs are read with
    ``load_init_dict`` and loaded with ``load_submodels``, passing *kwargs_ort* through.
    The result is adjusted by ``patch_kwargs`` and passed to ``cls(...)``.
    """
    if not os.path.isdir(path):
        return cls.from_single_file(path)
    is_sdxl = check_pipeline_sdxl(cls)
    init_dict = load_init_dict(cls, path)
    components = load_submodels(path, is_sdxl, init_dict, **kwargs_ort)
    return cls(**patch_kwargs(cls, components))
def patch_kwargs(cls: Type[diffusers.DiffusionPipeline], kwargs: Dict) -> Dict:
    """Adjust constructor kwargs for the ONNX pipeline class *cls*; mutates and returns *kwargs*.

    - Non-XL ONNX Stable Diffusion pipelines (txt2img / img2img / inpaint): the safety
      checker is disabled (``safety_checker=None``, ``requires_safety_checker=False``).
    - ONNX SDXL pipelines (txt2img / img2img): an empty ``config`` dict is injected
      (presumably required by those constructors — confirm).
    - Any other class: *kwargs* is returned untouched.
    """
    sd_pipelines = (
        diffusers.OnnxStableDiffusionPipeline,
        diffusers.OnnxStableDiffusionImg2ImgPipeline,
        diffusers.OnnxStableDiffusionInpaintPipeline,
    )
    sdxl_pipelines = (
        diffusers.OnnxStableDiffusionXLPipeline,
        diffusers.OnnxStableDiffusionXLImg2ImgPipeline,
    )
    if cls in sd_pipelines:
        kwargs.update(safety_checker=None, requires_safety_checker=False)
    if cls in sdxl_pipelines:
        kwargs["config"] = {}
    return kwargs
def get_base_constructor(cls: Type[diffusers.DiffusionPipeline], is_refiner: bool):
    """Return the base pipeline class to construct for *cls*.

    - ONNX SD img2img / inpaint pipelines map to ``OnnxStableDiffusionPipeline``.
    - ONNX SDXL img2img maps to ``OnnxStableDiffusionXLPipeline`` unless *is_refiner*
      is set, in which case the img2img class itself is kept.
    - Every other class is returned unchanged.
    """
    sd_derived = (diffusers.OnnxStableDiffusionImg2ImgPipeline, diffusers.OnnxStableDiffusionInpaintPipeline)
    if cls in sd_derived:
        return diffusers.OnnxStableDiffusionPipeline
    is_sdxl_img2img = cls == diffusers.OnnxStableDiffusionXLImg2ImgPipeline
    if is_sdxl_img2img and not is_refiner:
        return diffusers.OnnxStableDiffusionXLPipeline
    return cls
def get_io_config(submodel: str, is_sdxl: bool):
    """Load the Olive ``io_config`` for *submodel* from the bundled JSON configs.

    Reads ``<sd_configs_path>/olive/<sdxl|sd>/<submodel>.json`` and returns
    ``["input_model"]["config"]["io_config"]``. JSON object keys are always strings,
    so the axis-index keys of every ``dynamic_axes`` entry are converted back to ``int``.
    """
    from modules.paths import sd_configs_path
    variant = "sdxl" if is_sdxl else "sd"
    config_path = os.path.join(sd_configs_path, "olive", variant, f"{submodel}.json")
    with open(config_path, "r", encoding="utf-8") as config_file:
        io_config: Dict[str, Any] = json.load(config_file)["input_model"]["config"]["io_config"]
    dynamic_axes = io_config["dynamic_axes"]
    # Only values are reassigned (keys unchanged), so mutating while iterating is safe.
    for axis_name, axes in dynamic_axes.items():
        dynamic_axes[axis_name] = {int(index): label for index, label in axes.items()}
    return io_config