make FlashAttention logic more robust
modeling_gptbert.py  +5 -5
@@ -367,7 +367,7 @@ class SelfAttention(nn.Module):
         theta = 160_000 if (layer_idx + 1) % config.local_global_ratio == 0 else 10_000

         # Initialize rotary embeddings based on whether FlashAttention is available
-        if
+        if flash_attn_varlen_qkvpacked_func is not None:
             self.rope_embedding = UnpaddedRotaryEmbedding(dim=self.d_qk, base=theta, max_seqlen=config.max_sequence_length)
         else:
             self.rope_embedding = RotaryPositionalEmbeddings(config, theta)
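The `is not None` comparison only works as an availability test if the symbol is bound to `None` whenever flash-attn cannot be imported. A minimal sketch of the optional-import guard this presumably relies on (the guard itself is not part of this diff):

# Assumed module-level guard, not shown in this diff: bind the flash-attn entry
# point if the package is installed, otherwise set it to None so every
# `if flash_attn_varlen_qkvpacked_func is not None` branch takes the fallback path.
try:
    from flash_attn import flash_attn_varlen_qkvpacked_func
except ImportError:
    flash_attn_varlen_qkvpacked_func = None

With such a guard in place, the same condition can gate the rotary-embedding choice, the attention kernel, and the padding handling consistently.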
@@ -418,7 +418,7 @@ class SelfAttention(nn.Module):

     def forward(self, hidden_layer: torch.Tensor, qk_layer: torch.Tensor, v1: torch.Tensor | None, padding_info):
         # Get original shape info
-        if
+        if flash_attn_varlen_qkvpacked_func is not None:
             # Unpadded case
             indices, cu_seqlens, max_seqlen = padding_info
             total_seqlen = hidden_layer.size(0)
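For orientation, here is what the `padding_info` tuple contains for a toy batch; the values below are illustrative, not taken from the model:

# Illustrative only: a batch of 2 sequences with lengths 3 and 5, padded to seq_length = 5.
indices    = [0, 1, 2, 5, 6, 7, 8, 9]   # flat positions of real tokens in the (batch * seq_length) view
cu_seqlens = [0, 3, 8]                  # cumulative sequence lengths, with a leading zero
max_seqlen = 5                          # length of the longest sequence in the batch
# hidden_layer is then the unpadded activation of shape (total_seqlen, hidden_size) = (8, hidden_size).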
@@ -433,7 +433,7 @@ class SelfAttention(nn.Module):
         query, key = self.qk_proj(qk_layer).tensor_split([self.q_out_dim], dim=-1)
         value = self.v_proj(hidden_layer)

-        if
+        if flash_attn_varlen_qkvpacked_func is not None:
             # Reshape for FlashAttention: (total_seqlen, num_heads, head_dim)
             query = query.view(total_seqlen, self.num_attention_heads, self.d_qk)
             key = key.view(total_seqlen, self.num_kv_heads, self.d_qk)
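After this reshape the tensors are in the layout flash-attn's variable-length kernels expect. A hedged sketch of such a call, reusing the names from the hunk above and assuming the flash-attn 2.x `flash_attn_varlen_func` entry point (the module may call a different kernel, e.g. a packed-QKV variant):

from flash_attn import flash_attn_varlen_func  # assumption: flash-attn 2.x is installed

# query: (total_seqlen, num_attention_heads, d_qk); key, value: (total_seqlen, num_kv_heads, head_dim)
# Inputs must be fp16/bf16 CUDA tensors; cu_seqlens is int32, max_seqlen a Python int.
attn_output = flash_attn_varlen_func(
    query, key, value,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
    causal=False,  # bidirectional, BERT-style attention
)
# attn_output keeps the unpadded layout: (total_seqlen, num_attention_heads, head_dim)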
@@ -645,7 +645,7 @@ class GptBertModel(GptBertPreTrainedModel):
         else:
             attention_mask = attention_mask.bool()

-        if
+        if flash_attn_varlen_qkvpacked_func is not None:
             if len(attention_mask.size()) != 2:
                 raise ValueError("Bare `attention_mask` med to dimensjoner støttes nå for FlashAttention.")
             with torch.no_grad():
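The error message here ("Bare `attention_mask` med to dimensjoner støttes nå for FlashAttention.", roughly: only a two-dimensional `attention_mask` is now supported for FlashAttention) guards the unpadding step. A minimal sketch of how the `torch.no_grad()` block might derive the padding metadata from that 2-D mask; the exact implementation is an assumption, and flash-attn ships a comparable helper as `flash_attn.bert_padding.unpad_input`:

import torch
import torch.nn.functional as F

with torch.no_grad():
    # attention_mask: (batch_size, seq_length) bool tensor, True for real tokens
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)                  # tokens per sequence
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    cu_seqlens = F.pad(seqlens.cumsum(dim=0, dtype=torch.int32), (1, 0))     # prepend a zero
    max_seqlen = int(seqlens.max())
    padding_info = (indices, cu_seqlens, max_seqlen)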
@@ -676,7 +676,7 @@ class GptBertModel(GptBertPreTrainedModel):
         contextualized_embeddings = [layer.to(original_dtype) for layer in contextualized_embeddings]

         # Pad output if using FlashAttention
-        if
+        if flash_attn_varlen_qkvpacked_func is not None:
             last_layer = _pad_output(last_layer, indices, batch_size, seq_length)
             if output_hidden_states:
                 contextualized_embeddings = [_pad_output(layer, indices, batch_size, seq_length) for layer in contextualized_embeddings]
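`_pad_output` undoes the unpadding so callers get a regular (batch_size, seq_length, hidden_size) tensor back. A minimal sketch of what such a helper typically does, matching the call signature above; this is an assumption, not the model's actual implementation:

import torch

def _pad_output(unpadded: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_length: int) -> torch.Tensor:
    # unpadded: (total_seqlen, hidden_size); indices: flat positions of the real tokens
    padded = unpadded.new_zeros(batch_size * seq_length, unpadded.size(-1))
    padded[indices] = unpadded                        # scatter real tokens back; padding rows stay zero
    return padded.view(batch_size, seq_length, -1)    # (batch_size, seq_length, hidden_size)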