Fix SDPA & Flash-Attention
#7 by Agnellino - opened

Files changed:
- modeling_attn_mask_utils.py +89 -0
- modeling_llama2.py +89 -87
modeling_attn_mask_utils.py
CHANGED
@@ -160,6 +160,95 @@ class AttentionMaskConverter:
 
         return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
 
+    @staticmethod
+    def _unmask_unattended(
+        expanded_mask: torch.Tensor, attention_mask: torch.Tensor, unmasked_value: Union[bool, float]
+    ):
+        # fmt: off
+        """
+        Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
+        using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+        Details: https://github.com/pytorch/pytorch/issues/110213
+
+        `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
+        `attention_mask` is [bsz, src_seq_len].
+
+        The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.
+
+        For example, if `attention_mask` is
+        ```
+        [[0, 0, 1],
+         [1, 1, 1],
+         [0, 1, 1]]
+        ```
+        and `expanded_mask` is (e.g. here left-padding case)
+        ```
+        [[[[0, 0, 0],
+           [0, 0, 0],
+           [0, 0, 1]]],
+         [[[1, 0, 0],
+           [1, 1, 0],
+           [1, 1, 1]]],
+         [[[0, 0, 0],
+           [0, 1, 0],
+           [0, 1, 1]]]]
+        ```
+        then the modified `expanded_mask` will be
+        ```
+        [[[[1, 1, 1], <-- modified
+           [1, 1, 1], <-- modified
+           [0, 0, 1]]],
+         [[[1, 0, 0],
+           [1, 1, 0],
+           [1, 1, 1]]],
+         [[[1, 1, 1], <-- modified
+           [0, 1, 0],
+           [0, 1, 1]]]]
+        ```
+        """
+        # fmt: on
+
+        # Get the index of the first non-zero value for every sample in the batch.
+        # In the above example, indices = [[2], [0], [1]]
+        tmp = torch.arange(attention_mask.shape[1], 0, -1)
+        indices = torch.argmax(attention_mask.cpu() * tmp, 1, keepdim=True)
+
+        # Find the batch indexes that have unattended tokens on the leftmost side (e.g. [0, 0, 1, 1, 1]), for which the first rows of the
+        # expanded mask will be completely unattended.
+        left_masked_rows = torch.where(indices > 0)[0]
+
+        if left_masked_rows.shape[0] == 0:
+            return expanded_mask
+        indices = indices[left_masked_rows]
+
+        max_len = torch.max(indices)
+        range_tensor = torch.arange(max_len).unsqueeze(0)
+        range_tensor = range_tensor.repeat(indices.size(0), 1)
+
+        # Avoid unmasking tokens at relevant target positions (on the row axis), by rather unmasking possibly several times the first row that should always be unmasked as we filtered out the batch above.
+        range_tensor[range_tensor >= indices] = 0
+
+        # TODO: we may drop support for 3D attention mask as the refactor from Patrick maybe dropped this case
+        if expanded_mask.dim() == 4:
+            num_masks = expanded_mask.shape[1]
+            if num_masks == 1:
+                # Broadcast [left_masked_rows, 1], [left_masked_rows, max_len]
+                mask_slice = (left_masked_rows[:, None], 0, range_tensor)
+            else:
+                # Broadcast [left_masked_rows, 1, 1], [1, num_masks, 1], [left_masked_rows, 1, max_len]
+                mask_slice = (
+                    left_masked_rows[:, None, None],
+                    torch.arange(num_masks)[None, :, None],
+                    range_tensor[:, None, :],
+                )
+        else:
+            # Broadcast [left_masked_rows, 1], [left_masked_rows, max_len]
+            mask_slice = (left_masked_rows[:, None], range_tensor)
+
+        expanded_mask[mask_slice] = unmasked_value
+
+        return expanded_mask
+
 
 def _prepare_4d_causal_attention_mask(
     attention_mask: Optional[torch.Tensor],
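Not part of the diff: a minimal standalone sketch that replays the logic of the `_unmask_unattended` helper added above on the docstring's own example, with `unmasked_value=1`. Variable names are illustrative only; the point is that rows of the expanded causal mask that are fully masked by left padding get re-enabled, which is what the SDPA memory-efficient path needs to avoid producing NaNs.

import torch

# 2D padding mask from the docstring example: 1 = real token, 0 = left padding.
attention_mask = torch.tensor([[0, 0, 1],
                               [1, 1, 1],
                               [0, 1, 1]])

# Expanded causal mask [bsz, 1, tgt_len, src_len] after left padding is applied
# (1 = attend, 0 = masked), as in the docstring example.
expanded_mask = torch.tensor([[[[0, 0, 0],
                                [0, 0, 0],
                                [0, 0, 1]]],
                              [[[1, 0, 0],
                                [1, 1, 0],
                                [1, 1, 1]]],
                              [[[0, 0, 0],
                                [0, 1, 0],
                                [0, 1, 1]]]])

# Same steps as the added helper, specialised to the 4D / num_masks == 1 case.
tmp = torch.arange(attention_mask.shape[1], 0, -1)
indices = torch.argmax(attention_mask * tmp, 1, keepdim=True)  # first real token per sample
left_masked_rows = torch.where(indices > 0)[0]                  # samples with left padding
indices = indices[left_masked_rows]

max_len = torch.max(indices)
range_tensor = torch.arange(max_len).unsqueeze(0).repeat(indices.size(0), 1)
range_tensor[range_tensor >= indices] = 0                       # only touch fully masked rows

# Re-enable attention on the fully masked rows of the affected samples.
expanded_mask[left_masked_rows[:, None], 0, range_tensor] = 1
print(expanded_mask)  # first rows of samples 0 and 2 are now all ones, matching the docstring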
modeling_llama2.py
CHANGED
@@ -8,8 +8,6 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 
-
-import copy
 import os
 import sys
 
@@ -18,27 +16,28 @@ sys.path.insert(0, dir_path)
 
 import transformers
 from transformers.models.llama.modeling_llama import *
-from transformers.models.llama.modeling_llama import
+from transformers.models.llama.modeling_llama import _get_unpad_data
 from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging
 
 from .modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
 from .configuration_mplug_owl2 import LlamaConfig
 
+
 class MultiwayNetwork(nn.Module):
 
     def __init__(self, module_provider, num_multiway=2):
         super(MultiwayNetwork, self).__init__()
 
         self.multiway = torch.nn.ModuleList([module_provider() for _ in range(num_multiway)])
-
+
     def forward(self, hidden_states, multiway_indices):
 
         if len(self.multiway) == 1:
             return self.multiway[0](hidden_states)
 
         output_hidden_states = torch.empty_like(hidden_states)
-
+
         for idx, subway in enumerate(self.multiway):
             local_indices = multiway_indices.eq(idx).nonzero(as_tuple=True)
             hidden = hidden_states[local_indices].unsqueeze(1).contiguous()
@@ -48,9 +47,9 @@ class MultiwayNetwork(nn.Module):
             output = output[0]
             output = output.squeeze(1)
             output_hidden_states[local_indices] = output
-
+
         return output_hidden_states.contiguous()
-
+
 
 class LlamaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -65,7 +64,7 @@ class LlamaAttention(nn.Module):
                 "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                 "when creating this class."
             )
-
+
         self.attention_dropout = config.attention_dropout
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
@@ -83,10 +82,12 @@ class LlamaAttention(nn.Module):
             )
         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
         self.k_proj = MultiwayNetwork(module_provider=partial(
-            nn.Linear, in_features=self.hidden_size, out_features=self.num_key_value_heads * self.head_dim,
+            nn.Linear, in_features=self.hidden_size, out_features=self.num_key_value_heads * self.head_dim,
+            bias=config.attention_bias)
         )
         self.v_proj = MultiwayNetwork(module_provider=partial(
-            nn.Linear, in_features=self.hidden_size, out_features=self.num_key_value_heads * self.head_dim,
+            nn.Linear, in_features=self.hidden_size, out_features=self.num_key_value_heads * self.head_dim,
+            bias=config.attention_bias)
         )
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
         self._init_rope()
@@ -122,15 +123,15 @@ class LlamaAttention(nn.Module):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
 
     def forward(
-
-
-
-
-
-
-
-
-
+        self,
+        hidden_states: torch.Tensor,
+        modality_indicators: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        padding_mask: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
 
@@ -193,7 +194,7 @@ class LlamaAttention(nn.Module):
             attn_weights = None
 
         return attn_output, attn_weights, past_key_value
-
+
 
 class LlamaFlashAttention2(LlamaAttention):
     """
@@ -211,15 +212,15 @@ class LlamaFlashAttention2(LlamaAttention):
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
 
     def forward(
-
-
-
-
-
-
-
-
-
+        self,
+        hidden_states: torch.Tensor,
+        modality_indicators: torch.Tensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         # LlamaFlashAttention2 attention does not support output_attentions
         if "padding_mask" in kwargs:
@@ -302,7 +303,7 @@ class LlamaFlashAttention2(LlamaAttention):
         return attn_output, attn_weights, past_key_value
 
     def _flash_attention_forward(
-
+        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
     ):
         """
         Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
@@ -408,14 +409,14 @@ class LlamaSdpaAttention(LlamaAttention):
 
     # Adapted from LlamaAttention.forward
     def forward(
-
-
-
-
-
-
-
-
+        self,
+        hidden_states: torch.Tensor,
+        modality_indicators: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -488,13 +489,13 @@ class LlamaSdpaAttention(LlamaAttention):
         return attn_output, None, past_key_value
 
 
-
 LLAMA_ATTENTION_CLASSES = {
     "eager": LlamaAttention,
     "flash_attention_2": LlamaFlashAttention2,
     "sdpa": LlamaSdpaAttention,
 }
 
+
 class LlamaDecoderLayer(nn.Module):
     def __init__(self, config: LlamaConfig, layer_idx):
         super().__init__()
@@ -510,14 +511,14 @@ class LlamaDecoderLayer(nn.Module):
         ))
 
     def forward(
-
-
-
-
-
-
-
-
+        self,
+        hidden_states: torch.Tensor,
+        modality_indicators: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -567,17 +568,17 @@ class LlamaDecoderLayer(nn.Module):
 
 
 def model_forward(
-
-
-
-
-
-
-
-
-
-
-
+    self,
+    input_ids: torch.LongTensor = None,
+    modality_indicators: torch.Tensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
@@ -596,7 +597,7 @@ def model_forward(
         batch_size, seq_length, _ = inputs_embeds.shape
     else:
         raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
-
+
     seq_length_with_past = seq_length
     past_key_values_length = 0
 
@@ -620,24 +621,24 @@ def model_forward(
         attention_mask = torch.ones(
             (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
         )
-
+
     if self._use_flash_attention_2:
-
-
+        # 2d mask is passed through the layers
+        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
     elif self._use_sdpa and not output_attentions:
-
-
-
-
-
-
-
-
+        # output_attentions=True can not be supported when using SDPA, and we fall back on
+        # the manual implementation that requires a 4D causal mask in all cases.
+        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+            attention_mask,
+            (batch_size, seq_length),
+            inputs_embeds,
+            past_key_values_length,
+        )
     else:
-
-
-
-
+        # 4d mask is passed through the layers
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+        )
 
     hidden_states = inputs_embeds
 
@@ -712,18 +713,18 @@ def model_forward(
 
 
 def causal_model_forward(
-
-
-
-
-
-
-
-
-
-
-
-
+    self,
+    input_ids: torch.LongTensor = None,
+    modality_indicators: torch.Tensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -805,6 +806,7 @@ def causal_model_forward(
         attentions=outputs.attentions,
     )
 
+
 def replace_llama_modality_adaptive():
     transformers.models.llama.configuration_llama.LlamaConfig = LlamaConfig
     transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
@@ -814,7 +816,7 @@ def replace_llama_modality_adaptive():
     transformers.models.llama.modeling_llama.LlamaModel.forward = model_forward
     transformers.models.llama.modeling_llama.LlamaForCausalLM.forward = causal_model_forward
 
-
+
 if __name__ == "__main__":
     replace_llama_modality_adaptive()
     config = transformers.LlamaConfig.from_pretrained('/cpfs01/shared/public/test/vicuna-7b-v1.5/')