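# Model merge and conversion helpers: read PNG generation info, merge checkpoints
# with optional per-block weights and VAE baking, and convert checkpoint precision/format.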
import os
import html
import json
import time
import shutil
import torch
import tqdm
import gradio as gr
import safetensors.torch
from modules.merging.merge import merge_models
from modules.merging.merge_utils import TRIPLE_METHODS
from modules import shared, images, sd_models, sd_vae, sd_models_config, devices
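
# Extract embedded generation parameters from an image and format them as HTML key/value rows.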
def run_pnginfo(image):
    if image is None:
        return '', '', ''
    geninfo, items = images.read_info_from_image(image)
    items = {**{'parameters': geninfo}, **items}
    info = ''
    for key, text in items.items():
        if key != 'UserComment':
            info += f"<div><b>{html.escape(str(key))}</b>: {html.escape(str(text))}</div>"
    return '', geninfo, info
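
# Pick a .yaml config from one of the source checkpoints (per config_source) and copy it
# next to the merged checkpoint so the result loads with the right architecture.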
def create_config(ckpt_result, config_source, a, b, c):
    def config(x):
        res = sd_models_config.find_checkpoint_config_near_filename(x) if x else None
        return res if res != shared.sd_default_config else None

    if config_source == 0:
        cfg = config(a) or config(b) or config(c)
    elif config_source == 1:
        cfg = config(b)
    elif config_source == 2:
        cfg = config(c)
    else:
        cfg = None
    if cfg is None:
        return
    filename, _ = os.path.splitext(ckpt_result)
    checkpoint_filename = filename + ".yaml"
    shared.log.info(f"Copying config: {cfg} -> {checkpoint_filename}")
    shutil.copyfile(cfg, checkpoint_filename)
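
# Downcast fp32 tensors to fp16 when half-precision output is requested; leave other dtypes untouched.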
def to_half(tensor, enable):
    if enable and tensor.dtype == torch.float:
        return tensor.half()
    return tensor
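
# Resolve model names to checkpoint files, parse per-block merge weights, run the merge,
# optionally bake in a VAE, and save the result with optional merge-recipe metadata.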
def run_modelmerger(id_task, **kwargs): # pylint: disable=unused-argument
    shared.state.begin('merge')
    t0 = time.time()

    def fail(message):
        shared.state.textinfo = message
        shared.state.end()
        return [*[gr.update() for _ in range(4)], message]

    if kwargs.get("primary_model_name", None) in [None, 'None']:
        return fail("Failed: Merging requires a primary model.")
    primary_model_info = sd_models.get_closet_checkpoint_match(kwargs.get("primary_model_name", None))
    if kwargs.get("secondary_model_name", None) in [None, 'None']:
        return fail("Failed: Merging requires a secondary model.")
    secondary_model_info = sd_models.get_closet_checkpoint_match(kwargs.get("secondary_model_name", None))
    if kwargs.get("tertiary_model_name", None) in [None, 'None'] and kwargs.get("merge_mode", None) in TRIPLE_METHODS:
        return fail(f"Failed: Interpolation method ({kwargs.get('merge_mode', None)}) requires a tertiary model.")
    tertiary_model_info = sd_models.get_closet_checkpoint_match(kwargs.get("tertiary_model_name", None)) if kwargs.get("merge_mode", None) in TRIPLE_METHODS else None
    # build the models dict only after validation so a missing name fails cleanly
    kwargs["models"] = {
        "model_a": primary_model_info.filename,
        "model_b": secondary_model_info.filename,
    }
    if tertiary_model_info is not None:
        kwargs["models"] |= {"model_c": tertiary_model_info.filename}
    del kwargs["primary_model_name"]
    del kwargs["secondary_model_name"]
    kwargs.pop("tertiary_model_name", None)
| if kwargs.get("alpha_base", None) and kwargs.get("alpha_in_blocks", None) and kwargs.get("alpha_mid_block", None) and kwargs.get("alpha_out_blocks", None): | |
| try: | |
| alpha = [float(x) for x in | |
| [kwargs["alpha_base"]] + kwargs["alpha_in_blocks"].split(",") + [kwargs["alpha_mid_block"]] + kwargs["alpha_out_blocks"].split(",")] | |
| assert len(alpha) == 26 or len(alpha) == 20, "Alpha Block Weights are wrong length (26 or 20 for SDXL)" | |
| kwargs["alpha"] = alpha | |
| except KeyError as ke: | |
| shared.log.warning(f"Merge: Malformed manual block weight: {ke}") | |
| elif kwargs.get("alpha_preset", None) or kwargs.get("alpha", None): | |
| kwargs["alpha"] = kwargs.get("alpha_preset", kwargs["alpha"]) | |
| kwargs.pop("alpha_base", None) | |
| kwargs.pop("alpha_in_blocks", None) | |
| kwargs.pop("alpha_mid_block", None) | |
| kwargs.pop("alpha_out_blocks", None) | |
| kwargs.pop("alpha_preset", None) | |
| if kwargs.get("beta_base", None) and kwargs.get("beta_in_blocks", None) and kwargs.get("beta_mid_block", None) and kwargs.get("beta_out_blocks", None): | |
| try: | |
| beta = [float(x) for x in | |
| [kwargs["beta_base"]] + kwargs["beta_in_blocks"].split(",") + [kwargs["beta_mid_block"]] + kwargs["beta_out_blocks"].split(",")] | |
| assert len(beta) == 26 or len(beta) == 20, "Beta Block Weights are wrong length (26 or 20 for SDXL)" | |
| kwargs["beta"] = beta | |
| except KeyError as ke: | |
| shared.log.warning(f"Merge: Malformed manual block weight: {ke}") | |
| elif kwargs.get("beta_preset", None) or kwargs.get("beta", None): | |
| kwargs["beta"] = kwargs.get("beta_preset", kwargs["beta"]) | |
| kwargs.pop("beta_base", None) | |
| kwargs.pop("beta_in_blocks", None) | |
| kwargs.pop("beta_mid_block", None) | |
| kwargs.pop("beta_out_blocks", None) | |
| kwargs.pop("beta_preset", None) | |
| if kwargs["device"] == "gpu": | |
| kwargs["device"] = devices.device | |
| elif kwargs["device"] == "shuffle": | |
| kwargs["device"] = torch.device("cpu") | |
| kwargs["work_device"] = devices.device | |
| else: | |
| kwargs["device"] = torch.device("cpu") | |
| if kwargs.pop("unload", False): | |
| sd_models.unload_model_weights() | |
| try: | |
| theta_0 = merge_models(**kwargs) | |
| except Exception as e: | |
| return fail(f"{e}") | |
| try: | |
| theta_0 = theta_0.to_dict() #TensorDict -> Dict if necessary | |
| except Exception: | |
| pass | |
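    # Bake in a VAE: overwrite the merged model's first_stage_model.* weights with the selected VAE.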
    bake_in_vae_filename = sd_vae.vae_dict.get(kwargs.get("bake_in_vae", None), None)
    if bake_in_vae_filename is not None:
        shared.log.info(f"Merge VAE='{bake_in_vae_filename}'")
        shared.state.textinfo = 'Merge VAE'
        vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename)
        for key in vae_dict.keys():
            theta_0_key = 'first_stage_model.' + key
            if theta_0_key in theta_0:
                theta_0[theta_0_key] = to_half(vae_dict[key], kwargs.get("precision", "fp16") == "fp16")
        del vae_dict
    ckpt_dir = shared.opts.ckpt_dir or sd_models.model_path
    filename = kwargs.get("custom_name", "Unnamed_Merge")
    filename += "." + kwargs.get("checkpoint_format", None)
    output_modelname = os.path.join(ckpt_dir, filename)
    shared.state.textinfo = "merge saving"
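    # Optionally embed a JSON merge recipe and the source models' hashes in the
    # checkpoint metadata (only persisted for .safetensors output).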
    metadata = None
    if kwargs.get("save_metadata", False):
        metadata = {"format": "pt", "sd_merge_models": {}}
        merge_recipe = {
            "type": "SDNext", # indicate this model was merged with webui's built-in merger
            "primary_model_hash": primary_model_info.sha256,
            "secondary_model_hash": secondary_model_info.sha256 if secondary_model_info else None,
            "tertiary_model_hash": tertiary_model_info.sha256 if tertiary_model_info else None,
            "merge_mode": kwargs.get('merge_mode', None),
            "alpha": kwargs.get('alpha', None),
            "beta": kwargs.get('beta', None),
            "precision": kwargs.get('precision', None),
            "custom_name": kwargs.get("custom_name", "Unnamed_Merge"),
        }
        metadata["sd_merge_recipe"] = json.dumps(merge_recipe)

        def add_model_metadata(checkpoint_info):
            checkpoint_info.calculate_shorthash()
            metadata["sd_merge_models"][checkpoint_info.sha256] = {
                "name": checkpoint_info.name,
                "legacy_hash": checkpoint_info.hash,
                "sd_merge_recipe": checkpoint_info.metadata.get("sd_merge_recipe", None)
            }
            metadata["sd_merge_models"].update(checkpoint_info.metadata.get("sd_merge_models", {}))

        add_model_metadata(primary_model_info)
        if secondary_model_info:
            add_model_metadata(secondary_model_info)
        if tertiary_model_info:
            add_model_metadata(tertiary_model_info)
        metadata["sd_merge_models"] = json.dumps(metadata["sd_merge_models"])

    _, extension = os.path.splitext(output_modelname)
    if os.path.exists(output_modelname) and not kwargs.get("overwrite", False):
        return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], f"Model already exists: {output_modelname}"]
    if extension.lower() == ".safetensors":
        safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata)
    else:
        torch.save(theta_0, output_modelname)
    t1 = time.time()
    shared.log.info(f"Merge complete: saved='{output_modelname}' time={t1-t0:.2f}")
    sd_models.list_models()
    created_model = next((ckpt for ckpt in sd_models.checkpoints_list.values() if ckpt.name == filename), None)
    if created_model:
        created_model.calculate_shorthash()
    devices.torch_gc(force=True)
    shared.state.end()
    return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], f"Model saved to {output_modelname}"]
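
# Convert a checkpoint's precision (fp16/bf16/fp32) and format (ckpt/safetensors), with
# per-component handling for the UNet, text encoder, VAE, and remaining tensors.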
def run_modelconvert(model, checkpoint_formats, precision, conv_type, custom_name, unet_conv, text_encoder_conv,
                     vae_conv, others_conv, fix_clip):
    # position_ids in clip is int64. model_ema.num_updates is int32
    dtypes_to_fp16 = {torch.float32, torch.float64, torch.bfloat16}
    dtypes_to_bf16 = {torch.float32, torch.float64, torch.float16}

    def conv_fp16(t: torch.Tensor):
        return t.half() if t.dtype in dtypes_to_fp16 else t

    def conv_bf16(t: torch.Tensor):
        return t.bfloat16() if t.dtype in dtypes_to_bf16 else t

    def conv_full(t):
        return t

    _g_precision_func = {
        "full": conv_full,
        "fp32": conv_full,
        "fp16": conv_fp16,
        "bf16": conv_bf16,
    }
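
    # Classify a state-dict key by model component so each can be converted, copied, or deleted independently.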
    def check_weight_type(k: str) -> str:
        if k.startswith("model.diffusion_model"):
            return "unet"
        elif k.startswith("first_stage_model"):
            return "vae"
        elif k.startswith("cond_stage_model"):
            return "clip"
        return "other"
    def load_model(path):
        if path.endswith(".safetensors"):
            m = safetensors.torch.load_file(path, device="cpu")
        else:
            m = torch.load(path, map_location="cpu")
        state_dict = m["state_dict"] if "state_dict" in m else m
        return state_dict
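
    # Repair known checkpoint defects: rename NovelAI-style CLIP keys to the standard
    # layout and, if requested, reset corrupted CLIP position_ids to 0..76.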
    def fix_model(model, fix_clip=False):
        # code from model-toolkit
        nai_keys = {
            'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
            'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
            'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.'
        }
        for k in list(model.keys()):
            for r in nai_keys:
                if isinstance(k, str) and k.startswith(r):
                    new_key = k.replace(r, nai_keys[r])
                    model[new_key] = model[k]
                    del model[k]
                    shared.log.warning(f"Model convert: fixed NovelAI error key: {k}")
                    break
        if fix_clip:
            i = "cond_stage_model.transformer.text_model.embeddings.position_ids"
            if i in model:
                correct = torch.Tensor([list(range(77))]).to(torch.int64)
                now = model[i].to(torch.int64)
                broken = correct.ne(now)
                broken = [j for j in range(77) if broken[0][j]]
                model[i] = correct
                if len(broken) != 0:
                    shared.log.warning(f"Model convert: fixed broken CLIP: {broken}")
        return model
| if model == "": | |
| return "Error: you must choose a model" | |
| if len(checkpoint_formats) == 0: | |
| return "Error: at least choose one model save format" | |
| extra_opt = { | |
| "unet": unet_conv, | |
| "clip": text_encoder_conv, | |
| "vae": vae_conv, | |
| "other": others_conv | |
| } | |
| shared.state.begin('convert') | |
| model_info = sd_models.checkpoints_list[model] | |
| shared.state.textinfo = f"Loading {model_info.filename}..." | |
| shared.log.info(f"Model convert loading: {model_info.filename}") | |
| state_dict = load_model(model_info.filename) | |
| ok = {} # {"state_dict": {}} | |
| conv_func = _g_precision_func[precision] | |
| def _hf(wk: str, t: torch.Tensor): | |
| if not isinstance(t, torch.Tensor): | |
| return | |
| w_t = check_weight_type(wk) | |
| conv_t = extra_opt[w_t] | |
| if conv_t == "convert": | |
| ok[wk] = conv_func(t) | |
| elif conv_t == "copy": | |
| ok[wk] = t | |
| elif conv_t == "delete": | |
| return | |
| shared.log.info("Model convert: running") | |
| if conv_type == "ema-only": | |
| for k in tqdm.tqdm(state_dict): | |
| ema_k = "___" | |
| try: | |
| ema_k = "model_ema." + k[6:].replace(".", "") | |
| except Exception: | |
| pass | |
| if ema_k in state_dict: | |
| _hf(k, state_dict[ema_k]) | |
| elif not k.startswith("model_ema.") or k in ["model_ema.num_updates", "model_ema.decay"]: | |
| _hf(k, state_dict[k]) | |
| elif conv_type == "no-ema": | |
| for k, v in tqdm.tqdm(state_dict.items()): | |
| if "model_ema." not in k: | |
| _hf(k, v) | |
| else: | |
| for k, v in tqdm.tqdm(state_dict.items()): | |
| _hf(k, v) | |
| ok = fix_model(ok, fix_clip=fix_clip) | |
| output = "" | |
| ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path | |
| save_name = f"{model_info.model_name}-{precision}" | |
| if conv_type != "disabled": | |
| save_name += f"-{conv_type}" | |
| if custom_name != "": | |
| save_name = custom_name | |
| for fmt in checkpoint_formats: | |
| ext = ".safetensors" if fmt == "safetensors" else ".ckpt" | |
| _save_name = save_name + ext | |
| save_path = os.path.join(ckpt_dir, _save_name) | |
| shared.log.info(f"Model convert saving: {save_path}") | |
| if fmt == "safetensors": | |
| safetensors.torch.save_file(ok, save_path) | |
| else: | |
| torch.save({"state_dict": ok}, save_path) | |
| output += f"Checkpoint saved to {save_path}<br>" | |
| shared.state.end() | |
| return output | |