acip_llama2_13b / acip_model.py
martingenzel's picture
Add model
985b259 verified
raw
history blame
8.71 kB
from typing import Any
import torch
from transformers import PreTrainedModel
from .parametrized_model import ParametrizedModel, ParametrizedModelConfig
class ACIPModelConfig(ParametrizedModelConfig):
    """
    Configuration class used by `ACIPModel`.

    It does not define any options of its own: everything is inherited from
    `ParametrizedModelConfig`. The subclass only sets its own `model_type` identifier.

    See Also:
        - `ParametrizedModelConfig`
        - `ACIPModel`
    """

    model_type = "acip_model"
class ACIPModel(ParametrizedModel):
    """
    This class extends `ParametrizedModel` by additional functionality required for ACIP.

    It manages a `score_map` that stores the scores of the parametrized modules' target parameters,
    which are updated during tuning by the ACIP method.
    Moreover, it provides `prune_model_by_score` that prunes the target parameters of the model according to
    their scores to reach a given compression ratio.

    Notes:
        The score buffers are created in float32 and every value written through the `score_map` setter is
        cast to float32, because a lower precision may lead to unexpected numerical inaccuracies in the
        resulting parameter ranking. The memory consumption is negligible compared to the model weights.

    See Also:
        - `ParametrizedModel`
        - `ACIPModelConfig`
    """

    config_class = ACIPModelConfig

    def __init__(self, config: ACIPModelConfig, base_model: PreTrainedModel | None = None, **_: Any):
        """
        Args:
            config: Model configuration, forwarded to `ParametrizedModel.__init__`.
            base_model: Optional base model, forwarded to `ParametrizedModel.__init__`.
            **_: Additional keyword arguments are accepted and ignored.
        """
        super().__init__(config, base_model)
        self.config = config  # redundant but enables type hinting for ACIPModelConfig
        # Lazily rendered cache, see the `score_map` property.
        self._score_map: dict[str, torch.Tensor] | None = None

        # Register and initialize score map buffers
        # Important: don't run _update_score_map here because load_state_dict might still override the buffers
        self._init_score_map_buffers()

    def _init_score_map_buffers(self) -> None:
        """
        Register the score buffers in the parametrized modules and initialize them with ones (float32).

        Each target parameter "p_name" is associated with a buffer "p_name_score" that stores its score
        and has the same shape as the parameter.
        """
        for module in self.parametrized_modules.values():
            for p_name, param in module.parametrization.get_target_params().items():
                module.parametrization.register_buffer(
                    p_name + "_score", torch.ones_like(param.data, dtype=torch.float32)
                )

    def _update_score_map(self) -> None:
        """Render `score_map` from the parametrized modules' score buffers."""
        self._score_map = {}
        for m_name, module in self.parametrized_modules.items():
            for p_name in module.parametrization.get_target_params():
                self._score_map[f"{m_name}.parametrization.{p_name}"] = module.parametrization.get_buffer(
                    p_name + "_score"
                )

    @property
    def score_map(self) -> dict[str, torch.Tensor]:
        """
        Returns the score map as Tensor dictionary whose keys match those of `self.get_target_params`.

        The dictionary is rendered from the score buffers on first access and cached afterwards.
        """
        # NOTE(review): the cache holds references to the buffer tensors that existed at render time.
        # If an operation replaces the buffer objects afterwards (presumably e.g. a device/dtype move
        # of the module), the cache may be stale - confirm whether a refresh is needed in that case.
        if self._score_map is None:
            self._update_score_map()
        return self._score_map

    @score_map.setter
    def score_map(self, score_map: dict[str, torch.Tensor]) -> None:
        """
        Updates `score_map` and the corresponding parametrized modules' score buffers.

        Args:
            score_map: Dictionary whose keys should match (a subset of) `self.get_target_params`.

        Raises:
            ValueError: If a score tensor does not have the shape of its score buffer.
        """
        if self._score_map is None:
            self._update_score_map()

        # score_map.keys() can be a subset of self.get_target_params().keys()
        for p_name, score in score_map.items():
            buffer = self.model.get_buffer(p_name + "_score")
            if buffer.shape != score.shape:
                raise ValueError(
                    f"Score map for '{p_name}' has incorrect shape: expected {buffer.shape}, got {score.shape}"
                )
            # Write in place so that the registered buffer itself is updated;
            # cast to float32 to avoid numerical instabilities
            buffer.copy_(score.detach().float())
            self._score_map[p_name] = buffer

    def _predict_compression_ratio_by_score(self, k: int, full: bool = False) -> tuple[float, dict[str, torch.Tensor]]:
        """
        Helper function that checks what would happen if the target parameters with the k smallest scores
        are pruned according to the global score map ranking. It returns the resulting compression ratio
        and the corresponding parameter masks.

        The k-th smallest score is used as threshold and only entries with a strictly larger score are kept.
        Consequently, entries that tie with the threshold are pruned as well, so more than k entries may
        be masked out.

        Args:
            k: Number of target parameters to prune (1-based rank of the threshold score).
            full: Whether to count the number of parameters of the entire model or only the parametrized modules.
                See also `ParametrizedModel.get_num_params`.

        Returns: Tuple of compression ratio and parameter masks. The masks have the dtype of the scores and
            indicate which parameters to keep (1) and which to prune (0).
        """
        # Find the threshold value for the k smallest entries according to the global score map ranking.
        score_map_cat = torch.cat([score.flatten() for score in self.score_map.values()])
        threshold = torch.kthvalue(score_map_cat, k).values.item()

        # Create a set of parameter masks marking which values to keep.
        param_masks = {
            p_name: (score > threshold).to(dtype=score.dtype) for p_name, score in self.score_map.items()
        }

        # Compute hypothetical compression ratio if param_masks would be used as masks for the target parameters.
        compression_ratio = self.get_compression_ratio(full=full, target_params=param_masks)
        return compression_ratio, param_masks

    def _get_param_masks(self, compression_ratio: float, full: bool = False) -> dict[str, torch.Tensor]:
        """
        Helper function that determines which parameters to keep to reach a target compression ratio.

        Instead of looping over `k -> _predict_compression_ratio_by_score(k)`, a binary search over k is used.
        The search relies on the predicted ratio being monotone in k: it treats `ratio > compression_ratio`
        as a condition that holds for small k and fails for large k, and it selects the largest k for which
        the predicted ratio is still strictly greater than `compression_ratio` (k = 1 if there is none).

        NOTE(review): earlier documentation described this as "the smallest k such that the compression ratio
        is at least the target"; the description above follows what the search actually computes - confirm
        the intended semantics against `ParametrizedModel.get_compression_ratio`.

        Args:
            compression_ratio: Target compression ratio. A value of exactly 1.0 keeps all parameters.
            full: Whether to count the number of parameters of the entire model or only the parametrized modules.
                See also `ParametrizedModel.get_num_params`.

        Returns: Parameter masks indicating which parameters to keep to reach the target compression ratio.
        """
        if compression_ratio == 1.0:
            return {p_name: torch.ones_like(score) for p_name, score in self.score_map.items()}

        # Here, k_lo and k_hi are the lower and upper bound of the search interval.
        k_lo, k_hi = 1, sum(score.numel() for score in self.score_map.values())
        while k_lo < k_hi:
            # Round up so that k_mid > k_lo; otherwise `k_lo = k_mid` could stall the loop.
            k_mid = (k_lo + k_hi + 1) // 2
            ratio, _ = self._predict_compression_ratio_by_score(k=k_mid, full=full)
            if ratio > compression_ratio:
                k_lo = k_mid
            else:
                k_hi = k_mid - 1
        k = k_lo

        # TODO: handle tie-breaks
        return self._predict_compression_ratio_by_score(k=k, full=full)[1]

    def prune_model_by_score(self, compression_ratio: float, full: bool = False) -> None:
        """
        This method prunes the target parameters of the model according to their scores to reach
        a given compression ratio.

        The parameter masks are determined by a binary search over the number of pruned parameters in the
        global score ranking (see `_get_param_masks`). The target parameters are then overwritten in place
        with the masks and each affected parametrization is reset via `reset_target_params(mode="nonzero")`.

        Args:
            compression_ratio: The target compression ratio.
            full: Whether to count the number of parameters of the entire model or only the parametrized modules.
                See also `ParametrizedModel.get_num_params`.
        """
        param_masks = self._get_param_masks(compression_ratio=compression_ratio, full=full)

        # Reset the target parameters according to the parameter masks
        for p_name, param in self.get_target_params().items():
            mask = param_masks[p_name]
            param.data[mask > 0.0] = 1.0  # dummy value, will be rescaled by reset_target_params
            param.data[mask == 0.0] = 0.0

        for m_name, module in self.parametrized_modules.items():
            # Mask keys have the form "<module name>.parametrization.<param name>". Match including the
            # trailing dot so that a module name that is a mere string prefix of another one
            # (e.g. "fc1" vs. "fc10") is not matched by mistake.
            module_prefix = f"{m_name}."
            if any(p_name.startswith(module_prefix) for p_name in param_masks):
                module.parametrization.reset_target_params(mode="nonzero")
# Register ACIPModelConfig and ACIPModel for the Auto* classes when this module is imported.
# Required to push the custom model to the Hugging Face Hub
# (see https://huggingface.co/docs/transformers/en/custom_models).
# The config is registered without an argument, i.e. with the method's default auto class;
# the model is registered explicitly for "AutoModel".
ACIPModelConfig.register_for_auto_class()
ACIPModel.register_for_auto_class("AutoModel")