from typing import Any

import torch
from transformers import PreTrainedModel

from .parametrized_model import ParametrizedModel, ParametrizedModelConfig


class ACIPModelConfig(ParametrizedModelConfig):
    """
    Configuration for `ACIPModel`. Same functionality as `ParametrizedModelConfig`.

    See Also:
        - `ParametrizedModelConfig`
        - `ACIPModel`
    """

    model_type = "acip_model"


class ACIPModel(ParametrizedModel):
    """
    This class extends `ParametrizedModel` with additional functionality required for ACIP.
    It manages a `score_map` that stores the scores of the parametrized modules' target parameters,
    which are updated during tuning by the ACIP method.
    Moreover, it provides `prune_model_by_score`, which prunes the target parameters of the model
    according to their scores to achieve any given compression ratio.

    Notes: The `score_map` is managed in float32 internally because a lower precision may lead to
        unexpected numerical inaccuracies in the resulting parameter ranking. Fortunately, its memory
        consumption is negligible compared to the model weights themselves.
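
    Example: A minimal usage sketch (the checkpoint path is a placeholder):

        model = ACIPModel.from_pretrained("path/to/acip_checkpoint")
        # Prune such that roughly 40% of the parametrized modules' parameters remain
        # (the achieved ratio may be slightly above the target, see `prune_model_by_score`).
        model.prune_model_by_score(compression_ratio=0.4)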

    See Also:
        - `ParametrizedModel`
        - `ACIPModelConfig`
    """

    config_class = ACIPModelConfig

    def __init__(self, config: ACIPModelConfig, base_model: PreTrainedModel | None = None, **_: Any):
        super().__init__(config, base_model)
        self.config = config

        self._score_map: dict[str, torch.Tensor] | None = None
        self._init_score_map_buffers()

    def _init_score_map_buffers(self):
        """
        Register score map buffers in the parametrized modules and initialize them with ones.
        Each target parameter "p_name" is associated with a buffer "p_name_score" that stores its score vector.
        """
        for module in self.parametrized_modules.values():
            for p_name, param in module.parametrization.get_target_params().items():
                module.parametrization.register_buffer(p_name + "_score", torch.ones_like(param.data).float())

    def _update_score_map(self):
        """Render `score_map` from the parametrized modules' score buffers."""
        self._score_map = {}
        for m_name, module in self.parametrized_modules.items():
            for p_name in module.parametrization.get_target_params().keys():
                self._score_map[f"{m_name}.parametrization.{p_name}"] = module.parametrization.get_buffer(
                    p_name + "_score"
                )

    @property
    def score_map(self) -> dict[str, torch.Tensor]:
        """Returns the score map as a Tensor dictionary whose keys match those of `self.get_target_params`."""
        if self._score_map is None:
            self._update_score_map()
        return self._score_map

    @score_map.setter
    def score_map(self, score_map: dict[str, torch.Tensor]) -> None:
        """
        Updates `score_map` and the corresponding parametrized modules' score buffers.

        Args:
            score_map: Dictionary whose keys should match (a subset of) `self.get_target_params`.
        """
        if self._score_map is None:
            self._update_score_map()

        for p_name, score in score_map.items():
            buffer = self.model.get_buffer(p_name + "_score")
            if buffer.shape != score.shape:
                raise ValueError(
                    f"Score map for '{p_name}' has incorrect shape: expected {buffer.shape}, got {score.shape}"
                )
            # Copy in float32 to keep the parameter ranking numerically stable.
            buffer.copy_(score.detach().float())
            self._score_map[p_name] = buffer

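    # Example: a minimal sketch of reading and updating a single score entry. Keys follow
    # the pattern "{module_name}.parametrization.{param_name}" and match those of
    # `self.get_target_params`; the name below is hypothetical.
    #
    #   p_name = "layers.0.mlp.up_proj.parametrization.mask"
    #   scores = model.score_map[p_name]                # read the current score vector
    #   model.score_map = {p_name: scores.abs() + 1.0}  # write back an updated score
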
    def _predict_compression_ratio_by_score(self, k: int, full: bool = False) -> tuple[float, dict[str, torch.Tensor]]:
        """
        Helper function that checks what would happen if the k smallest target parameters were pruned
        according to the global score map ranking. It returns the resulting compression ratio
        and the corresponding parameter masks.

        Args:
            k: Number of target parameters to prune.
            full: Whether to count the number of parameters of the entire model or only the parametrized modules.
                See also `ParametrizedModel.get_num_params`.

        Returns: Tuple of compression ratio and parameter masks. The masks indicate which parameters to keep.
        """
        # The k-th smallest score across all target parameters serves as the global pruning threshold.
        score_map_cat = torch.cat([param.flatten() for param in self.score_map.values()])
        threshold = torch.kthvalue(score_map_cat, k).values.item()

        # Keep only parameters whose score lies strictly above the threshold
        # (in case of ties, all parameters scoring exactly the threshold are pruned).
        param_masks = {}
        for p_name, score in self.score_map.items():
            param_masks[p_name] = (score > threshold).to(dtype=score.dtype)

        compression_ratio = self.get_compression_ratio(full=full, target_params=param_masks)
        return compression_ratio, param_masks

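    # Example for `_predict_compression_ratio_by_score`: with concatenated scores
    # [0.3, 0.1, 0.5] and k=2, the threshold is 0.3 (the 2nd-smallest score) and the
    # resulting mask is [0, 0, 1], i.e., the two lowest-scoring parameters are pruned.
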
    def _get_param_masks(self, compression_ratio: float, full: bool = False) -> dict[str, torch.Tensor]:
        """
        Helper function that determines which parameters to keep to reach a target compression ratio.
        Instead of looping over `k -> _predict_compression_ratio_by_score(k)`, a binary search can be used
        because the compression ratio is monotonically decreasing in k.

        Args:
            compression_ratio: Target compression ratio.
            full: Whether to count the number of parameters of the entire model or only the parametrized modules.
                See also `ParametrizedModel.get_num_params`.

        Returns: Parameter masks indicating which parameters to keep to reach the target compression ratio.
        """
        if compression_ratio == 1.0:
            return {p_name: torch.ones_like(score) for p_name, score in self.score_map.items()}

        # Binary search for the largest k whose resulting compression ratio still exceeds the target.
        k_lo, k_hi = 1, sum(score.numel() for score in self.score_map.values())
        while k_lo < k_hi:
            k_mid = (k_lo + k_hi + 1) // 2
            ratio, _ = self._predict_compression_ratio_by_score(k=k_mid, full=full)
            if ratio > compression_ratio:
                k_lo = k_mid
            else:
                k_hi = k_mid - 1
        k = k_lo

        return self._predict_compression_ratio_by_score(k=k, full=full)[1]

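    # Example for `_get_param_masks`: a sketch that inspects the compression ratio actually
    # achieved by the selected masks (it may end up slightly above the requested target).
    #
    #   masks = model._get_param_masks(compression_ratio=0.4)
    #   print(model.get_compression_ratio(full=False, target_params=masks))
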
    def prune_model_by_score(self, compression_ratio: float, full: bool = False) -> None:
        """
        This method prunes the target parameters of the model according to their scores to achieve
        a given compression ratio.

        This can be implemented efficiently by a simple binary search strategy:
        we find the largest number of parameters to prune according to the score map ranking
        such that the resulting compression ratio is still at least the target `compression_ratio`.

        Args:
            compression_ratio: The target compression ratio.
            full: Whether to count the number of parameters of the entire model or only the parametrized modules.
                See also `ParametrizedModel.get_num_params`.
        """
        param_masks = self._get_param_masks(compression_ratio=compression_ratio, full=full)

        # Binarize the target parameters according to the masks (1.0 = keep, 0.0 = prune) ...
        for p_name, param in self.get_target_params().items():
            param.data[param_masks[p_name] > 0.0] = 1.0
            param.data[param_masks[p_name] == 0.0] = 0.0
        # ... and reset the target params of all affected parametrized modules.
        for m_name, module in self.parametrized_modules.items():
            if any(p_name.startswith(m_name) for p_name in param_masks.keys()):
                module.parametrization.reset_target_params(mode="nonzero")


ACIPModelConfig.register_for_auto_class()
ACIPModel.register_for_auto_class("AutoModel")
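
# Example: a minimal sketch of how the auto-class registrations above are typically used.
# The checkpoint path is a placeholder; `trust_remote_code=True` is required when loading
# a custom registered model class from a checkpoint.
#
#   model.save_pretrained("path/to/acip_checkpoint")
#
#   from transformers import AutoModel
#   model = AutoModel.from_pretrained("path/to/acip_checkpoint", trust_remote_code=True)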