Spaces:
Runtime error
Runtime error
| import math | |
| from typing import Tuple | |
| import torch | |
| from torch import Tensor | |
# Public API: the merge methods exported by this module.
__all__ = [
    "weighted_sum",
    "weighted_subtraction",
    "tensor_sum",
    "add_difference",
    "sum_twice",
    "triple_sum",
    "euclidean_add_difference",
    "multiply_difference",
    "top_k_tensor_sum",
    "similarity_add_difference",
    "distribution_crossover",
    "ties_add_difference",
]
EPSILON = 1e-10  # Define a small constant EPSILON to prevent division by zero
def weighted_sum(a: Tensor, b: Tensor, alpha: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    Basic Merge:
    alpha 0 returns Primary Model
    alpha 1 returns Secondary Model
    """
    # Linear interpolation between the two models, weighted by alpha.
    return alpha * b + (1 - alpha) * a
def weighted_subtraction(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    The inverse of a Weighted Sum Merge
    Returns Primary Model when alpha*beta = 0
    High values of alpha*beta are likely to break the merged model

    The denominator (1 - alpha*beta) vanishes whenever alpha*beta == 1;
    beta is nudged by EPSILON in that case so the division never hits zero.
    """
    # FIX: the original guard only caught alpha == beta == 1.0, but any pair
    # whose product is exactly 1 (e.g. alpha=2, beta=0.5) divides by zero too.
    if alpha * beta == 1.0:
        beta -= EPSILON
    return (a - alpha * beta * b) / (1 - alpha * beta)
def tensor_sum(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    Takes a slice of Secondary Model and pastes it into Primary Model
    Alpha sets the width of the slice
    Beta sets the start point of the slice
    ie Alpha = 0.5 Beta = 0.25 is (ABBA) Alpha = 0.25 Beta = 0 is (BAAA)
    """
    n = a.shape[0]
    if alpha + beta <= 1:
        # Slice fits without wrapping: paste b into a over [beta, alpha + beta).
        lo = int(n * beta)
        hi = int(n * (alpha + beta))
        result = a.clone()
        result[lo:hi] = b[lo:hi].clone()
    else:
        # Slice wraps past the end: equivalent to pasting a into b over the
        # complementary region [alpha + beta - 1, beta).
        lo = int(n * (alpha + beta - 1))
        hi = int(n * beta)
        result = b.clone()
        result[lo:hi] = a[lo:hi].clone()
    return result
def add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    Classic Add Difference Merge
    """
    # Add the (Secondary - Tertiary) delta onto Primary, scaled by alpha.
    delta = b - c
    return a + alpha * delta
def sum_twice(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    Stacked Basic Merge:
    Equivalent to Merging Primary and Secondary @ alpha
    Then merging the result with Tertiary @ beta
    """
    # First pass: Primary/Secondary at alpha; second pass: blend with Tertiary at beta.
    first_pass = (1 - alpha) * a + alpha * b
    return (1 - beta) * first_pass + beta * c
def triple_sum(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    Weights Secondary and Tertiary at alpha and beta respectively
    Fills in the rest with Primary
    Expect odd results if alpha + beta > 1 as Primary will be merged with a negative ratio
    """
    # Primary receives whatever weight is left over after alpha and beta.
    remainder = 1 - alpha - beta
    return remainder * a + alpha * b + beta * c
def euclidean_add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    Subtract Primary and Secondary from Tertiary
    Compare the remainders via Euclidean distance
    Add to Tertiary
    Note: Slow
    """
    a_diff = a.float() - c.float()
    b_diff = b.float() - c.float()
    # Normalize each difference; nan_to_num guards the 0/0 case when a == c or b == c.
    a_diff = torch.nan_to_num(a_diff / torch.linalg.norm(a_diff))
    b_diff = torch.nan_to_num(b_diff / torch.linalg.norm(b_diff))
    # Root-mean-square style interpolation of the two unit differences.
    distance = (1 - alpha) * a_diff**2 + alpha * b_diff**2
    distance = torch.sqrt(distance)
    # The plain weighted-sum difference supplies both the per-element sign
    # and the target overall magnitude. (Inlined weighted_sum(a, b, alpha) - c.)
    sum_diff = (1 - alpha) * a.float() + alpha * b.float() - c.float()
    distance = torch.copysign(distance, sum_diff)
    target_norm = torch.linalg.norm(sum_diff)
    # FIX: when a == b == c the distance tensor is all zeros, so the final
    # division was 0/0 and the whole result became NaN; nan_to_num returns c instead.
    return c + torch.nan_to_num(distance / torch.linalg.norm(distance)) * target_norm
def multiply_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    Similar to Add Difference but with a geometric mean instead of an arithmetic mean
    """
    # Alpha-weighted geometric interpolation of the two absolute deltas from c.
    magnitude = torch.abs(a.float() - c) ** (1 - alpha) * torch.abs(b.float() - c) ** alpha
    # The sign of each element comes from the plain weighted sum (at beta) relative to c.
    reference = (1 - beta) * a + beta * b - c
    difference = torch.copysign(magnitude, reference)
    return c + difference.to(c.dtype)
def top_k_tensor_sum(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    Redistributes the largest weights of Secondary Model into Primary Model

    A region of a's sorted value distribution (selected by alpha/beta via
    ratio_to_region) is written back into a at the positions dictated by
    b's value ranking, so inside the region a's values follow b's ordering.
    """
    a_flat = torch.flatten(a)
    a_dist = torch.msort(a_flat)  # a's values in ascending order
    b_indices = torch.argsort(torch.flatten(b), stable=True)  # rank -> position in b
    redist_indices = torch.argsort(b_indices)  # position in b -> rank of that position
    start_i, end_i, region_is_inverted = ratio_to_region(alpha, beta, torch.numel(a))
    # Magnitude thresholds bounding the selected region of the sorted distribution.
    start_top_k = kth_abs_value(a_dist, start_i)
    end_top_k = kth_abs_value(a_dist, end_i)
    # Mask over sorted values: magnitudes in (start_top_k, end_top_k] are selected.
    indices_mask = (start_top_k < torch.abs(a_dist)) & (torch.abs(a_dist) <= end_top_k)
    if region_is_inverted:
        indices_mask = ~indices_mask
    # Reorder the mask and the sorted values to follow b's ranking per position.
    indices_mask = torch.gather(indices_mask.float(), 0, redist_indices)
    a_redist = torch.gather(a_dist, 0, redist_indices)
    # Inside the region take the redistributed value; elsewhere keep a's own value.
    a_redist = (1 - indices_mask) * a_flat + indices_mask * a_redist
    return a_redist.reshape_as(a)
def kth_abs_value(a: Tensor, k: int) -> Tensor:
    """Return the k-th smallest absolute value of `a` (1-indexed); -1 for k <= 0."""
    if k > 0:
        return torch.kthvalue(torch.abs(a.float()), k)[0]
    # Sentinel below any possible magnitude, so `sentinel < abs(x)` is always true.
    return torch.tensor(-1, device=a.device)
def ratio_to_region(width: float, offset: float, n: int) -> Tuple[int, int, bool]:
    """Map a (width, offset) ratio pair onto [start, end) indices of a length-n axis.

    Negative width slides the offset back; offsets wrap modulo 1. When the
    region runs past the end of the axis, the complementary region is returned
    together with inverted=True.
    """
    if width < 0:
        # A negative width is the same region anchored at its other edge.
        offset, width = offset + width, -width
    width = min(width, 1)
    if offset < 0:
        offset = 1 + offset - int(offset)
    offset = math.fmod(offset, 1.0)
    inverted = width + offset > 1
    if inverted:
        # Region wraps: report the complement and flag it.
        start, end = (width + offset - 1) * n, offset * n
    else:
        start, end = offset * n, (width + offset) * n
    return round(start), round(end), inverted
def similarity_add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    Weighted Sum where A and B are similar and Add Difference where A and B are dissimilar
    """
    # Per-element agreement of a and b in [0, 1]: 1 when equal, 0 when opposed.
    largest = torch.maximum(torch.abs(a), torch.abs(b))
    similarity = ((a * b / largest**2) + 1) / 2
    # 0/0 where both are zero -> treat as fully similar; beta scales the blend.
    similarity = torch.nan_to_num(similarity * beta, nan=beta)
    add_diff = a + alpha * (b - c)
    half_sum = (1 - alpha / 2) * a + (alpha / 2) * b
    # Dissimilar elements take the add-difference path, similar ones the sum path.
    return (1 - similarity) * add_diff + similarity * half_sum
def distribution_crossover(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs):  # pylint: disable=unused-argument
    """
    From the creator:
    It's Primary high-passed + Secondary low-passed. Takes the fourrier transform of the weights of
    Primary and Secondary when ordered with respect to Tertiary. Split the frequency domain
    using a linear function. Alpha is the split frequency and Beta is the inclination of the line.
    add everything under the line as the contribution of Primary and everything over the line as the contribution of Secondary
    """
    # Scalar tensors have no distribution to sort/transform; blend directly.
    # NOTE(review): alpha weights `a` here, the opposite of weighted_sum — confirm intended.
    if a.shape == ():
        return alpha * a + (1 - alpha) * b
    c_indices = torch.argsort(torch.flatten(c))
    # Order a and b by c's value ranking so both distributions are aligned.
    a_dist = torch.gather(torch.flatten(a), 0, c_indices)
    b_dist = torch.gather(torch.flatten(b), 0, c_indices)
    a_dft = torch.fft.rfft(a_dist.float())
    b_dft = torch.fft.rfft(b_dist.float())
    # Normalized frequency axis in [0, 1).
    dft_filter = torch.arange(0, torch.numel(a_dft), device=a_dft.device).float()
    dft_filter /= torch.numel(a_dft)
    if beta > EPSILON:
        # Linear ramp centered at frequency alpha with slope 1/beta, clamped to [0, 1].
        dft_filter = (dft_filter - alpha) / beta + 1 / 2
        dft_filter = torch.clamp(dft_filter, 0.0, 1.0)
    else:
        # beta ~ 0 degenerates to a hard cutoff at frequency alpha.
        dft_filter = (dft_filter >= alpha).float()
    # Low frequencies come from a, high frequencies from b, per the filter.
    x_dft = (1 - dft_filter) * a_dft + dft_filter * b_dft
    x_dist = torch.fft.irfft(x_dft, a_dist.shape[0])
    # Undo the c-ordering to restore the original element positions.
    x_values = torch.gather(x_dist, 0, torch.argsort(c_indices))
    return x_values.reshape_as(a)
def ties_add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    An implementation of arXiv:2306.01708

    Trims each task vector (a - c, b - c) to its largest-magnitude entries,
    elects a per-parameter sign by majority, then averages only the deltas
    that agree with the elected sign before adding them onto c at strength alpha.
    """
    trimmed = [filter_top_k(m - c, beta) for m in (a, b)]
    signs = torch.stack([torch.sign(t) for t in trimmed], dim=0)
    # Majority vote on the sign of each parameter across the trimmed deltas.
    final_sign = torch.sign(torch.sum(signs, dim=0))
    agreement = (signs == final_sign).float()
    merged = torch.zeros_like(c, device=c.device)
    for mask, delta in zip(agreement, trimmed):
        merged += mask * delta
    contributors = torch.sum(agreement, dim=0)
    # nan_to_num covers parameters where no delta agrees with the elected sign (0/0).
    return c + alpha * torch.nan_to_num(merged / contributors)
def filter_top_k(a: Tensor, k: float):
    """Zero out all but the largest-magnitude (1 - k) fraction of entries in `a`."""
    # Always keep at least one entry, even when k rounds the count down to zero.
    keep = max(int((1 - k) * torch.numel(a)), 1)
    # Threshold is the keep-th smallest absolute value; entries at or above it survive.
    threshold, _ = torch.kthvalue(torch.abs(a.flatten()).float(), keep)
    return a * (torch.abs(a) >= threshold).float()