|
import torch |
|
import numpy as np |
|
from functools import reduce, partial |
|
from operator import mul |
|
from torch.nn.utils.parametrize import is_parametrized, remove_parametrizations |
|
|
|
|
|
def chain_functions(*functions):
    """Compose ``functions`` left to right into a single callable.

    The returned callable takes one argument, feeds it to the first
    function, feeds that result to the second, and so on, returning the
    final result. With no functions the argument is returned unchanged.
    """

    def pipeline(initial):
        value = initial
        for func in functions:
            value = func(value)
        return value

    return pipeline
|
|
|
|
|
def remove_fx_parametrisation(fx):
    """Remove every registered parametrization from ``fx`` in place.

    ``fx.apply`` is used to visit ``fx`` and its submodules; on each
    parametrized module all entries of ``module.parametrizations`` are
    removed with ``remove_parametrizations`` (called with its default
    arguments).

    Args:
        fx: object exposing ``apply(fn)`` -- presumably a
            ``torch.nn.Module``; confirm against callers.

    Returns:
        The same ``fx`` object, to allow call chaining.
    """

    def strip_module(module):
        if is_parametrized(module):
            # Snapshot the names first: removing entries mutates
            # ``module.parametrizations`` while we iterate.
            tensor_names = list(module.parametrizations.keys())
            for tensor_name in tensor_names:
                remove_parametrizations(module, tensor_name)

    fx.apply(strip_module)
    return fx
|
|
|
|
|
def get_chunks(keys, original_shapes):
    """Compute split sizes for a flat vector that omits U's lower triangle.

    The first key containing ``"U.original"`` identifies the ``U`` matrix.
    Its lower-triangular entries (diagonal included) are treated as not
    needed, so the chunk size for that entry is reduced accordingly.

    Args:
        keys: iterable of state-dict key strings.
        original_shapes: sequence of shapes, aligned with ``keys``. A
            0-dim shape ``()`` counts as one element.

    Returns:
        A tuple ``(selected_chunks, position, U_matrix_shape,
        dimensions_not_need)`` where

        * ``selected_chunks`` is the list of element counts per key, with
          the ``U`` entry reduced by the number of dropped elements;
        * ``position`` is the index of the ``U`` key;
        * ``U_matrix_shape`` is ``original_shapes[position]``;
        * ``dimensions_not_need`` is a NumPy array with the flat indices of
          the dropped ``U`` entries, offset by the total size of all
          chunks that precede ``U``.

    Raises:
        ValueError: if no key contains ``"U.original"``.
    """
    position = next(
        (index for index, key in enumerate(keys) if "U.original" in key),
        None,
    )
    if position is None:
        raise ValueError('no key containing "U.original" was found')

    # The initial value 1 makes 0-dim shapes () yield one element instead
    # of raising "reduce() of empty sequence with no initial value".
    original_chunks = [reduce(mul, shape, 1) for shape in original_shapes]
    U_matrix_shape = original_shapes[position]

    # Flat (row-major) indices of the lower triangle of U, diagonal included.
    lower_triangle = np.ravel_multi_index(
        np.tril_indices(n=U_matrix_shape[0], m=U_matrix_shape[1]),
        U_matrix_shape,
    )
    # Shift into the coordinates of the full concatenated vector.
    dimensions_not_need = lower_triangle + sum(original_chunks[:position])

    selected_chunks = list(original_chunks)
    selected_chunks[position] -= dimensions_not_need.size
    return selected_chunks, position, U_matrix_shape, dimensions_not_need
|
|
|
|
|
def vec2statedict(
    x: torch.Tensor,
    keys,
    original_shapes,
    selected_chunks,
    position,
    U_matrix_shape,
):
    """Rebuild a state dict from a flat vector that omits U's lower triangle.

    ``x`` is split into consecutive pieces of sizes ``selected_chunks``.
    The piece at ``position`` holds only the strictly upper-triangular
    entries of the ``U`` matrix (row-major order); it is scattered into a
    zero-filled flat buffer of the full ``U`` size. Every piece is then
    reshaped to its entry in ``original_shapes``.

    Args:
        x: 1-D tensor whose length equals ``sum(selected_chunks)``.
        keys: iterable of state-dict keys, aligned with ``original_shapes``.
        original_shapes: iterable of target shapes; a 0-dim shape ``()``
            is supported.
        selected_chunks: split sizes passed to ``torch.split``.
        position: index of the ``U`` entry within the chunks.
        U_matrix_shape: 2-D shape ``(rows, cols)`` of the full ``U`` matrix.

    Returns:
        dict mapping each key to its reshaped tensor.
    """
    chunks = list(torch.split(x, selected_chunks))

    # Flat row-major positions of the strictly upper triangle (k=1) of U.
    upper_triangle = np.ravel_multi_index(
        np.triu_indices(n=U_matrix_shape[0], k=1, m=U_matrix_shape[1]),
        U_matrix_shape,
    )
    # Same dtype/device as x; entries outside the upper triangle stay zero.
    U = x.new_zeros(reduce(mul, U_matrix_shape, 1))
    U[upper_triangle] = chunks[position]
    chunks[position] = U

    # Pass the shape as one tuple: ``reshape(*shape)`` would be called with
    # no arguments for a 0-dim shape and fail.
    return {
        key: chunk.reshape(tuple(shape))
        for key, chunk, shape in zip(keys, chunks, original_shapes)
    }
|
|