feat: add dropout support
configuration_stablelm_epoch.py
CHANGED
@@ -65,6 +65,8 @@ class StableLMEpochConfig(PretrainedConfig):
             Whether or not the model should use bias for qkv layers.
         tie_word_embeddings(`bool`, *optional*, defaults to `False`):
             Whether to tie weight embeddings
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
     """
     model_type = "stablelm_epoch"
     keys_to_ignore_at_inference = ["past_key_values"]
@@ -88,6 +90,7 @@ class StableLMEpochConfig(PretrainedConfig):
         bos_token_id=0,
         eos_token_id=2,
         tie_word_embeddings=False,
+        attention_dropout: float = 0.0,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -105,6 +108,7 @@ class StableLMEpochConfig(PretrainedConfig):
         self.use_cache = use_cache
         self.use_qkv_bias = use_qkv_bias
         self.tie_word_embeddings = tie_word_embeddings
+        self.attention_dropout = attention_dropout
         super().__init__(
             bos_token_id=bos_token_id,
             eos_token_id=eos_token_id,
modeling_stablelm_epoch.py
CHANGED
@@ -191,6 +191,7 @@ class Attention(nn.Module):
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings
         self.is_causal = True
+        self.attention_dropout = config.attention_dropout

         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
@@ -275,6 +276,7 @@ class Attention(nn.Module):

         # Upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
         attn_output = torch.matmul(attn_weights, value_states)

         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):