# https://github.com/ogkalu2/Merge-Stable-Diffusion-models-without-distortion
from collections import defaultdict
from random import shuffle
from typing import NamedTuple

import torch
from scipy.optimize import linear_sum_assignment

from modules.shared import log
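# Final normalization and output-projection tensors of the VAE encoder/decoder and
# the UNet. They belong to no permutation group; the calling merge code is expected
# to handle them separately (assumption based on the key names).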
SPECIAL_KEYS = [
    "first_stage_model.decoder.norm_out.weight",
    "first_stage_model.decoder.norm_out.bias",
    "first_stage_model.encoder.norm_out.weight",
    "first_stage_model.encoder.norm_out.bias",
    "model.diffusion_model.out.0.weight",
    "model.diffusion_model.out.0.bias",
]
class PermutationSpec(NamedTuple):
    """Two-way mapping between permutation groups and the (weight key, axis) pairs they act on."""
    perm_to_axes: dict
    axes_to_perm: dict
def permutation_spec_from_axes_to_perm(axes_to_perm: dict) -> PermutationSpec:
    perm_to_axes = defaultdict(list)
    for wk, axis_perms in axes_to_perm.items():
        for axis, perm in enumerate(axis_perms):
            if perm is not None:
                perm_to_axes[perm].append((wk, axis))
    return PermutationSpec(perm_to_axes=dict(perm_to_axes), axes_to_perm=axes_to_perm)
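# Expected `axes_to_perm` layout (illustrative keys only, not taken from a real spec):
#     {
#         "model.diffusion_model.some.weight": ("P_out", "P_in"),  # one name per tensor axis
#         "model.diffusion_model.some.bias": ("P_out",),           # 1-D tensor, single axis
#     }
# A `None` entry in place of a permutation name means that axis is never permuted.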
def get_permuted_param(ps: PermutationSpec, perm, k: str, params, except_axis=None):
    """Get parameter `k` from `params`, with the permutations applied."""
    w = params[k]
    for axis, p in enumerate(ps.axes_to_perm[k]):
        # Skip the axis we're trying to permute.
        if axis == except_axis:
            continue
        # None indicates that there is no permutation relevant to that axis.
        if p:
            w = torch.index_select(w, axis, perm[p].int())
    return w
def apply_permutation(ps: PermutationSpec, perm, params):
    """Apply a `perm` to `params`."""
    return {k: get_permuted_param(ps, perm, k, params) for k in params.keys()}
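# Re-ordering hidden units consistently (rows of one layer together with the matching
# columns of the next, as encoded in the PermutationSpec) leaves the network function
# unchanged, which is what makes the permuted copy safe to interpolate with.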
def update_model_a(ps: PermutationSpec, perm, model_a, new_alpha):
    """Blend each tensor of `model_a` with its permuted copy, weighted by `new_alpha`."""
    for k in model_a:
        try:
            perm_params = get_permuted_param(ps, perm, k, model_a)
            model_a[k] = model_a[k] * (1 - new_alpha) + new_alpha * perm_params
        except RuntimeError:  # dealing with pix2pix and inpainting models
            continue
    return model_a
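# inner_matching handles one permutation group `p`: it accumulates an n x n similarity
# matrix A over every (tensor, axis) pair in the group (A[i, j] ~ correlation between
# unit i of model A and unit j of model B), solves a linear assignment problem on A to
# pick the best one-to-one matching, and records whether the objective improved.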
def inner_matching(
    n,
    ps,
    p,
    params_a,
    params_b,
    usefp16,
    progress,
    number,
    linear_sum,
    perm,
    device,
):
    A = torch.zeros((n, n), dtype=torch.float16) if usefp16 else torch.zeros((n, n))
    A = A.to(device)
    for wk, axis in ps.perm_to_axes[p]:
        w_a = params_a[wk]
        w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis)
        w_a = torch.moveaxis(w_a, axis, 0).reshape((n, -1)).to(device)
        w_b = torch.moveaxis(w_b, axis, 0).reshape((n, -1)).T.to(device)
        if usefp16:
            w_a = w_a.half().to(device)
            w_b = w_b.half().to(device)
        try:
            A += torch.matmul(w_a, w_b)
        except RuntimeError:
            A += torch.matmul(torch.dequantize(w_a), torch.dequantize(w_b))

    A = A.cpu()
    ri, ci = linear_sum_assignment(A.detach().numpy(), maximize=True)
    A = A.to(device)
    assert (torch.tensor(ri) == torch.arange(len(ri))).all()

    eye_tensor = torch.eye(n).to(device)
    oldL = torch.vdot(
        torch.flatten(A).float(), torch.flatten(eye_tensor[perm[p].long()])
    )
    newL = torch.vdot(torch.flatten(A).float(), torch.flatten(eye_tensor[ci, :]))
    if usefp16:
        oldL = oldL.half()
        newL = newL.half()

    if newL - oldL != 0:
        linear_sum += abs((newL - oldL).item())
        number += 1
        log.debug(f"Merge Rebasin permutation: {p}={newL-oldL}")

    progress = progress or newL > oldL + 1e-12
    perm[p] = torch.Tensor(ci).to(device)
    return linear_sum, number, perm, progress
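# weight_matching repeats the assignment step above for up to `max_iter` passes over the
# permutation groups (here only "P_bg324"), starting from the identity permutation unless
# `init_perm` is given. It returns the final permutation together with the average
# improvement in the assignment objective per accepted update.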
def weight_matching(
    ps: PermutationSpec,
    params_a,
    params_b,
    max_iter=1,
    init_perm=None,
    usefp16=False,
    device="cpu",
):
    perm_sizes = {
        p: params_a[axes[0][0]].shape[axes[0][1]]
        for p, axes in ps.perm_to_axes.items()
        if axes[0][0] in params_a.keys()
    }
    perm = (
        {p: torch.arange(n).to(device) for p, n in perm_sizes.items()}
        if init_perm is None
        else init_perm
    )
    linear_sum = 0
    number = 0
    special_layers = ["P_bg324"]
    for _i in range(max_iter):
        progress = False
        shuffle(special_layers)
        for p in special_layers:
            n = perm_sizes[p]
            linear_sum, number, perm, progress = inner_matching(
                n,
                ps,
                p,
                params_a,
                params_b,
                usefp16,
                progress,
                number,
                linear_sum,
                perm,
                device,
            )
        # Stop early once a full pass produces no improvement.
        if not progress:
            break
    average = linear_sum / number if number > 0 else 0
    return perm, average
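# Example usage (a minimal sketch; `sdunet_permutation_spec` is assumed to be provided by
# the calling merge code and is not defined in this module):
#
#     ps = sdunet_permutation_spec()
#     perm, average = weight_matching(ps, theta_a, theta_b, usefp16=True, device="cuda")
#     theta_a = update_model_a(ps, perm, theta_a, new_alpha=0.5)
#
# `theta_a` and `theta_b` are the two state dicts being merged; `average` reports the mean
# change in the assignment objective per accepted permutation update.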