from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation


class NeuroBLASTConfig(PretrainedConfig):
    model_type = "neuroblast"

    def __init__(
        self,
        vocab_size=28886,
        hidden_size=2048,
        kv_dim=2048,
        intermediate_size=None,
        num_attention_heads=32,
        num_sensory_cortex_layers=6,
        num_motor_cortex_layers=6,
        num_association_cortex_layers=6,
        dropout=0.1,
        layer_norm_epsilon=1e-6,
        pad_token_id=None,
        use_cache=False,
        rope_theta=10000.0,
        rope_scaling=None,
        max_position_embeddings=2048,
        initializer_range=0.02,
        use_flash_attn=True,
        num_experts=None,
        num_experts_per_tok=None,
        norm_topk_prob=False,
        hidden_act="silu",
        use_zero_memory=False,
        zero_memory_alpha=1.0,
        zero_memory_layers=None,
        gradient_scaling_enabled=True,
        association_gradient_scale=0.9,
        sensory_gradient_scale=0.95,
        cross_attention_gradient_scale=0.95,
        clamp_value=1e5,
        _attn_implementation="sdpa",
        **kwargs,
    ):
        # Default to the ~8/3 * hidden_size expansion commonly paired with
        # SwiGLU-style (silu-gated) MLPs when no intermediate_size is given.
        if intermediate_size is None:
            intermediate_size = int(hidden_size * 4 * 2 / 3)

        super().__init__(pad_token_id=pad_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.kv_dim = kv_dim
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_sensory_cortex_layers = num_sensory_cortex_layers
        self.num_motor_cortex_layers = num_motor_cortex_layers
        self.num_association_cortex_layers = num_association_cortex_layers
        self.dropout = dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        # Expose the same epsilon under both names so modules that look up
        # either layer_norm_epsilon or rms_norm_eps find a value.
        self.rms_norm_eps = layer_norm_epsilon
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.use_flash_attn = use_flash_attn
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.norm_topk_prob = norm_topk_prob
        self.hidden_act = hidden_act
        self.use_zero_memory = use_zero_memory
        self.zero_memory_alpha = zero_memory_alpha
        self.zero_memory_layers = zero_memory_layers
        self.gradient_scaling_enabled = gradient_scaling_enabled
        self.association_gradient_scale = association_gradient_scale
        self.sensory_gradient_scale = sensory_gradient_scale
        self.cross_attention_gradient_scale = cross_attention_gradient_scale
        self._attn_implementation = _attn_implementation
        self.clamp_value = clamp_value

        # Accept the legacy "type" key in rope_scaling dicts by mirroring it
        # to "rope_type" before running the standard transformers validation.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)
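

# A minimal usage sketch (illustrative only, not part of the module): build a
# config, observe the derived intermediate_size, and round-trip it through
# save_pretrained/from_pretrained as with any PretrainedConfig subclass. The
# output directory name below is a hypothetical placeholder.
if __name__ == "__main__":
    config = NeuroBLASTConfig(hidden_size=1024, num_attention_heads=16)

    # intermediate_size defaults to int(hidden_size * 4 * 2 / 3): 2730 here.
    print(config.intermediate_size)

    # Standard PretrainedConfig serialization: writes config.json to disk.
    config.save_pretrained("./neuroblast-config")
    reloaded = NeuroBLASTConfig.from_pretrained("./neuroblast-config")
    assert reloaded.hidden_size == 1024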