mkurman committed
Commit 500706e · verified · 1 Parent(s): 3fb4a31

Upload 4 files

neuroblast_model/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from neuroblast_model.configuration_neuroblast import NeuroBLASTConfig
+ from neuroblast_model.modeling_neuroblast import NeuroBLASTForCausalLM
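
A minimal usage sketch of these exports (assumptions: the neuroblast_model package directory is on the Python path, and the checkpoint path below is a placeholder):

    from neuroblast_model import NeuroBLASTConfig, NeuroBLASTForCausalLM

    # Build a small, untrained model directly from a config.
    config = NeuroBLASTConfig(hidden_size=256, kv_dim=256, num_attention_heads=4)
    model = NeuroBLASTForCausalLM(config)

    # Or load weights previously written with save_pretrained() (placeholder path).
    # model = NeuroBLASTForCausalLM.from_pretrained("./neuroblast-checkpoint")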
neuroblast_model/configuration_neuroblast.py ADDED
@@ -0,0 +1,85 @@
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.modeling_rope_utils import rope_config_validation
+ from typing import Optional
+
+
+ class NeuroBLASTConfig(PretrainedConfig):
+     model_type = "neuroblast"
+
+     def __init__(
+         self,
+         vocab_size=28886,
+         hidden_size=2048,
+         kv_dim=2048,
+         intermediate_size=None,
+         num_attention_heads=32,
+         num_sensory_cortex_layers=6,
+         num_motor_cortex_layers=6,
+         num_association_cortex_layers=6,
+         dropout=0.1,
+         layer_norm_epsilon=1e-6,
+         pad_token_id=None,
+         use_cache=False,
+         rope_theta=10000.0,
+         rope_scaling=None,
+         max_position_embeddings=2048,
+         initializer_range=0.02,
+         use_flash_attn=True,
+         num_experts=None,
+         num_experts_per_tok=None,
+         norm_topk_prob=False,
+         hidden_act="silu",
+         use_zero_memory=False,
+         zero_memory_alpha=1.0,
+         zero_memory_layers=None,
+         gradient_scaling_enabled=True,
+         association_gradient_scale=0.9,
+         sensory_gradient_scale=0.95,
+         cross_attention_gradient_scale=0.95,
+         clamp_value=1e5,
+         _attn_implementation='sdpa',
+         **kwargs
+     ):
+         # Calculate intermediate_size if not provided
+         if intermediate_size is None:
+             intermediate_size = int(hidden_size * 4 * 2 / 3)
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             **kwargs
+         )
+
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.kv_dim = kv_dim
+         self.intermediate_size = intermediate_size
+         self.num_attention_heads = num_attention_heads
+         self.num_sensory_cortex_layers = num_sensory_cortex_layers
+         self.num_motor_cortex_layers = num_motor_cortex_layers
+         self.num_association_cortex_layers = num_association_cortex_layers
+         self.dropout = dropout
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.rms_norm_eps = layer_norm_epsilon
+         self.use_cache = use_cache
+         self.rope_theta = rope_theta
+         self.rope_scaling = rope_scaling
+         self.max_position_embeddings = max_position_embeddings
+         self.initializer_range = initializer_range
+         self.use_flash_attn = use_flash_attn
+         self.num_experts = num_experts
+         self.num_experts_per_tok = num_experts_per_tok
+         self.norm_topk_prob = norm_topk_prob
+         self.hidden_act = hidden_act
+         self.use_zero_memory = use_zero_memory
+         self.zero_memory_alpha = zero_memory_alpha
+         self.zero_memory_layers = zero_memory_layers
+         self.gradient_scaling_enabled = gradient_scaling_enabled
+         self.association_gradient_scale = association_gradient_scale
+         self.sensory_gradient_scale = sensory_gradient_scale
+         self.cross_attention_gradient_scale = cross_attention_gradient_scale
+         self._attn_implementation = _attn_implementation
+         self.clamp_value = clamp_value
+
+         if self.rope_scaling is not None and "type" in self.rope_scaling:
+             self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+         rope_config_validation(self)
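
For reference, a short sketch of the defaulting behaviour above (illustrative only; the non-default values are made up):

    from neuroblast_model.configuration_neuroblast import NeuroBLASTConfig

    cfg = NeuroBLASTConfig(hidden_size=1536)
    # intermediate_size falls back to int(hidden_size * 4 * 2 / 3).
    assert cfg.intermediate_size == 4096

    # Legacy rope_scaling dicts that use the "type" key are mirrored into
    # "rope_type" before rope_config_validation() runs.
    cfg = NeuroBLASTConfig(rope_scaling={"type": "linear", "factor": 2.0})
    assert cfg.rope_scaling["rope_type"] == "linear"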
neuroblast_model/modeling_neuroblast.py ADDED
@@ -0,0 +1,1961 @@
1
+ import torch
2
+ import math
3
+ from torch import nn
4
+ import torch.nn.functional as F
5
+ from transformers import PreTrainedModel, GenerationMixin
6
+ from transformers.cache_utils import DynamicCache, Cache
7
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
8
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
9
+ from transformers.utils import logging
10
+ from transformers.modeling_outputs import (
11
+ BaseModelOutputWithPast,
12
+ CausalLMOutputWithPast,
13
+ )
14
+ from transformers.activations import ACT2FN
15
+ from typing import Optional, Tuple, Union, List
16
+ from neuroblast_model.configuration_neuroblast import NeuroBLASTConfig
17
+
18
+ CLAMP_VALUE = 1e5
19
+
20
+ logger = logging.get_logger(__name__)
21
+
22
+ def apply_gradient_scaling(
23
+ tensor: torch.Tensor, scale: float, enabled: bool = True
24
+ ) -> torch.Tensor:
25
+ """
26
+ Apply gradient scaling to a tensor.
27
+ This scales the gradients during the backward pass while keeping the forward pass unchanged.
28
+ """
29
+ if not enabled or scale == 1.0 or not tensor.requires_grad:
30
+ return tensor
31
+
32
+ # Use a custom autograd function for gradient scaling
33
+ class GradientScale(torch.autograd.Function):
34
+ @staticmethod
35
+ def forward(ctx, input_tensor, scale_factor):
36
+ ctx.scale = scale_factor
37
+ return input_tensor.clone()
38
+
39
+ @staticmethod
40
+ def backward(ctx, grad_output):
41
+ if grad_output is None:
42
+ return None, None
43
+ return grad_output * ctx.scale, None
44
+
45
+ return GradientScale.apply(tensor, scale)
46
+
47
+
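
apply_gradient_scaling leaves the forward value untouched and multiplies the incoming gradient by scale on the backward pass. A small self-contained check (illustrative sketch, not part of the uploaded file):

    import torch
    from neuroblast_model.modeling_neuroblast import apply_gradient_scaling

    x = torch.ones(3, requires_grad=True)
    y = apply_gradient_scaling(x, scale=0.5)      # forward: y equals x
    y.sum().backward()
    # Without scaling x.grad would be all ones; with scale=0.5 it is halved.
    assert torch.allclose(x.grad, torch.full_like(x, 0.5))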
48
+ def _prepare_4d_causal_attention_mask_with_cache_position(
49
+ attention_mask: torch.Tensor,
50
+ sequence_length: int,
51
+ target_length: int,
52
+ dtype: torch.dtype,
53
+ device: torch.device,
54
+ min_dtype: float,
55
+ cache_position: torch.Tensor,
56
+ batch_size: int,
57
+ ):
58
+ if attention_mask is not None and attention_mask.dim() == 4:
59
+ causal_mask = attention_mask
60
+ else:
61
+ causal_mask = torch.full(
62
+ (sequence_length, target_length),
63
+ fill_value=min_dtype,
64
+ dtype=dtype,
65
+ device=device,
66
+ )
67
+ if sequence_length != 1:
68
+ causal_mask = torch.triu(causal_mask, diagonal=1)
69
+ causal_mask *= torch.arange(
70
+ target_length, device=device
71
+ ) > cache_position.reshape(-1, 1)
72
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
73
+ if attention_mask is not None:
74
+ causal_mask = (
75
+ causal_mask.clone()
76
+ )
77
+ mask_length = attention_mask.shape[-1]
78
+ padding_mask = (
79
+ causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
80
+ )
81
+ padding_mask = padding_mask == 0
82
+ causal_mask[:, :, :, :mask_length] = causal_mask[
83
+ :, :, :, :mask_length
84
+ ].masked_fill(padding_mask, min_dtype)
85
+
86
+ return causal_mask
87
+
88
+
89
+ # --- RoPE Implementation (using HF LlamaRotaryEmbedding) ---
90
+
91
+
92
+ class LlamaRMSNorm(nn.Module):
93
+ def __init__(self, hidden_size, eps=1e-6):
94
+ """
95
+ LlamaRMSNorm is equivalent to T5LayerNorm
96
+ """
97
+ super().__init__()
98
+ self.weight = nn.Parameter(torch.ones(hidden_size))
99
+ self.variance_epsilon = eps
100
+
101
+ def forward(self, hidden_states):
102
+ input_dtype = hidden_states.dtype
103
+ hidden_states = hidden_states.to(torch.float32)
104
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
105
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
106
+ return self.weight * hidden_states.to(input_dtype)
107
+
108
+ def extra_repr(self):
109
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
110
+
111
+
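
A quick numerical sanity check of the RMS normalization above (illustrative; the sizes are arbitrary):

    import torch
    from neuroblast_model.modeling_neuroblast import LlamaRMSNorm

    norm = LlamaRMSNorm(hidden_size=8, eps=1e-6)
    x = torch.randn(2, 3, 8)
    y = norm(x)
    # With the default unit weight, each feature vector comes out with ~unit RMS.
    rms = y.pow(2).mean(-1).sqrt()
    assert torch.allclose(rms, torch.ones_like(rms), atol=1e-3)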
112
+ class NeuroBLASTRotaryEmbedding(nn.Module):
113
+ """
114
+ Rotary Positional Embedding for NeuroBLAST model.
115
+ Source: LlamaRotaryEmbedding
116
+ """
117
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
118
+ super().__init__()
119
+ self.dim = dim
120
+ self.max_position_embeddings = max_position_embeddings
121
+ self.base = base
122
+ inv_freq = 1.0 / (
123
+ self.base
124
+ ** (torch.arange(0, self.dim, 2, dtype=torch.float32).to(device) / self.dim)
125
+ )
126
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
127
+ self._set_cos_sin_cache(
128
+ seq_len=max_position_embeddings, device="cpu", dtype=torch.float32
129
+ )
130
+
131
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
132
+ self.max_seq_len_cached = seq_len
133
+ t = torch.arange(
134
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
135
+ )
136
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
137
+ emb = torch.cat((freqs, freqs), dim=-1)
138
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
139
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
140
+
141
+ def forward(self, x, seq_len=None):
142
+ if seq_len > self.max_seq_len_cached:
143
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
144
+ elif self.cos_cached.device != x.device or self.cos_cached.dtype != x.dtype:
145
+ self.cos_cached = self.cos_cached.to(device=x.device, dtype=x.dtype)
146
+ self.sin_cached = self.sin_cached.to(device=x.device, dtype=x.dtype)
147
+
148
+ return (
149
+ self.cos_cached[:seq_len],
150
+ self.sin_cached[:seq_len],
151
+ )
152
+
153
+
154
+ def rotate_half(x):
155
+ """Rotates half the hidden dims of the input."""
156
+ x1 = x[..., : x.shape[-1] // 2]
157
+ x2 = x[..., x.shape[-1] // 2 :]
158
+ return torch.cat((-x2, x1), dim=-1)
159
+
160
+
161
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
162
+ """ Applies rotary positional embeddings to query and key tensors."""
163
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
164
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
165
+ q_embed = (q * cos) + (rotate_half(q) * sin)
166
+ k_embed = (k * cos) + (rotate_half(k) * sin)
167
+ return q_embed, k_embed
168
+
169
+
170
+ # Overload for Cross Attention where only query is rotated
171
+ def apply_rotary_pos_emb_single(q, cos, sin, position_ids):
172
+ """ Applies rotary positional embeddings to query tensor. """
173
+ cos = cos[position_ids].unsqueeze(1) # [1, 1, seq_len, dim]
174
+ sin = sin[position_ids].unsqueeze(1) # [1, 1, seq_len, dim]
175
+ q_embed = (q * cos) + (rotate_half(q) * sin)
176
+ return q_embed
177
+
178
+
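
A shape-level sketch of the rotary embedding helpers above (illustrative; the dimensions are made up):

    import torch
    from neuroblast_model.modeling_neuroblast import (
        NeuroBLASTRotaryEmbedding, apply_rotary_pos_emb,
    )

    rope = NeuroBLASTRotaryEmbedding(dim=16, max_position_embeddings=64)
    q = torch.randn(1, 4, 8, 16)                   # (batch, heads, seq, head_dim)
    k = torch.randn(1, 4, 8, 16)
    position_ids = torch.arange(8).unsqueeze(0)    # (1, seq)
    cos, sin = rope(q, seq_len=8)
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
    assert q_rot.shape == q.shape and k_rot.shape == k.shape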
179
+ class SwiGLUMLP(nn.Module):
180
+ """SwiGLU MLP block"""
181
+
182
+ def __init__(self, hidden_size, config: NeuroBLASTConfig, dropout):
183
+ super().__init__()
184
+ intermediate_size = getattr(config, "intermediate_size", int(hidden_size * 2.5))
185
+ self.init_std = getattr(config, "initializer_range", 0.02)
186
+ self.clamp_value = config.clamp_value
187
+
188
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=True)
189
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
190
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
191
+ self.act_fn = nn.SiLU()
192
+ self.dropout = nn.Dropout(dropout)
193
+
194
+ if config.num_experts is not None:
195
+ self.experts = NeuroBLASTSparseMoeBlock(config)
196
+
197
+ # Re-enable scaled initialization
198
+ with torch.no_grad():
199
+ # Scale down initial weights in the up/gate projections
200
+ self.gate_proj.weight.data.normal_(
201
+ mean=0.0,
202
+ std=self.init_std / math.sqrt(hidden_size), # Scale by input dim
203
+ )
204
+ if self.gate_proj.bias is not None:
205
+ self.gate_proj.bias.data.zero_()
206
+
207
+ self.up_proj.weight.data.normal_(
208
+ mean=0.0, std=self.init_std / math.sqrt(hidden_size)
209
+ ) # Scale by input dim
210
+ # Scale down initial weights in the down projection even further
211
+ self.down_proj.weight.data.normal_(
212
+ mean=0.0,
213
+ std=self.init_std
214
+ / math.sqrt(intermediate_size), # Scale by intermediate dim
215
+ )
216
+
217
+ def forward(self, x):
218
+ gated_x = self.gate_proj(x)
219
+ activated_x = self.act_fn(gated_x)
220
+ up_projected_x = self.up_proj(x)
221
+
222
+ intermediate_activation = activated_x * up_projected_x
223
+
224
+ # Clamp the intermediate activation before down_proj
225
+ clamp_value = self.clamp_value
226
+
227
+ intermediate_activation = torch.clamp(
228
+ intermediate_activation, min=-clamp_value, max=clamp_value
229
+ )
230
+ intermediate_activation = torch.nan_to_num(
231
+ intermediate_activation
232
+ ) # Safeguard against NaNs
233
+
234
+ y = self.down_proj(intermediate_activation)
235
+ y = self.dropout(y)
236
+
237
+ if hasattr(self, "experts"):
238
+ z = self.experts(y)
239
+
240
+ y = y + z
241
+
242
+ return y
243
+
244
+
245
+ class NeuroBLASTMoeMLP(nn.Module):
246
+ """ Source: Qwen3MoeMLP """
247
+ def __init__(self, config, intermediate_size=None):
248
+ super().__init__()
249
+ self.config = config
250
+ self.clamp_value = config.clamp_value
251
+ self.hidden_size = config.hidden_size
252
+ self.intermediate_size = (
253
+ intermediate_size
254
+ if intermediate_size is not None
255
+ else config.intermediate_size
256
+ )
257
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
258
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
259
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
260
+ self.act_fn = ACT2FN[config.hidden_act]
261
+
262
+ def forward(self, x):
263
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
264
+
265
+ down_proj = torch.clamp(down_proj, min=-self.clamp_value, max=self.clamp_value)
266
+ down_proj = torch.nan_to_num(down_proj) # Safeguard against NaNs
267
+
268
+ return down_proj
269
+
270
+
271
+ class NeuroBLASTSparseMoeBlock(nn.Module):
272
+ """ Source: Qwen3SparseMoeBlock """
273
+ def __init__(self, config):
274
+ super().__init__()
275
+ self.num_experts = config.num_experts
276
+ self.top_k = config.num_experts_per_tok
277
+ self.norm_topk_prob = config.norm_topk_prob
278
+
279
+ # gating
280
+ self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
281
+ self.experts = nn.ModuleList(
282
+ [NeuroBLASTMoeMLP(config) for _ in range(self.num_experts)]
283
+ )
284
+
285
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
286
+ """ """
287
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
288
+ hidden_states = hidden_states.view(-1, hidden_dim)
289
+ # router_logits: (batch * sequence_length, n_experts)
290
+ router_logits = self.gate(hidden_states)
291
+
292
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
293
+ routing_weights, selected_experts = torch.topk(
294
+ routing_weights, self.top_k, dim=-1
295
+ )
296
+ if self.norm_topk_prob: # only diff with mixtral sparse moe block!
297
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
298
+ # we cast back to the input dtype
299
+ routing_weights = routing_weights.to(hidden_states.dtype)
300
+
301
+ final_hidden_states = torch.zeros(
302
+ (batch_size * sequence_length, hidden_dim),
303
+ dtype=hidden_states.dtype,
304
+ device=hidden_states.device,
305
+ )
306
+
307
+ # One hot encode the selected experts to create an expert mask
308
+ # this will be used to easily index which expert is going to be solicited
309
+ expert_mask = torch.nn.functional.one_hot(
310
+ selected_experts, num_classes=self.num_experts
311
+ ).permute(2, 1, 0)
312
+
313
+ # Loop over all available experts in the model and perform the computation on each expert
314
+ for expert_idx in range(self.num_experts):
315
+ expert_layer = self.experts[expert_idx]
316
+ idx, top_x = torch.where(expert_mask[expert_idx])
317
+
318
+ # Index the correct hidden states and compute the expert hidden state for
319
+ # the current expert. We need to make sure to multiply the output hidden
320
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
321
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
322
+ current_hidden_states = (
323
+ expert_layer(current_state) * routing_weights[top_x, idx, None]
324
+ )
325
+
326
+ # However `index_add_` only supports torch tensors for indexing, so we'll use
327
+ # the `top_x` tensor here.
328
+ final_hidden_states.index_add_(
329
+ 0, top_x, current_hidden_states.to(hidden_states.dtype)
330
+ )
331
+ final_hidden_states = final_hidden_states.reshape(
332
+ batch_size, sequence_length, hidden_dim
333
+ )
334
+ return final_hidden_states
335
+
336
+
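
A shape-level usage sketch of the sparse MoE block above (illustrative; the tiny config values are made up):

    import torch
    from neuroblast_model.configuration_neuroblast import NeuroBLASTConfig
    from neuroblast_model.modeling_neuroblast import NeuroBLASTSparseMoeBlock

    cfg = NeuroBLASTConfig(
        hidden_size=64, kv_dim=64, num_attention_heads=4,
        num_experts=4, num_experts_per_tok=2,
    )
    moe = NeuroBLASTSparseMoeBlock(cfg)
    x = torch.randn(2, 5, cfg.hidden_size)         # (batch, seq_len, hidden)
    out = moe(x)
    assert out.shape == x.shape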
337
+ class NeuroBLASTRouterBlock(nn.Module):
338
+ """ Memory router; overcomplicated due to backward compatibility """
339
+ def __init__(
340
+ self,
341
+ config,
342
+ hidden_size,
343
+ ):
344
+ super().__init__()
345
+ self.num_experts = 2
346
+ self.top_k = 1
347
+ self.norm_topk_prob = config.norm_topk_prob
348
+
349
+ # gating
350
+ self.gate = nn.Linear(hidden_size, self.num_experts, bias=False)
351
+
352
+ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
353
+ """ """
354
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
355
+ hidden_states = hidden_states.view(-1, hidden_dim)
356
+ # router_logits: (batch * sequence_length, n_experts)
357
+ router_logits = self.gate(hidden_states)
358
+
359
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
360
+ routing_weights, selected_experts = torch.topk(
361
+ routing_weights, self.top_k, dim=-1
362
+ )
363
+ if self.norm_topk_prob: # only diff with mixtral sparse moe block!
364
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
365
+ # we cast back to the input dtype
366
+ routing_weights = routing_weights.to(hidden_states.dtype)
367
+
368
+ return routing_weights, selected_experts
369
+
370
+
371
+ class SelfAttention(torch.nn.Module):
372
+ def __init__(
373
+ self,
374
+ config: NeuroBLASTConfig,
375
+ hidden_size: int,
376
+ is_causal: bool = False,
377
+ layer_idx: Optional[int] = None,
378
+ ):
379
+ super().__init__()
380
+ self.is_causal = is_causal
381
+ self.dropout_p = config.dropout # Will apply based on self.training
382
+ self.layer_idx = layer_idx
383
+ self.hidden_size = hidden_size
384
+ self.intermediate_size = config.kv_dim
385
+ self.use_flash_attn = config.use_flash_attn
386
+ # Allow overriding num_heads, default to config.num_attention_heads
387
+ self.num_heads = getattr(
388
+ config, f"num_heads_{layer_idx}", config.num_attention_heads
389
+ )
390
+ self.head_dim = self.intermediate_size // self.num_heads
391
+
392
+ if (self.head_dim * self.num_heads) != self.intermediate_size:
393
+ raise ValueError(
394
+ f"Layer {self.layer_idx}: hidden_size ({self.intermediate_size}) must be divisible by num_heads ({self.num_heads})"
395
+ )
396
+
397
+ self.qkv_proj = nn.Linear(
398
+ self.hidden_size, self.intermediate_size * 3, bias=True
399
+ )
400
+ self.o_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
401
+ self.dropout = nn.Dropout(config.dropout)
402
+
403
+ def forward(
404
+ self,
405
+ hidden_states: torch.Tensor,
406
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor], # (cos, sin)
407
+ attention_mask: Optional[
408
+ torch.Tensor
409
+ ], # Not directly used by flash_attn causal
410
+ past_key_value: Optional[Cache] = None,
411
+ cache_position: Optional[torch.LongTensor] = None,
412
+ output_attentions: Optional[bool] = False,
413
+ use_cache: Optional[bool] = False,
414
+ position_ids: Optional[torch.LongTensor] = None,
415
+ ):
416
+ batch_size, seq_len, _ = hidden_states.shape
417
+ dropout_p = self.dropout_p if self.training else 0.0
418
+
419
+ qkv = self.qkv_proj(hidden_states)
420
+ query_states, key_states, value_states = qkv.chunk(3, dim=-1)
421
+
422
+ query_states = query_states.view(
423
+ batch_size, seq_len, self.num_heads, self.head_dim
424
+ ).transpose(1, 2)
425
+ key_states = key_states.view(
426
+ batch_size, seq_len, self.num_heads, self.head_dim
427
+ ).transpose(1, 2)
428
+ value_states = value_states.view(
429
+ batch_size, seq_len, self.num_heads, self.head_dim
430
+ ).transpose(1, 2)
431
+
432
+ cos, sin = position_embeddings
433
+ query_states, key_states = apply_rotary_pos_emb(
434
+ query_states, key_states, cos, sin, position_ids
435
+ )
436
+
437
+ if past_key_value is not None:
438
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
439
+ # When using Cache object, updating happens in-place.
440
+ key_states, value_states = past_key_value.update(
441
+ key_states, value_states, self.layer_idx, cache_kwargs
442
+ )
443
+
444
+ if self.use_flash_attn:
445
+ causal_mask = attention_mask
446
+ if attention_mask is not None:
447
+ if attention_mask.dim() == 4:
448
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
449
+ elif attention_mask.dim() == 2:
450
+ causal_mask = attention_mask
451
+
452
+ if causal_mask.dtype not in [
453
+ torch.bool,
454
+ torch.float16,
455
+ torch.float32,
456
+ torch.bfloat16,
457
+ ]:
458
+ causal_mask = causal_mask.to(query_states.dtype)
459
+
460
+ is_causal = (
461
+ True if causal_mask is None and query_states.shape[-2] > 1 else False
462
+ )
463
+
464
+ attn_output = F.scaled_dot_product_attention(
465
+ query_states,
466
+ key_states,
467
+ value_states,
468
+ attn_mask=causal_mask,
469
+ dropout_p=dropout_p,
470
+ enable_gqa=False,
471
+ scale=self.head_dim**-0.5,
472
+ is_causal=is_causal,
473
+ )
474
+ attn_weights = None
475
+ else:
476
+ attn_weights = torch.matmul(query_states, key_states.transpose(-1, -2)) / (
477
+ self.head_dim**0.5
478
+ )
479
+
480
+ if attention_mask is not None:
481
+ attn_weights = attn_weights + attention_mask
482
+ elif self.is_causal and seq_len > 1:
483
+ causal_mask = torch.triu(
484
+ torch.ones(
485
+ (seq_len, key_states.shape[2]),
486
+ dtype=torch.bool,
487
+ device=query_states.device,
488
+ ),
489
+ diagonal=1,
490
+ )
491
+ attn_weights = attn_weights.masked_fill(
492
+ causal_mask.unsqueeze(0).unsqueeze(0), float("-inf")
493
+ )
494
+
495
+ attn_weights = F.softmax(attn_weights, dim=-1)
496
+ attn_weights_dropped = F.dropout(
497
+ attn_weights, p=dropout_p, training=self.training
498
+ )
499
+ attn_output = torch.matmul(attn_weights_dropped, value_states)
500
+
501
+ attn_output = attn_output.transpose(1, 2).contiguous()
502
+ attn_output = attn_output.view(batch_size, seq_len, -1)
503
+ attn_output = self.o_proj(attn_output)
504
+ attn_output = self.dropout(attn_output)
505
+
506
+ if output_attentions:
507
+ logger.warning(
508
+ f"Layer {self.layer_idx}: Flash Attention does not return attention weights."
509
+ )
510
+
511
+ outputs = (attn_output,)
512
+ if output_attentions:
513
+ outputs += (attn_weights,)
514
+ if use_cache:
515
+ outputs += (past_key_value,) # Return the cache object
516
+
517
+ return outputs
518
+
519
+
520
+ class CrossAttention(torch.nn.Module):
521
+ def __init__(
522
+ self,
523
+ config: NeuroBLASTConfig,
524
+ query_dim: int,
525
+ kv_dim: int,
526
+ layer_idx: int,
527
+ is_causal: bool = True,
528
+ ):
529
+ super().__init__()
530
+ self.dropout_p = config.dropout
531
+ self.layer_idx = layer_idx
532
+ self.query_dim = query_dim
533
+ self.kv_dim = kv_dim
534
+ self.is_causal = is_causal
535
+
536
+ self.num_heads = config.num_attention_heads
537
+ self.head_dim = self.kv_dim // self.num_heads
538
+ self.kv_head_dim = (
539
+ self.kv_dim // self.num_heads
540
+ )
541
+
542
+ if (self.head_dim * self.num_heads) != self.kv_dim:
543
+ raise ValueError(
544
+ f"CrossAttn {layer_idx}: query_dim ({self.kv_dim}) must be divisible by num_heads ({self.num_heads})"
545
+ )
546
+ if (self.kv_head_dim * self.num_heads) != self.kv_dim:
547
+ raise ValueError(
548
+ f"CrossAttn {layer_idx}: kv_dim ({kv_dim}) must be divisible by num_heads ({self.num_heads})"
549
+ )
550
+
551
+ self.q_proj = nn.Linear(self.query_dim, self.kv_dim, bias=True)
552
+ self.k_proj = nn.Linear(self.query_dim, self.kv_dim, bias=True)
553
+ self.v_proj = nn.Linear(self.query_dim, self.kv_dim, bias=True)
554
+ self.o_proj = nn.Linear(self.kv_dim, self.query_dim, bias=False)
555
+ self.dropout = nn.Dropout(config.dropout)
556
+
557
+ self.use_flash_attn = hasattr(
558
+ F, "scaled_dot_product_attention"
559
+ )
560
+
561
+ def forward(
562
+ self,
563
+ query_states: torch.Tensor,
564
+ kv_states: torch.Tensor,
565
+ position_embeddings: Tuple[
566
+ torch.Tensor, torch.Tensor
567
+ ],
568
+ past_key_value: Optional[Cache] = None,
569
+ cache_position: Optional[torch.LongTensor] = None,
570
+ attention_mask: Optional[torch.Tensor] = None,
571
+ output_attentions: Optional[bool] = False,
572
+ position_ids: Optional[torch.LongTensor] = None,
573
+ use_cache: Optional[bool] = False,
574
+ ):
575
+ batch_size, q_seq_len, _ = query_states.shape
576
+ kv_seq_len = kv_states.shape[1]
577
+ dropout_p = self.dropout_p if self.training else 0.0
578
+
579
+ query = self.q_proj(query_states)
580
+ key = self.k_proj(kv_states)
581
+ value = self.v_proj(kv_states)
582
+
583
+ query = query.view(
584
+ batch_size, q_seq_len, self.num_heads, self.head_dim
585
+ ).transpose(1, 2)
586
+
587
+ cos, sin = position_embeddings
588
+ query = apply_rotary_pos_emb_single(query, cos, sin, position_ids)
589
+
590
+ if past_key_value is not None:
591
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
592
+ key, value = past_key_value.update(
593
+ key, value, self.layer_idx or 0, cache_kwargs
594
+ )
595
+ kv_seq_len = key.shape[1]
596
+
597
+ key = key.view(
598
+ batch_size, kv_seq_len, self.num_heads, self.kv_head_dim
599
+ ).transpose(1, 2)
600
+ value = value.view(
601
+ batch_size, kv_seq_len, self.num_heads, self.kv_head_dim
602
+ ).transpose(1, 2)
603
+
604
+ sdpa_attn_mask = attention_mask
605
+
606
+ if self.use_flash_attn:
607
+ is_causal = True if sdpa_attn_mask is None and q_seq_len > 1 else False
608
+ attn_output = F.scaled_dot_product_attention(
609
+ query,
610
+ key,
611
+ value,
612
+ attn_mask=sdpa_attn_mask if not is_causal else None,
613
+ dropout_p=dropout_p,
614
+ is_causal=is_causal,
615
+ enable_gqa=False,
616
+ scale=self.head_dim**-0.5,
617
+ )
618
+ attn_weights = None
619
+ else:
620
+ attn_weights = torch.matmul(query, key.transpose(-1, -2)) / (
621
+ self.head_dim**0.5
622
+ )
623
+
624
+ if sdpa_attn_mask is not None:
625
+ attn_weights = attn_weights + sdpa_attn_mask
626
+ elif self.is_causal and q_seq_len > 1:
627
+ causal_mask = torch.triu(
628
+ torch.ones(
629
+ (q_seq_len, kv_seq_len), dtype=torch.bool, device=query.device
630
+ ),
631
+ diagonal=1,
632
+ )
633
+ attn_weights = attn_weights.masked_fill(
634
+ causal_mask.unsqueeze(0).unsqueeze(0), float("-inf")
635
+ )
636
+
637
+ attn_weights = F.softmax(attn_weights, dim=-1)
638
+ attn_weights_dropped = F.dropout(
639
+ attn_weights, p=dropout_p, training=self.training
640
+ )
641
+ attn_output = torch.matmul(attn_weights_dropped, value)
642
+
643
+ attn_output = attn_output.transpose(
644
+ 1, 2
645
+ ).contiguous()
646
+ attn_output = attn_output.reshape(batch_size, q_seq_len, self.kv_dim)
647
+
648
+ attn_output = self.o_proj(attn_output)
649
+ attn_output = self.dropout(attn_output)
650
+
651
+ outputs = (attn_output,)
652
+ if output_attentions:
653
+ outputs += (attn_weights,)
654
+
655
+ if use_cache:
656
+ outputs += (past_key_value,)
657
+
658
+ return outputs
659
+
660
+
661
+ class AttentionBlock(torch.nn.Module):
662
+ """Modified Attention Block with Pre-Norm and choice of Self/Cross Attention & MLP"""
663
+
664
+ def __init__(
665
+ self,
666
+ config: NeuroBLASTConfig,
667
+ hidden_size: int,
668
+ attention_module: nn.Module,
669
+ mlp_module: nn.Module,
670
+ is_cross_attention: bool = False,
671
+ layer_idx: int = 0,
672
+ precomputed_total_layers: Optional[int] = None,
673
+ ):
674
+ super().__init__()
675
+ self.hidden_size = hidden_size
676
+ self.config = config
677
+ self.layer_idx = layer_idx
678
+ self.is_cross_attention = is_cross_attention
679
+
680
+ self.input_layernorm = nn.LayerNorm(
681
+ hidden_size, eps=getattr(config, "rms_norm_eps", 1e-5)
682
+ )
683
+ self.attention = attention_module
684
+
685
+ self.post_attention_layernorm = nn.LayerNorm(
686
+ hidden_size, eps=getattr(config, "rms_norm_eps", 1e-5)
687
+ )
688
+ self.mlp = mlp_module
689
+
690
+ if self.config.use_zero_memory and (
691
+ self.config.zero_memory_layers is None
692
+ or self.layer_idx in self.config.zero_memory_layers
693
+ ):
694
+ self.router = NeuroBLASTRouterBlock(config, hidden_size)
695
+ self.memory = NeuroBLASTMemory(
696
+ config,
697
+ hidden_size=hidden_size,
698
+ layer_idx=layer_idx,
699
+ precomputed_total_layers=precomputed_total_layers,
700
+ )
701
+
702
+ def forward(
703
+ self,
704
+ hidden_states: torch.Tensor,
705
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
706
+ attention_mask: Optional[torch.Tensor],
707
+ position_ids: Optional[torch.LongTensor],
708
+ kv_states: Optional[torch.Tensor] = None,
709
+ past_key_value: Optional[Cache] = None,
710
+ cache_position: Optional[torch.LongTensor] = None,
711
+ output_attentions: Optional[bool] = False,
712
+ use_cache: Optional[bool] = False,
713
+ previous_states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
714
+ ):
715
+ residual = hidden_states
716
+ hidden_states = torch.nan_to_num(hidden_states)
717
+ normed_hidden_states = self.input_layernorm(hidden_states)
718
+
719
+ if self.is_cross_attention:
720
+ if kv_states is None:
721
+ raise ValueError("kv_states must be provided for CrossAttention")
722
+ attn_outputs = self.attention(
723
+ query_states=normed_hidden_states,
724
+ kv_states=kv_states,
725
+ past_key_value=past_key_value,
726
+ cache_position=cache_position,
727
+ position_embeddings=position_embeddings,
728
+ attention_mask=attention_mask,
729
+ output_attentions=output_attentions,
730
+ position_ids=position_ids,
731
+ use_cache=use_cache,
732
+ )
733
+ else:
734
+ attn_outputs = self.attention(
735
+ normed_hidden_states,
736
+ position_embeddings=position_embeddings,
737
+ attention_mask=attention_mask,
738
+ past_key_value=past_key_value,
739
+ cache_position=cache_position,
740
+ output_attentions=output_attentions,
741
+ use_cache=use_cache,
742
+ position_ids=position_ids,
743
+ )
744
+ attn_output = attn_outputs[0]
745
+ past_key_value = attn_outputs[-1] if use_cache else None
746
+
747
+ hidden_states = residual + attn_output
748
+ hidden_states = torch.nan_to_num(hidden_states)
749
+
750
+ residual = hidden_states
751
+
752
+ normed_hidden_states = self.post_attention_layernorm(hidden_states)
753
+
754
+ mlp_output = self.mlp(normed_hidden_states)
755
+
756
+ hidden_states = residual + mlp_output
757
+
758
+ hidden_states = torch.nan_to_num(hidden_states)
759
+
760
+ if self.config.use_zero_memory and (
761
+ self.config.zero_memory_layers is None
762
+ or self.layer_idx in self.config.zero_memory_layers
763
+ ):
764
+ routing_weights, selected_experts = self.router(hidden_states)
765
+
766
+ residual = hidden_states
767
+
768
+ hidden_states, (hx, cx), past_key_value = self.memory(
769
+ hidden_states,
770
+ previous_states,
771
+ past_key_value=past_key_value,
772
+ cache_position=cache_position,
773
+ position_embeddings=position_embeddings,
774
+ attention_mask=attention_mask,
775
+ output_attentions=output_attentions,
776
+ position_ids=position_ids,
777
+ use_cache=use_cache,
778
+ )
779
+
780
+ hidden_states = torch.nan_to_num(hidden_states)
781
+ hx = torch.nan_to_num(hx)
782
+ cx = torch.nan_to_num(cx)
783
+
784
+ hidden_states = hidden_states * routing_weights.reshape(
785
+ hidden_states.shape[:-1]
786
+ ).unsqueeze(-1)
787
+
788
+ hidden_states = residual + self.config.zero_memory_alpha * hidden_states
789
+ hidden_states = torch.nan_to_num(hidden_states)
790
+
791
+ outputs = (hidden_states,) + attn_outputs[1:]
792
+
793
+ if self.config.use_zero_memory and (
794
+ self.config.zero_memory_layers is None
795
+ or self.layer_idx in self.config.zero_memory_layers
796
+ ):
797
+ outputs += ((hx, cx),)
798
+ else:
799
+ outputs += (previous_states,)
800
+
801
+ if use_cache:
802
+ outputs += (past_key_value,)
803
+
804
+ return outputs
805
+
806
+
807
+ class NeuroBLASTMemory(nn.Module):
808
+ def __init__(
809
+ self,
810
+ config: NeuroBLASTConfig,
811
+ hidden_size: int = 256,
812
+ num_heads: int = 4,
813
+ scale_factor: int = 4,
814
+ layer_idx: int = 0,
815
+ with_hx: bool = True,
816
+ precomputed_total_layers: Optional[int] = None,
817
+ *args,
818
+ **kwargs,
819
+ ):
820
+ super().__init__(*args, **kwargs)
821
+
822
+ self.hidden_size = hidden_size
823
+ self.num_heads = num_heads
824
+ self.scale_factor = scale_factor
825
+ self.clamp_value = config.clamp_value
826
+ # Use precomputed_total_layers instead of hardcoded 100 for layer index shift
827
+ layer_shift = (
828
+ precomputed_total_layers if precomputed_total_layers is not None else 100
829
+ )
830
+ self.layer_idx = layer_idx + layer_shift
831
+ self.with_hx = with_hx
832
+ self.kv_dim = (
833
+ config.kv_dim
834
+ )
835
+
836
+ self.scaled_dim = hidden_size * scale_factor
837
+ self.head_dim = self.kv_dim // config.num_attention_heads
838
+ self.num_heads = self.hidden_size // self.head_dim
839
+
840
+ self.norm1 = nn.LayerNorm(hidden_size)
841
+
842
+ if self.with_hx:
843
+ self.lin1 = nn.Linear(self.hidden_size, self.scaled_dim)
844
+ self.lin2 = nn.Linear(self.hidden_size, self.scaled_dim)
845
+ self.lin3 = nn.Linear(self.scaled_dim, self.hidden_size)
846
+
847
+ self.lin4 = nn.Linear(self.hidden_size, self.scaled_dim)
848
+ self.lin5 = nn.Linear(self.scaled_dim, self.hidden_size)
849
+
850
+ if self.with_hx:
851
+ self.lin6 = nn.Linear(self.scaled_dim, self.hidden_size)
852
+
853
+ self.gate1 = nn.Linear(self.scaled_dim, self.scaled_dim)
854
+ self.act1 = nn.SiLU()
855
+
856
+ self.gate2 = nn.Linear(self.scaled_dim, self.scaled_dim)
857
+ self.act2 = nn.SiLU()
858
+
859
+ self.last_token_reg = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
860
+ self.prev_tokens_reg = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
861
+
862
+ self.norm2 = nn.LayerNorm(hidden_size)
863
+ self.dropout = nn.Dropout(config.dropout)
864
+
865
+ def forward(
866
+ self,
867
+ x: torch.Tensor,
868
+ previous_state: tuple[torch.Tensor, torch.Tensor],
869
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
870
+ attention_mask: Optional[torch.Tensor],
871
+ position_ids: Optional[torch.LongTensor],
872
+ past_key_value: Optional[Cache] = None,
873
+ cache_position: Optional[torch.LongTensor] = None,
874
+ output_attentions: Optional[bool] = False,
875
+ use_cache: Optional[bool] = False,
876
+ ):
877
+ hx, cx = previous_state
878
+
879
+ b, s, d = x.size()
880
+
881
+ x = torch.nan_to_num(x)
882
+ norm_x = self.norm1(x)
883
+
884
+ norm_x = norm_x.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
885
+
886
+ cos, sin = position_embeddings
887
+ norm_x = apply_rotary_pos_emb_single(norm_x, cos, sin, position_ids)
888
+
889
+ kv_seq_len = s
890
+
891
+ if past_key_value is not None:
892
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
893
+ norm_x, _ = past_key_value.update(
894
+ norm_x,
895
+ torch.zeros((b, 1, kv_seq_len, d)),
896
+ self.layer_idx,
897
+ cache_kwargs,
898
+ )
899
+ kv_seq_len = norm_x.shape[2]
900
+
901
+ norm_x = norm_x.transpose(1, 2).contiguous()
902
+ norm_x = norm_x.view(b, kv_seq_len, d)
903
+
904
+ norm_x = torch.nan_to_num(norm_x)
905
+
906
+ expanded_x = None
907
+
908
+ shifted_x = torch.cat(
909
+ [
910
+ torch.zeros((b, 1, d), device=x.device, dtype=x.dtype),
911
+ (norm_x[:, :-1].contiguous()),
912
+ ],
913
+ dim=1,
914
+ ).contiguous()
915
+
916
+ shifted_x = torch.nan_to_num(shifted_x) # Replace NaNs with zeros
917
+
918
+ prev_tokens_x = norm_x.cumsum(dim=1)
919
+
920
+ prev_tokens_x = prev_tokens_x - shifted_x
921
+ prev_tokens_x = torch.nan_to_num(prev_tokens_x)[:, -s:].contiguous()
922
+
923
+ if self.with_hx:
924
+ expanded_x = self.lin1(norm_x[:, -s:].contiguous())
925
+ expanded_x = torch.nan_to_num(expanded_x)
926
+
927
+ expanded_shifted_x = self.lin2(shifted_x[:, -s:].contiguous())
928
+
929
+ expanded_shifted_x = torch.nan_to_num(expanded_shifted_x)
930
+
931
+ gated_shifted_x = self.gate1(expanded_shifted_x)
932
+ gated_shifted_x = self.act1(gated_shifted_x)
933
+ gated_shifted_x = torch.clamp(
934
+ gated_shifted_x, min=-self.clamp_value, max=self.clamp_value
935
+ )
936
+
937
+ gated_shifted_x = torch.nan_to_num(gated_shifted_x)
938
+
939
+ collapsed_shifted_x = self.lin3(gated_shifted_x)
940
+ collapsed_shifted_x = torch.nan_to_num(collapsed_shifted_x)
941
+
942
+ prev_tokens_x = torch.nan_to_num(prev_tokens_x)
943
+
944
+ expanded_prev_tokens_x = self.lin4(prev_tokens_x)
945
+ expanded_prev_tokens_x = torch.nan_to_num(expanded_prev_tokens_x)
946
+
947
+ gated_prev_tokens_x = self.gate2(expanded_prev_tokens_x)
948
+ gated_prev_tokens_x = self.act2(gated_prev_tokens_x)
949
+ gated_prev_tokens_x = torch.clamp(
950
+ gated_prev_tokens_x, min=-self.clamp_value, max=self.clamp_value
951
+ )
952
+
953
+ gated_prev_tokens_x = torch.nan_to_num(gated_prev_tokens_x)
954
+
955
+ collapsed_prev_tokens_x = self.lin5(gated_prev_tokens_x)
956
+ collapsed_prev_tokens_x = torch.nan_to_num(collapsed_prev_tokens_x)
957
+
958
+ if self.with_hx:
959
+ weights = torch.softmax(expanded_x * expanded_shifted_x, dim=-1)
960
+
961
+ expanded_x_attn = weights * expanded_x
962
+
963
+ expanded_x_attn = torch.nan_to_num(expanded_x_attn)
964
+ hx = hx + self.lin6(expanded_x_attn)
965
+ hx = torch.nan_to_num(hx)
966
+
967
+ if self.with_hx:
968
+ x = torch.nan_to_num(x)
969
+ hx = torch.nan_to_num(hx)
970
+ collapsed_shifted_x = torch.nan_to_num(collapsed_shifted_x)
971
+ collapsed_prev_tokens_x = torch.nan_to_num(collapsed_prev_tokens_x)
972
+ output = x + (
973
+ hx
974
+ * (
975
+ self.last_token_reg(collapsed_shifted_x)
976
+ + self.prev_tokens_reg(collapsed_prev_tokens_x)
977
+ )
978
+ )
979
+ output = torch.nan_to_num(output)
980
+
981
+ else:
982
+ output = (
983
+ x
984
+ + (
985
+ self.last_token_reg(collapsed_shifted_x)
986
+ + self.prev_tokens_reg(collapsed_prev_tokens_x)
987
+ )[:, -s:].contiguous()
988
+ )
989
+
990
+ output = self.norm2(output)
991
+ output = self.dropout(output)
992
+ output = torch.nan_to_num(output)
993
+
994
+ return (
995
+ output,
996
+ (hx, cx),
997
+ past_key_value,
998
+ )
999
+
1000
+
1001
+ class NeuroBLASTPreTrainedModel(PreTrainedModel):
1002
+ config_class = NeuroBLASTConfig
1003
+ base_model_prefix = "brain"
1004
+ supports_gradient_checkpointing = True
1005
+ _no_split_modules = []
1006
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
1007
+ _supports_flash_attn_2 = False
1008
+ _supports_sdpa = True
1009
+
1010
+ def _init_weights(self, module):
1011
+ std = getattr(self.config, "initializer_range", 0.02)
1012
+ if isinstance(module, nn.Linear):
1013
+ module.weight.data.normal_(mean=0.0, std=std)
1014
+ if module.bias is not None:
1015
+ module.bias.data.zero_()
1016
+ elif isinstance(module, nn.Embedding):
1017
+ module.weight.data.normal_(mean=0.0, std=std)
1018
+ if module.padding_idx is not None:
1019
+ module.weight.data[module.padding_idx].zero_()
1020
+ elif isinstance(module, nn.LayerNorm):
1021
+ if module.bias is not None:
1022
+ module.bias.data.zero_()
1023
+ module.weight.data.fill_(1.0)
1024
+
1025
+
1026
+ class NeuroBLASTModel(NeuroBLASTPreTrainedModel):
1027
+ def __init__(self, config: NeuroBLASTConfig):
1028
+ super(NeuroBLASTModel, self).__init__(config)
1029
+ self.config = config
1030
+ self.padding_idx = config.pad_token_id
1031
+ self.vocab_size = config.vocab_size
1032
+
1033
+ self.embed_tokens = nn.Embedding(
1034
+ config.vocab_size,
1035
+ config.hidden_size, # Using main hidden_size for embeddings now
1036
+ padding_idx=self.padding_idx,
1037
+ )
1038
+ self.rotary_emb = NeuroBLASTRotaryEmbedding(
1039
+ config.kv_dim // config.num_attention_heads,
1040
+ max_position_embeddings=config.max_position_embeddings,
1041
+ base=config.rope_theta,
1042
+ )
1043
+ self.dropout = nn.Dropout(config.dropout)
1044
+
1045
+ self.assoc_to_sensory_pooler = nn.Sequential(
1046
+ nn.Linear(config.hidden_size, config.hidden_size),
1047
+ nn.Identity(), # Backward compatibility - previously LayerNorm, but we found that removing it improves generalization
1048
+ nn.GELU(),
1049
+ nn.LayerNorm(config.hidden_size),
1050
+ )
1051
+ self.assoc_to_motor_pooler = nn.Sequential(
1052
+ nn.Linear(config.hidden_size, config.hidden_size),
1053
+ nn.Identity(), # Backward compatibility
1054
+ nn.GELU(),
1055
+ nn.LayerNorm(config.hidden_size),
1056
+ )
1057
+ self.sensory_to_motor_pooler = nn.Sequential(
1058
+ nn.Linear(config.hidden_size, config.hidden_size),
1059
+ nn.Identity(), # Backward compatibility
1060
+ nn.GELU(),
1061
+ nn.LayerNorm(config.hidden_size),
1062
+ )
1063
+
1064
+ # --- Cortex Layers ---
1065
+ # Using generic AttentionBlock with specific Self/Cross Attention modules passed in
1066
+ total_layers = 0
1067
+
1068
+ # Precompute total layers for Memory layer indexing before creating any layers
1069
+ precomputed_total_layers = (
1070
+ config.num_association_cortex_layers
1071
+ + config.num_sensory_cortex_layers * 2
1072
+ + config.num_motor_cortex_layers
1073
+ * 3
1074
+ )
1075
+ self.precomputed_total_layers = precomputed_total_layers
1076
+ config.precomputed_total_layers = precomputed_total_layers
1077
+
1078
+ # 1. Association Cortex (Self-Attention)
1079
+ self.association_cortex = nn.ModuleList()
1080
+ for i in range(config.num_association_cortex_layers):
1081
+ layer_idx = total_layers + i
1082
+ print(f"Adding layer {layer_idx} to association cortex")
1083
+ self.association_cortex.append(
1084
+ AttentionBlock(
1085
+ config,
1086
+ config.hidden_size, # Use main hidden_size
1087
+ attention_module=SelfAttention(
1088
+ config, config.hidden_size, is_causal=True, layer_idx=layer_idx
1089
+ ),
1090
+ mlp_module=(
1091
+ NeuroBLASTSparseMoeBlock(
1092
+ config,
1093
+ )
1094
+ if config.num_experts
1095
+ else NeuroBLASTMoeMLP(config)
1096
+ ),
1097
+ is_cross_attention=False,
1098
+ layer_idx=layer_idx,
1099
+ precomputed_total_layers=precomputed_total_layers,
1100
+ )
1101
+ )
1102
+
1103
+ total_layers += config.num_association_cortex_layers
1104
+ # 2. Sensory Cortex (Self-Attention + Cross-Attention to Association)
1105
+ self.sensory_self_attn_layers = nn.ModuleList()
1106
+ self.sensory_cross_attn_layers = (
1107
+ nn.ModuleList()
1108
+ ) # One cross-attn per self-attn layer
1109
+ for i in range(config.num_sensory_cortex_layers):
1110
+ layer_idx = total_layers + i
1111
+ print(f"Adding layer {layer_idx} to sensory cortex")
1112
+ self.sensory_self_attn_layers.append(
1113
+ AttentionBlock(
1114
+ config,
1115
+ config.hidden_size,
1116
+ attention_module=SelfAttention(
1117
+ config,
1118
+ config.hidden_size,
1119
+ is_causal=True,
1120
+ layer_idx=layer_idx,
1121
+ ),
1122
+ mlp_module=(
1123
+ NeuroBLASTSparseMoeBlock(
1124
+ config,
1125
+ )
1126
+ if config.num_experts
1127
+ else NeuroBLASTMoeMLP(config)
1128
+ ),
1129
+ is_cross_attention=False,
1130
+ layer_idx=layer_idx,
1131
+ precomputed_total_layers=precomputed_total_layers,
1132
+ )
1133
+ )
1134
+
1135
+ total_layers += config.num_sensory_cortex_layers
1136
+ for i in range(config.num_sensory_cortex_layers):
1137
+ layer_idx = total_layers + i
1138
+ print(f"Adding layer {layer_idx} to sensory cross-attention")
1139
+ # Add Cross-Attention layer: Sensory queries, Association is K/V source
1140
+ self.sensory_cross_attn_layers.append(
1141
+ AttentionBlock(
1142
+ config,
1143
+ config.hidden_size, # Query Dim
1144
+ attention_module=CrossAttention(
1145
+ config,
1146
+ query_dim=config.hidden_size,
1147
+ kv_dim=config.kv_dim, # Assoc output dim
1148
+ layer_idx=layer_idx,
1149
+ ),
1150
+ mlp_module=(
1151
+ NeuroBLASTSparseMoeBlock(
1152
+ config,
1153
+ )
1154
+ if config.num_experts
1155
+ else NeuroBLASTMoeMLP(config)
1156
+ ),
1157
+ is_cross_attention=True,
1158
+ layer_idx=layer_idx,
1159
+ precomputed_total_layers=precomputed_total_layers,
1160
+ )
1161
+ )
1162
+
1163
+ total_layers += config.num_sensory_cortex_layers
1164
+
1165
+ # 3. Motor Cortex (Self-Attention + Cross-Attention to Sensory + Cross-Attention to Association)
1166
+ self.motor_self_attn_layers = nn.ModuleList()
1167
+ self.motor_cross_sensory_layers = nn.ModuleList()
1168
+ self.motor_cross_assoc_layers = nn.ModuleList()
1169
+ for i in range(config.num_motor_cortex_layers):
1170
+ layer_idx = total_layers + i
1171
+ print(f"Adding layer {layer_idx} to motor cortex")
1172
+ self.motor_self_attn_layers.append(
1173
+ AttentionBlock(
1174
+ config,
1175
+ config.hidden_size,
1176
+ attention_module=SelfAttention(
1177
+ config,
1178
+ config.hidden_size,
1179
+ is_causal=True,
1180
+ layer_idx=layer_idx,
1181
+ ),
1182
+ mlp_module=(
1183
+ NeuroBLASTSparseMoeBlock(
1184
+ config,
1185
+ )
1186
+ if config.num_experts
1187
+ else NeuroBLASTMoeMLP(config)
1188
+ ),
1189
+ is_cross_attention=False,
1190
+ layer_idx=layer_idx,
1191
+ precomputed_total_layers=precomputed_total_layers,
1192
+ )
1193
+ )
1194
+
1195
+ total_layers += config.num_motor_cortex_layers
1196
+ for i in range(config.num_motor_cortex_layers):
1197
+ layer_idx = total_layers + i
1198
+ print(f"Adding layer {layer_idx} to motor cross-sensory")
1199
+ # Cross-Attend to Sensory Output
1200
+ self.motor_cross_sensory_layers.append(
1201
+ AttentionBlock(
1202
+ config,
1203
+ config.hidden_size, # Query Dim
1204
+ attention_module=CrossAttention(
1205
+ config,
1206
+ query_dim=config.hidden_size,
1207
+ kv_dim=config.kv_dim, # Sensory output dim
1208
+ layer_idx=layer_idx,
1209
+ ),
1210
+ mlp_module=(
1211
+ NeuroBLASTSparseMoeBlock(
1212
+ config,
1213
+ )
1214
+ if config.num_experts
1215
+ else NeuroBLASTMoeMLP(config)
1216
+ ),
1217
+ is_cross_attention=True,
1218
+ layer_idx=layer_idx,
1219
+ precomputed_total_layers=precomputed_total_layers,
1220
+ )
1221
+ )
1222
+
1223
+ total_layers += config.num_motor_cortex_layers
1224
+ for i in range(config.num_motor_cortex_layers):
1225
+ layer_idx = total_layers + i
1226
+ print(f"Adding layer {layer_idx} to motor cross-association")
1227
+ # Cross-Attend to Association Output
1228
+ self.motor_cross_assoc_layers.append(
1229
+ AttentionBlock(
1230
+ config,
1231
+ config.hidden_size, # Query Dim
1232
+ attention_module=CrossAttention(
1233
+ config,
1234
+ query_dim=config.hidden_size,
1235
+ kv_dim=config.kv_dim, # Assoc output dim
1236
+ layer_idx=layer_idx,
1237
+ ),
1238
+ mlp_module=(
1239
+ NeuroBLASTSparseMoeBlock(
1240
+ config,
1241
+ )
1242
+ if config.num_experts
1243
+ else NeuroBLASTMoeMLP(config)
1244
+ ),
1245
+ is_cross_attention=True,
1246
+ layer_idx=layer_idx,
1247
+ precomputed_total_layers=precomputed_total_layers,
1248
+ )
1249
+ )
1250
+
1251
+ total_layers += config.num_motor_cortex_layers
1252
+
1253
+ # Initialize more conservatively to prevent strong gradient flow initially
1254
+ self.sensory_cross_assoc_gate = NeuroBLASTMoeMLP(
1255
+ config,
1256
+ )
1257
+ self.motor_cross_sensory_gate = NeuroBLASTMoeMLP(
1258
+ config,
1259
+ )
1260
+ self.motor_cross_assoc_gate = NeuroBLASTMoeMLP(
1261
+ config,
1262
+ )
1263
+
1264
+ # Final normalization before output head
1265
+ self.norm = nn.LayerNorm(
1266
+ config.hidden_size, eps=getattr(config, "rms_norm_eps", 1e-5)
1267
+ )
1268
+
1269
+ self.gradient_checkpointing = False
1270
+ self.post_init()
1271
+
1272
+ def forward(
1273
+ self,
1274
+ input_ids: torch.LongTensor,
1275
+ attention_mask: Optional[torch.Tensor] = None,
1276
+ position_ids: Optional[torch.LongTensor] = None,
1277
+ past_key_values: Optional[
1278
+ Cache
1279
+ ] = None,
1280
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1281
+ use_cache: Optional[bool] = None,
1282
+ output_attentions: Optional[bool] = None,
1283
+ output_hidden_states: Optional[bool] = None,
1284
+ return_dict: Optional[bool] = None,
1285
+ cache_position: Optional[torch.LongTensor] = None,
1286
+ ):
1287
+ output_attentions = (
1288
+ output_attentions
1289
+ if output_attentions is not None
1290
+ else self.config.output_attentions
1291
+ )
1292
+ output_hidden_states = (
1293
+ output_hidden_states
1294
+ if output_hidden_states is not None
1295
+ else self.config.output_hidden_states
1296
+ )
1297
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1298
+ # use_cache = False
1299
+ return_dict = (
1300
+ return_dict if return_dict is not None else self.config.use_return_dict
1301
+ )
1302
+
1303
+ if input_ids is not None and inputs_embeds is not None:
1304
+ raise ValueError("Specify either input_ids or inputs_embeds")
1305
+ batch_size, seq_length = (
1306
+ input_ids.shape if input_ids is not None else inputs_embeds.shape[:2]
1307
+ )
1308
+
1309
+ if self.gradient_checkpointing and self.training:
1310
+ if use_cache:
1311
+ logger.warning_once(
1312
+ "`use_cache=True` incompatible with gradient checkpointing. Setting `use_cache=False`"
1313
+ )
1314
+ use_cache = False
1315
+
1316
+ if not any(param.requires_grad for param in self.parameters()):
1317
+ logger.warning_once(
1318
+ "No parameters require gradients. Disabling gradient checkpointing to avoid warnings."
1319
+ )
1320
+ self.gradient_checkpointing = False
1321
+
1322
+ if inputs_embeds is None:
1323
+ inputs_embeds = self.embed_tokens(input_ids)
1324
+
1325
+ past_key_values_length = 0
1326
+ if use_cache:
1327
+ if not isinstance(past_key_values, Cache):
1328
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1329
+ past_key_values_length = past_key_values.get_seq_length()
1330
+
1331
+ if cache_position is None:
1332
+ past_seen_tokens = (
1333
+ past_key_values.get_seq_length() if past_key_values is not None else 0
1334
+ )
1335
+ cache_position = torch.arange(
1336
+ past_seen_tokens,
1337
+ past_seen_tokens + inputs_embeds.shape[1],
1338
+ device=inputs_embeds.device,
1339
+ )
1340
+
1341
+ if position_ids is None:
1342
+ position_ids = cache_position.unsqueeze(0)
1343
+
1344
+ causal_mask = self._update_causal_mask(
1345
+ attention_mask,
1346
+ inputs_embeds,
1347
+ cache_position,
1348
+ past_key_values,
1349
+ output_attentions,
1350
+ )
1351
+
1352
+ hidden_states = inputs_embeds
1353
+
1354
+ cos, sin = self.rotary_emb(
1355
+ hidden_states, seq_len=seq_length + past_key_values_length
1356
+ )
1357
+ position_embeddings = (cos, sin)
1358
+
1359
+ all_hidden_states = () if output_hidden_states else None
1360
+ all_attentions = (
1361
+ () if output_attentions else None
1362
+ )
1363
+ next_decoder_cache = (
1364
+ past_key_values if use_cache else None
1365
+ )
1366
+
1367
+ if self.config.use_zero_memory:
1368
+ hx = torch.ones(
1369
+ (batch_size, seq_length, hidden_states.size(-1)),
1370
+ device=hidden_states.device,
1371
+ dtype=hidden_states.dtype,
1372
+ )
1373
+ cx = torch.ones(
1374
+ (batch_size, seq_length, hidden_states.size(-1)),
1375
+ device=hidden_states.device,
1376
+ dtype=hidden_states.dtype,
1377
+ )
1378
+
1379
+ if self.training:
1380
+ hx.requires_grad_()
1381
+ cx.requires_grad_()
1382
+ else:
1383
+ hx = None
1384
+ cx = None
1385
+
1386
+ # 1. Association Cortex (Self-Attention)
1387
+ assoc_output = hidden_states
1388
+ for i, layer in enumerate(self.association_cortex):
1389
+ if output_hidden_states:
1390
+ all_hidden_states += (assoc_output,)
1391
+ if self.gradient_checkpointing and self.training:
1392
+ outputs = torch.utils.checkpoint.checkpoint(
1393
+ layer,
1394
+ assoc_output,
1395
+ position_embeddings,
1396
+ causal_mask,
1397
+ position_ids,
1398
+ None,
1399
+ next_decoder_cache,
1400
+ cache_position,
1401
+ output_attentions,
1402
+ use_cache,
1403
+ (hx, cx),
1404
+ )
1405
+ else:
1406
+ outputs = layer(
1407
+ assoc_output,
1408
+ position_embeddings,
1409
+ causal_mask,
1410
+ position_ids,
1411
+ kv_states=None,
1412
+ past_key_value=next_decoder_cache,
1413
+ cache_position=cache_position,
1414
+ output_attentions=output_attentions,
1415
+ use_cache=use_cache,
1416
+ previous_states=(hx, cx),
1417
+ )
1418
+ assoc_output = outputs[0]
1419
+
1420
+ hx, cx = outputs[-1 if not use_cache else -2]
1421
+
1422
+ if output_attentions:
1423
+ all_attentions += (outputs[1],)
1424
+
1425
+ if use_cache:
1426
+ next_decoder_cache = outputs[-1]
1427
+ else:
1428
+ next_decoder_cache = None
1429
+
1430
+ sensory_state = self.assoc_to_sensory_pooler(assoc_output)
1431
+
1432
+ sensory_state = apply_gradient_scaling(
1433
+ sensory_state,
1434
+ self.config.association_gradient_scale,
1435
+ self.config.gradient_scaling_enabled,
1436
+ )
1437
+
1438
+ # 2. Sensory Cortex (Self-Attention + Cross-Attention to Association)
1439
+ for i in range(self.config.num_sensory_cortex_layers):
1440
+ if output_hidden_states:
1441
+ all_hidden_states += (sensory_state,)
1442
+
1443
+ self_attn_layer = self.sensory_self_attn_layers[i]
1444
+ if self.gradient_checkpointing and self.training:
1445
+ outputs_self = torch.utils.checkpoint.checkpoint(
1446
+ self_attn_layer,
1447
+ sensory_state,
1448
+ position_embeddings,
1449
+ causal_mask,
1450
+ position_ids,
1451
+ None,
1452
+ next_decoder_cache,
1453
+ cache_position,
1454
+ output_attentions,
1455
+ use_cache,
1456
+ (hx, cx),
1457
+ )
1458
+ else:
1459
+ outputs_self = self_attn_layer(
1460
+ sensory_state,
1461
+ position_embeddings,
1462
+ causal_mask,
1463
+ position_ids,
1464
+ kv_states=None,
1465
+ past_key_value=next_decoder_cache,
1466
+ cache_position=cache_position,
1467
+ output_attentions=output_attentions,
1468
+ use_cache=use_cache,
1469
+ previous_states=(hx, cx),
1470
+ )
1471
+ sensory_state = outputs_self[0]
1472
+ hx, cx = outputs_self[-1 if not use_cache else -2]
1473
+
1474
+ if output_attentions:
1475
+ all_attentions += (outputs_self[1],)
1476
+
1477
+ if use_cache:
1478
+ next_decoder_cache = outputs_self[-1]
1479
+ else:
1480
+ next_decoder_cache = None
1481
+
1482
+ cross_attn_layer = self.sensory_cross_attn_layers[i]
1483
+
1484
+ cross_attn_causal_mask = None
1485
+
1486
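+ # Build an additive cross-attention mask: future key positions are filled with the dtype minimum, while the constant fill on allowed positions is harmless because softmax is shift-invariant within each row.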
+ if causal_mask is not None:
1487
+ q_seq_len = sensory_state.size(1)
1488
+ kv_seq_len = assoc_output.size(1)
1489
+
1490
+ cross_attn_causal_mask = torch.ones(
1491
+ (batch_size, 1, q_seq_len, kv_seq_len),
1492
+ device=hidden_states.device,
1493
+ dtype=hidden_states.dtype,
1494
+ )
1495
+
1496
+ causal_mask_upper = torch.triu(
1497
+ torch.ones((q_seq_len, kv_seq_len), device=hidden_states.device),
1498
+ diagonal=1,
1499
+ )
1500
+
1501
+ cross_attn_causal_mask = cross_attn_causal_mask.masked_fill(
1502
+ causal_mask_upper.unsqueeze(0).unsqueeze(0).bool(),
1503
+ torch.finfo(hidden_states.dtype).min,
1504
+ )
1505
+
1506
+ if self.gradient_checkpointing and self.training:
1507
+ outputs_cross = torch.utils.checkpoint.checkpoint(
1508
+ cross_attn_layer,
1509
+ sensory_state,
1510
+ position_embeddings,
1511
+ cross_attn_causal_mask,
1512
+ position_ids,
1513
+ assoc_output,
1514
+ next_decoder_cache,
1515
+ cache_position,
1516
+ output_attentions,
1517
+ use_cache,
1518
+ (hx, cx),
1519
+ )
1520
+ else:
1521
+ outputs_cross = cross_attn_layer(
1522
+ sensory_state,
1523
+ position_embeddings,
1524
+ past_key_value=next_decoder_cache,
1525
+ cache_position=cache_position,
1526
+ attention_mask=cross_attn_causal_mask,
1527
+ position_ids=position_ids,
1528
+ kv_states=apply_gradient_scaling(
1529
+ assoc_output,
1530
+ self.config.cross_attention_gradient_scale,
1531
+ self.config.gradient_scaling_enabled,
1532
+ ),
1533
+ output_attentions=output_attentions,
1534
+ use_cache=use_cache,
1535
+ previous_states=(hx, cx),
1536
+ )
1537
+
1538
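+ # Normalize the cross-attention output and add it back through the sensory_cross_assoc_gate module, letting the sensory stream control how much association context it absorbs.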
+ cross_contribution = nn.functional.layer_norm(
1539
+ outputs_cross[0],
1540
+ normalized_shape=(self.config.hidden_size,),
1541
+ eps=getattr(self.config, "rms_norm_eps", 1e-5),
1542
+ )
1543
+
1544
+ sensory_state = sensory_state + self.sensory_cross_assoc_gate(
1545
+ cross_contribution
1546
+ )
1547
+
1548
+ hx, cx = outputs_cross[-1 if not use_cache else -2]
1549
+ if output_attentions:
1550
+ all_attentions += (outputs_cross[1],)
1551
+
1552
+ if use_cache:
1553
+ next_decoder_cache = outputs_cross[-1]
1554
+ else:
1555
+ next_decoder_cache = None
1556
+
1557
+ motor_state = self.sensory_to_motor_pooler(sensory_state)
1558
+
1559
+ motor_state = apply_gradient_scaling(
1560
+ motor_state,
1561
+ self.config.sensory_gradient_scale,
1562
+ self.config.gradient_scaling_enabled,
1563
+ )
1564
+
1565
+ motor_state_from_assoc = self.assoc_to_motor_pooler(assoc_output)
1566
+
1567
+ motor_state_from_assoc = apply_gradient_scaling(
1568
+ motor_state_from_assoc,
1569
+ self.config.association_gradient_scale,
1570
+ self.config.gradient_scaling_enabled,
1571
+ )
1572
+
1573
+ motor_state = motor_state + motor_state_from_assoc # Combine pooled inputs
1574
+
1575
+ # 3. Motor Cortex (Self + Cross-Sensory + Cross-Association)
1576
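+ # Each motor layer applies self-attention, then gated cross-attention over the sensory stream, then gated cross-attention over the association stream.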
+ for i in range(self.config.num_motor_cortex_layers):
1577
+ if output_hidden_states:
1578
+ all_hidden_states += (motor_state,)
1579
+
1580
+ self_attn_layer = self.motor_self_attn_layers[i]
1581
+ if self.gradient_checkpointing and self.training:
1582
+ outputs_self = torch.utils.checkpoint.checkpoint(
1583
+ self_attn_layer,
1584
+ motor_state,
1585
+ position_embeddings,
1586
+ causal_mask,
1587
+ position_ids,
1588
+ None,
1589
+ next_decoder_cache,
1590
+ cache_position,
1591
+ output_attentions,
1592
+ use_cache,
1593
+ (hx, cx),
1594
+ )
1595
+ else:
1596
+ outputs_self = self_attn_layer(
1597
+ motor_state,
1598
+ position_embeddings,
1599
+ causal_mask,
1600
+ position_ids,
1601
+ kv_states=None,
1602
+ past_key_value=next_decoder_cache,
1603
+ cache_position=cache_position,
1604
+ output_attentions=output_attentions,
1605
+ use_cache=use_cache,
1606
+ previous_states=(hx, cx),
1607
+ )
1608
+ motor_state = outputs_self[0]
1609
+ hx, cx = outputs_self[-1 if not use_cache else -2]
1610
+ if output_attentions:
1611
+ all_attentions += (outputs_self[1],)
1612
+
1613
+ if use_cache:
1614
+ next_decoder_cache = outputs_self[-1]
1615
+ else:
1616
+ next_decoder_cache = None
1617
+
1618
+ cross_sensory_layer = self.motor_cross_sensory_layers[i]
1619
+
1620
+ motor_cross_sensory_mask = None
1621
+
1622
+ if causal_mask is not None:
1623
+ motor_q_seq_len = motor_state.size(1)
1624
+ sensory_kv_seq_len = sensory_state.size(1)
1627
+
1628
+ motor_cross_sensory_mask = torch.ones(
1629
+ (batch_size, 1, motor_q_seq_len, sensory_kv_seq_len),
1630
+ device=hidden_states.device,
1631
+ dtype=hidden_states.dtype,
1632
+ )
1633
+
1634
+ causal_mask_upper = torch.triu(
1635
+ torch.ones(
1636
+ (motor_q_seq_len, sensory_kv_seq_len),
1637
+ device=hidden_states.device,
1638
+ ),
1639
+ diagonal=1,
1640
+ )
1641
+
1642
+ motor_cross_sensory_mask = motor_cross_sensory_mask.masked_fill(
1643
+ causal_mask_upper.unsqueeze(0).unsqueeze(0).bool(),
1644
+ torch.finfo(hidden_states.dtype).min,
1645
+ )
1646
+
1647
+ if self.gradient_checkpointing and self.training:
1648
+ outputs_cross_sensory = torch.utils.checkpoint.checkpoint(
1649
+ cross_sensory_layer,
1650
+ motor_state,
1651
+ position_embeddings,
1652
+ motor_cross_sensory_mask,
1653
+ position_ids,
1654
+ sensory_state,
1655
+ next_decoder_cache,
1656
+ cache_position,
1657
+ output_attentions,
1658
+ use_cache,
1659
+ (hx, cx),
1660
+ )
1661
+ else:
1662
+ outputs_cross_sensory = cross_sensory_layer(
1663
+ motor_state,
1664
+ position_embeddings,
1665
+ attention_mask=motor_cross_sensory_mask,
1666
+ position_ids=position_ids,
1667
+ kv_states=apply_gradient_scaling(
1668
+ sensory_state,
1669
+ self.config.cross_attention_gradient_scale,
1670
+ self.config.gradient_scaling_enabled,
1671
+ ),
1672
+ output_attentions=output_attentions,
1673
+ past_key_value=next_decoder_cache,
1674
+ cache_position=cache_position,
1675
+ use_cache=use_cache,
1676
+ previous_states=(hx, cx),
1677
+ )
1678
+ motor_state = motor_state + self.motor_cross_sensory_gate(
1679
+ outputs_cross_sensory[0]
1680
+ )
1681
+ hx, cx = outputs_cross_sensory[-1 if not use_cache else -2]
1682
+ if output_attentions:
1683
+ all_attentions += (outputs_cross_sensory[1],)
1684
+
1685
+ if use_cache:
1686
+ next_decoder_cache = outputs_cross_sensory[-1]
1687
+ else:
1688
+ next_decoder_cache = None
1689
+
1690
+ cross_assoc_layer = self.motor_cross_assoc_layers[i]
1691
+
1692
+ motor_cross_assoc_mask = None
1693
+ if causal_mask is not None:
1694
+ motor_q_seq_len = motor_state.size(1)
1695
+ assoc_kv_seq_len = assoc_output.size(1)
1698
+
1699
+ motor_cross_assoc_mask = torch.ones(
1700
+ (batch_size, 1, motor_q_seq_len, assoc_kv_seq_len),
1701
+ device=hidden_states.device,
1702
+ dtype=hidden_states.dtype,
1703
+ )
1704
+
1705
+ causal_mask_upper = torch.triu(
1706
+ torch.ones(
1707
+ (motor_q_seq_len, assoc_kv_seq_len), device=hidden_states.device
1708
+ ),
1709
+ diagonal=1,
1710
+ )
1711
+
1712
+ motor_cross_assoc_mask = motor_cross_assoc_mask.masked_fill(
1713
+ causal_mask_upper.unsqueeze(0).unsqueeze(0).bool(),
1714
+ torch.finfo(hidden_states.dtype).min,
1715
+ )
1716
+
1717
+ if self.gradient_checkpointing and self.training:
1718
+ outputs_cross_assoc = torch.utils.checkpoint.checkpoint(
1719
+ cross_assoc_layer,
1720
+ motor_state,
1721
+ position_embeddings,
1722
+ motor_cross_assoc_mask,
1723
+ position_ids,
1724
+ assoc_output,
1725
+ next_decoder_cache,
1726
+ cache_position,
1727
+ output_attentions,
1728
+ use_cache,
1729
+ (hx, cx),
1730
+ )
1731
+ else:
1732
+ outputs_cross_assoc = cross_assoc_layer(
1733
+ motor_state,
1734
+ position_embeddings,
1735
+ attention_mask=motor_cross_assoc_mask,
1736
+ position_ids=position_ids,
1737
+ kv_states=apply_gradient_scaling(
1738
+ assoc_output,
1739
+ self.config.cross_attention_gradient_scale,
1740
+ self.config.gradient_scaling_enabled,
1741
+ ),
1742
+ output_attentions=output_attentions,
1743
+ past_key_value=next_decoder_cache,
1744
+ cache_position=cache_position,
1745
+ use_cache=use_cache,
1746
+ previous_states=(hx, cx),
1747
+ )
1748
+ motor_state = motor_state + self.motor_cross_assoc_gate(
1749
+ outputs_cross_assoc[0]
1750
+ )
1751
+ hx, cx = outputs_cross_assoc[-1 if not use_cache else -2]
1752
+ if output_attentions:
1753
+ all_attentions += (outputs_cross_assoc[1],)
1754
+
1755
+ if use_cache:
1756
+ next_decoder_cache = outputs_cross_assoc[-1]
1757
+ else:
1758
+ next_decoder_cache = None
1759
+
1760
+ final_output = self.norm(motor_state)
1761
+
1762
+ if output_hidden_states:
1763
+ all_hidden_states += (final_output,)
1764
+
1765
+ if not return_dict:
1766
+ outputs_tuple = (final_output,)
1767
+ if use_cache:
1768
+ outputs_tuple += (next_decoder_cache,)
1769
+ if output_hidden_states:
1770
+ outputs_tuple += (all_hidden_states,)
1771
+ if output_attentions:
1772
+ outputs_tuple += (all_attentions,)
1773
+ return tuple(v for v in outputs_tuple if v is not None)
1774
+
1775
+ return BaseModelOutputWithPast(
1776
+ last_hidden_state=final_output,
1777
+ past_key_values=next_decoder_cache,
1778
+ hidden_states=all_hidden_states,
1779
+ attentions=all_attentions,
1780
+ )
1781
+
1782
+ def get_input_embeddings(self):
1783
+ return self.embed_tokens
1784
+
1785
+ def set_input_embeddings(self, value):
1786
+ self.embed_tokens = value
1787
+
1788
+ def _update_causal_mask(
1789
+ self,
1790
+ attention_mask: torch.Tensor,
1791
+ input_tensor: torch.Tensor,
1792
+ cache_position: torch.Tensor,
1793
+ past_key_values: Cache,
1794
+ output_attentions: bool,
1795
+ ):
1796
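+ # Fast paths: flash-attention consumes the 2D padding mask (or none) directly, and SDPA can skip building a 4D mask when the pattern is purely causal.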
+ if self.config._attn_implementation == "flash_attention_2":
1797
+ if attention_mask is not None and 0.0 in attention_mask:
1798
+ return attention_mask
1799
+ return None
1800
+
1801
+ past_seen_tokens = (
1802
+ past_key_values.get_seq_length() if past_key_values is not None else 0
1803
+ )
1804
+ using_static_cache = isinstance(past_key_values, StaticCache)
1805
+
1806
+ if (
1807
+ self.config._attn_implementation == "sdpa"
1808
+ and not using_static_cache
1809
+ and not output_attentions
1810
+ ):
1811
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
1812
+ attention_mask,
1813
+ inputs_embeds=input_tensor,
1814
+ past_key_values_length=past_seen_tokens,
1815
+ is_training=self.training,
1816
+ ):
1817
+ return None
1818
+
1819
+ dtype, device = input_tensor.dtype, input_tensor.device
1820
+ min_dtype = torch.finfo(dtype).min
1821
+ sequence_length = input_tensor.shape[1]
1822
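+ # With a static cache the mask must span the cache's maximum length; otherwise it is sized from the attention mask or the tokens seen so far.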
+ if using_static_cache:
1823
+ target_length = (
1824
+ getattr(
1825
+ past_key_values,
1826
+ "get_max_length",
1827
+ lambda: past_key_values.get_seq_length(),
1828
+ )()
1829
+ if hasattr(past_key_values, "get_seq_length")
1830
+ else sequence_length + past_seen_tokens
1831
+ )
1832
+ else:
1833
+ target_length = (
1834
+ attention_mask.shape[-1]
1835
+ if isinstance(attention_mask, torch.Tensor)
1836
+ else past_seen_tokens + sequence_length + 1
1837
+ )
1838
+
1839
+ causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1840
+ attention_mask,
1841
+ sequence_length=sequence_length,
1842
+ target_length=target_length,
1843
+ dtype=dtype,
1844
+ device=device,
1845
+ min_dtype=min_dtype,
1846
+ cache_position=cache_position,
1847
+ batch_size=input_tensor.shape[0],
1848
+ )
1849
+
1850
+ if (
1851
+ self.config._attn_implementation == "sdpa"
1852
+ and attention_mask is not None
1853
+ and attention_mask.device.type == "cuda"
1854
+ and not output_attentions
1855
+ ):
1856
+ causal_mask = AttentionMaskConverter._unmask_unattended(
1857
+ causal_mask, min_dtype
1858
+ )
1859
+
1860
+ return causal_mask
1861
+
1862
+
1863
+ class NeuroBLASTForCausalLM(NeuroBLASTPreTrainedModel, GenerationMixin):
1864
+ _tied_weights_keys = ["lm_head.weight"]
1865
+
1866
+ def __init__(self, config: NeuroBLASTConfig):
1867
+ super().__init__(config)
1868
+ self.config = config
1869
+ self.model = NeuroBLASTModel(config)
1870
+ self.vocab_size = config.vocab_size # Ensure vocab_size is accessible
1871
+ self.lm_head = torch.nn.Linear(config.hidden_size, self.vocab_size, bias=False)
1872
+ self.loss_steps = 0
1873
+ self.post_init() # Initialize weights
1874
+
1875
+ def get_input_embeddings(self):
1876
+ return self.model.get_input_embeddings()
1877
+
1878
+ def set_input_embeddings(self, value):
1879
+ self.model.set_input_embeddings(value)
1880
+
1881
+ def get_output_embeddings(self):
1882
+ return self.lm_head
1883
+
1884
+ def set_output_embeddings(self, new_embeddings):
1885
+ self.lm_head = new_embeddings
1886
+
1887
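+ # Weights are tied to the input embeddings only when the config explicitly sets tie_word_embeddings.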
+ def tie_weights(self):
1888
+ if getattr(self.config, "tie_word_embeddings", False):
1889
+ output_embeddings = self.get_output_embeddings()
1890
+ input_embeddings = self.get_input_embeddings()
1891
+ output_embeddings.weight = input_embeddings.weight
1892
+ if getattr(output_embeddings, "bias", None) is not None:
1893
+ output_embeddings.bias.data.zero_()
1894
+
1895
+ def forward(
1896
+ self,
1897
+ input_ids: torch.LongTensor = None,
1898
+ attention_mask: Optional[torch.Tensor] = None,
1899
+ position_ids: Optional[torch.LongTensor] = None,
1900
+ past_key_values: Optional[Cache] = None,
1901
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1902
+ labels: Optional[torch.LongTensor] = None,
1903
+ use_cache: Optional[bool] = None,
1904
+ output_attentions: Optional[bool] = None,
1905
+ output_hidden_states: Optional[bool] = None,
1906
+ return_dict: Optional[bool] = None,
1907
+ cache_position: Optional[torch.LongTensor] = None,
1908
+ **loss_kwargs,
1909
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1910
+
1911
+ return_dict = (
1912
+ return_dict if return_dict is not None else self.config.use_return_dict
1913
+ )
1914
+
1915
+ outputs = self.model(
1916
+ input_ids=input_ids,
1917
+ attention_mask=attention_mask,
1918
+ position_ids=position_ids,
1919
+ past_key_values=past_key_values,
1920
+ inputs_embeds=inputs_embeds,
1921
+ use_cache=use_cache,
1922
+ output_attentions=output_attentions,
1923
+ output_hidden_states=output_hidden_states,
1924
+ return_dict=return_dict,
1925
+ cache_position=cache_position,
1926
+ )
1927
+
1928
+ hidden_states = outputs[0]
1929
+
1930
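+ # Project to vocabulary logits and upcast to fp32 for a numerically stable loss.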
+ logits = self.lm_head(hidden_states)
1931
+ logits = logits.float()
1932
+
1933
+ loss = None
1934
+ if labels is not None:
1935
+ loss = self.loss_function(
1936
+ logits=logits,
1937
+ labels=labels,
1938
+ vocab_size=self.config.vocab_size,
1939
+ **loss_kwargs,
1940
+ )
1941
+
1942
+ if not return_dict:
1943
+ output = (logits,) + outputs[1:]
1944
+ return (loss,) + output if loss is not None else output
1945
+
1946
+ return CausalLMOutputWithPast(
1947
+ loss=loss,
1948
+ logits=logits,
1949
+ past_key_values=outputs.past_key_values,
1950
+ hidden_states=outputs.hidden_states,
1951
+ attentions=outputs.attentions,
1952
+ )
1953
+
1960
+ def __str__(self):
1961
+ return f"NeuroBLASTForCausalLM(config={self.config})"
neuroblast_model/registration.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from transformers import AutoConfig, AutoModelForCausalLM
2
+ from neuroblast_model import NeuroBLASTConfig, NeuroBLASTForCausalLM
3
+
4
+ AutoConfig.register("neuroblast", NeuroBLASTConfig)
5
+ AutoModelForCausalLM.register(NeuroBLASTConfig, NeuroBLASTForCausalLM)