Fix SDPA & Flash-Attention
#7
by Agnellino - opened
This PR aims to solve two issues with SDPA and flash attention.
- SDPA uses the `_unmask_unattended` method of `AttentionMaskConverter`, but this function appears nowhere; it is added in this PR.
- Flash attention uses `_get_unpad_data` from `transformers.models.llama.modeling_llama`, but the star import does not include it for more recent versions of transformers (>=4.48.0); see the import sketch after this list.
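For context, here is a minimal sketch of the kind of version-tolerant import this implies. The try/except structure and the fallback definition are illustrative only (the PR itself just adds the import); the fallback body follows the well-known upstream helper:

```python
import torch
import torch.nn.functional as F

try:
    # Older transformers releases expose the helper here (and via the star import).
    from transformers.models.llama.modeling_llama import _get_unpad_data
except ImportError:
    # Fallback for more recent transformers versions where the import above no longer works.
    # This re-implementation mirrors the widely used upstream helper.
    def _get_unpad_data(attention_mask):
        # Number of real (non-padding) tokens per sequence in the batch.
        seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
        # Flat indices of all non-padding positions.
        indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
        max_seqlen_in_batch = seqlens_in_batch.max().item()
        # Cumulative sequence lengths prefixed with 0, as expected by flash-attention varlen kernels.
        cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
        return indices, cu_seqlens, max_seqlen_in_batch
```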
The implementation of `_unmask_unattended` is a raw copy-paste of the implementation given there, so nothing fancy to worry about: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/modeling_attn_mask_utils.py#L189
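To illustrate what that helper is for (not the verbatim upstream code, which is at the link above): when a query row of the 4D SDPA mask is fully masked, e.g. left-padding positions under a causal mask, the memory-efficient SDPA backend can return NaN, so those rows are simply "unmasked"; their outputs are never used downstream. A condensed, self-contained sketch of that behavior:

```python
import torch

def _unmask_unattended_sketch(expanded_mask, attention_mask, unmasked_value):
    """Condensed sketch: re-enable attention on query rows that cannot attend to
    anything (left-padding positions). See the linked file for the exact upstream code."""
    # Index of the first attended token per sample (0 when there is no left padding).
    first_attended = attention_mask.float().argmax(dim=-1)  # shape: [batch]
    for batch_idx, start in enumerate(first_attended.tolist()):
        # Query rows 0..start-1 are fully masked; unmask them entirely.
        expanded_mask[batch_idx, :, :start, :] = unmasked_value
    return expanded_mask

# Tiny demo: batch of 2, seq len 4, second sample left-padded by 2 tokens.
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [0, 0, 1, 1]])
min_value = torch.finfo(torch.float32).min
causal = torch.tril(torch.ones(4, 4, dtype=torch.bool))
allowed = causal[None, None, :, :] & attention_mask[:, None, None, :].bool()
expanded = torch.zeros(2, 1, 4, 4).masked_fill_(~allowed, min_value)

patched = _unmask_unattended_sketch(expanded, attention_mask, unmasked_value=0.0)
print(patched[1, 0, 0])  # previously fully-masked padding row is now all zeros
```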
I don't know why, but it seems that a lot of lines of code are changed... that's not the case; it's simply an import of `_get_unpad_data` and the implementation of `_unmask_unattended`.
Agnellino changed pull request status to open