Fix SDPA & Flash-Attention

#7
Files changed (2)
  1. modeling_attn_mask_utils.py +89 -0
  2. modeling_llama2.py +89 -87
modeling_attn_mask_utils.py CHANGED
@@ -160,6 +160,95 @@ class AttentionMaskConverter:
 
         return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
 
+    @staticmethod
+    def _unmask_unattended(
+        expanded_mask: torch.Tensor, attention_mask: torch.Tensor, unmasked_value: Union[bool, float]
+    ):
+        # fmt: off
+        """
+        Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
+        using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+        Details: https://github.com/pytorch/pytorch/issues/110213
+
+        `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
+        `attention_mask` is [bsz, src_seq_len].
+
+        The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.
+
+        For example, if `attention_mask` is
+        ```
+        [[0, 0, 1],
+         [1, 1, 1],
+         [0, 1, 1]]
+        ```
+        and `expanded_mask` is (e.g. here left-padding case)
+        ```
+        [[[[0, 0, 0],
+           [0, 0, 0],
+           [0, 0, 1]]],
+         [[[1, 0, 0],
+           [1, 1, 0],
+           [1, 1, 1]]],
+         [[[0, 0, 0],
+           [0, 1, 0],
+           [0, 1, 1]]]]
+        ```
+        then the modified `expanded_mask` will be
+        ```
+        [[[[1, 1, 1], <-- modified
+           [1, 1, 1], <-- modified
+           [0, 0, 1]]],
+         [[[1, 0, 0],
+           [1, 1, 0],
+           [1, 1, 1]]],
+         [[[1, 1, 1], <-- modified
+           [0, 1, 0],
+           [0, 1, 1]]]]
+        ```
+        """
+        # fmt: on
+
+        # Get the index of the first non-zero value for every sample in the batch.
+        # In the above example, indices = [[2], [0], [1]]
+        tmp = torch.arange(attention_mask.shape[1], 0, -1)
+        indices = torch.argmax(attention_mask.cpu() * tmp, 1, keepdim=True)
+
+        # Find the batch indexes that have unattended tokens on the leftmost side (e.g. [0, 0, 1, 1, 1]), for which the first rows of the
+        # expanded mask will be completely unattended.
+        left_masked_rows = torch.where(indices > 0)[0]
+
+        if left_masked_rows.shape[0] == 0:
+            return expanded_mask
+        indices = indices[left_masked_rows]
+
+        max_len = torch.max(indices)
+        range_tensor = torch.arange(max_len).unsqueeze(0)
+        range_tensor = range_tensor.repeat(indices.size(0), 1)
+
+        # Avoid unmasking tokens at relevant target positions (on the row axis), by rather unmasking possibly several times the first row that should always be unmasked as we filtered out the batch above.
+        range_tensor[range_tensor >= indices] = 0
+
+        # TODO: we may drop support for 3D attention mask as the refactor from Patrick maybe dropped this case
+        if expanded_mask.dim() == 4:
+            num_masks = expanded_mask.shape[1]
+            if num_masks == 1:
+                # Broadcast [left_masked_rows, 1], [left_masked_rows, max_len]
+                mask_slice = (left_masked_rows[:, None], 0, range_tensor)
+            else:
+                # Broadcast [left_masked_rows, 1, 1], [1, num_masks, 1], [left_masked_rows, 1, max_len]
+                mask_slice = (
+                    left_masked_rows[:, None, None],
+                    torch.arange(num_masks)[None, :, None],
+                    range_tensor[:, None, :],
+                )
+        else:
+            # Broadcast [left_masked_rows, 1], [left_masked_rows, max_len]
+            mask_slice = (left_masked_rows[:, None], range_tensor)
+
+        expanded_mask[mask_slice] = unmasked_value
+
+        return expanded_mask
+
 
 def _prepare_4d_causal_attention_mask(
     attention_mask: Optional[torch.Tensor],
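
For reference, a minimal sketch (not part of this PR) that reproduces the docstring example of `_unmask_unattended` above; the flat import path is an assumption about how the local module is laid out.

```python
# Minimal sketch, not part of the diff: reproduces the docstring example of
# _unmask_unattended. Assumes the patched AttentionMaskConverter is importable
# from the local modeling_attn_mask_utils module (this import path is an assumption).
import torch

from modeling_attn_mask_utils import AttentionMaskConverter

# [bsz=3, src_seq_len=3] padding mask; rows 0 and 2 are left-padded.
attention_mask = torch.tensor([[0, 0, 1],
                               [1, 1, 1],
                               [0, 1, 1]])

# [bsz, 1, tgt_seq_len, src_seq_len] mask already expanded from `attention_mask`.
expanded_mask = torch.tensor(
    [[[[0, 0, 0], [0, 0, 0], [0, 0, 1]]],
     [[[1, 0, 0], [1, 1, 0], [1, 1, 1]]],
     [[[0, 0, 0], [0, 1, 0], [0, 1, 1]]]]
)

unmasked = AttentionMaskConverter._unmask_unattended(expanded_mask, attention_mask, unmasked_value=1)

# Rows that were fully masked (which trip SDPA's memory-efficient kernel, see
# https://github.com/pytorch/pytorch/issues/110213) now attend everywhere:
print(unmasked[0, 0, 0])  # tensor([1, 1, 1])
print(unmasked[2, 0, 0])  # tensor([1, 1, 1])
print(unmasked[1, 0, 0])  # tensor([1, 0, 0])  -- untouched, no left padding in this row
```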
modeling_llama2.py CHANGED
@@ -8,8 +8,6 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 
-
-import copy
 import os
 import sys
 
@@ -18,27 +16,28 @@ sys.path.insert(0, dir_path)
 
 import transformers
 from transformers.models.llama.modeling_llama import *
-from transformers.models.llama.modeling_llama import *
+from transformers.models.llama.modeling_llama import _get_unpad_data
 from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging
 
 from .modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
 from .configuration_mplug_owl2 import LlamaConfig
 
+
 class MultiwayNetwork(nn.Module):
 
     def __init__(self, module_provider, num_multiway=2):
         super(MultiwayNetwork, self).__init__()
 
         self.multiway = torch.nn.ModuleList([module_provider() for _ in range(num_multiway)])
-
+
     def forward(self, hidden_states, multiway_indices):
 
         if len(self.multiway) == 1:
             return self.multiway[0](hidden_states)
 
         output_hidden_states = torch.empty_like(hidden_states)
-
+
         for idx, subway in enumerate(self.multiway):
             local_indices = multiway_indices.eq(idx).nonzero(as_tuple=True)
             hidden = hidden_states[local_indices].unsqueeze(1).contiguous()
@@ -48,9 +47,9 @@ class MultiwayNetwork(nn.Module):
             output = output[0]
             output = output.squeeze(1)
             output_hidden_states[local_indices] = output
-
+
         return output_hidden_states.contiguous()
-
+
 
 class LlamaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -65,7 +64,7 @@ class LlamaAttention(nn.Module):
                 "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                 "when creating this class."
             )
-
+
         self.attention_dropout = config.attention_dropout
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
@@ -83,10 +82,12 @@ class LlamaAttention(nn.Module):
         )
         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
         self.k_proj = MultiwayNetwork(module_provider=partial(
-            nn.Linear, in_features=self.hidden_size, out_features=self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+            nn.Linear, in_features=self.hidden_size, out_features=self.num_key_value_heads * self.head_dim,
+            bias=config.attention_bias)
         )
         self.v_proj = MultiwayNetwork(module_provider=partial(
-            nn.Linear, in_features=self.hidden_size, out_features=self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+            nn.Linear, in_features=self.hidden_size, out_features=self.num_key_value_heads * self.head_dim,
+            bias=config.attention_bias)
         )
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
         self._init_rope()
@@ -122,15 +123,15 @@ class LlamaAttention(nn.Module):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
 
     def forward(
-        self,
-        hidden_states: torch.Tensor,
-        modality_indicators: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        padding_mask: Optional[torch.LongTensor] = None,
+        self,
+        hidden_states: torch.Tensor,
+        modality_indicators: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        padding_mask: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
 
@@ -193,7 +194,7 @@ class LlamaAttention(nn.Module):
             attn_weights = None
 
         return attn_output, attn_weights, past_key_value
-
+
 
 class LlamaFlashAttention2(LlamaAttention):
     """
@@ -211,15 +212,15 @@ class LlamaFlashAttention2(LlamaAttention):
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
 
     def forward(
-        self,
-        hidden_states: torch.Tensor,
-        modality_indicators: torch.Tensor,
-        attention_mask: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        **kwargs,
+        self,
+        hidden_states: torch.Tensor,
+        modality_indicators: torch.Tensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         # LlamaFlashAttention2 attention does not support output_attentions
         if "padding_mask" in kwargs:
@@ -302,7 +303,7 @@ class LlamaFlashAttention2(LlamaAttention):
         return attn_output, attn_weights, past_key_value
 
     def _flash_attention_forward(
-        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
+        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
     ):
         """
         Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
@@ -408,14 +409,14 @@ class LlamaSdpaAttention(LlamaAttention):
 
     # Adapted from LlamaAttention.forward
     def forward(
-        self,
-        hidden_states: torch.Tensor,
-        modality_indicators: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
+        self,
+        hidden_states: torch.Tensor,
+        modality_indicators: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -488,13 +489,13 @@ class LlamaSdpaAttention(LlamaAttention):
         return attn_output, None, past_key_value
 
 
-
 LLAMA_ATTENTION_CLASSES = {
     "eager": LlamaAttention,
     "flash_attention_2": LlamaFlashAttention2,
     "sdpa": LlamaSdpaAttention,
 }
 
+
 class LlamaDecoderLayer(nn.Module):
     def __init__(self, config: LlamaConfig, layer_idx):
         super().__init__()
@@ -510,14 +511,14 @@ class LlamaDecoderLayer(nn.Module):
         ))
 
     def forward(
-        self,
-        hidden_states: torch.Tensor,
-        modality_indicators: torch.Tensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        output_attentions: Optional[bool] = False,
-        use_cache: Optional[bool] = False,
+        self,
+        hidden_states: torch.Tensor,
+        modality_indicators: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -567,17 +568,17 @@ class LlamaDecoderLayer(nn.Module):
 
 
 def model_forward(
-    self,
-    input_ids: torch.LongTensor = None,
-    modality_indicators: torch.Tensor = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[List[torch.FloatTensor]] = None,
-    inputs_embeds: Optional[torch.FloatTensor] = None,
-    use_cache: Optional[bool] = None,
-    output_attentions: Optional[bool] = None,
-    output_hidden_states: Optional[bool] = None,
-    return_dict: Optional[bool] = None,
+    self,
+    input_ids: torch.LongTensor = None,
+    modality_indicators: torch.Tensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
@@ -596,7 +597,7 @@ def model_forward(
         batch_size, seq_length, _ = inputs_embeds.shape
     else:
         raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
-
+
     seq_length_with_past = seq_length
     past_key_values_length = 0
 
@@ -620,24 +621,24 @@ def model_forward(
         attention_mask = torch.ones(
             (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
         )
-
+
     if self._use_flash_attention_2:
-        # 2d mask is passed through the layers
-        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        # 2d mask is passed through the layers
+        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
    elif self._use_sdpa and not output_attentions:
-        # output_attentions=True can not be supported when using SDPA, and we fall back on
-        # the manual implementation that requires a 4D causal mask in all cases.
-        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-            attention_mask,
-            (batch_size, seq_length),
-            inputs_embeds,
-            past_key_values_length,
-        )
+        # output_attentions=True can not be supported when using SDPA, and we fall back on
+        # the manual implementation that requires a 4D causal mask in all cases.
+        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+            attention_mask,
+            (batch_size, seq_length),
+            inputs_embeds,
+            past_key_values_length,
+        )
    else:
-        # 4d mask is passed through the layers
-        attention_mask = _prepare_4d_causal_attention_mask(
-            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-        )
+        # 4d mask is passed through the layers
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+        )
 
     hidden_states = inputs_embeds
 
@@ -712,18 +713,18 @@ def model_forward(
 
 
 def causal_model_forward(
-    self,
-    input_ids: torch.LongTensor = None,
-    modality_indicators: torch.Tensor = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[List[torch.FloatTensor]] = None,
-    inputs_embeds: Optional[torch.FloatTensor] = None,
-    labels: Optional[torch.LongTensor] = None,
-    use_cache: Optional[bool] = None,
-    output_attentions: Optional[bool] = None,
-    output_hidden_states: Optional[bool] = None,
-    return_dict: Optional[bool] = None,
+    self,
+    input_ids: torch.LongTensor = None,
+    modality_indicators: torch.Tensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -805,6 +806,7 @@ def causal_model_forward(
         attentions=outputs.attentions,
     )
 
+
 def replace_llama_modality_adaptive():
     transformers.models.llama.configuration_llama.LlamaConfig = LlamaConfig
     transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
@@ -814,7 +816,7 @@ def replace_llama_modality_adaptive():
     transformers.models.llama.modeling_llama.LlamaModel.forward = model_forward
     transformers.models.llama.modeling_llama.LlamaForCausalLM.forward = causal_model_forward
 
-
+
 if __name__ == "__main__":
     replace_llama_modality_adaptive()
     config = transformers.LlamaConfig.from_pretrained('/cpfs01/shared/public/test/vicuna-7b-v1.5/')
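
The mask handling in `model_forward` above dispatches three ways: flash_attention_2 keeps the 2D padding mask (or drops it when there is no padding, relying on unpadding via `_get_unpad_data`), SDPA builds a 4D additive mask with `_prepare_4d_causal_attention_mask_for_sdpa`, and the eager path uses `_prepare_4d_causal_attention_mask`. A hedged sketch of what each branch produces for a left-padded batch, assuming the two helpers are importable from the local `modeling_attn_mask_utils` module:

```python
# Illustrative sketch, not part of the diff: the three mask-preparation branches of
# model_forward on a left-padded batch. The flat import path is an assumption
# (in the repo these are relative imports).
import torch

from modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)

batch_size, seq_length, hidden_size = 2, 4, 8
inputs_embeds = torch.zeros(batch_size, seq_length, hidden_size, dtype=torch.float16)
# 2D padding mask; sample 0 is left-padded with one pad token.
attention_mask_2d = torch.tensor([[0, 1, 1, 1],
                                  [1, 1, 1, 1]])

# flash_attention_2 branch: the 2D mask is passed through as-is, or dropped entirely
# when there is no padding.
fa2_mask = attention_mask_2d if 0 in attention_mask_2d else None

# sdpa branch: a 4D additive causal mask; per this PR, rows that end up fully masked
# (left padding) are re-opened by AttentionMaskConverter._unmask_unattended so that
# SDPA's memory-efficient kernel does not produce NaNs.
sdpa_mask = _prepare_4d_causal_attention_mask_for_sdpa(
    attention_mask_2d, (batch_size, seq_length), inputs_embeds, past_key_values_length=0
)

# eager branch: the plain 4D additive causal mask.
eager_mask = _prepare_4d_causal_attention_mask(
    attention_mask_2d, (batch_size, seq_length), inputs_embeds, past_key_values_length=0
)

print(fa2_mask.shape)                                  # torch.Size([2, 4])
print(eager_mask.shape)                                # torch.Size([2, 1, 4, 4])
print(None if sdpa_mask is None else sdpa_mask.shape)  # expected 4D here too, since padding is present
```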
 
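
A usage sketch, not taken from the diff: after `replace_llama_modality_adaptive()` monkey-patches the transformers LLaMA classes, the attention backend keyed by `LLAMA_ATTENTION_CLASSES` would normally be selected through `attn_implementation` at load time. The import path, checkpoint path, and dtype below are placeholders.

```python
# Hedged usage sketch; adapt paths to how modeling_llama2 is packaged in this repo.
import torch
import transformers

from modeling_llama2 import replace_llama_modality_adaptive  # hypothetical import path

# Monkey-patch the transformers LLaMA classes with the modality-adaptive versions above.
replace_llama_modality_adaptive()

# attn_implementation selects the entry from LLAMA_ATTENTION_CLASSES:
# "eager", "sdpa", or "flash_attention_2".
model = transformers.LlamaForCausalLM.from_pretrained(
    "path/to/llama-checkpoint",      # placeholder checkpoint path
    torch_dtype=torch.float16,
    attn_implementation="sdpa",      # or "flash_attention_2" (requires flash-attn) / "eager"
)
```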