Commit 06c1397
1 Parent(s): 9c0f1cf
Remove flash attn support

attention.py  CHANGED  (+4 -71)
@@ -58,59 +58,6 @@ def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
         if not tensor.is_cuda:
             raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).')

-def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
-    try:
-        from flash_attn import bert_padding, flash_attn_interface
-    except:
-        raise RuntimeError('Please install flash-attn==1.0.3.post0')
-    check_valid_inputs(query, key, value)
-    if attn_bias is not None:
-        raise NotImplementedError(f'attn_bias not implemented for flash attn.')
-    (batch_size, seqlen) = query.shape[:2]
-    if key_padding_mask is None:
-        key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
-    query_padding_mask = key_padding_mask[:, -query.size(1):]
-    (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(query, query_padding_mask)
-    query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
-    (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(key, key_padding_mask)
-    key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads)
-    (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask)
-    value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads)
-    if multiquery:
-        key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
-        value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, value_unpad.size(-1))
-    dropout_p = dropout_p if training else 0.0
-    reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
-    output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
-    output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
-    return (output, None)
-
-def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
-    try:
-        from flash_attn import flash_attn_triton
-    except:
-        raise RuntimeError('Please install flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202')
-    check_valid_inputs(query, key, value)
-    if dropout_p:
-        raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
-    if needs_weights:
-        raise NotImplementedError(f'attn_impl: triton cannot return attn weights.')
-    if key_padding_mask is not None:
-        warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unnecessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
-        (b_size, s_k) = key_padding_mask.shape[:2]
-        if attn_bias is None:
-            attn_bias = query.new_zeros(b_size, 1, 1, s_k)
-        attn_bias = attn_bias.masked_fill(~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min)
-    query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
-    key = rearrange(key, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads)
-    value = rearrange(value, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads)
-    if multiquery:
-        key = key.expand(*key.shape[:2], n_heads, key.size(-1))
-        value = value.expand(*value.shape[:2], n_heads, value.size(-1))
-    reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
-    attn_output = flash_attn_triton.flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
-    output = attn_output.view(*attn_output.shape[:2], -1)
-    return (output, None)

 class MultiheadAttention(nn.Module):
     """Multi-head self attention.
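After this removal, the only attention implementation wired up below is the 'torch' path (scaled_multihead_dot_product_attention). As a point of reference, here is a minimal sketch of the dense attention that path computes, softmax(Q Kᵀ · scale + bias) V over packed (batch, seq, n_heads * head_dim) inputs. It is illustrative only: the function name and the simplified causal masking are assumptions, not the repository's scaled_multihead_dot_product_attention.

    # Sketch of dense multi-head attention equivalent in spirit to the remaining 'torch' path.
    import math
    import torch

    def simple_torch_attention(query, key, value, n_heads, softmax_scale=None, attn_bias=None, is_causal=False):
        # query/key/value: (batch, seq_len, n_heads * head_dim), matching the packed layout used above.
        (b, s_q, d) = query.shape
        s_k = key.size(1)
        head_dim = d // n_heads
        q = query.view(b, s_q, n_heads, head_dim).transpose(1, 2)   # (b, h, s_q, hd)
        k = key.view(b, s_k, n_heads, head_dim).transpose(1, 2)
        v = value.view(b, s_k, n_heads, head_dim).transpose(1, 2)
        scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(head_dim)
        scores = torch.matmul(q, k.transpose(-2, -1)) * scale       # (b, h, s_q, s_k)
        if attn_bias is not None:
            scores = scores + attn_bias                             # e.g. ALiBi or a key-padding bias, broadcast over heads
        if is_causal:
            causal = torch.ones(s_q, s_k, dtype=torch.bool, device=q.device).tril(diagonal=s_k - s_q)
            scores = scores.masked_fill(~causal, torch.finfo(scores.dtype).min)
        weights = torch.softmax(scores, dim=-1)
        out = torch.matmul(weights, v)                              # (b, h, s_q, hd)
        return out.transpose(1, 2).reshape(b, s_q, d)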
@@ -137,12 +84,7 @@ class MultiheadAttention(nn.Module):
             layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
             self.q_ln = layernorm_class(self.d_model, device=device)
             self.k_ln = layernorm_class(self.d_model, device=device)
-        if self.attn_impl == 'flash':
-            self.attn_fn = flash_attn_fn
-        elif self.attn_impl == 'triton':
-            self.attn_fn = triton_flash_attn_fn
-            warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
-        elif self.attn_impl == 'torch':
+        if self.attn_impl == 'torch':
             self.attn_fn = scaled_multihead_dot_product_attention
             if torch.cuda.is_available():
                 warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
@@ -197,12 +139,7 @@ class MultiQueryAttention(nn.Module):
             layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
             self.q_ln = layernorm_class(d_model, device=device)
             self.k_ln = layernorm_class(self.head_dim, device=device)
-        if self.attn_impl == 'flash':
-            self.attn_fn = flash_attn_fn
-        elif self.attn_impl == 'triton':
-            self.attn_fn = triton_flash_attn_fn
-            warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
-        elif self.attn_impl == 'torch':
+        if self.attn_impl == 'torch':
             self.attn_fn = scaled_multihead_dot_product_attention
             if torch.cuda.is_available():
                 warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
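Note that with only the 'torch' branch left in both constructors, passing attn_impl='flash' or 'triton' no longer assigns self.attn_fn. A hypothetical caller-side guard is sketched below; the import path and the MultiheadAttention constructor arguments shown are assumptions for illustration, not taken from this diff.

    # Hypothetical usage sketch; module layout and constructor arguments are assumed.
    from attention import MultiheadAttention

    attn_impl = 'torch'  # 'flash' and 'triton' no longer wire up an attn_fn after this commit
    if attn_impl != 'torch':
        raise ValueError(f"attn_impl={attn_impl!r} is unsupported after removing flash-attn; use 'torch'.")
    attn = MultiheadAttention(d_model=2048, n_heads=16, attn_impl=attn_impl)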
@@ -232,9 +169,7 @@
         return (self.out_proj(context), attn_weights, past_key_value)

 def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
-    if attn_impl == 'flash':
-        return None
-    elif attn_impl in ['torch', 'triton']:
+    if attn_impl in ['torch', 'triton']:
         if alibi:
             if (prefix_lm or not causal) or use_sequence_id:
                 return (1, n_heads, seq_len, seq_len)
@@ -246,9 +181,7 @@ def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_s
         raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')

 def build_attn_bias(attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8):
-    if attn_impl == 'flash':
-        return None
-    elif attn_impl in ['torch', 'triton']:
+    if attn_impl in ['torch', 'triton']:
         if alibi:
             (device, dtype) = (attn_bias.device, attn_bias.dtype)
             attn_bias = attn_bias.add(build_alibi_bias(n_heads, seq_len, full=not causal, alibi_bias_max=alibi_bias_max, device=device, dtype=dtype))
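The context above adds build_alibi_bias(...) into attn_bias for the 'torch'/'triton' paths. For orientation, here is one common way ALiBi-style biases are built (a negative per-head slope times relative key distance). It is a sketch under the same argument names, not necessarily this repository's build_alibi_bias; in particular it skips details such as rounding n_heads to a power of two.

    # Sketch of an ALiBi-style bias of shape (1, n_heads, 1 or seq_len, seq_len).
    import torch

    def alibi_bias_sketch(n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None):
        # Relative key positions 1-seq_len .. 0, shaped to broadcast against (b, h, s_q, s_k) scores.
        bias = torch.arange(1 - seq_len, 1, device=device, dtype=torch.float32).view(1, 1, 1, seq_len)
        if full:
            # Non-causal variant: penalize by absolute distance in both directions.
            bias = bias - torch.arange(1 - seq_len, 1, device=device, dtype=torch.float32).view(1, 1, seq_len, 1)
            bias = -bias.abs()
        # One slope per head, geometrically decaying; alibi_bias_max bounds the exponent range.
        m = torch.arange(1, n_heads + 1, device=device, dtype=torch.float32) * (alibi_bias_max / n_heads)
        slopes = 1.0 / torch.pow(2, m)
        return (bias * slopes.view(1, n_heads, 1, 1)).to(dtype=dtype)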