Fix SDPA & Flash-Attention

#7

This PR aims to solve two issues with SDPA and flash attention.

  • SDPA calls the `_unmask_unattended` method of `AttentionMaskConverter`, but that method is not defined anywhere. This PR adds it.
  • Flash attention uses `_get_unpad_data` from `transformers.models.llama.modeling_llama`, but the star import no longer includes it in more recent versions of transformers (>=4.48.0); see the import sketch after this list.
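
A version-guarded import along these lines covers both old and new transformers releases. The exact fix in this PR may differ; the fallback module path `transformers.modeling_flash_attention_utils` is an assumption based on where more recent releases define the helper:

```python
# Sketch of a version-tolerant import; the fallback location is an assumption
# based on where recent transformers releases define _get_unpad_data.
try:
    # Older transformers: re-exported via the Llama modeling file.
    from transformers.models.llama.modeling_llama import _get_unpad_data
except ImportError:
    # Newer transformers (>=4.48.0): defined in the shared flash-attention utils.
    from transformers.modeling_flash_attention_utils import _get_unpad_data
```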

The implementation of `_unmask_unattended` is a raw copy-paste of the implementation given there, so there is nothing fancy to worry about: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/modeling_attn_mask_utils.py#L189
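
For reference, this is roughly what the copied method looks like (a condensed sketch of the linked v4.37.0 code with the docstring trimmed; in the PR it is added as a static method of `AttentionMaskConverter`):

```python
import torch


def _unmask_unattended(expanded_mask, attention_mask, unmasked_value):
    # Rows that correspond to left padding attend to nothing in the expanded 4D mask,
    # which makes SDPA's memory-efficient backend produce NaNs. Unmask those rows so
    # they attend to something harmless; their outputs are ignored anyway.

    # Index of the first attended token for every sample in the batch.
    tmp = torch.arange(attention_mask.shape[1], 0, -1)
    indices = torch.argmax(attention_mask * tmp, 1, keepdim=True)

    # Batch rows that have left padding, i.e. whose first expanded-mask rows are fully masked.
    left_masked_rows = torch.where(indices > 0)[0]
    if left_masked_rows.shape[0] == 0:
        return expanded_mask
    indices = indices[left_masked_rows]

    max_len = torch.max(indices)
    range_tensor = torch.arange(max_len).unsqueeze(0).repeat(indices.size(0), 1)
    # Only touch the padded (fully masked) rows, never the real target positions.
    range_tensor[range_tensor >= indices] = 0

    if expanded_mask.dim() == 4:
        num_masks = expanded_mask.shape[1]
        if num_masks == 1:
            mask_slice = (left_masked_rows[:, None], 0, range_tensor)
        else:
            mask_slice = (
                left_masked_rows[:, None, None],
                torch.arange(num_masks)[None, :, None],
                range_tensor[:, None, :],
            )
    else:
        mask_slice = (left_masked_rows[:, None], range_tensor)

    expanded_mask[mask_slice] = unmasked_value
    return expanded_mask
```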

The diff appears to touch a lot of lines, but the actual changes are minimal: an import of `_get_unpad_data` and the implementation of `_unmask_unattended`.
