Mariusz Kurman committed
Commit 2a4d1ef · Parent: 97edb23

Remove registration of NeuroBLAST model in registration.py

neuroblast_model/__init__.py DELETED
@@ -1,2 +0,0 @@
- from neuroblast_model.configuration_neuroblast import NeuroBLASTConfig
- from neuroblast_model.modeling_neuroblast import NeuroBLASTForCausalLM
 
 
 
neuroblast_model/configuration_neuroblast.py DELETED
@@ -1,85 +0,0 @@
- from transformers.configuration_utils import PretrainedConfig
- from transformers.modeling_rope_utils import rope_config_validation
- from typing import Optional
-
-
- class NeuroBLASTConfig(PretrainedConfig):
-     model_type = "neuroblast"
-
-     def __init__(
-         self,
-         vocab_size=28886,
-         hidden_size=2048,
-         kv_dim=2048,
-         intermediate_size=None,
-         num_attention_heads=32,
-         num_sensory_cortex_layers=6,
-         num_motor_cortex_layers=6,
-         num_association_cortex_layers=6,
-         dropout=0.1,
-         layer_norm_epsilon=1e-6,
-         pad_token_id=None,
-         use_cache=False,
-         rope_theta=10000.0,
-         rope_scaling=None,
-         max_position_embeddings=2048,
-         initializer_range=0.02,
-         use_flash_attn=True,
-         num_experts=None,
-         num_experts_per_tok=None,
-         norm_topk_prob=False,
-         hidden_act="silu",
-         use_zero_memory=False,
-         zero_memory_alpha=1.0,
-         zero_memory_layers=None,
-         gradient_scaling_enabled=True,
-         association_gradient_scale=0.9,
-         sensory_gradient_scale=0.95,
-         cross_attention_gradient_scale=0.95,
-         clamp_value=1e5,
-         _attn_implementation='sdpa',
-         **kwargs
-     ):
-         # Calculate intermediate_size if not provided
-         if intermediate_size is None:
-             intermediate_size = int(hidden_size * 4 * 2 / 3)
-
-         super().__init__(
-             pad_token_id=pad_token_id,
-             **kwargs
-         )
-
-         self.vocab_size = vocab_size
-         self.hidden_size = hidden_size
-         self.kv_dim = kv_dim
-         self.intermediate_size = intermediate_size
-         self.num_attention_heads = num_attention_heads
-         self.num_sensory_cortex_layers = num_sensory_cortex_layers
-         self.num_motor_cortex_layers = num_motor_cortex_layers
-         self.num_association_cortex_layers = num_association_cortex_layers
-         self.dropout = dropout
-         self.layer_norm_epsilon = layer_norm_epsilon
-         self.rms_norm_eps = layer_norm_epsilon
-         self.use_cache = use_cache
-         self.rope_theta = rope_theta
-         self.rope_scaling = rope_scaling
-         self.max_position_embeddings = max_position_embeddings
-         self.initializer_range = initializer_range
-         self.use_flash_attn = use_flash_attn
-         self.num_experts = num_experts
-         self.num_experts_per_tok = num_experts_per_tok
-         self.norm_topk_prob = norm_topk_prob
-         self.hidden_act = hidden_act
-         self.use_zero_memory = use_zero_memory
-         self.zero_memory_alpha = zero_memory_alpha
-         self.zero_memory_layers = zero_memory_layers
-         self.gradient_scaling_enabled = gradient_scaling_enabled
-         self.association_gradient_scale = association_gradient_scale
-         self.sensory_gradient_scale = sensory_gradient_scale
-         self.cross_attention_gradient_scale = cross_attention_gradient_scale
-         self._attn_implementation = _attn_implementation
-         self.clamp_value = clamp_value
-
-         if self.rope_scaling is not None and "type" in self.rope_scaling:
-             self.rope_scaling["rope_type"] = self.rope_scaling["type"]
-         rope_config_validation(self)
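
For reference, the deleted configuration can be exercised directly from the defaults shown above. A minimal sketch (editorial, not part of the commit), assuming the neuroblast_model package is still importable locally; note that with the default hidden_size=2048 the derived intermediate_size is int(2048 * 4 * 2 / 3) = 5461.

    from neuroblast_model.configuration_neuroblast import NeuroBLASTConfig

    config = NeuroBLASTConfig()              # defaults exactly as in the signature above
    print(config.model_type)                 # "neuroblast"
    print(config.num_attention_heads)        # 32
    print(config.intermediate_size)          # 5461 == int(2048 * 4 * 2 / 3)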
neuroblast_model/modeling_neuroblast.py DELETED
@@ -1,1961 +0,0 @@
1
- import torch
2
- import math
3
- from torch import nn
4
- import torch.nn.functional as F
5
- from transformers import PreTrainedModel, GenerationMixin
6
- from transformers.cache_utils import DynamicCache, Cache
7
- from transformers.modeling_attn_mask_utils import AttentionMaskConverter
8
- from transformers.cache_utils import Cache, DynamicCache, StaticCache
9
- from transformers.utils import logging
10
- from transformers.modeling_outputs import (
11
- BaseModelOutputWithPast,
12
- CausalLMOutputWithPast,
13
- )
14
- from transformers.activations import ACT2FN
15
- from typing import Optional, Tuple, Union, List
16
- from neuroblast_model.configuration_neuroblast import NeuroBLASTConfig
17
-
18
- CLAMP_VALUE = 1e5
19
-
20
- logger = logging.get_logger(__name__)
21
-
22
- def apply_gradient_scaling(
23
- tensor: torch.Tensor, scale: float, enabled: bool = True
24
- ) -> torch.Tensor:
25
- """
26
- Apply gradient scaling to a tensor.
27
- This scales the gradients during backward pass while keeping forward pass unchanged.
28
- """
29
- if not enabled or scale == 1.0 or not tensor.requires_grad:
30
- return tensor
31
-
32
- # Use a custom autograd function for gradient scaling
33
- class GradientScale(torch.autograd.Function):
34
- @staticmethod
35
- def forward(ctx, input_tensor, scale_factor):
36
- ctx.scale = scale_factor
37
- return input_tensor.clone()
38
-
39
- @staticmethod
40
- def backward(ctx, grad_output):
41
- if grad_output is None:
42
- return None, None
43
- return grad_output * ctx.scale, None
44
-
45
- return GradientScale.apply(tensor, scale)
46
-
47
-
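The deleted apply_gradient_scaling helper is an identity in the forward pass and only multiplies the incoming gradient by `scale` on the way back. A standalone sketch of the same autograd trick in plain PyTorch (editorial, independent of this repository):

    import torch

    class GradScale(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, scale):
            ctx.scale = scale
            return x.clone()                      # forward value is unchanged

        @staticmethod
        def backward(ctx, grad_out):
            return grad_out * ctx.scale, None     # gradient is scaled; `scale` gets no gradient

    x = torch.ones(3, requires_grad=True)
    GradScale.apply(x, 0.5).sum().backward()
    print(x.grad)                                 # tensor([0.5000, 0.5000, 0.5000])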
48
- def _prepare_4d_causal_attention_mask_with_cache_position(
49
- attention_mask: torch.Tensor,
50
- sequence_length: int,
51
- target_length: int,
52
- dtype: torch.dtype,
53
- device: torch.device,
54
- min_dtype: float,
55
- cache_position: torch.Tensor,
56
- batch_size: int,
57
- ):
58
- if attention_mask is not None and attention_mask.dim() == 4:
59
- causal_mask = attention_mask
60
- else:
61
- causal_mask = torch.full(
62
- (sequence_length, target_length),
63
- fill_value=min_dtype,
64
- dtype=dtype,
65
- device=device,
66
- )
67
- if sequence_length != 1:
68
- causal_mask = torch.triu(causal_mask, diagonal=1)
69
- causal_mask *= torch.arange(
70
- target_length, device=device
71
- ) > cache_position.reshape(-1, 1)
72
- causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
73
- if attention_mask is not None:
74
- causal_mask = (
75
- causal_mask.clone()
76
- )
77
- mask_length = attention_mask.shape[-1]
78
- padding_mask = (
79
- causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
80
- )
81
- padding_mask = padding_mask == 0
82
- causal_mask[:, :, :, :mask_length] = causal_mask[
83
- :, :, :, :mask_length
84
- ].masked_fill(padding_mask, min_dtype)
85
-
86
- return causal_mask
87
-
88
-
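A small illustration (editorial, ordinary PyTorch) of the additive causal mask the helper above assembles: positions above the diagonal are filled with the most negative finite value so softmax drives them to zero, while visible positions stay 0.

    import torch

    seq = 4
    min_dtype = torch.finfo(torch.float32).min
    mask = torch.full((seq, seq), min_dtype).triu(diagonal=1)
    print(mask[1, 0].item())   # 0.0        -> token 1 may attend to token 0
    print(mask[0, 1].item())   # ~-3.4e38   -> token 0 cannot attend to the future token 1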
89
- # --- RoPE Implementation (using HF LlamaRotaryEmbedding) ---
90
-
91
-
92
- class LlamaRMSNorm(nn.Module):
93
- def __init__(self, hidden_size, eps=1e-6):
94
- """
95
- LlamaRMSNorm is equivalent to T5LayerNorm
96
- """
97
- super().__init__()
98
- self.weight = nn.Parameter(torch.ones(hidden_size))
99
- self.variance_epsilon = eps
100
-
101
- def forward(self, hidden_states):
102
- input_dtype = hidden_states.dtype
103
- hidden_states = hidden_states.to(torch.float32)
104
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
105
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
106
- return self.weight * hidden_states.to(input_dtype)
107
-
108
- def extra_repr(self):
109
- return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
110
-
111
-
112
- class NeuroBLASTRotaryEmbedding(nn.Module):
113
- """
114
- Rotary Positional Embedding for NeuroBLAST model.
115
- Source: LlamaRotaryEmbedding
116
- """
117
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
118
- super().__init__()
119
- self.dim = dim
120
- self.max_position_embeddings = max_position_embeddings
121
- self.base = base
122
- inv_freq = 1.0 / (
123
- self.base
124
- ** (torch.arange(0, self.dim, 2, dtype=torch.float32).to(device) / self.dim)
125
- )
126
- self.register_buffer("inv_freq", inv_freq, persistent=False)
127
- self._set_cos_sin_cache(
128
- seq_len=max_position_embeddings, device="cpu", dtype=torch.float32
129
- )
130
-
131
- def _set_cos_sin_cache(self, seq_len, device, dtype):
132
- self.max_seq_len_cached = seq_len
133
- t = torch.arange(
134
- self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
135
- )
136
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
137
- emb = torch.cat((freqs, freqs), dim=-1)
138
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
139
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
140
-
141
- def forward(self, x, seq_len=None):
142
- if seq_len > self.max_seq_len_cached:
143
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
144
- elif self.cos_cached.device != x.device or self.cos_cached.dtype != x.dtype:
145
- self.cos_cached = self.cos_cached.to(device=x.device, dtype=x.dtype)
146
- self.sin_cached = self.sin_cached.to(device=x.device, dtype=x.dtype)
147
-
148
- return (
149
- self.cos_cached[:seq_len],
150
- self.sin_cached[:seq_len],
151
- )
152
-
153
-
154
- def rotate_half(x):
155
- """Rotates half the hidden dims of the input."""
156
- x1 = x[..., : x.shape[-1] // 2]
157
- x2 = x[..., x.shape[-1] // 2 :]
158
- return torch.cat((-x2, x1), dim=-1)
159
-
160
-
161
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
162
- """ Applies rotary positional embeddings to query and key tensors."""
163
- cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
164
- sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
165
- q_embed = (q * cos) + (rotate_half(q) * sin)
166
- k_embed = (k * cos) + (rotate_half(k) * sin)
167
- return q_embed, k_embed
168
-
169
-
170
- # Overload for Cross Attention where only query is rotated
171
- def apply_rotary_pos_emb_single(q, cos, sin, position_ids):
172
- """ Applies rotary positional embeddings to query tensor. """
173
- cos = cos[position_ids].unsqueeze(1) # [1, 1, seq_len, dim]
174
- sin = sin[position_ids].unsqueeze(1) # [1, 1, seq_len, dim]
175
- q_embed = (q * cos) + (rotate_half(q) * sin)
176
- return q_embed
177
-
178
-
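rotate_half and the two apply_rotary_pos_emb* helpers implement the standard rotary formula q * cos + rotate_half(q) * sin. A standalone check (editorial sketch, not tied to this model) that the operation is a pure rotation and therefore leaves vector norms unchanged:

    import torch

    def rotate_half(x):
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    dim, seq = 8, 4
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
    freqs = torch.einsum("i,j->ij", torch.arange(seq).float(), inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)              # [seq, dim], same layout as the cos/sin cache above
    cos, sin = emb.cos(), emb.sin()

    q = torch.randn(1, 1, seq, dim)                      # [batch, heads, seq, head_dim]
    q_rot = q * cos + rotate_half(q) * sin
    print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))   # True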
179
- class SwiGLUMLP(nn.Module):
180
- """SwiGLU MLP block"""
181
-
182
- def __init__(self, hidden_size, config: NeuroBLASTConfig, dropout):
183
- super().__init__()
184
- intermediate_size = getattr(config, "intermediate_size", int(hidden_size * 2.5))
185
- self.init_std = getattr(config, "initializer_range", 0.02)
186
- self.clamp_value = config.clamp_value
187
-
188
- self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=True)
189
- self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
190
- self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
191
- self.act_fn = nn.SiLU()
192
- self.dropout = nn.Dropout(dropout)
193
-
194
- if config.num_experts is not None:
195
- self.experts = NeuroBLASTSparseMoeBlock(config)
196
-
197
- # Re-enable scaled initialization
198
- with torch.no_grad():
199
- # Scale down initial weights in the up/gate projections
200
- self.gate_proj.weight.data.normal_(
201
- mean=0.0,
202
- std=self.init_std / math.sqrt(hidden_size), # Scale by input dim
203
- )
204
- if self.gate_proj.bias is not None:
205
- self.gate_proj.bias.data.zero_()
206
-
207
- self.up_proj.weight.data.normal_(
208
- mean=0.0, std=self.init_std / math.sqrt(hidden_size)
209
- ) # Scale by input dim
210
- # Scale down initial weights in the down projection even further
211
- self.down_proj.weight.data.normal_(
212
- mean=0.0,
213
- std=self.init_std
214
- / math.sqrt(intermediate_size), # Scale by intermediate dim
215
- )
216
-
217
- def forward(self, x):
218
- gated_x = self.gate_proj(x)
219
- activated_x = self.act_fn(gated_x)
220
- up_projected_x = self.up_proj(x)
221
-
222
- intermediate_activation = activated_x * up_projected_x
223
-
224
- # Clamp the intermediate activation before down_proj
225
- clamp_value = self.clamp_value
226
-
227
- intermediate_activation = torch.clamp(
228
- intermediate_activation, min=-clamp_value, max=clamp_value
229
- )
230
- intermediate_activation = torch.nan_to_num(
231
- intermediate_activation
232
- ) # Safeguard against NaNs
233
-
234
- y = self.down_proj(intermediate_activation)
235
- y = self.dropout(y)
236
-
237
- if hasattr(self, "experts"):
238
- z = self.experts(y)
239
-
240
- y = y + z
241
-
242
- return y
243
-
244
-
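SwiGLUMLP above is a SwiGLU feed-forward: silu(gate(x)) * up(x) followed by a down projection, with the intermediate activation clamped and NaN-scrubbed before the last matmul. A minimal sketch of that data flow with made-up sizes (hidden=8, intermediate=16):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    hidden, inter = 8, 16
    gate = nn.Linear(hidden, inter, bias=True)
    up = nn.Linear(hidden, inter, bias=False)
    down = nn.Linear(inter, hidden, bias=False)

    x = torch.randn(2, 3, hidden)
    h = F.silu(gate(x)) * up(x)
    h = torch.nan_to_num(h.clamp(min=-1e5, max=1e5))     # same safeguard as in the block above
    print(down(h).shape)                                 # torch.Size([2, 3, 8])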
245
- class NeuroBLASTMoeMLP(nn.Module):
246
- """ Source: Qwen3MoeMLP """
247
- def __init__(self, config, intermediate_size=None):
248
- super().__init__()
249
- self.config = config
250
- self.clamp_value = config.clamp_value
251
- self.hidden_size = config.hidden_size
252
- self.intermediate_size = (
253
- intermediate_size
254
- if intermediate_size is not None
255
- else config.intermediate_size
256
- )
257
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
258
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
259
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
260
- self.act_fn = ACT2FN[config.hidden_act]
261
-
262
- def forward(self, x):
263
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
264
-
265
- down_proj = torch.clamp(down_proj, min=-self.clamp_value, max=self.clamp_value)
266
- down_proj = torch.nan_to_num(down_proj) # Safeguard against NaNs
267
-
268
- return down_proj
269
-
270
-
271
- class NeuroBLASTSparseMoeBlock(nn.Module):
272
- """ Source: Qwen3SparseMoeBlock """
273
- def __init__(self, config):
274
- super().__init__()
275
- self.num_experts = config.num_experts
276
- self.top_k = config.num_experts_per_tok
277
- self.norm_topk_prob = config.norm_topk_prob
278
-
279
- # gating
280
- self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
281
- self.experts = nn.ModuleList(
282
- [NeuroBLASTMoeMLP(config) for _ in range(self.num_experts)]
283
- )
284
-
285
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
286
- """ """
287
- batch_size, sequence_length, hidden_dim = hidden_states.shape
288
- hidden_states = hidden_states.view(-1, hidden_dim)
289
- # router_logits: (batch * sequence_length, n_experts)
290
- router_logits = self.gate(hidden_states)
291
-
292
- routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
293
- routing_weights, selected_experts = torch.topk(
294
- routing_weights, self.top_k, dim=-1
295
- )
296
- if self.norm_topk_prob: # only diff with mixtral sparse moe block!
297
- routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
298
- # we cast back to the input dtype
299
- routing_weights = routing_weights.to(hidden_states.dtype)
300
-
301
- final_hidden_states = torch.zeros(
302
- (batch_size * sequence_length, hidden_dim),
303
- dtype=hidden_states.dtype,
304
- device=hidden_states.device,
305
- )
306
-
307
- # One hot encode the selected experts to create an expert mask
308
- # this will be used to easily index which expert is going to be solicited
309
- expert_mask = torch.nn.functional.one_hot(
310
- selected_experts, num_classes=self.num_experts
311
- ).permute(2, 1, 0)
312
-
313
- # Loop over all available experts in the model and perform the computation on each expert
314
- for expert_idx in range(self.num_experts):
315
- expert_layer = self.experts[expert_idx]
316
- idx, top_x = torch.where(expert_mask[expert_idx])
317
-
318
- # Index the correct hidden states and compute the expert hidden state for
319
- # the current expert. We need to make sure to multiply the output hidden
320
- # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
321
- current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
322
- current_hidden_states = (
323
- expert_layer(current_state) * routing_weights[top_x, idx, None]
324
- )
325
-
326
- # However `index_add_` only support torch tensors for indexing so we'll use
327
- # the `top_x` tensor here.
328
- final_hidden_states.index_add_(
329
- 0, top_x, current_hidden_states.to(hidden_states.dtype)
330
- )
331
- final_hidden_states = final_hidden_states.reshape(
332
- batch_size, sequence_length, hidden_dim
333
- )
334
- return final_hidden_states
335
-
336
-
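The routing in NeuroBLASTSparseMoeBlock follows the Mixtral/Qwen recipe: softmax over the router logits, keep the top-k experts per token, optionally renormalize the kept weights, then dispatch tokens through a one-hot expert mask. A toy standalone example (editorial, made-up sizes: 5 tokens, 4 experts, top-2):

    import torch
    import torch.nn.functional as F

    tokens, n_experts, top_k = 5, 4, 2
    router_logits = torch.randn(tokens, n_experts)

    weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    weights, selected = torch.topk(weights, top_k, dim=-1)
    weights = weights / weights.sum(dim=-1, keepdim=True)       # the norm_topk_prob=True branch

    expert_mask = F.one_hot(selected, num_classes=n_experts).permute(2, 1, 0)  # [experts, top_k, tokens]
    idx, top_x = torch.where(expert_mask[0])    # (slot, token) pairs routed to expert 0
    print(weights.sum(dim=-1))                  # each token's kept weights now sum to 1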
337
- class NeuroBLASTRouterBlock(nn.Module):
338
- """ Memory router; overcomplicated due to backward compatibility """
339
- def __init__(
340
- self,
341
- config,
342
- hidden_size,
343
- ):
344
- super().__init__()
345
- self.num_experts = 2
346
- self.top_k = 1
347
- self.norm_topk_prob = config.norm_topk_prob
348
-
349
- # gating
350
- self.gate = nn.Linear(hidden_size, self.num_experts, bias=False)
351
-
352
- def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
353
- """ """
354
- batch_size, sequence_length, hidden_dim = hidden_states.shape
355
- hidden_states = hidden_states.view(-1, hidden_dim)
356
- # router_logits: (batch * sequence_length, n_experts)
357
- router_logits = self.gate(hidden_states)
358
-
359
- routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
360
- routing_weights, selected_experts = torch.topk(
361
- routing_weights, self.top_k, dim=-1
362
- )
363
- if self.norm_topk_prob: # only diff with mixtral sparse moe block!
364
- routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
365
- # we cast back to the input dtype
366
- routing_weights = routing_weights.to(hidden_states.dtype)
367
-
368
- return routing_weights, selected_experts
369
-
370
-
371
- class SelfAttention(torch.nn.Module):
372
- def __init__(
373
- self,
374
- config: NeuroBLASTConfig,
375
- hidden_size: int,
376
- is_causal: bool = False,
377
- layer_idx: Optional[int] = None,
378
- ):
379
- super().__init__()
380
- self.is_causal = is_causal
381
- self.dropout_p = config.dropout # Will apply based on self.training
382
- self.layer_idx = layer_idx
383
- self.hidden_size = hidden_size
384
- self.intermediate_size = config.kv_dim
385
- self.use_flash_attn = config.use_flash_attn
386
- # Allow overriding num_heads, default to config.num_attention_heads
387
- self.num_heads = getattr(
388
- config, f"num_heads_{layer_idx}", config.num_attention_heads
389
- )
390
- self.head_dim = self.intermediate_size // self.num_heads
391
-
392
- if (self.head_dim * self.num_heads) != self.intermediate_size:
393
- raise ValueError(
394
- f"Layer {self.layer_idx}: kv_dim ({self.intermediate_size}) must be divisible by num_heads ({self.num_heads})"
395
- )
396
-
397
- self.qkv_proj = nn.Linear(
398
- self.hidden_size, self.intermediate_size * 3, bias=True
399
- )
400
- self.o_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
401
- self.dropout = nn.Dropout(config.dropout)
402
-
403
- def forward(
404
- self,
405
- hidden_states: torch.Tensor,
406
- position_embeddings: Tuple[torch.Tensor, torch.Tensor], # (cos, sin)
407
- attention_mask: Optional[
408
- torch.Tensor
409
- ], # Not directly used by flash_attn causal
410
- past_key_value: Optional[Cache] = None,
411
- cache_position: Optional[torch.LongTensor] = None,
412
- output_attentions: Optional[bool] = False,
413
- use_cache: Optional[bool] = False,
414
- position_ids: Optional[torch.LongTensor] = None,
415
- ):
416
- batch_size, seq_len, _ = hidden_states.shape
417
- dropout_p = self.dropout_p if self.training else 0.0
418
-
419
- qkv = self.qkv_proj(hidden_states)
420
- query_states, key_states, value_states = qkv.chunk(3, dim=-1)
421
-
422
- query_states = query_states.view(
423
- batch_size, seq_len, self.num_heads, self.head_dim
424
- ).transpose(1, 2)
425
- key_states = key_states.view(
426
- batch_size, seq_len, self.num_heads, self.head_dim
427
- ).transpose(1, 2)
428
- value_states = value_states.view(
429
- batch_size, seq_len, self.num_heads, self.head_dim
430
- ).transpose(1, 2)
431
-
432
- cos, sin = position_embeddings
433
- query_states, key_states = apply_rotary_pos_emb(
434
- query_states, key_states, cos, sin, position_ids
435
- )
436
-
437
- if past_key_value is not None:
438
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
439
- # When using Cache object, updating happens in-place.
440
- key_states, value_states = past_key_value.update(
441
- key_states, value_states, self.layer_idx, cache_kwargs
442
- )
443
-
444
- if self.use_flash_attn:
445
- causal_mask = attention_mask
446
- if attention_mask is not None:
447
- if attention_mask.dim() == 4:
448
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
449
- elif attention_mask.dim() == 2:
450
- causal_mask = attention_mask
451
-
452
- if causal_mask.dtype not in [
453
- torch.bool,
454
- torch.float16,
455
- torch.float32,
456
- torch.bfloat16,
457
- ]:
458
- causal_mask = causal_mask.to(query_states.dtype)
459
-
460
- is_causal = (
461
- True if causal_mask is None and query_states.shape[-2] > 1 else False
462
- )
463
-
464
- attn_output = F.scaled_dot_product_attention(
465
- query_states,
466
- key_states,
467
- value_states,
468
- attn_mask=causal_mask,
469
- dropout_p=dropout_p,
470
- enable_gqa=False,
471
- scale=self.head_dim**-0.5,
472
- is_causal=is_causal,
473
- )
474
- attn_weights = None
475
- else:
476
- attn_weights = torch.matmul(query_states, key_states.transpose(-1, -2)) / (
477
- self.head_dim**0.5
478
- )
479
-
480
- if attention_mask is not None:
481
- attn_weights = attn_weights + attention_mask
482
- elif self.is_causal and seq_len > 1:
483
- causal_mask = torch.triu(
484
- torch.ones(
485
- (seq_len, key_states.shape[2]),
486
- dtype=torch.bool,
487
- device=query_states.device,
488
- ),
489
- diagonal=1,
490
- )
491
- attn_weights = attn_weights.masked_fill(
492
- causal_mask.unsqueeze(0).unsqueeze(0), float("-inf")
493
- )
494
-
495
- attn_weights = F.softmax(attn_weights, dim=-1)
496
- attn_weights_dropped = F.dropout(
497
- attn_weights, p=dropout_p, training=self.training
498
- )
499
- attn_output = torch.matmul(attn_weights_dropped, value_states)
500
-
501
- attn_output = attn_output.transpose(1, 2).contiguous()
502
- attn_output = attn_output.view(batch_size, seq_len, -1)
503
- attn_output = self.o_proj(attn_output)
504
- attn_output = self.dropout(attn_output)
505
-
506
- if output_attentions and attn_weights is None:
507
- logger.warning(
508
- f"Layer {self.layer_idx}: Flash Attention does not return attention weights."
509
- )
510
-
511
- outputs = (attn_output,)
512
- if output_attentions:
513
- outputs += (attn_weights,)
514
- if use_cache:
515
- outputs += (past_key_value,) # Return the cache object
516
-
517
- return outputs
518
-
519
-
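With use_flash_attn set, SelfAttention above delegates to torch.nn.functional.scaled_dot_product_attention with tensors in [batch, heads, seq, head_dim] layout, letting is_causal handle masking when no explicit mask is supplied. A minimal standalone call in the same layout (requires PyTorch 2.1+ for the scale argument):

    import torch
    import torch.nn.functional as F

    batch, heads, seq, head_dim = 1, 2, 5, 16
    q = k = v = torch.randn(batch, heads, seq, head_dim)
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=head_dim ** -0.5)
    print(out.shape)   # torch.Size([1, 2, 5, 16])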
520
- class CrossAttention(torch.nn.Module):
521
- def __init__(
522
- self,
523
- config: NeuroBLASTConfig,
524
- query_dim: int,
525
- kv_dim: int,
526
- layer_idx: int,
527
- is_causal: bool = True,
528
- ):
529
- super().__init__()
530
- self.dropout_p = config.dropout
531
- self.layer_idx = layer_idx
532
- self.query_dim = query_dim
533
- self.kv_dim = kv_dim
534
- self.is_causal = is_causal
535
-
536
- self.num_heads = config.num_attention_heads
537
- self.head_dim = self.kv_dim // self.num_heads
538
- self.kv_head_dim = (
539
- self.kv_dim // self.num_heads
540
- )
541
-
542
- if (self.head_dim * self.num_heads) != self.kv_dim:
543
- raise ValueError(
544
- f"CrossAttn {layer_idx}: kv_dim ({self.kv_dim}) must be divisible by num_heads ({self.num_heads})"
545
- )
546
- if (self.kv_head_dim * self.num_heads) != self.kv_dim:
547
- raise ValueError(
548
- f"CrossAttn {layer_idx}: kv_dim ({kv_dim}) must be divisible by num_heads ({self.num_heads})"
549
- )
550
-
551
- self.q_proj = nn.Linear(self.query_dim, self.kv_dim, bias=True)
552
- self.k_proj = nn.Linear(self.query_dim, self.kv_dim, bias=True)
553
- self.v_proj = nn.Linear(self.query_dim, self.kv_dim, bias=True)
554
- self.o_proj = nn.Linear(self.kv_dim, self.query_dim, bias=False)
555
- self.dropout = nn.Dropout(config.dropout)
556
-
557
- self.use_flash_attn = hasattr(
558
- F, "scaled_dot_product_attention"
559
- )
560
-
561
- def forward(
562
- self,
563
- query_states: torch.Tensor,
564
- kv_states: torch.Tensor,
565
- position_embeddings: Tuple[
566
- torch.Tensor, torch.Tensor
567
- ],
568
- past_key_value: Optional[Cache] = None,
569
- cache_position: Optional[torch.LongTensor] = None,
570
- attention_mask: Optional[torch.Tensor] = None,
571
- output_attentions: Optional[bool] = False,
572
- position_ids: Optional[torch.LongTensor] = None,
573
- use_cache: Optional[bool] = False,
574
- ):
575
- batch_size, q_seq_len, _ = query_states.shape
576
- kv_seq_len = kv_states.shape[1]
577
- dropout_p = self.dropout_p if self.training else 0.0
578
-
579
- query = self.q_proj(query_states)
580
- key = self.k_proj(kv_states)
581
- value = self.v_proj(kv_states)
582
-
583
- query = query.view(
584
- batch_size, q_seq_len, self.num_heads, self.head_dim
585
- ).transpose(1, 2)
586
-
587
- cos, sin = position_embeddings
588
- query = apply_rotary_pos_emb_single(query, cos, sin, position_ids)
589
-
590
- if past_key_value is not None:
591
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
592
- key, value = past_key_value.update(
593
- key, value, self.layer_idx or 0, cache_kwargs
594
- )
595
- kv_seq_len = key.shape[1]
596
-
597
- key = key.view(
598
- batch_size, kv_seq_len, self.num_heads, self.kv_head_dim
599
- ).transpose(1, 2)
600
- value = value.view(
601
- batch_size, kv_seq_len, self.num_heads, self.kv_head_dim
602
- ).transpose(1, 2)
603
-
604
- sdpa_attn_mask = attention_mask
605
-
606
- if self.use_flash_attn:
607
- is_causal = True if sdpa_attn_mask is None and q_seq_len > 1 else False
608
- attn_output = F.scaled_dot_product_attention(
609
- query,
610
- key,
611
- value,
612
- attn_mask=sdpa_attn_mask if not is_causal else None,
613
- dropout_p=dropout_p,
614
- is_causal=is_causal,
615
- enable_gqa=False,
616
- scale=self.head_dim**-0.5,
617
- )
618
- attn_weights = None
619
- else:
620
- attn_weights = torch.matmul(query, key.transpose(-1, -2)) / (
621
- self.head_dim**0.5
622
- )
623
-
624
- if sdpa_attn_mask is not None:
625
- attn_weights = attn_weights + sdpa_attn_mask
626
- elif self.is_causal and q_seq_len > 1:
627
- causal_mask = torch.triu(
628
- torch.ones(
629
- (q_seq_len, kv_seq_len), dtype=torch.bool, device=query.device
630
- ),
631
- diagonal=1,
632
- )
633
- attn_weights = attn_weights.masked_fill(
634
- causal_mask.unsqueeze(0).unsqueeze(0), float("-inf")
635
- )
636
-
637
- attn_weights = F.softmax(attn_weights, dim=-1)
638
- attn_weights_dropped = F.dropout(
639
- attn_weights, p=dropout_p, training=self.training
640
- )
641
- attn_output = torch.matmul(attn_weights_dropped, value)
642
-
643
- attn_output = attn_output.transpose(
644
- 1, 2
645
- ).contiguous()
646
- attn_output = attn_output.reshape(batch_size, q_seq_len, self.kv_dim)
647
-
648
- attn_output = self.o_proj(attn_output)
649
- attn_output = self.dropout(attn_output)
650
-
651
- outputs = (attn_output,)
652
- if output_attentions:
653
- outputs += (attn_weights,)
654
-
655
- if use_cache:
656
- outputs += (past_key_value,)
657
-
658
- return outputs
659
-
660
-
661
- class AttentionBlock(torch.nn.Module):
662
- """Modified Attention Block with Pre-Norm and choice of Self/Cross Attention & MLP"""
663
-
664
- def __init__(
665
- self,
666
- config: NeuroBLASTConfig,
667
- hidden_size: int,
668
- attention_module: nn.Module,
669
- mlp_module: nn.Module,
670
- is_cross_attention: bool = False,
671
- layer_idx: int = 0,
672
- precomputed_total_layers: Optional[int] = None,
673
- ):
674
- super().__init__()
675
- self.hidden_size = hidden_size
676
- self.config = config
677
- self.layer_idx = layer_idx
678
- self.is_cross_attention = is_cross_attention
679
-
680
- self.input_layernorm = nn.LayerNorm(
681
- hidden_size, eps=getattr(config, "rms_norm_eps", 1e-5)
682
- )
683
- self.attention = attention_module
684
-
685
- self.post_attention_layernorm = nn.LayerNorm(
686
- hidden_size, eps=getattr(config, "rms_norm_eps", 1e-5)
687
- )
688
- self.mlp = mlp_module
689
-
690
- if self.config.use_zero_memory and (
691
- self.config.zero_memory_layers is None
692
- or self.layer_idx in self.config.zero_memory_layers
693
- ):
694
- self.router = NeuroBLASTRouterBlock(config, hidden_size)
695
- self.memory = NeuroBLASTMemory(
696
- config,
697
- hidden_size=hidden_size,
698
- layer_idx=layer_idx,
699
- precomputed_total_layers=precomputed_total_layers,
700
- )
701
-
702
- def forward(
703
- self,
704
- hidden_states: torch.Tensor,
705
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
706
- attention_mask: Optional[torch.Tensor],
707
- position_ids: Optional[torch.LongTensor],
708
- kv_states: Optional[torch.Tensor] = None,
709
- past_key_value: Optional[Cache] = None,
710
- cache_position: Optional[torch.LongTensor] = None,
711
- output_attentions: Optional[bool] = False,
712
- use_cache: Optional[bool] = False,
713
- previous_states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
714
- ):
715
- residual = hidden_states
716
- hidden_states = torch.nan_to_num(hidden_states)
717
- normed_hidden_states = self.input_layernorm(hidden_states)
718
-
719
- if self.is_cross_attention:
720
- if kv_states is None:
721
- raise ValueError("kv_states must be provided for CrossAttention")
722
- attn_outputs = self.attention(
723
- query_states=normed_hidden_states,
724
- kv_states=kv_states,
725
- past_key_value=past_key_value,
726
- cache_position=cache_position,
727
- position_embeddings=position_embeddings,
728
- attention_mask=attention_mask,
729
- output_attentions=output_attentions,
730
- position_ids=position_ids,
731
- use_cache=use_cache,
732
- )
733
- else:
734
- attn_outputs = self.attention(
735
- normed_hidden_states,
736
- position_embeddings=position_embeddings,
737
- attention_mask=attention_mask,
738
- past_key_value=past_key_value,
739
- cache_position=cache_position,
740
- output_attentions=output_attentions,
741
- use_cache=use_cache,
742
- position_ids=position_ids,
743
- )
744
- attn_output = attn_outputs[0]
745
- past_key_value = attn_outputs[-1] if use_cache else None
746
-
747
- hidden_states = residual + attn_output
748
- hidden_states = torch.nan_to_num(hidden_states)
749
-
750
- residual = hidden_states
751
-
752
- normed_hidden_states = self.post_attention_layernorm(hidden_states)
753
-
754
- mlp_output = self.mlp(normed_hidden_states)
755
-
756
- hidden_states = residual + mlp_output
757
-
758
- hidden_states = torch.nan_to_num(hidden_states)
759
-
760
- if self.config.use_zero_memory and (
761
- self.config.zero_memory_layers is None
762
- or self.layer_idx in self.config.zero_memory_layers
763
- ):
764
- routing_weights, selected_experts = self.router(hidden_states)
765
-
766
- residual = hidden_states
767
-
768
- hidden_states, (hx, cx), past_key_value = self.memory(
769
- hidden_states,
770
- previous_states,
771
- past_key_value=past_key_value,
772
- cache_position=cache_position,
773
- position_embeddings=position_embeddings,
774
- attention_mask=attention_mask,
775
- output_attentions=output_attentions,
776
- position_ids=position_ids,
777
- use_cache=use_cache,
778
- )
779
-
780
- hidden_states = torch.nan_to_num(hidden_states)
781
- hx = torch.nan_to_num(hx)
782
- cx = torch.nan_to_num(cx)
783
-
784
- hidden_states = hidden_states * routing_weights.reshape(
785
- hidden_states.shape[:-1]
786
- ).unsqueeze(-1)
787
-
788
- hidden_states = residual + self.config.zero_memory_alpha * hidden_states
789
- hidden_states = torch.nan_to_num(hidden_states)
790
-
791
- outputs = (hidden_states,) + attn_outputs[1:]
792
-
793
- if self.config.use_zero_memory and (
794
- self.config.zero_memory_layers is None
795
- or self.layer_idx in self.config.zero_memory_layers
796
- ):
797
- outputs += ((hx, cx),)
798
- else:
799
- outputs += (previous_states,)
800
-
801
- if use_cache:
802
- outputs += (past_key_value,)
803
-
804
- return outputs
805
-
806
-
807
- class NeuroBLASTMemory(nn.Module):
808
- def __init__(
809
- self,
810
- config: NeuroBLASTConfig,
811
- hidden_size: int = 256,
812
- num_heads: int = 4,
813
- scale_factor: int = 4,
814
- layer_idx: int = 0,
815
- with_hx: bool = True,
816
- precomputed_total_layers: Optional[int] = None,
817
- *args,
818
- **kwargs,
819
- ):
820
- super().__init__(*args, **kwargs)
821
-
822
- self.hidden_size = hidden_size
823
- self.num_heads = num_heads
824
- self.scale_factor = scale_factor
825
- self.clamp_value = config.clamp_value
826
- # Use precomputed_total_layers instead of hardcoded 100 for layer index shift
827
- layer_shift = (
828
- precomputed_total_layers if precomputed_total_layers is not None else 100
829
- )
830
- self.layer_idx = layer_idx + layer_shift
831
- self.with_hx = with_hx
832
- self.kv_dim = (
833
- config.kv_dim
834
- )
835
-
836
- self.scaled_dim = hidden_size * scale_factor
837
- self.head_dim = self.kv_dim // config.num_attention_heads
838
- self.num_heads = self.hidden_size // self.head_dim
839
-
840
- self.norm1 = nn.LayerNorm(hidden_size)
841
-
842
- if self.with_hx:
843
- self.lin1 = nn.Linear(self.hidden_size, self.scaled_dim)
844
- self.lin2 = nn.Linear(self.hidden_size, self.scaled_dim)
845
- self.lin3 = nn.Linear(self.scaled_dim, self.hidden_size)
846
-
847
- self.lin4 = nn.Linear(self.hidden_size, self.scaled_dim)
848
- self.lin5 = nn.Linear(self.scaled_dim, self.hidden_size)
849
-
850
- if self.with_hx:
851
- self.lin6 = nn.Linear(self.scaled_dim, self.hidden_size)
852
-
853
- self.gate1 = nn.Linear(self.scaled_dim, self.scaled_dim)
854
- self.act1 = nn.SiLU()
855
-
856
- self.gate2 = nn.Linear(self.scaled_dim, self.scaled_dim)
857
- self.act2 = nn.SiLU()
858
-
859
- self.last_token_reg = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
860
- self.prev_tokens_reg = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
861
-
862
- self.norm2 = nn.LayerNorm(hidden_size)
863
- self.dropout = nn.Dropout(config.dropout)
864
-
865
- def forward(
866
- self,
867
- x: torch.Tensor,
868
- previous_state: tuple[torch.Tensor, torch.Tensor],
869
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
870
- attention_mask: Optional[torch.Tensor],
871
- position_ids: Optional[torch.LongTensor],
872
- past_key_value: Optional[Cache] = None,
873
- cache_position: Optional[torch.LongTensor] = None,
874
- output_attentions: Optional[bool] = False,
875
- use_cache: Optional[bool] = False,
876
- ):
877
- hx, cx = previous_state
878
-
879
- b, s, d = x.size()
880
-
881
- x = torch.nan_to_num(x)
882
- norm_x = self.norm1(x)
883
-
884
- norm_x = norm_x.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
885
-
886
- cos, sin = position_embeddings
887
- norm_x = apply_rotary_pos_emb_single(norm_x, cos, sin, position_ids)
888
-
889
- kv_seq_len = s
890
-
891
- if past_key_value is not None:
892
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
893
- norm_x, _ = past_key_value.update(
894
- norm_x,
895
- torch.zeros((b, 1, kv_seq_len, d)),
896
- self.layer_idx,
897
- cache_kwargs,
898
- )
899
- kv_seq_len = norm_x.shape[2]
900
-
901
- norm_x = norm_x.transpose(1, 2).contiguous()
902
- norm_x = norm_x.view(b, kv_seq_len, d)
903
-
904
- norm_x = torch.nan_to_num(norm_x)
905
-
906
- expanded_x = None
907
-
908
- shifted_x = torch.cat(
909
- [
910
- torch.zeros((b, 1, d), device=x.device, dtype=x.dtype),
911
- (norm_x[:, :-1].contiguous()),
912
- ],
913
- dim=1,
914
- ).contiguous()
915
-
916
- shifted_x = torch.nan_to_num(shifted_x) # Replace NaNs with zeros
917
-
918
- prev_tokens_x = norm_x.cumsum(dim=1)
919
-
920
- prev_tokens_x = prev_tokens_x - shifted_x
921
- prev_tokens_x = torch.nan_to_num(prev_tokens_x)[:, -s:].contiguous()
922
-
923
- if self.with_hx:
924
- expanded_x = self.lin1(norm_x[:, -s:].contiguous())
925
- expanded_x = torch.nan_to_num(expanded_x)
926
-
927
- expanded_shifted_x = self.lin2(shifted_x[:, -s:].contiguous())
928
-
929
- expanded_shifted_x = torch.nan_to_num(expanded_shifted_x)
930
-
931
- gated_shifted_x = self.gate1(expanded_shifted_x)
932
- gated_shifted_x = self.act1(gated_shifted_x)
933
- gated_shifted_x = torch.clamp(
934
- gated_shifted_x, min=-self.clamp_value, max=self.clamp_value
935
- )
936
-
937
- gated_shifted_x = torch.nan_to_num(gated_shifted_x)
938
-
939
- collapsed_shifted_x = self.lin3(gated_shifted_x)
940
- collapsed_shifted_x = torch.nan_to_num(collapsed_shifted_x)
941
-
942
- prev_tokens_x = torch.nan_to_num(prev_tokens_x)
943
-
944
- expanded_prev_tokens_x = self.lin4(prev_tokens_x)
945
- expanded_prev_tokens_x = torch.nan_to_num(expanded_prev_tokens_x)
946
-
947
- gated_prev_tokens_x = self.gate2(expanded_prev_tokens_x)
948
- gated_prev_tokens_x = self.act2(gated_prev_tokens_x)
949
- gated_prev_tokens_x = torch.clamp(
950
- gated_prev_tokens_x, min=-self.clamp_value, max=self.clamp_value
951
- )
952
-
953
- gated_prev_tokens_x = torch.nan_to_num(gated_prev_tokens_x)
954
-
955
- collapsed_prev_tokens_x = self.lin5(gated_prev_tokens_x)
956
- collapsed_prev_tokens_x = torch.nan_to_num(collapsed_prev_tokens_x)
957
-
958
- if self.with_hx:
959
- weights = torch.softmax(expanded_x * expanded_shifted_x, dim=-1)
960
-
961
- expanded_x_attn = weights * expanded_x
962
-
963
- expanded_x_attn = torch.nan_to_num(expanded_x_attn)
964
- hx = hx + self.lin6(expanded_x_attn)
965
- hx = torch.nan_to_num(hx)
966
-
967
- if self.with_hx:
968
- x = torch.nan_to_num(x)
969
- hx = torch.nan_to_num(hx)
970
- collapsed_shifted_x = torch.nan_to_num(collapsed_shifted_x)
971
- collapsed_prev_tokens_x = torch.nan_to_num(collapsed_prev_tokens_x)
972
- output = x + (
973
- hx
974
- * (
975
- self.last_token_reg(collapsed_shifted_x)
976
- + self.prev_tokens_reg(collapsed_prev_tokens_x)
977
- )
978
- )
979
- output = torch.nan_to_num(output)
980
-
981
- else:
982
- output = (
983
- x
984
- + (
985
- self.last_token_reg(collapsed_shifted_x)
986
- + self.prev_tokens_reg(collapsed_prev_tokens_x)
987
- )[:, -s:].contiguous()
988
- )
989
-
990
- output = self.norm2(output)
991
- output = self.dropout(output)
992
- output = torch.nan_to_num(output)
993
-
994
- return (
995
- output,
996
- (hx, cx),
997
- past_key_value,
998
- )
999
-
1000
-
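Inside NeuroBLASTMemory, token history is summarized with two cheap tensors: a one-step right shift of the normalized sequence (each position sees its immediate predecessor) and the cumulative sum of the sequence minus that shifted copy. A tiny standalone illustration of the indexing (editorial sketch):

    import torch

    b, s, d = 1, 4, 2
    x = torch.arange(b * s * d, dtype=torch.float32).view(b, s, d)

    shifted = torch.cat([torch.zeros(b, 1, d), x[:, :-1]], dim=1)   # position t holds x[t-1]
    prev_tokens = x.cumsum(dim=1) - shifted                         # as in the deleted forward pass
    print(shifted[0, 1], prev_tokens[0, 1])                         # tensor([0., 1.]) tensor([2., 3.])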
1001
- class NeuroBLASTPreTrainedModel(PreTrainedModel):
1002
- config_class = NeuroBLASTConfig
1003
- base_model_prefix = "brain"
1004
- supports_gradient_checkpointing = True
1005
- _no_split_modules = []
1006
- _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
1007
- _supports_flash_attn_2 = False
1008
- _supports_sdpa = True
1009
-
1010
- def _init_weights(self, module):
1011
- std = getattr(self.config, "initializer_range", 0.02)
1012
- if isinstance(module, nn.Linear):
1013
- module.weight.data.normal_(mean=0.0, std=std)
1014
- if module.bias is not None:
1015
- module.bias.data.zero_()
1016
- elif isinstance(module, nn.Embedding):
1017
- module.weight.data.normal_(mean=0.0, std=std)
1018
- if module.padding_idx is not None:
1019
- module.weight.data[module.padding_idx].zero_()
1020
- elif isinstance(module, nn.LayerNorm):
1021
- if module.bias is not None:
1022
- module.bias.data.zero_()
1023
- module.weight.data.fill_(1.0)
1024
-
1025
-
1026
- class NeuroBLASTModel(NeuroBLASTPreTrainedModel):
1027
- def __init__(self, config: NeuroBLASTConfig):
1028
- super(NeuroBLASTModel, self).__init__(config)
1029
- self.config = config
1030
- self.padding_idx = config.pad_token_id
1031
- self.vocab_size = config.vocab_size
1032
-
1033
- self.embed_tokens = nn.Embedding(
1034
- config.vocab_size,
1035
- config.hidden_size, # Using main hidden_size for embeddings now
1036
- padding_idx=self.padding_idx,
1037
- )
1038
- self.rotary_emb = NeuroBLASTRotaryEmbedding(
1039
- config.kv_dim // config.num_attention_heads,
1040
- max_position_embeddings=config.max_position_embeddings,
1041
- base=config.rope_theta,
1042
- )
1043
- self.dropout = nn.Dropout(config.dropout)
1044
-
1045
- self.assoc_to_sensory_pooler = nn.Sequential(
1046
- nn.Linear(config.hidden_size, config.hidden_size),
1047
- nn.Identity(), # Backward compatibility - previously LayerNorm, but we found that removing it improves generalization
1048
- nn.GELU(),
1049
- nn.LayerNorm(config.hidden_size),
1050
- )
1051
- self.assoc_to_motor_pooler = nn.Sequential(
1052
- nn.Linear(config.hidden_size, config.hidden_size),
1053
- nn.Identity(), # Backward compatibility
1054
- nn.GELU(),
1055
- nn.LayerNorm(config.hidden_size),
1056
- )
1057
- self.sensory_to_motor_pooler = nn.Sequential(
1058
- nn.Linear(config.hidden_size, config.hidden_size),
1059
- nn.Identity(), # Backward compatibility
1060
- nn.GELU(),
1061
- nn.LayerNorm(config.hidden_size),
1062
- )
1063
-
1064
- # --- Cortex Layers ---
1065
- # Using generic AttentionBlock with specific Self/Cross Attention modules passed in
1066
- total_layers = 0
1067
-
1068
- # Precompute total layers for Memory layer indexing before creating any layers
1069
- precomputed_total_layers = (
1070
- config.num_association_cortex_layers
1071
- + config.num_sensory_cortex_layers * 2
1072
- + config.num_motor_cortex_layers
1073
- * 3
1074
- )
1075
- self.precomputed_total_layers = precomputed_total_layers
1076
- config.precomputed_total_layers = precomputed_total_layers
1077
-
1078
- # 1. Association Cortex (Self-Attention)
1079
- self.association_cortex = nn.ModuleList()
1080
- for i in range(config.num_association_cortex_layers):
1081
- layer_idx = total_layers + i
1082
- print(f"Adding layer {layer_idx} to association cortex")
1083
- self.association_cortex.append(
1084
- AttentionBlock(
1085
- config,
1086
- config.hidden_size, # Use main hidden_size
1087
- attention_module=SelfAttention(
1088
- config, config.hidden_size, is_causal=True, layer_idx=layer_idx
1089
- ),
1090
- mlp_module=(
1091
- NeuroBLASTSparseMoeBlock(
1092
- config,
1093
- )
1094
- if config.num_experts
1095
- else NeuroBLASTMoeMLP(config)
1096
- ),
1097
- is_cross_attention=False,
1098
- layer_idx=layer_idx,
1099
- precomputed_total_layers=precomputed_total_layers,
1100
- )
1101
- )
1102
-
1103
- total_layers += config.num_association_cortex_layers
1104
- # 2. Sensory Cortex (Self-Attention + Cross-Attention to Association)
1105
- self.sensory_self_attn_layers = nn.ModuleList()
1106
- self.sensory_cross_attn_layers = (
1107
- nn.ModuleList()
1108
- ) # One cross-attn per self-attn layer
1109
- for i in range(config.num_sensory_cortex_layers):
1110
- layer_idx = total_layers + i
1111
- print(f"Adding layer {layer_idx} to sensory cortex")
1112
- self.sensory_self_attn_layers.append(
1113
- AttentionBlock(
1114
- config,
1115
- config.hidden_size,
1116
- attention_module=SelfAttention(
1117
- config,
1118
- config.hidden_size,
1119
- is_causal=True,
1120
- layer_idx=layer_idx,
1121
- ),
1122
- mlp_module=(
1123
- NeuroBLASTSparseMoeBlock(
1124
- config,
1125
- )
1126
- if config.num_experts
1127
- else NeuroBLASTMoeMLP(config)
1128
- ),
1129
- is_cross_attention=False,
1130
- layer_idx=layer_idx,
1131
- precomputed_total_layers=precomputed_total_layers,
1132
- )
1133
- )
1134
-
1135
- total_layers += config.num_sensory_cortex_layers
1136
- for i in range(config.num_sensory_cortex_layers):
1137
- layer_idx = total_layers + i
1138
- print(f"Adding layer {layer_idx} to sensory cross-attention")
1139
- # Add Cross-Attention layer: Sensory queries, Association is K/V source
1140
- self.sensory_cross_attn_layers.append(
1141
- AttentionBlock(
1142
- config,
1143
- config.hidden_size, # Query Dim
1144
- attention_module=CrossAttention(
1145
- config,
1146
- query_dim=config.hidden_size,
1147
- kv_dim=config.kv_dim, # Assoc output dim
1148
- layer_idx=layer_idx,
1149
- ),
1150
- mlp_module=(
1151
- NeuroBLASTSparseMoeBlock(
1152
- config,
1153
- )
1154
- if config.num_experts
1155
- else NeuroBLASTMoeMLP(config)
1156
- ),
1157
- is_cross_attention=True,
1158
- layer_idx=layer_idx,
1159
- precomputed_total_layers=precomputed_total_layers,
1160
- )
1161
- )
1162
-
1163
- total_layers += config.num_sensory_cortex_layers
1164
-
1165
- # 3. Motor Cortex (Self-Attention + Cross-Attention to Sensory + Cross-Attention to Association)
1166
- self.motor_self_attn_layers = nn.ModuleList()
1167
- self.motor_cross_sensory_layers = nn.ModuleList()
1168
- self.motor_cross_assoc_layers = nn.ModuleList()
1169
- for i in range(config.num_motor_cortex_layers):
1170
- layer_idx = total_layers + i
1171
- print(f"Adding layer {layer_idx} to motor cortex")
1172
- self.motor_self_attn_layers.append(
1173
- AttentionBlock(
1174
- config,
1175
- config.hidden_size,
1176
- attention_module=SelfAttention(
1177
- config,
1178
- config.hidden_size,
1179
- is_causal=True,
1180
- layer_idx=layer_idx,
1181
- ),
1182
- mlp_module=(
1183
- NeuroBLASTSparseMoeBlock(
1184
- config,
1185
- )
1186
- if config.num_experts
1187
- else NeuroBLASTMoeMLP(config)
1188
- ),
1189
- is_cross_attention=False,
1190
- layer_idx=layer_idx,
1191
- precomputed_total_layers=precomputed_total_layers,
1192
- )
1193
- )
1194
-
1195
- total_layers += config.num_motor_cortex_layers
1196
- for i in range(config.num_motor_cortex_layers):
1197
- layer_idx = total_layers + i
1198
- print(f"Adding layer {layer_idx} to motor cross-sensory")
1199
- # Cross-Attend to Sensory Output
1200
- self.motor_cross_sensory_layers.append(
1201
- AttentionBlock(
1202
- config,
1203
- config.hidden_size, # Query Dim
1204
- attention_module=CrossAttention(
1205
- config,
1206
- query_dim=config.hidden_size,
1207
- kv_dim=config.kv_dim, # Sensory output dim
1208
- layer_idx=layer_idx,
1209
- ),
1210
- mlp_module=(
1211
- NeuroBLASTSparseMoeBlock(
1212
- config,
1213
- )
1214
- if config.num_experts
1215
- else NeuroBLASTMoeMLP(config)
1216
- ),
1217
- is_cross_attention=True,
1218
- layer_idx=layer_idx,
1219
- precomputed_total_layers=precomputed_total_layers,
1220
- )
1221
- )
1222
-
1223
- total_layers += config.num_motor_cortex_layers
1224
- for i in range(config.num_motor_cortex_layers):
1225
- layer_idx = total_layers + i
1226
- print(f"Adding layer {layer_idx} to motor cross-association")
1227
- # Cross-Attend to Association Output
1228
- self.motor_cross_assoc_layers.append(
1229
- AttentionBlock(
1230
- config,
1231
- config.hidden_size, # Query Dim
1232
- attention_module=CrossAttention(
1233
- config,
1234
- query_dim=config.hidden_size,
1235
- kv_dim=config.kv_dim, # Assoc output dim
1236
- layer_idx=layer_idx,
1237
- ),
1238
- mlp_module=(
1239
- NeuroBLASTSparseMoeBlock(
1240
- config,
1241
- )
1242
- if config.num_experts
1243
- else NeuroBLASTMoeMLP(config)
1244
- ),
1245
- is_cross_attention=True,
1246
- layer_idx=layer_idx,
1247
- precomputed_total_layers=precomputed_total_layers,
1248
- )
1249
- )
1250
-
1251
- total_layers += config.num_motor_cortex_layers
1252
-
1253
- # Initialize more conservatively to prevent strong gradient flow initially
1254
- self.sensory_cross_assoc_gate = NeuroBLASTMoeMLP(
1255
- config,
1256
- )
1257
- self.motor_cross_sensory_gate = NeuroBLASTMoeMLP(
1258
- config,
1259
- )
1260
- self.motor_cross_assoc_gate = NeuroBLASTMoeMLP(
1261
- config,
1262
- )
1263
-
1264
- # Final normalization before output head
1265
- self.norm = nn.LayerNorm(
1266
- config.hidden_size, eps=getattr(config, "rms_norm_eps", 1e-5)
1267
- )
1268
-
1269
- self.gradient_checkpointing = False
1270
- self.post_init()
1271
-
1272
- def forward(
1273
- self,
1274
- input_ids: torch.LongTensor,
1275
- attention_mask: Optional[torch.Tensor] = None,
1276
- position_ids: Optional[torch.LongTensor] = None,
1277
- past_key_values: Optional[
1278
- Cache
1279
- ] = None,
1280
- inputs_embeds: Optional[torch.FloatTensor] = None,
1281
- use_cache: Optional[bool] = None,
1282
- output_attentions: Optional[bool] = None,
1283
- output_hidden_states: Optional[bool] = None,
1284
- return_dict: Optional[bool] = None,
1285
- cache_position: Optional[torch.LongTensor] = None,
1286
- ):
1287
- output_attentions = (
1288
- output_attentions
1289
- if output_attentions is not None
1290
- else self.config.output_attentions
1291
- )
1292
- output_hidden_states = (
1293
- output_hidden_states
1294
- if output_hidden_states is not None
1295
- else self.config.output_hidden_states
1296
- )
1297
- use_cache = use_cache if use_cache is not None else self.config.use_cache
1298
- # use_cache = False
1299
- return_dict = (
1300
- return_dict if return_dict is not None else self.config.use_return_dict
1301
- )
1302
-
1303
- if input_ids is not None and inputs_embeds is not None:
1304
- raise ValueError("Specify either input_ids or inputs_embeds, not both")
1305
- batch_size, seq_length = (
1306
- input_ids.shape if input_ids is not None else inputs_embeds.shape[:2]
1307
- )
1308
-
1309
- if self.gradient_checkpointing and self.training:
1310
- if use_cache:
1311
- logger.warning_once(
1312
- "`use_cache=True` incompatible with gradient checkpointing. Setting `use_cache=False`"
1313
- )
1314
- use_cache = False
1315
-
1316
- if not any(param.requires_grad for param in self.parameters()):
1317
- logger.warning_once(
1318
- "No parameters require gradients. Disabling gradient checkpointing to avoid warnings."
1319
- )
1320
- self.gradient_checkpointing = False
1321
-
1322
- if inputs_embeds is None:
1323
- inputs_embeds = self.embed_tokens(input_ids)
1324
-
1325
- past_key_values_length = 0
1326
- if use_cache:
1327
- if not isinstance(past_key_values, Cache):
1328
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1329
- past_key_values_length = past_key_values.get_seq_length()
1330
-
1331
- if cache_position is None:
1332
- past_seen_tokens = (
1333
- past_key_values.get_seq_length() if past_key_values is not None else 0
1334
- )
1335
- cache_position = torch.arange(
1336
- past_seen_tokens,
1337
- past_seen_tokens + inputs_embeds.shape[1],
1338
- device=inputs_embeds.device,
1339
- )
1340
-
1341
- if position_ids is None:
1342
- position_ids = cache_position.unsqueeze(0)
1343
-
1344
- causal_mask = self._update_causal_mask(
1345
- attention_mask,
1346
- inputs_embeds,
1347
- cache_position,
1348
- past_key_values,
1349
- output_attentions,
1350
- )
1351
-
1352
- hidden_states = inputs_embeds
1353
-
1354
- cos, sin = self.rotary_emb(
1355
- hidden_states, seq_len=seq_length + past_key_values_length
1356
- )
1357
- position_embeddings = (cos, sin)
1358
-
1359
- all_hidden_states = () if output_hidden_states else None
1360
- all_attentions = (
1361
- () if output_attentions else None
1362
- )
1363
- next_decoder_cache = (
1364
- past_key_values if use_cache else None
1365
- )
1366
-
1367
- if self.config.use_zero_memory:
1368
- hx = torch.ones(
1369
- (batch_size, seq_length, hidden_states.size(-1)),
1370
- device=hidden_states.device,
1371
- dtype=hidden_states.dtype,
1372
- )
1373
- cx = torch.ones(
1374
- (batch_size, seq_length, hidden_states.size(-1)),
1375
- device=hidden_states.device,
1376
- dtype=hidden_states.dtype,
1377
- )
1378
-
1379
- if self.training:
1380
- hx.requires_grad_()
1381
- cx.requires_grad_()
1382
- else:
1383
- hx = None
1384
- cx = None
1385
-
1386
- # 1. Association Cortex (Self-Attention)
1387
- assoc_output = hidden_states
1388
- for i, layer in enumerate(self.association_cortex):
1389
- if output_hidden_states:
1390
- all_hidden_states += (assoc_output,)
1391
- if self.gradient_checkpointing and self.training:
1392
- outputs = torch.utils.checkpoint.checkpoint(
1393
- layer,
1394
- assoc_output,
1395
- position_embeddings,
1396
- causal_mask,
1397
- position_ids,
1398
- None,
1399
- next_decoder_cache,
1400
- cache_position,
1401
- output_attentions,
1402
- use_cache,
1403
- (hx, cx),
1404
- )
1405
- else:
1406
- outputs = layer(
1407
- assoc_output,
1408
- position_embeddings,
1409
- causal_mask,
1410
- position_ids,
1411
- kv_states=None,
1412
- past_key_value=next_decoder_cache,
1413
- cache_position=cache_position,
1414
- output_attentions=output_attentions,
1415
- use_cache=use_cache,
1416
- previous_states=(hx, cx),
1417
- )
1418
- assoc_output = outputs[0]
1419
-
1420
- hx, cx = outputs[-1 if not use_cache else -2]
1421
-
1422
- if output_attentions:
1423
- all_attentions += (outputs[1],)
1424
-
1425
- if use_cache:
1426
- next_decoder_cache = outputs[-1]
1427
- else:
1428
- next_decoder_cache = None
1429
-
1430
- sensory_state = self.assoc_to_sensory_pooler(assoc_output)
1431
-
1432
- sensory_state = apply_gradient_scaling(
1433
- sensory_state,
1434
- self.config.association_gradient_scale,
1435
- self.config.gradient_scaling_enabled,
1436
- )
1437
-
1438
- # 2. Sensory Cortex (Self-Attention + Cross-Attention to Association)
1439
- for i in range(self.config.num_sensory_cortex_layers):
1440
- if output_hidden_states:
1441
- all_hidden_states += (sensory_state,)
1442
-
1443
- self_attn_layer = self.sensory_self_attn_layers[i]
1444
- if self.gradient_checkpointing and self.training:
1445
- outputs_self = torch.utils.checkpoint.checkpoint(
1446
- self_attn_layer,
1447
- sensory_state,
1448
- position_embeddings,
1449
- causal_mask,
1450
- position_ids,
1451
- None,
1452
- next_decoder_cache,
1453
- cache_position,
1454
- output_attentions,
1455
- use_cache,
1456
- (hx, cx),
1457
- )
1458
- else:
1459
- outputs_self = self_attn_layer(
1460
- sensory_state,
1461
- position_embeddings,
1462
- causal_mask,
1463
- position_ids,
1464
- kv_states=None,
1465
- past_key_value=next_decoder_cache,
1466
- cache_position=cache_position,
1467
- output_attentions=output_attentions,
1468
- use_cache=use_cache,
1469
- previous_states=(hx, cx),
1470
- )
1471
- sensory_state = outputs_self[0]
1472
- hx, cx = outputs_self[-1 if not use_cache else -2]
1473
-
1474
- if output_attentions:
1475
- all_attentions += (outputs_self[1],)
1476
-
1477
- if use_cache:
1478
- next_decoder_cache = outputs_self[-1]
1479
- else:
1480
- next_decoder_cache = None
1481
-
1482
- cross_attn_layer = self.sensory_cross_attn_layers[i]
1483
-
1484
- cross_attn_causal_mask = None
1485
-
1486
- if causal_mask is not None:
1487
- q_seq_len = sensory_state.size(1)
1488
- kv_seq_len = assoc_output.size(1)
1489
-
1490
- cross_attn_causal_mask = torch.ones(
1491
- (batch_size, 1, q_seq_len, kv_seq_len),
1492
- device=hidden_states.device,
1493
- dtype=hidden_states.dtype,
1494
- )
1495
-
1496
- causal_mask_upper = torch.triu(
1497
- torch.ones((q_seq_len, kv_seq_len), device=hidden_states.device),
1498
- diagonal=1,
1499
- )
1500
-
1501
- cross_attn_causal_mask = cross_attn_causal_mask.masked_fill(
1502
- causal_mask_upper.unsqueeze(0).unsqueeze(0).bool(),
1503
- torch.finfo(hidden_states.dtype).min,
1504
- )
1505
-
1506
- if self.gradient_checkpointing and self.training:
1507
- outputs_cross = torch.utils.checkpoint.checkpoint(
1508
- cross_attn_layer,
1509
- sensory_state,
1510
- position_embeddings,
1511
- cross_attn_causal_mask,
1512
- position_ids,
1513
- assoc_output,
1514
- next_decoder_cache,
1515
- cache_position,
1516
- output_attentions,
1517
- use_cache,
1518
- (hx, cx),
1519
- )
1520
- else:
1521
- outputs_cross = cross_attn_layer(
1522
- sensory_state,
1523
- position_embeddings,
1524
- past_key_value=next_decoder_cache,
1525
- cache_position=cache_position,
1526
- attention_mask=cross_attn_causal_mask,
1527
- position_ids=position_ids,
1528
- kv_states=apply_gradient_scaling(
1529
- assoc_output,
1530
- self.config.cross_attention_gradient_scale,
1531
- self.config.gradient_scaling_enabled,
1532
- ),
1533
- output_attentions=output_attentions,
1534
- use_cache=use_cache,
1535
- previous_states=(hx, cx),
1536
- )
1537
-
1538
- cross_contribution = nn.functional.layer_norm(
1539
- outputs_cross[0],
1540
- normalized_shape=(self.config.hidden_size,),
1541
- eps=getattr(self.config, "rms_norm_eps", 1e-5),
1542
- )
1543
-
1544
- sensory_state = sensory_state + self.sensory_cross_assoc_gate(
1545
- cross_contribution
1546
- )
1547
-
1548
- hx, cx = outputs_cross[-1 if not use_cache else -2]
1549
- if output_attentions:
1550
- all_attentions += (outputs_cross[1],)
1551
-
1552
- if use_cache:
1553
- next_decoder_cache = outputs_cross[-1]
1554
- else:
1555
- next_decoder_cache = None
1556
-
1557
- motor_state = self.sensory_to_motor_pooler(sensory_state)
1558
-
1559
- motor_state = apply_gradient_scaling(
1560
- motor_state,
1561
- self.config.sensory_gradient_scale,
1562
- self.config.gradient_scaling_enabled,
1563
- )
1564
-
1565
- motor_state_from_assoc = self.assoc_to_motor_pooler(assoc_output)
1566
-
1567
- motor_state_from_assoc = apply_gradient_scaling(
1568
- motor_state_from_assoc,
1569
- self.config.association_gradient_scale,
1570
- self.config.gradient_scaling_enabled,
1571
- )
1572
-
1573
- motor_state = motor_state + motor_state_from_assoc # Combine pooled inputs
1574
-
1575
- # 3. Motor Cortex (Self + Cross-Sensory + Cross-Association)
1576
- for i in range(self.config.num_motor_cortex_layers):
1577
- if output_hidden_states:
1578
- all_hidden_states += (motor_state,)
1579
-
1580
- self_attn_layer = self.motor_self_attn_layers[i]
1581
- if self.gradient_checkpointing and self.training:
1582
- outputs_self = torch.utils.checkpoint.checkpoint(
1583
- self_attn_layer,
1584
- motor_state,
1585
- position_embeddings,
1586
- causal_mask,
1587
- position_ids,
1588
- None,
1589
- next_decoder_cache,
1590
- cache_position,
1591
- output_attentions,
1592
- use_cache,
1593
- (hx, cx),
1594
- )
1595
- else:
1596
- outputs_self = self_attn_layer(
1597
- motor_state,
1598
- position_embeddings,
1599
- causal_mask,
1600
- position_ids,
1601
- kv_states=None,
1602
- past_key_value=next_decoder_cache,
1603
- cache_position=cache_position,
1604
- output_attentions=output_attentions,
1605
- use_cache=use_cache,
1606
- previous_states=(hx, cx),
1607
- )
1608
- motor_state = outputs_self[0]
1609
- hx, cx = outputs_self[-1 if not use_cache else -2]
1610
- if output_attentions:
1611
- all_attentions += (outputs_self[1],)
1612
-
1613
- if use_cache:
1614
- next_decoder_cache = outputs_self[-1]
1615
- else:
1616
- next_decoder_cache = None
1617
-
1618
- cross_sensory_layer = self.motor_cross_sensory_layers[i]
1619
-
1620
- motor_cross_sensory_mask = None
1621
-
1622
- if causal_mask is not None:
1623
- motor_q_seq_len = motor_state.size(1)
1624
- sensory_kv_seq_len = sensory_state.size(
1625
- 1
1626
- )
1627
-
1628
- motor_cross_sensory_mask = torch.ones(
1629
- (batch_size, 1, motor_q_seq_len, sensory_kv_seq_len),
1630
- device=hidden_states.device,
1631
- dtype=hidden_states.dtype,
1632
- )
1633
-
1634
- causal_mask_upper = torch.triu(
1635
- torch.ones(
1636
- (motor_q_seq_len, sensory_kv_seq_len),
1637
- device=hidden_states.device,
1638
- ),
1639
- diagonal=1,
1640
- )
1641
-
1642
- motor_cross_sensory_mask = motor_cross_sensory_mask.masked_fill(
1643
- causal_mask_upper.unsqueeze(0).unsqueeze(0).bool(),
1644
- torch.finfo(hidden_states.dtype).min,
1645
- )
1646
-
1647
- if self.gradient_checkpointing and self.training:
1648
- outputs_cross_sensory = torch.utils.checkpoint.checkpoint(
1649
- cross_sensory_layer,
1650
- motor_state,
1651
- position_embeddings,
1652
- motor_cross_sensory_mask,
1653
- position_ids,
1654
- sensory_state,
1655
- next_decoder_cache,
1656
- cache_position,
1657
- output_attentions,
1658
- use_cache,
1659
- (hx, cx),
1660
- )
1661
- else:
1662
- outputs_cross_sensory = cross_sensory_layer(
1663
- motor_state,
1664
- position_embeddings,
1665
- attention_mask=motor_cross_sensory_mask,
1666
- position_ids=position_ids,
1667
- kv_states=apply_gradient_scaling(
1668
- sensory_state,
1669
- self.config.cross_attention_gradient_scale,
1670
- self.config.gradient_scaling_enabled,
1671
- ),
1672
- output_attentions=output_attentions,
1673
- past_key_value=next_decoder_cache,
1674
- cache_position=cache_position,
1675
- use_cache=use_cache,
1676
- previous_states=(hx, cx),
1677
- )
1678
- motor_state = motor_state + self.motor_cross_sensory_gate(
1679
- outputs_cross_sensory[0]
1680
- )
1681
- hx, cx = outputs_cross_sensory[-1 if not use_cache else -2]
1682
- if output_attentions:
1683
- all_attentions += (outputs_cross_sensory[1],)
1684
-
1685
- if use_cache:
1686
- next_decoder_cache = outputs_cross_sensory[-1]
1687
- else:
1688
- next_decoder_cache = None
1689
-
1690
- cross_assoc_layer = self.motor_cross_assoc_layers[i]
1691
-
1692
- motor_cross_assoc_mask = None
1693
- if causal_mask is not None:
1694
- motor_q_seq_len = motor_state.size(1)
1695
- assoc_kv_seq_len = assoc_output.size(
1696
- 1
1697
- )
1698
-
1699
- motor_cross_assoc_mask = torch.ones(
1700
- (batch_size, 1, motor_q_seq_len, assoc_kv_seq_len),
1701
- device=hidden_states.device,
1702
- dtype=hidden_states.dtype,
1703
- )
1704
-
1705
- causal_mask_upper = torch.triu(
1706
- torch.ones(
1707
- (motor_q_seq_len, assoc_kv_seq_len), device=hidden_states.device
1708
- ),
1709
- diagonal=1,
1710
- )
1711
-
1712
- motor_cross_assoc_mask = motor_cross_assoc_mask.masked_fill(
1713
- causal_mask_upper.unsqueeze(0).unsqueeze(0).bool(),
1714
- torch.finfo(hidden_states.dtype).min,
1715
- )
1716
-
1717
- if self.gradient_checkpointing and self.training:
1718
- outputs_cross_assoc = torch.utils.checkpoint.checkpoint(
1719
- cross_assoc_layer,
1720
- motor_state,
1721
- position_embeddings,
1722
- motor_cross_assoc_mask,
1723
- position_ids,
1724
- assoc_output,
1725
- next_decoder_cache,
1726
- cache_position,
1727
- output_attentions,
1728
- use_cache,
1729
- (hx, cx),
1730
- )
1731
- else:
1732
- outputs_cross_assoc = cross_assoc_layer(
1733
- motor_state,
1734
- position_embeddings,
1735
- attention_mask=motor_cross_assoc_mask,
1736
- position_ids=position_ids,
1737
- kv_states=apply_gradient_scaling(
1738
- assoc_output,
1739
- self.config.cross_attention_gradient_scale,
1740
- self.config.gradient_scaling_enabled,
1741
- ),
1742
- output_attentions=output_attentions,
1743
- past_key_value=next_decoder_cache,
1744
- cache_position=cache_position,
1745
- use_cache=use_cache,
1746
- previous_states=(hx, cx),
1747
- )
1748
- motor_state = motor_state + self.motor_cross_assoc_gate(
1749
- outputs_cross_assoc[0]
1750
- )
1751
- hx, cx = outputs_cross_assoc[-1 if not use_cache else -2]
1752
- if output_attentions:
1753
- all_attentions += (outputs_cross_assoc[1],)
1754
-
1755
- if use_cache:
1756
- next_decoder_cache = outputs_cross_assoc[-1]
1757
- else:
1758
- next_decoder_cache = None
1759
-
1760
- final_output = self.norm(motor_state)
1761
-
1762
- if output_hidden_states:
1763
- all_hidden_states += (final_output,)
1764
-
1765
- if not return_dict:
1766
- outputs_tuple = (final_output,)
1767
- if use_cache:
1768
- outputs_tuple += (next_decoder_cache,)
1769
- if output_hidden_states:
1770
- outputs_tuple += (all_hidden_states,)
1771
- if output_attentions:
1772
- outputs_tuple += (all_attentions,)
1773
- return tuple(v for v in outputs_tuple if v is not None)
1774
-
1775
- return BaseModelOutputWithPast(
1776
- last_hidden_state=final_output,
1777
- past_key_values=next_decoder_cache,
1778
- hidden_states=all_hidden_states,
1779
- attentions=all_attentions,
1780
- )
1781
-
1782
- def get_input_embeddings(self):
1783
- return self.embed_tokens
1784
-
1785
- def set_input_embeddings(self, value):
1786
- self.embed_tokens = value
1787
-
1788
- def _update_causal_mask(
1789
- self,
1790
- attention_mask: torch.Tensor,
1791
- input_tensor: torch.Tensor,
1792
- cache_position: torch.Tensor,
1793
- past_key_values: Cache,
1794
- output_attentions: bool,
1795
- ):
1796
- if self.config._attn_implementation == "flash_attention_2":
1797
- if attention_mask is not None and 0.0 in attention_mask:
1798
- return attention_mask
1799
- return None
1800
-
1801
- past_seen_tokens = (
1802
- past_key_values.get_seq_length() if past_key_values is not None else 0
1803
- )
1804
- using_static_cache = isinstance(past_key_values, StaticCache)
1805
-
1806
- if (
1807
- self.config._attn_implementation == "sdpa"
1808
- and not using_static_cache
1809
- and not output_attentions
1810
- ):
1811
- if AttentionMaskConverter._ignore_causal_mask_sdpa(
1812
- attention_mask,
1813
- inputs_embeds=input_tensor,
1814
- past_key_values_length=past_seen_tokens,
1815
- is_training=self.training,
1816
- ):
1817
- return None
1818
-
1819
- dtype, device = input_tensor.dtype, input_tensor.device
1820
- min_dtype = torch.finfo(dtype).min
1821
- sequence_length = input_tensor.shape[1]
1822
- if using_static_cache:
1823
- target_length = (
1824
- getattr(
1825
- past_key_values,
1826
- "get_max_length",
1827
- lambda: past_key_values.get_seq_length(),
1828
- )()
1829
- if hasattr(past_key_values, "get_seq_length")
1830
- else sequence_length + past_seen_tokens
1831
- )
1832
- else:
1833
- target_length = (
1834
- attention_mask.shape[-1]
1835
- if isinstance(attention_mask, torch.Tensor)
1836
- else past_seen_tokens + sequence_length + 1
1837
- )
1838
-
1839
- causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1840
- attention_mask,
1841
- sequence_length=sequence_length,
1842
- target_length=target_length,
1843
- dtype=dtype,
1844
- device=device,
1845
- min_dtype=min_dtype,
1846
- cache_position=cache_position,
1847
- batch_size=input_tensor.shape[0],
1848
- )
1849
-
1850
- if (
1851
- self.config._attn_implementation == "sdpa"
1852
- and attention_mask is not None
1853
- and attention_mask.device.type == "cuda"
1854
- and not output_attentions
1855
- ):
1856
- causal_mask = AttentionMaskConverter._unmask_unattended(
1857
- causal_mask, min_dtype
1858
- )
1859
-
1860
- return causal_mask
1861
-
1862
-
1863
- class NeuroBLASTForCausalLM(NeuroBLASTPreTrainedModel, GenerationMixin):
1864
- _tied_weights_keys = ["lm_head.weight"]
1865
-
1866
- def __init__(self, config: NeuroBLASTConfig):
1867
- super().__init__(config)
1868
- self.config = config
1869
- self.model = NeuroBLASTModel(config)
1870
- self.vocab_size = config.vocab_size # Ensure vocab_size is accessible
1871
- self.lm_head = torch.nn.Linear(config.hidden_size, self.vocab_size, bias=False)
1872
- self.loss_steps = 0
1873
- self.post_init() # Initialize weights
1874
-
1875
- def get_input_embeddings(self):
1876
- return self.model.get_input_embeddings()
1877
-
1878
- def set_input_embeddings(self, value):
1879
- self.model.set_input_embeddings(value)
1880
-
1881
- def get_output_embeddings(self):
1882
- return self.lm_head
1883
-
1884
- def set_output_embeddings(self, new_embeddings):
1885
- self.lm_head = new_embeddings
1886
-
1887
- def tie_weights(self):
1888
- if getattr(self.config, "tie_word_embeddings", False):
1889
- output_embeddings = self.get_output_embeddings()
1890
- input_embeddings = self.get_input_embeddings()
1891
- output_embeddings.weight = input_embeddings.weight
1892
- if getattr(output_embeddings, "bias", None) is not None:
1893
- output_embeddings.bias.data.zero_()
1894
-
1895
- def forward(
1896
- self,
1897
- input_ids: torch.LongTensor = None,
1898
- attention_mask: Optional[torch.Tensor] = None,
1899
- position_ids: Optional[torch.LongTensor] = None,
1900
- past_key_values: Optional[Cache] = None,
1901
- inputs_embeds: Optional[torch.FloatTensor] = None,
1902
- labels: Optional[torch.LongTensor] = None,
1903
- use_cache: Optional[bool] = None,
1904
- output_attentions: Optional[bool] = None,
1905
- output_hidden_states: Optional[bool] = None,
1906
- return_dict: Optional[bool] = None,
1907
- cache_position: Optional[torch.LongTensor] = None,
1908
- **loss_kwargs,
1909
- ) -> Union[Tuple, CausalLMOutputWithPast]:
1910
-
1911
- return_dict = (
1912
- return_dict if return_dict is not None else self.config.use_return_dict
1913
- )
1914
-
1915
- outputs = self.model(
1916
- input_ids=input_ids,
1917
- attention_mask=attention_mask,
1918
- position_ids=position_ids,
1919
- past_key_values=past_key_values,
1920
- inputs_embeds=inputs_embeds,
1921
- use_cache=use_cache,
1922
- output_attentions=output_attentions,
1923
- output_hidden_states=output_hidden_states,
1924
- return_dict=return_dict,
1925
- cache_position=cache_position,
1926
- )
1927
-
1928
- hidden_states = outputs[0]
1929
-
1930
- logits = self.lm_head(hidden_states)
1931
- logits = logits.float()
1932
-
1933
- loss = None
1934
- if labels is not None:
1935
- loss = self.loss_function(
1936
- logits=logits,
1937
- labels=labels,
1938
- vocab_size=self.config.vocab_size,
1939
- **loss_kwargs,
1940
- )
1941
-
1942
- if not return_dict:
1943
- output = (logits,) + outputs[1:]
1944
- return (loss,) + output if loss is not None else output
1945
-
1946
- return CausalLMOutputWithPast(
1947
- loss=loss,
1948
- logits=logits,
1949
- past_key_values=outputs.past_key_values,
1950
- hidden_states=hidden_states,
1951
- attentions=outputs.attentions,
1952
- )
1953
-
1954
- def get_input_embeddings(self):
1955
- return self.model.get_input_embeddings()
1956
-
1957
- def set_input_embeddings(self, value):
1958
- self.model.set_input_embeddings(value)
1959
-
1960
- def __str__(self):
1961
- return f"NeuroBLASTForCausalLM(config={self.config})"
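
Note on the removed forward pass above: each cross-attention hop (sensory over association, motor over sensory, motor over association) rebuilds the same additive causal mask inline. Below is a minimal standalone sketch of that pattern; the helper name is mine, and visible positions are zeroed here (the removed code initialised the mask with ones before masking), so treat it as an illustrative reconstruction rather than the module's API.

import torch

def build_cross_attn_causal_mask(batch_size, q_len, kv_len, dtype, device):
    # Additive mask broadcast over heads: 0.0 where a query may attend,
    # dtype-min strictly above the diagonal (future key/value positions).
    mask = torch.zeros((batch_size, 1, q_len, kv_len), dtype=dtype, device=device)
    upper = torch.triu(torch.ones((q_len, kv_len), device=device), diagonal=1).bool()
    return mask.masked_fill(upper, torch.finfo(dtype).min)

# Example: 2 sequences, 8 queries attending over 8 keys/values.
mask = build_cross_attn_causal_mask(2, 8, 8, torch.float32, "cpu")

As in the removed layers, the diagonal-based masking only makes strict sense when the query and key/value sequences are index-aligned (same length and token positions).
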
neuroblast_model/registration.py DELETED
@@ -1,5 +0,0 @@
1
- from transformers import AutoConfig, AutoModel
2
- from neuroblast_model import NeuroBLASTConfig, NeuroBLASTForCausalLM
3
-
4
- AutoConfig.register("neuroblast", NeuroBLASTForCausalLM)
5
- AutoModel.register(NeuroBLASTConfig, NeuroBLASTForCausalLM)
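
For reference, the module deleted above passed the model class to AutoConfig.register, which expects the configuration class, and attached the causal-LM head to the plain AutoModel factory. A minimal sketch of the conventional wiring, reusing the same imports as the removed file, would be:

from transformers import AutoConfig, AutoModelForCausalLM
from neuroblast_model import NeuroBLASTConfig, NeuroBLASTForCausalLM

# Map the "neuroblast" model_type string to its config class, then map the
# config class to the model class the causal-LM Auto factory should build.
AutoConfig.register("neuroblast", NeuroBLASTConfig)
AutoModelForCausalLM.register(NeuroBLASTConfig, NeuroBLASTForCausalLM)

When the model is hosted on the Hub, the same mapping is usually supplied through auto_map entries in config.json and loaded with trust_remote_code=True, so an explicit registration module is not required.
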