from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation


class NeuroBLASTConfig(PretrainedConfig):
    model_type = "neuroblast"

    def __init__(
        self,
        vocab_size=28886,
        hidden_size=2048,
        kv_dim=2048,
        intermediate_size=None,
        num_attention_heads=32,
        num_sensory_cortex_layers=6,
        num_motor_cortex_layers=6,
        num_association_cortex_layers=6,
        dropout=0.1,
        layer_norm_epsilon=1e-6,
        pad_token_id=None,
        use_cache=False,
        rope_theta=10000.0,
        rope_scaling=None,
        max_position_embeddings=2048,
        initializer_range=0.02,
        use_flash_attn=True,
        num_experts=None,
        num_experts_per_tok=None,
        norm_topk_prob=False,
        hidden_act="silu",
        use_zero_memory=False,
        zero_memory_alpha=1.0,
        zero_memory_layers=None,
        gradient_scaling_enabled=True,
        association_gradient_scale=0.9,
        sensory_gradient_scale=0.95,
        cross_attention_gradient_scale=0.95,
        clamp_value=1e5,
        _attn_implementation="sdpa",
        **kwargs,
    ):
        # Default to the SwiGLU sizing convention (~8/3 * hidden_size)
        # when no intermediate size is provided.
        if intermediate_size is None:
            intermediate_size = int(hidden_size * 4 * 2 / 3)

        super().__init__(
            pad_token_id=pad_token_id,
            **kwargs,
        )

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.kv_dim = kv_dim
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_sensory_cortex_layers = num_sensory_cortex_layers
        self.num_motor_cortex_layers = num_motor_cortex_layers
        self.num_association_cortex_layers = num_association_cortex_layers
        self.dropout = dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        # Mirror the epsilon under the attribute name most HF modeling code reads.
        self.rms_norm_eps = layer_norm_epsilon
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.use_flash_attn = use_flash_attn
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.norm_topk_prob = norm_topk_prob
        self.hidden_act = hidden_act
        self.use_zero_memory = use_zero_memory
        self.zero_memory_alpha = zero_memory_alpha
        self.zero_memory_layers = zero_memory_layers
        self.gradient_scaling_enabled = gradient_scaling_enabled
        self.association_gradient_scale = association_gradient_scale
        self.sensory_gradient_scale = sensory_gradient_scale
        self.cross_attention_gradient_scale = cross_attention_gradient_scale
        self._attn_implementation = _attn_implementation
        self.clamp_value = clamp_value

        # Accept the legacy `{"type": ...}` rope_scaling key and normalize it
        # to the `rope_type` key that `rope_config_validation` expects.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)
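

# Usage sketch (illustrative, not part of the upstream file): constructing a
# config with the legacy rope_scaling "type" key, which is normalized above
# before validation. Parameter values here are arbitrary examples.
if __name__ == "__main__":
    config = NeuroBLASTConfig(rope_scaling={"type": "linear", "factor": 2.0})
    # Derived SwiGLU-style default: int(2048 * 4 * 2 / 3) == 5461
    print(config.intermediate_size)
    # The legacy "type" key has been mirrored into "rope_type".
    print(config.rope_scaling["rope_type"])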