davda54 committed on
Commit fd35fdf · verified · 1 Parent(s): 40430cf

FlashAttention support

Files changed (1)
  1. modeling_gptbert.py +484 -613
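The new code path uses FlashAttention 2 when the flash_attn package is installed (attention then runs over unpadded variable-length sequences via flash_attn_varlen_qkvpacked_func) and otherwise falls back to a masked F.scaled_dot_product_attention path. A minimal usage sketch, assuming a hypothetical repository id, a CUDA GPU, and bf16 weights (FlashAttention 2 only accepts fp16/bf16 inputs):

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

repo_id = "ltg/norbert4-base"  # hypothetical repository id, for illustration only
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForMaskedLM.from_pretrained(
    repo_id,
    trust_remote_code=True,      # loads the custom modeling_gptbert.py from the repository
    torch_dtype=torch.bfloat16,  # FlashAttention 2 requires fp16 or bf16
).to("cuda").eval()

inputs = tokenizer(f"The capital of Norway is {tokenizer.mask_token}.", return_tensors="pt").to("cuda")
with torch.no_grad():
    logits = model(**inputs).logits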
modeling_gptbert.py CHANGED
@@ -3,13 +3,13 @@ from __future__ import annotations
3
  import torch
4
  import torch.nn as nn
5
  from torch.nn import functional as F
6
- from torch import _softmax_backward_data as _softmax_backward_data
7
 
8
- from functools import partial
9
 
10
  from .configuration_gptbert import GptBertConfig
11
  from transformers.modeling_utils import PreTrainedModel
12
  from transformers.activations import gelu_new
 
13
  from transformers.modeling_outputs import (
14
  MaskedLMOutput,
15
  MultipleChoiceModelOutput,
@@ -22,111 +22,82 @@ from transformers.modeling_outputs import (
22
  import math
23
  from typing import TYPE_CHECKING, Optional, Union, Tuple, List
24
 
 
 
 
 
 
25
  try:
26
- from torch.nn.attention.flex_attention import flex_attention, create_block_mask
27
  except ImportError:
28
- pass
 
 
 
29
 
30
 
31
- class ModelOutput:
32
 
33
- def __init__(
34
- self,
35
- logits: torch.Tensor | None = None,
36
- loss: torch.Tensor | float | None = None,
37
- perplexity: torch.Tensor | float | None = None,
38
- accuracy: float | None = None,
39
- z_loss: torch.Tensor | float | None = None,
40
- **kwargs
41
- ):
42
- self.logits: torch.Tensor | None
43
- self.loss: torch.Tensor | float | None
44
- self.perplexity: torch.Tensor | float | None
45
- self.accuracy: float | None
46
- self.z_loss: torch.Tensor | float | None
47
 
48
- self.logits = logits
49
- self.loss = loss
50
- self.perplexity = perplexity
51
- self.accuracy = accuracy
52
- self.z_loss = z_loss
53
 
54
- for attr, value in kwargs.items():
55
- setattr(self, attr, value)
56
57
 
58
- class CastedLinear(nn.Linear):
59
 
 
60
  def __init__(self, in_features, out_features, bias):
61
  super().__init__(in_features, out_features, bias=bias)
62
 
63
- def reset_parameters(self) -> None:
64
- std: float = math.sqrt(2.0 / (self.in_features + self.out_features))
65
- nn.init.trunc_normal_(self.weight, mean=0.0, std=std, a=-2*std, b=2*std)
66
-
67
  def forward(self, x):
68
  return F.linear(x, self.weight.type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
69
 
70
 
71
  class CastedLinearIn(nn.Linear):
72
-
73
  def __init__(self, in_features, out_features, bias):
74
  super().__init__(in_features, out_features, bias=bias)
75
  self.scale = nn.Parameter(torch.ones(in_features))
76
 
77
- def reset_parameters(self) -> None:
78
- std: float = math.sqrt(2.0 / (self.in_features + self.out_features))
79
- nn.init.trunc_normal_(self.weight, mean=0.0, std=std, a=-2*std, b=2*std)
80
-
81
  def forward(self, x):
82
  return F.linear(x, (self.weight * (self.scale + 1.0).unsqueeze(0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
83
 
84
 
85
- class CastedLinearOut(nn.Linear):
86
-
87
- def __init__(self, in_features, out_features, bias):
88
- super().__init__(in_features, out_features, bias=bias)
89
- self.scale = nn.Parameter(torch.ones(out_features))
90
-
91
- def reset_parameters(self) -> None:
92
- std: float = math.sqrt(2.0 / (self.in_features + self.out_features))
93
- nn.init.trunc_normal_(self.weight, mean=0.0, std=std, a=-2*std, b=2*std)
94
-
95
- def forward(self, x):
96
- return F.linear(x, (self.scale.unsqueeze(1) * self.weight).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
97
-
98
-
99
- class MultiCastedLinearOrtho(nn.Module):
100
-
101
- def __init__(self, in_features, out_features, bias):
102
- super().__init__()
103
- self.in_features = in_features
104
- self.out_features = out_features
105
-
106
- self.weights = nn.ParameterList()
107
- for out_feature in out_features:
108
- self.weights.append(nn.Parameter(torch.empty((out_feature, in_features))))
109
-
110
- if bias:
111
- self.bias = nn.Parameter(torch.zeros(sum(out_features)))
112
- else:
113
- self.bias = self.register_parameter("bias", None)
114
-
115
- self.reset_parameters()
116
-
117
- def reset_parameters(self) -> None:
118
- for i, weight in enumerate(self.weights):
119
- std: float = math.sqrt(2.0 / (self.in_features + self.out_features[i]))
120
- nn.init.trunc_normal_(weight, mean=0.0, std=std, a=-2*std, b=2*std)
121
-
122
- def forward(self, x):
123
- return F.linear(x, torch.cat([weight for weight in self.weights], dim=0).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
124
-
125
-
126
  class MultiCastedLinearOrthoIn(nn.Module):
127
-
128
  def __init__(self, in_features, out_features, bias):
129
  super().__init__()
 
130
  self.in_features = in_features
131
  self.out_features = out_features
132
 
@@ -141,244 +112,217 @@ class MultiCastedLinearOrthoIn(nn.Module):
141
 
142
  self.scale = nn.Parameter(torch.ones(in_features))
143
 
144
- self.reset_parameters()
145
-
146
- def reset_parameters(self) -> None:
147
- for weight in self.weights:
148
- std = 0.5 * (self.in_features ** -0.5)
149
- bound = (3 ** 0.5) * std
150
- with torch.no_grad():
151
- weight.uniform_(-bound, bound)
152
-
153
  def forward(self, x):
154
  return F.linear(x, (torch.cat([weight for weight in self.weights], dim=0) * (self.scale + 1.0).unsqueeze(0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
155
 
156
 
157
- class MultiCastedLinearOrthoOut(nn.Module):
158
-
159
- def __init__(self, in_features, out_features, bias):
160
- super().__init__()
161
- self.in_features = in_features
162
- self.out_features = out_features
163
-
164
- self.weights = nn.ParameterList()
165
- for out_feature in out_features:
166
- self.weights.append(nn.Parameter(torch.empty((out_feature, in_features))))
167
-
168
- if bias:
169
- self.bias = nn.Parameter(torch.zeros(sum(out_features)))
170
- else:
171
- self.bias = self.register_parameter("bias", None)
172
-
173
- self.scale = nn.Parameter(torch.ones(sum(out_features)))
174
-
175
- self.reset_parameters()
176
-
177
- def reset_parameters(self) -> None:
178
- for weight in self.weights:
179
- std = 0.5 * (self.in_features ** -0.5)
180
- bound = (3 ** 0.5) * std
181
- with torch.no_grad():
182
- weight.uniform_(-bound, bound)
183
-
184
- def forward(self, x):
185
- return F.linear(x, (self.scale.unsqueeze(1) * torch.cat([weight for weight in self.weights], dim=0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
186
-
187
-
188
  class GeGLU(nn.Module):
189
  def forward(self, x):
190
  x, gate = x.chunk(2, dim=-1)
191
- x = x * gelu_new(gate)
192
- return x
193
-
194
-
195
- class MaskedSoftmax(torch.autograd.Function):
196
- @staticmethod
197
- def forward(ctx, x: torch.Tensor, mask: torch.BoolTensor, dim: int) -> torch.Tensor:
198
- ctx.dim: int
199
-
200
- ctx.dim = dim
201
- x.masked_fill_(mask, float('-inf'))
202
- x = torch.softmax(x, ctx.dim)
203
- x.masked_fill_(mask, 0.0)
204
- ctx.save_for_backward(x)
205
- return x
206
-
207
- @staticmethod
208
- def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None]:
209
- output: torch.Tensor
210
-
211
- output, = ctx.saved_tensors
212
- inputGrad: torch.Tensor = _softmax_backward_data(grad_output, output, ctx.dim, output.dtype)
213
- return inputGrad, None, None
214
-
215
 
216
- class Encoder(nn.Module):
217
 
218
- def __init__(self, config) -> None:
 
219
  super().__init__()
220
 
221
- self.layers: nn.ModuleList[Layer]
222
-
223
- self.layers = nn.ModuleList([Layer(config, i) for i in range(config.num_layers)])
224
-
225
- for i, layer in enumerate(self.layers):
226
- for weight in layer.mlp.up_proj.weights:
227
- weight.data *= math.sqrt(1.0 / (2.0 * (i + 1)))
228
- layer.mlp.down_proj.weight.data *= math.sqrt(1.0 / (2.0 * (i + 1)))
229
-
230
- self.short_long_ratio = config.short_long_ratio
231
-
232
- def set_window_length(self, config) -> None:
233
- for i, layer in enumerate(self.layers):
234
- if (i+1) % self.short_long_ratio == 0:
235
- layer.set_window_length(config.window_length, config.not_flex)
236
- else:
237
- layer.set_window_length(256, config.not_flex)
238
-
239
- def forward(self, hidden_layer: torch.Tensor, embeddings: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
240
- hidden_layer: List[torch.Tensor]
241
- attention_probs: List[torch.Tensor]
242
-
243
- hidden_states = []
244
- attention_probs = []
245
- v1 = None
246
-
247
- for layer in self.layers:
248
- hidden_layer, v1, attention_p = layer(hidden_layer, embeddings, v1, mask)
249
- hidden_states.append(hidden_layer)
250
- attention_probs.append(attention_p)
251
 
252
- return hidden_states, attention_probs
 
 
 
253
 
 
254
 
255
- class Layer(nn.Module):
256
 
257
- def __init__(self, config, layer_idx: int) -> None:
 
258
  super().__init__()
259
 
260
- self.attention: SelfAttention
261
- self.mlp: FeedForward
262
-
263
- self.attention = SelfAttention(config, layer_idx)
264
- self.mlp = FeedForward(config)
265
- self.lambdas = nn.Parameter(torch.tensor([0., 0., 1., 0., 1., 0.]))
266
-
267
- def set_window_length(self, window_length: int, not_flex: bool) -> None:
268
- self.attention.set_window_length(window_length, not_flex)
269
-
270
- def forward(self, hidden_layer: torch.Tensor, embeddings: torch.Tensor, v1: torch.Tensor | None, mask: torch.Tensor | None = None) -> Tuple[torch.Tensor, torch.Tensor]:
271
- output: torch.Tensor
272
- attention_p: torch.Tensor
273
-
274
- attention_output = (1 - self.lambdas[0]) * hidden_layer + self.lambdas[0] * embeddings
275
- qk_layer = (1 - self.lambdas[1]) * hidden_layer + self.lambdas[1] * embeddings
276
- mlp_layer = F.softplus(self.lambdas[2]) * ((1 - self.lambdas[3]) * hidden_layer + self.lambdas[3] * embeddings)
277
-
278
- attention_output, v1, attention_p = self.attention(attention_output, qk_layer, v1, mask)
279
- mlp_layer = mlp_layer + attention_output
280
- hidden_layer = F.softplus(self.lambdas[4]) * ((1 - self.lambdas[5]) * hidden_layer + self.lambdas[5] * embeddings)
281
- output = hidden_layer + attention_output + self.mlp(mlp_layer)
282
-
283
- return output, v1, attention_p
284
-
285
 
286
- class Embedding(nn.Module):
287
 
288
- def __init__(self, config) -> None:
 
289
  super().__init__()
290
 
291
- assert hasattr(config, "vocab_size"), "The config must have a vocab_size attribute!"
292
- assert hasattr(config, "hidden_size"), "The config must have a hidden_size attribute!"
293
- assert hasattr(config, "embedding_dropout_p"), "The model must have a embedding_dropout_p attribute!"
294
-
295
- self.word_embedding: nn.Embedding
296
- self.word_norm: nn.LayerNorm
297
- self.dropout: nn.Dropout
298
-
299
- self.word_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
300
- self.word_norm = nn.LayerNorm(config.hidden_size, eps=config.word_norm_eps, elementwise_affine=False, bias=False)
301
- self.word_scale = nn.Parameter(torch.zeros(config.hidden_size))
302
-
303
- self.dropout = nn.Dropout(config.embedding_dropout_p)
 
304
 
305
- self.initialize(config.hidden_size, config.vocab_size)
306
 
307
- @torch.no_grad()
308
- def initialize(self, hidden_size: int, vocab_size: int) -> None:
309
- std: float
310
 
311
- std = math.sqrt(2.0 / (hidden_size + vocab_size))
312
- nn.init.trunc_normal_(self.word_embedding.weight, mean=0.0, std=std, a=-2*std, b=2*std)
313
 
314
- def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
315
- word_embedding: torch.Tensor
316
 
317
- word_embedding = self.word_embedding(input_ids)
318
- word_embedding = self.word_norm(word_embedding)
319
- word_embedding = (word_embedding * (self.word_scale + 1.0).unsqueeze(0).unsqueeze(0))
320
 
321
- return self.dropout(word_embedding)
322
 
323
 
324
- class MaskClassifier(nn.Module):
 
 
325
 
326
- def __init__(self, config, embedding_weights: nn.Parameter) -> None:
327
- super().__init__()
328
 
329
- self.projection: CastedLinear
330
- self.emb2vocab: CastedLinear
331
- self.pre_norm: nn.LayerNorm
332
- self.post_norm: nn.LayerNorm
 
333
 
334
- self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.classifier_pre_norm_eps, elementwise_affine=config.classifier_pre_norm_affine)
335
- self.projection = CastedLinearIn(config.hidden_size, config.hidden_size, bias=False)
336
- self.post_norm = nn.LayerNorm(config.hidden_size, eps=config.classifier_post_norm_eps, elementwise_affine=config.classifier_post_norm_affine)
337
- self.emb2vocab = CastedLinearIn(config.hidden_size, config.vocab_size, bias=True)
338
 
339
- self.initialize(config.hidden_size, config.vocab_size, embedding_weights)
 
 
 
 
 
 
340
 
341
- @torch.no_grad()
342
- def initialize(self, hidden_size: int, vocab_size: int, embedding_weights: nn.Parameter) -> None:
343
- proj_std: float = math.sqrt(2.0 / (hidden_size + 4*hidden_size))
344
 
345
- nn.init.trunc_normal_(self.projection.weight, mean=0.0, std=proj_std, a=-2*proj_std, b=2*proj_std)
346
- self.emb2vocab.weight = embedding_weights
347
- self.emb2vocab.bias.zero_()
348
 
349
- def project(self, hidden_layer: torch.Tensor) -> torch.Tensor:
350
- projection: torch.Tensor
 
351
 
352
- projection = self.projection(hidden_layer)
353
- projection = gelu_new(projection)
354
- projection = self.post_norm(projection)
355
 
356
- return projection
 
 
 
 
 
357
 
358
- def calculate_output(self, hidden_layer: torch.Tensor) -> torch.Tensor:
359
- return self.emb2vocab(hidden_layer)
360
 
361
- def forward(self, hidden_layer: torch.Tensor, labels: torch.Tensor | None = None) -> torch.Tensor:
362
- output: torch.Tensor
363
 
364
- if labels is not None:
365
- hidden_layer = torch.index_select(hidden_layer.flatten(0, 1), 0, torch.nonzero(labels.flatten() != -100).squeeze())
366
 
367
- hidden_layer = self.pre_norm(hidden_layer)
368
- hidden_layer = self.project(hidden_layer)
369
- output = self.calculate_output(hidden_layer)
 
 
 
 
370
 
371
- return output
 
372
 
373
 
374
  class SelfAttention(nn.Module):
375
-
376
- def __init__(self, config, layer_idx) -> None:
377
  super().__init__()
378
- self.d_qk = config.d_qk
379
- self.d_v = config.d_v
 
 
 
 
380
  self.num_attention_heads = config.num_attention_heads
381
- self.num_kv_heads = config.num_kv_heads
382
  self.hidden_size = config.hidden_size
383
 
384
  self.q_out_dim = self.d_qk * self.num_attention_heads
@@ -389,256 +333,228 @@ class SelfAttention(nn.Module):
389
  self.v_proj = CastedLinearIn(self.hidden_size, self.v_out_dim, bias=False)
390
  self.out_proj = CastedLinearIn(self.d_v*self.num_attention_heads, self.hidden_size, bias=False)
391
 
392
- self.pre_v_norm = nn.LayerNorm(config.hidden_size, eps=config.attention_pre_norm_eps, elementwise_affine=config.attention_pre_norm_affine)
393
- self.pre_qk_norm = nn.LayerNorm(config.hidden_size, eps=config.attention_pre_norm_eps, elementwise_affine=config.attention_pre_norm_affine)
394
- self.inter_norm = nn.LayerNorm(self.d_v * self.num_attention_heads, eps=config.attention_inter_norm_eps, elementwise_affine=config.attention_inter_norm_affine)
395
- self.q_norm = nn.LayerNorm(config.d_qk, eps=config.attention_pre_norm_eps, elementwise_affine=False, bias=False)
396
- self.k_norm = nn.LayerNorm(config.d_qk, eps=config.attention_pre_norm_eps, elementwise_affine=False, bias=False)
397
- self.k_scale = nn.Parameter(torch.ones(self.num_kv_heads, config.d_qk))
398
- self.q_scale = nn.Parameter(torch.ones(self.num_attention_heads, config.d_qk))
399
-
400
- self.dropout = nn.Dropout(config.attention_output_dropout_p)
401
 
402
- theta = 160_000 if (layer_idx + 1) % config.short_long_ratio == 0 else 10_000
403
 
404
- self.rope_embedding = RotaryPositionalEmbeddings(config, theta)
405
- self.scale: float = 1.0 / math.sqrt(self.d_qk)
406
 
407
- self.dropout = nn.Dropout(config.attention_dropout if hasattr(config, "attention_dropout") else 0.0)
 
 
 
 
408
 
 
409
  self.lambdas = nn.Parameter(torch.tensor([0.5]))
410
 
411
- self.initialize()
412
-
413
  self.sequence_length = config.max_sequence_length
414
  self.is_causal = config.is_decoder
415
- self.not_flex = config.not_flex
416
-
417
- @torch.no_grad()
418
- def initialize(self) -> None:
419
- std: float = math.sqrt(2.0 / (self.hidden_size + 4*self.hidden_size))
420
- for weight in self.qk_proj.weights:
421
- nn.init.trunc_normal_(weight, mean=0.0, std=std, a=-2*std, b=2*std)
422
- nn.init.trunc_normal_(self.v_proj.weight, mean=0.0, std=std, a=2*std, b=2*std)
423
- self.out_proj.weight.data.zero_()
424
-
425
- def set_window_length(self, window_length: int, not_flex: bool) -> None:
426
- self.window_length: int = window_length
427
- if not not_flex:
428
- self.block_mask = self.create_block_mask(window_length)
429
 
430
- def causal_mask_mode(self, window_length, b, _, q_idx, kv_idx):
431
- return (q_idx >= kv_idx) & ((q_idx - kv_idx) < window_length)
432
 
433
- def bidirectional_mask_mode(self, window_length, b, _, q_idx, kv_idx):
434
- return ((q_idx - kv_idx) < window_length) & ((kv_idx - q_idx) < window_length)
435
-
436
- def create_block_mask(self, window_length: int) -> torch.Tensor:
437
  if self.is_causal:
438
- return create_block_mask(
439
- partial(self.causal_mask_mode, self.window_length),
440
- 1, 1, self.sequence_length, self.sequence_length, device=self.k_scale.device
441
- )
442
  else:
443
- return create_block_mask(
444
- partial(self.bidirectional_mask_mode, self.window_length),
445
- 1, 1, self.sequence_length, self.sequence_length, device=self.k_scale.device
446
- )
447
-
448
- def attention_operation(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, padding_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
449
- attention_scores: torch.Tensor
450
- attention_probabilities: torch.Tensor
451
- batch_size: int
452
- query_length: int
453
- key_length: int
454
 
 
 
455
  batch_size, _, query_length, _ = query.size()
456
  _, _, key_length, _ = key.size()
457
 
458
- if self.is_causal:
459
- window_mask = ~torch.ones(query_length, key_length, dtype=torch.bool, device=self.k_scale.device).tril().triu(diagonal=-self.window_length).view(1, 1, query_length, key_length)
460
- else:
461
- window_mask = ~torch.ones(query_length, key_length, dtype=torch.bool, device=self.k_scale.device).tril(diagonal=self.window_length).triu(diagonal=-self.window_length).view(1, 1, query_length, key_length)
462
-
463
- if padding_mask is not None:
464
- attention_mask = padding_mask | window_mask
465
  else:
466
- attention_mask = window_mask
 
467
 
468
- attention_scores = torch.bmm(query.flatten(0, 1), key.transpose(-1, -2).flatten(0, 1)) * self.scale # shape: [B*H, T, T]
469
- attention_scores = attention_scores.view(batch_size, self.num_attention_heads, query_length, key_length)
470
 
471
- attention_probabilities = MaskedSoftmax.apply(attention_scores, attention_mask, -1)
472
- attention_probabilities = self.dropout(attention_probabilities)
473
 
474
- value = torch.bmm(attention_probabilities.flatten(0, 1), value.flatten(0, 1))
475
- value = value.view(batch_size, self.num_attention_heads, query_length, self.d_v)
 
 
 
476
 
477
- return value, attention_probabilities.detach()
 
 
478
 
479
- def forward(self, hidden_layer: torch.Tensor, qk_layer: torch.Tensor, v1: torch.Tensor | None, mask: torch.Tensor | None = None, doc_ids: torch.Tensor | None = None) -> Tuple[torch.Tensor, torch.Tensor]:
480
- hidden_layer = self.pre_v_norm(hidden_layer)
481
- qk_layer = self.pre_qk_norm(qk_layer)
482
 
483
- query, key = self.qk_proj(qk_layer).tensor_split([self.q_out_dim], dim=-1)
484
- value = self.v_proj(hidden_layer)
485
 
486
- query_length: int = hidden_layer.size(0)
487
- key_length: int = hidden_layer.size(0)
488
- batch_size: int = hidden_layer.size(1)
489
 
490
- query = query.reshape(query_length, batch_size, self.num_attention_heads, self.d_qk).permute(1, 2, 0, 3) # shape: [B, H, T, D]
491
- key = key.reshape(key_length, batch_size, self.num_kv_heads, self.d_qk).permute(1, 2, 0, 3) # shape: [B, H, T, D]
492
- value = value.reshape(key_length, batch_size, self.num_kv_heads, self.d_qk).permute(1, 2, 0, 3) # shape: [B, H, T, D]
493
 
494
- query, key = ((self.q_scale + 1.0).unsqueeze(1).unsqueeze(0) * self.q_norm(query.float())).type_as(query), ((self.k_scale + 1.0).unsqueeze(1).unsqueeze(0) * self.k_norm(key.float())).type_as(key)
 
 
 
495
 
496
- if v1 is None:
497
- v1 = value
498
- value = (1 - self.lambdas[0]) * value + self.lambdas[0] * v1
499
 
500
- query = self.rope_embedding(query)
501
- key = self.rope_embedding(key)
502
 
503
- if self.not_flex:
504
- output, attention_probabilities = self.attention_operation(query, key, value, mask)
505
- else:
506
- def document_score_mod(score, b, _, q_idx, kv_idx):
507
- return torch.where(doc_ids[q_idx] == doc_ids[kv_idx], score, -float("inf"))
508
-
509
- if self.is_causal:
510
- block_mask = create_block_mask(
511
- partial(self.causal_mask_mode, self.window_length),
512
- 1, 1, query_length, key_length, device=self.k_scale.device
513
- )
514
  else:
515
- block_mask = create_block_mask(
516
- partial(self.bidirectional_mask_mode, self.window_length),
517
- 1, 1, query_length, key_length, device=self.k_scale.device
518
- )
519
 
520
- output = flex_attention(
521
- query, key, value, block_mask=block_mask, enable_gqa=True
522
- )
523
- attention_probabilities = None
524
 
525
- output = output.permute(2, 0, 1, 3).flatten(2, 3) # shape: [T, B, H*D]
526
- output = self.inter_norm(output)
527
- output = self.out_proj(output)
528
 
529
- return self.dropout(output), v1, attention_probabilities
 
 
530
 
 
531
 
532
- class FeedForward(nn.Module):
533
 
534
- def __init__(self, config) -> None:
 
535
  super().__init__()
536
-
537
- self.up_proj: CastedLinear
538
- self.down_proj: CastedLinear
539
- self.pre_norm: nn.LayerNorm
540
- self.inter_norm: nn.LayerNorm
541
- self.activation: GeGLU
542
- self.dropout: nn.Dropout
543
-
544
- self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.feed_forward_pre_norm_eps, elementwise_affine=config.feed_forward_pre_norm_affine)
545
  self.up_proj = MultiCastedLinearOrthoIn(config.hidden_size, [config.intermediate_size, config.intermediate_size], bias=False)
546
  self.activation = GeGLU()
547
- self.inter_norm = nn.LayerNorm(config.intermediate_size, eps=config.feed_forward_inter_norm_eps, elementwise_affine=config.feed_forward_inter_norm_affine)
548
  self.down_proj = CastedLinearIn(config.intermediate_size, config.hidden_size, bias=False)
549
- self.dropout = nn.Dropout(config.feed_forward_dropout_p)
550
-
551
- self.initialize(config.hidden_size)
552
-
553
- @torch.no_grad()
554
- def initialize(self, hidden_size: int) -> None:
555
- std: float = math.sqrt(2.0 / (5*hidden_size))
556
-
557
- for weight in self.up_proj.weights:
558
- nn.init.trunc_normal_(weight, mean=0.0, std=std, a=-2*std, b=2*std)
559
- self.down_proj.weight.data.zero_()
560
-
561
- def up_project(self, hidden_layer: torch.Tensor) -> torch.Tensor:
562
- hidden_layer = self.pre_norm(hidden_layer)
563
- return self.up_proj(hidden_layer)
564
-
565
- def activate(self, projection: torch.Tensor) -> torch.Tensor:
566
- activated_projection: torch.Tensor
567
-
568
- activated_projection = self.activation(projection)
569
- activated_projection = self.inter_norm(activated_projection.float()).type_as(projection)
570
-
571
- return activated_projection
572
 
573
- def down_project(self, activated_projection: torch.Tensor) -> torch.Tensor:
574
- output: torch.Tensor
575
 
576
- output = self.down_proj(activated_projection)
 
 
577
 
578
- return self.dropout(output)
 
 
579
 
580
- def forward(self, hidden_layer: torch.Tensor) -> torch.Tensor:
581
- output: torch.Tensor
582
 
583
- output = self.up_project(hidden_layer)
584
- output = self.activate(output)
585
- output = self.down_project(output)
 
586
 
587
- return output
 
 
 
588
 
 
589
 
590
- class RotaryPositionalEmbeddings(nn.Module):
591
 
592
- def __init__(self, config, theta: int) -> None:
 
593
  super().__init__()
 
 
594
 
595
- assert hasattr(config, "d_qk"), "The config must have a d_qk attribute!"
596
- assert hasattr(config, "max_sequence_length"), "The config must have a max_sequence_length attribute!"
597
-
598
- self.inv_freq: torch.Tensor
599
- self.cos_matrix: torch.Tensor
600
- self.sin_matrix: torch.Tensor
601
- head_size: int
602
- max_seq_len: int
603
- inv_freq: torch.Tensor
604
- pos: torch.Tensor
605
- embedding: torch.Tensor
606
-
607
- head_size = config.d_qk
608
- assert head_size % 2 == 0
609
- max_seq_len = config.max_sequence_length
610
 
611
- inv_freq = 1.0 / (theta ** (torch.arange(0, head_size, 2, dtype=torch.float32) / head_size))
612
- pos = torch.arange(max_seq_len, dtype=torch.float32)
613
- embedding = torch.einsum('n, d -> nd', pos, inv_freq)
614
- embedding = torch.cat([embedding, embedding], dim=-1).unsqueeze(0)
615
- self.register_buffer("cos_matrix", embedding.cos(), persistent=False)
616
- self.register_buffer("sin_matrix", embedding.sin(), persistent=False)
617
 
618
- def forward(self, x: torch.Tensor) -> torch.Tensor:
619
- seq_len: int
620
- cos_matrix: torch.Tensor
621
- sin_matrix: torch.Tensor
622
- x_rotate_half: torch.Tensor
623
- out: torch.Tensor
624
 
625
- hidden_layer = x.float()
 
626
 
627
- seq_len = x.shape[2]
628
-
629
- cos_matrix = self.cos_matrix[:, None, :seq_len, :]
630
- sin_matrix = self.sin_matrix[:, None, :seq_len, :]
631
-
632
- x_rotate_half = torch.cat(
633
- [
634
- -hidden_layer[:, :, :, x.size(-1) // 2:],
635
- hidden_layer[:, :, :, :x.size(-1) // 2]
636
- ],
637
- dim=-1
638
- )
639
-
640
- out = hidden_layer * cos_matrix + x_rotate_half * sin_matrix
641
- return out.type_as(x)
642
 
643
 
644
  #
@@ -647,15 +563,15 @@ class RotaryPositionalEmbeddings(nn.Module):
647
 
648
  class GptBertPreTrainedModel(PreTrainedModel):
649
  config_class = GptBertConfig
650
- supports_gradient_checkpointing = False
651
-
652
- def _set_gradient_checkpointing(self, module, value=False):
653
- raise NotImplementedError("Gradient checkpointing is not supported by this model")
654
 
655
  def _init_weights(self, module):
656
  std = math.sqrt(2.0 / (5.0 * self.hidden_size))
657
 
658
- if isinstance(module, nn.Linear):
659
  nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std, a=-2*std, b=2*std)
660
  if module.bias is not None:
661
  module.bias.data.zero_()
@@ -667,16 +583,17 @@ class GptBertPreTrainedModel(PreTrainedModel):
667
 
668
 
669
  class GptBertModel(GptBertPreTrainedModel):
670
-
671
- def __init__(self, config, add_mlm_layer=False, **kwargs):
672
  super().__init__(config, **kwargs)
673
  self.config = config
674
  self.hidden_size = config.hidden_size
675
 
676
  self.embedding = Embedding(config)
677
  self.encoder = Encoder(config)
678
- self.classifier = MaskClassifier(config, self.embedding.word_embedding.weight) if add_mlm_layer else None
679
  self.set_window_length(config)
 
 
680
 
681
  def set_window_length(self, config) -> None:
682
  self.encoder.set_window_length(config)
@@ -690,8 +607,9 @@ class GptBertModel(GptBertPreTrainedModel):
690
  def get_contextualized_embeddings(
691
  self,
692
  input_ids: Optional[torch.Tensor] = None,
693
- attention_mask: Optional[torch.Tensor] = None
694
- ) -> List[torch.Tensor]:
 
695
  if input_ids is not None:
696
  input_shape = input_ids.size()
697
  else:
@@ -700,35 +618,55 @@ class GptBertModel(GptBertPreTrainedModel):
700
  batch_size, seq_length = input_shape
701
  device = input_ids.device
702
 
703
- # if attention_mask is None:
704
- # attention_mask = torch.zeros(batch_size, seq_length, dtype=torch.bool, device=device)
705
- if attention_mask is not None:
706
- attention_mask = ~attention_mask.bool()
707
708
  if len(attention_mask.size()) == 2:
709
  attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
710
  elif len(attention_mask.size()) == 3:
711
attention_mask = attention_mask.unsqueeze(1)
712
 
713
- if self.config.is_decoder:
714
- attention_mask = attention_mask | torch.triu(torch.ones(seq_length, seq_length, dtype=torch.bool, device=device), 1).unsqueeze(0).unsqueeze(0)
 
716
- static_embeddings = self.embedding(input_ids.t())
717
- contextualized_embeddings, attention_probs = self.encoder(static_embeddings, static_embeddings, attention_mask)
718
- contextualized_embeddings = [e.transpose(0, 1) for e in contextualized_embeddings]
719
- last_layer = contextualized_embeddings[-1]
720
- contextualized_embeddings = [contextualized_embeddings[0]] + [
721
- contextualized_embeddings[i] - contextualized_embeddings[i - 1]
722
- for i in range(1, len(contextualized_embeddings))
723
- ]
724
- return last_layer, contextualized_embeddings, attention_probs
725
 
726
  def forward(
727
  self,
728
  input_ids: Optional[torch.Tensor] = None,
729
  attention_mask: Optional[torch.Tensor] = None,
730
- token_type_ids: Optional[torch.Tensor] = None,
731
- position_ids: Optional[torch.Tensor] = None,
732
  output_hidden_states: Optional[bool] = None,
733
  output_attentions: Optional[bool] = None,
734
  return_dict: Optional[bool] = None,
@@ -736,26 +674,24 @@ class GptBertModel(GptBertPreTrainedModel):
736
  ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
737
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
738
 
739
- sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
740
 
741
  if not return_dict:
742
  return (
743
  sequence_output,
744
- *([contextualized_embeddings] if output_hidden_states else []),
745
- *([attention_probs] if output_attentions else [])
746
  )
747
 
748
  return BaseModelOutput(
749
  last_hidden_state=sequence_output,
750
- hidden_states=contextualized_embeddings if output_hidden_states else None,
751
- attentions=attention_probs if output_attentions else None
752
  )
753
 
754
 
755
  class GptBertForMaskedLM(GptBertModel):
756
- _keys_to_ignore_on_load_unexpected = ["head"]
757
 
758
- def __init__(self, config, **kwargs):
759
  super().__init__(config, add_mlm_layer=True, **kwargs)
760
 
761
  def get_output_embeddings(self):
@@ -768,17 +704,14 @@ class GptBertForMaskedLM(GptBertModel):
768
  self,
769
  input_ids: Optional[torch.Tensor] = None,
770
  attention_mask: Optional[torch.Tensor] = None,
771
- token_type_ids: Optional[torch.Tensor] = None,
772
- position_ids: Optional[torch.Tensor] = None,
773
  output_hidden_states: Optional[bool] = None,
774
- output_attentions: Optional[bool] = None,
775
  return_dict: Optional[bool] = None,
776
  labels: Optional[torch.LongTensor] = None,
777
  **kwargs
778
  ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
779
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
780
 
781
- sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
782
  subword_prediction = self.classifier(sequence_output)
783
  subword_prediction = 30 * torch.sigmoid(subword_prediction / 7.5)
784
 
@@ -788,78 +721,28 @@ class GptBertForMaskedLM(GptBertModel):
788
  subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
789
  masked_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
790
 
 
 
 
 
791
  if not return_dict:
792
  output = (
793
  subword_prediction,
794
- *([contextualized_embeddings] if output_hidden_states else []),
795
- *([attention_probs] if output_attentions else [])
796
  )
797
  return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
798
 
799
  return MaskedLMOutput(
800
  loss=masked_lm_loss,
801
  logits=subword_prediction,
802
- hidden_states=contextualized_embeddings if output_hidden_states else None,
803
- attentions=attention_probs if output_attentions else None
804
  )
805
 
806
 
807
- class Classifier(nn.Module):
808
- def __init__(self, config, num_labels: int):
809
- super().__init__()
810
-
811
- drop_out = getattr(config, "cls_dropout", None)
812
- drop_out = config.hidden_dropout_prob if drop_out is None else drop_out
813
-
814
- self.projection: CastedLinear
815
- self.emb2vocab: CastedLinear
816
- self.pre_norm: nn.LayerNorm
817
- self.post_norm: nn.LayerNorm
818
-
819
- self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.classifier_pre_norm_eps, elementwise_affine=config.classifier_pre_norm_affine)
820
- self.projection = CastedLinear(config.hidden_size, config.hidden_size, bias=False)
821
- self.post_norm = nn.LayerNorm(config.hidden_size, eps=config.classifier_post_norm_eps, elementwise_affine=config.classifier_post_norm_affine)
822
- self.emb2vocab = CastedLinear(config.hidden_size, num_labels, bias=True)
823
- self.dropout = nn.Dropout(drop_out)
824
-
825
- self.initialize(config.hidden_size, config.intermediate_size, num_labels)
826
-
827
- @torch.no_grad()
828
- def initialize(self, hidden_size: int, intermediate_size: int, vocab_size: int) -> None:
829
- proj_std: float = math.sqrt(2.0 / (hidden_size + intermediate_size))
830
-
831
- nn.init.trunc_normal_(self.projection.weight, mean=0.0, std=proj_std, a=-2*proj_std, b=2*proj_std)
832
- nn.init.trunc_normal_(self.emb2vocab.weight, mean=0.0, std=proj_std, a=-2*proj_std, b=2*proj_std)
833
- self.emb2vocab.bias.zero_()
834
-
835
- def project(self, hidden_layer: torch.Tensor) -> torch.Tensor:
836
- projection: torch.Tensor
837
-
838
- projection = self.pre_norm(hidden_layer)
839
- projection = self.dropout(projection)
840
- projection = self.projection(projection)
841
- projection = gelu_new(projection)
842
- projection = self.post_norm(projection)
843
-
844
- return projection
845
-
846
- def calculate_output(self, hidden_layer: torch.Tensor) -> torch.Tensor:
847
- return self.emb2vocab(hidden_layer)
848
-
849
- def forward(self, hidden_layer: torch.Tensor) -> torch.Tensor:
850
- output: torch.Tensor
851
- projection: torch.Tensor
852
-
853
- projection = self.project(hidden_layer)
854
- output = self.calculate_output(projection)
855
-
856
- return output
857
-
858
-
859
  class GptBertForCausalLM(GptBertModel):
860
- _keys_to_ignore_on_load_unexpected = ["head"]
861
 
862
- def __init__(self, config, **kwargs):
863
  config.is_decoder = True
864
  super().__init__(config, add_mlm_layer=True, **kwargs)
865
 
@@ -904,29 +787,27 @@ class GptBertForCausalLM(GptBertModel):
904
  assert past_key_values is None, "past_key_values is not supported for now"
905
  assert not use_cache, "use_cache is not supported for now"
906
 
907
- sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
908
  subword_prediction = self.classifier(sequence_output)
909
  subword_prediction = 30 * torch.sigmoid(subword_prediction / 7.5)
910
 
911
- masked_lm_loss = None
912
  if labels is not None:
913
  labels_flatten = labels[:, 1:].flatten()
914
  subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
915
- masked_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
916
 
917
  if not return_dict:
918
  output = (
919
  subword_prediction,
920
- *([contextualized_embeddings] if output_hidden_states else []),
921
- *([attention_probs] if output_attentions else [])
922
  )
923
- return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
924
 
925
  return CausalLMOutput(
926
- loss=masked_lm_loss,
927
  logits=subword_prediction,
928
- hidden_states=contextualized_embeddings if output_hidden_states else None,
929
- attentions=attention_probs if output_attentions else None
930
  )
931
 
932
  def prepare_inputs_for_generation(
@@ -982,21 +863,20 @@ class GptBertForCausalLM(GptBertModel):
982
 
983
 
984
  class GptBertForSequenceClassification(GptBertModel):
985
- _keys_to_ignore_on_load_unexpected = ["classifier"]
 
986
 
987
- def __init__(self, config, **kwargs):
988
  super().__init__(config, add_mlm_layer=False, **kwargs)
989
 
990
  self.num_labels = config.num_labels
991
- self.head = Classifier(config, self.num_labels)
 
992
 
993
  def forward(
994
  self,
995
  input_ids: Optional[torch.Tensor] = None,
996
  attention_mask: Optional[torch.Tensor] = None,
997
- token_type_ids: Optional[torch.Tensor] = None,
998
- position_ids: Optional[torch.Tensor] = None,
999
- output_attentions: Optional[bool] = None,
1000
  output_hidden_states: Optional[bool] = None,
1001
  return_dict: Optional[bool] = None,
1002
  labels: Optional[torch.LongTensor] = None,
@@ -1004,8 +884,8 @@ class GptBertForSequenceClassification(GptBertModel):
1004
  ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
1005
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1006
 
1007
- sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
1008
- logits = self.head(sequence_output[:, 0, :])
1009
 
1010
  loss = None
1011
  if labels is not None:
@@ -1033,35 +913,32 @@ class GptBertForSequenceClassification(GptBertModel):
1033
  if not return_dict:
1034
  output = (
1035
  logits,
1036
- *([contextualized_embeddings] if output_hidden_states else []),
1037
- *([attention_probs] if output_attentions else [])
1038
  )
1039
  return ((loss,) + output) if loss is not None else output
1040
 
1041
  return SequenceClassifierOutput(
1042
  loss=loss,
1043
  logits=logits,
1044
- hidden_states=contextualized_embeddings if output_hidden_states else None,
1045
- attentions=attention_probs if output_attentions else None
1046
  )
1047
 
1048
 
1049
  class GptBertForTokenClassification(GptBertModel):
1050
- _keys_to_ignore_on_load_unexpected = ["classifier"]
 
1051
 
1052
- def __init__(self, config, **kwargs):
1053
  super().__init__(config, add_mlm_layer=False, **kwargs)
1054
 
1055
  self.num_labels = config.num_labels
1056
- self.head = Classifier(config, self.num_labels)
 
1057
 
1058
  def forward(
1059
  self,
1060
  input_ids: Optional[torch.Tensor] = None,
1061
  attention_mask: Optional[torch.Tensor] = None,
1062
- token_type_ids: Optional[torch.Tensor] = None,
1063
- position_ids: Optional[torch.Tensor] = None,
1064
- output_attentions: Optional[bool] = None,
1065
  output_hidden_states: Optional[bool] = None,
1066
  return_dict: Optional[bool] = None,
1067
  labels: Optional[torch.LongTensor] = None,
@@ -1069,8 +946,8 @@ class GptBertForTokenClassification(GptBertModel):
1069
  ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1070
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1071
 
1072
- sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
1073
- logits = self.head(sequence_output)
1074
 
1075
  loss = None
1076
  if labels is not None:
@@ -1094,21 +971,20 @@ class GptBertForTokenClassification(GptBertModel):
1094
 
1095
 
1096
  class GptBertForQuestionAnswering(GptBertModel):
1097
- _keys_to_ignore_on_load_unexpected = ["classifier"]
1098
-
1099
- def __init__(self, config, **kwargs):
 
1100
  super().__init__(config, add_mlm_layer=False, **kwargs)
1101
 
1102
  self.num_labels = config.num_labels
1103
- self.head = Classifier(config, self.num_labels)
 
1104
 
1105
  def forward(
1106
  self,
1107
  input_ids: Optional[torch.Tensor] = None,
1108
  attention_mask: Optional[torch.Tensor] = None,
1109
- token_type_ids: Optional[torch.Tensor] = None,
1110
- position_ids: Optional[torch.Tensor] = None,
1111
- output_attentions: Optional[bool] = None,
1112
  output_hidden_states: Optional[bool] = None,
1113
  return_dict: Optional[bool] = None,
1114
  start_positions: Optional[torch.Tensor] = None,
@@ -1117,8 +993,8 @@ class GptBertForQuestionAnswering(GptBertModel):
1117
  ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
1118
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1119
 
1120
- sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
1121
- logits = self.head(sequence_output)
1122
 
1123
  start_logits, end_logits = logits.split(1, dim=-1)
1124
  start_logits = start_logits.squeeze(-1).contiguous()
@@ -1146,8 +1022,7 @@ class GptBertForQuestionAnswering(GptBertModel):
1146
  output = (
1147
  start_logits,
1148
  end_logits,
1149
- *([contextualized_embeddings] if output_hidden_states else []),
1150
- *([attention_probs] if output_attentions else [])
1151
  )
1152
  return ((total_loss,) + output) if total_loss is not None else output
1153
 
@@ -1155,28 +1030,26 @@ class GptBertForQuestionAnswering(GptBertModel):
1155
  loss=total_loss,
1156
  start_logits=start_logits,
1157
  end_logits=end_logits,
1158
- hidden_states=contextualized_embeddings if output_hidden_states else None,
1159
- attentions=attention_probs if output_attentions else None
1160
  )
1161
 
1162
 
1163
  class GptBertForMultipleChoice(GptBertModel):
1164
- _keys_to_ignore_on_load_unexpected = ["classifier"]
 
1165
 
1166
- def __init__(self, config, **kwargs):
1167
  super().__init__(config, add_mlm_layer=False, **kwargs)
1168
 
1169
  self.num_labels = getattr(config, "num_labels", 2)
1170
- self.head = Classifier(config, self.num_labels)
 
1171
 
1172
  def forward(
1173
  self,
1174
  input_ids: Optional[torch.Tensor] = None,
1175
  attention_mask: Optional[torch.Tensor] = None,
1176
- token_type_ids: Optional[torch.Tensor] = None,
1177
- position_ids: Optional[torch.Tensor] = None,
1178
  labels: Optional[torch.Tensor] = None,
1179
- output_attentions: Optional[bool] = None,
1180
  output_hidden_states: Optional[bool] = None,
1181
  return_dict: Optional[bool] = None,
1182
  **kwargs
@@ -1187,8 +1060,8 @@ class GptBertForMultipleChoice(GptBertModel):
1187
  flat_input_ids = input_ids.view(-1, input_ids.size(-1))
1188
  flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1189
 
1190
- sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(flat_input_ids, flat_attention_mask)
1191
- logits = self.head(sequence_output)
1192
  reshaped_logits = logits.view(-1, num_choices)
1193
 
1194
  loss = None
@@ -1199,14 +1072,12 @@ class GptBertForMultipleChoice(GptBertModel):
1199
  if not return_dict:
1200
  output = (
1201
  reshaped_logits,
1202
- *([contextualized_embeddings] if output_hidden_states else []),
1203
- *([attention_probs] if output_attentions else [])
1204
  )
1205
  return ((loss,) + output) if loss is not None else output
1206
 
1207
  return MultipleChoiceModelOutput(
1208
  loss=loss,
1209
  logits=reshaped_logits,
1210
- hidden_states=contextualized_embeddings if output_hidden_states else None,
1211
- attentions=attention_probs if output_attentions else None
1212
  )
 
3
  import torch
4
  import torch.nn as nn
5
  from torch.nn import functional as F
 
6
 
7
+ from functools import partial, lru_cache
8
 
9
  from .configuration_gptbert import GptBertConfig
10
  from transformers.modeling_utils import PreTrainedModel
11
  from transformers.activations import gelu_new
12
+ from transformers.utils import is_flash_attn_2_available, logging
13
  from transformers.modeling_outputs import (
14
  MaskedLMOutput,
15
  MultipleChoiceModelOutput,
 
22
  import math
23
  from typing import TYPE_CHECKING, Optional, Union, Tuple, List
24
 
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ # Workaround for transformers < 4.36.0 check_imports issue
29
+ # See: https://github.com/huggingface/transformers/issues/28459
30
  try:
31
+ if is_flash_attn_2_available():
32
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
33
+ from flash_attn.layers.rotary import RotaryEmbedding
34
+ from flash_attn.ops.triton.rotary import apply_rotary
35
+ else:
36
+ flash_attn_varlen_qkvpacked_func, RotaryEmbedding, apply_rotary = None, object, None
37
+ logger.warning_once(
38
+ "NorBERT4 supports FlashAttention, but it was not found in your environment. Consider updating your environment for faster and less memory-intensive processing."
39
+ )
40
  except ImportError:
41
+ flash_attn_varlen_qkvpacked_func, RotaryEmbedding, apply_rotary = None, object, None
42
+ logger.warning_once(
43
+ "NorBERT4 supports FlashAttention, but it was not found in your environment. Consider updating your environment for faster and less memory-intensive processing."
44
+ )
45
 
46
 
47
+ # from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
48
+ @torch.compiler.disable()
49
+ def _unpad_input(input_ids: torch.Tensor, attention_mask: torch.Tensor):
50
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
51
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
52
+ max_seqlen_in_batch = int(seqlens_in_batch.max().item())
53
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
54
 
55
+ if input_ids.dim() == 2:
56
+ unpadded_inputs = input_ids.flatten()[indices]
57
+ else:
58
+ batch_size, sequence_length, *rest = input_ids.shape
59
+ shape = batch_size * sequence_length
60
+ unpadded_inputs = input_ids.view(shape, *rest)[indices]
61
 
62
+ return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch
 
 
 
 
63
 
 
 
64
 
65
+ # from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
66
+ def _pad_output(input_ids: torch.Tensor, indices: torch.Tensor, batch_size: int, sequence_length: int) -> torch.Tensor:
67
+ if input_ids.dim() == 1:
68
+ output = torch.zeros(batch_size * sequence_length, dtype=input_ids.dtype, device=input_ids.device)
69
+ output[indices] = input_ids
70
+ padded_inputs = output.view(batch_size, sequence_length)
71
+ else:
72
+ _, *rest = input_ids.shape
73
+ output = torch.zeros(batch_size * sequence_length, *rest, dtype=input_ids.dtype, device=input_ids.device)
74
+ output[indices] = input_ids
75
+ padded_inputs = output.view(batch_size, sequence_length, *rest)
76
+
77
+ return padded_inputs
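The two helpers above are adapted from ModernBERT's unpadding utilities: _unpad_input flattens a padded batch down to the real tokens and returns the indices and cumulative sequence lengths the varlen FlashAttention kernel needs, while _pad_output scatters results back into the padded [batch, seq, ...] layout. A minimal round-trip sketch with toy tensors (assumes the two functions are in scope, e.g. run inside this module; no flash-attn required):

import torch

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])  # two sequences of lengths 3 and 2
hidden = torch.randn(2, 4, 8)                  # [batch, seq, hidden]

unpadded, indices, cu_seqlens, max_seqlen = _unpad_input(hidden, attention_mask)
# unpadded: [5, 8] (only real tokens), cu_seqlens: tensor([0, 3, 5]), max_seqlen: 3

repadded = _pad_output(unpadded, indices, batch_size=2, sequence_length=4)
assert torch.equal(repadded[attention_mask.bool()], unpadded)  # padding restored losslessly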
78
 
 
79
 
80
+ class CastedLinear(nn.Linear):
81
  def __init__(self, in_features, out_features, bias):
82
  super().__init__(in_features, out_features, bias=bias)
83
 
 
 
 
 
84
  def forward(self, x):
85
  return F.linear(x, self.weight.type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
86
 
87
 
88
  class CastedLinearIn(nn.Linear):
 
89
  def __init__(self, in_features, out_features, bias):
90
  super().__init__(in_features, out_features, bias=bias)
91
  self.scale = nn.Parameter(torch.ones(in_features))
92
 
 
 
 
 
93
  def forward(self, x):
94
  return F.linear(x, (self.weight * (self.scale + 1.0).unsqueeze(0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
95
 
96
97
  class MultiCastedLinearOrthoIn(nn.Module):
 
98
  def __init__(self, in_features, out_features, bias):
99
  super().__init__()
100
+
101
  self.in_features = in_features
102
  self.out_features = out_features
103
 
 
112
 
113
  self.scale = nn.Parameter(torch.ones(in_features))
114
115
  def forward(self, x):
116
  return F.linear(x, (torch.cat([weight for weight in self.weights], dim=0) * (self.scale + 1.0).unsqueeze(0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
117
 
118
119
  class GeGLU(nn.Module):
120
  def forward(self, x):
121
  x, gate = x.chunk(2, dim=-1)
122
+ return x * gelu_new(gate)
123
 
 
124
 
125
+ class Embedding(nn.Module):
126
+ def __init__(self, config: GptBertConfig):
127
  super().__init__()
128
 
129
+ self.word_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
130
+ self.word_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False, bias=False)
131
+ self.word_scale = nn.Parameter(torch.zeros(config.hidden_size))
132
+ self.dropout = nn.Dropout(config.embedding_dropout)
133
 
134
+ def forward(self, input_ids: torch.Tensor):
135
+ word_embedding = self.word_embedding(input_ids)
136
+ word_embedding = self.word_norm(word_embedding)
137
+ word_embedding = word_embedding * (self.word_scale + 1.0)
138
 
139
+ return self.dropout(word_embedding)
140
 
 
141
 
142
+ class LMClassifier(nn.Module):
143
+ def __init__(self, config: GptBertConfig, n_labels: int):
144
  super().__init__()
145
 
146
+ self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
147
+ self.projection = CastedLinearIn(config.hidden_size, config.hidden_size, bias=False)
148
+ self.post_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
149
+ self.emb2vocab = CastedLinearIn(config.hidden_size, n_labels, bias=True)
150
+
151
+ def forward(self, x: torch.Tensor):
152
+ x = self.pre_norm(x.float()).type_as(x)
153
+ x = self.projection(x)
154
+ x = gelu_new(x)
155
+ x = self.post_norm(x.float()).type_as(x)
156
+ x = self.emb2vocab(x)
157
+ return x
158
 
 
159
 
160
+ class Classifier(nn.Module):
161
+ def __init__(self, config: GptBertConfig, n_labels: int):
162
  super().__init__()
163
 
164
+ self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
165
+ self.projection = CastedLinearIn(config.hidden_size, config.hidden_size, bias=False)
166
+ self.post_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
167
+ self.dropout = nn.Dropout(config.classifier_dropout)
168
+ self.output_projection = CastedLinearIn(config.hidden_size, n_labels, bias=True)
169
+
170
+ def forward(self, x: torch.Tensor):
171
+ x = self.pre_norm(x.float()).type_as(x)
172
+ x = self.projection(x)
173
+ x = gelu_new(x)
174
+ x = self.post_norm(x.float()).type_as(x)
175
+ x = self.dropout(x)
176
+ x = self.output_projection(x)
177
+ return x
178
 
 
179
 
180
+ # from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
181
+ def flash_attention_forward(qkv: torch.Tensor, rotary_emb: UnpaddedRotaryEmbedding, cu_seqlens: torch.Tensor, max_seqlen: int, causal: bool, local_attention: Tuple[int, int], dropout_p: float, deterministic: bool, target_dtype: torch.dtype = torch.bfloat16, **_kwargs):
182
+ qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
183
+
184
+ convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
185
+ if convert_dtype:
186
+ # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
187
+ # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
188
+ orig_dtype = qkv.dtype
189
+ qkv = qkv.to(target_dtype)
190
+
191
+ attn = flash_attn_varlen_qkvpacked_func(
192
+ qkv,
193
+ cu_seqlens=cu_seqlens,
194
+ max_seqlen=max_seqlen,
195
+ dropout_p=dropout_p,
196
+ deterministic=deterministic,
197
+ window_size=local_attention,
198
+ causal=False
199
+ )
200
+ attn = attn.to(orig_dtype) # type: ignore
201
+ else:
202
+ attn = flash_attn_varlen_qkvpacked_func(
203
+ qkv,
204
+ cu_seqlens=cu_seqlens,
205
+ max_seqlen=max_seqlen,
206
+ dropout_p=dropout_p,
207
+ deterministic=deterministic,
208
+ window_size=local_attention,
209
+ causal=False
210
+ )
211
+ return attn
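For reference, flash_attn_varlen_qkvpacked_func expects queries, keys and values packed into one (total_tokens, 3, num_heads, head_dim) tensor, together with the cu_seqlens offsets produced by _unpad_input. A shape-only sketch (requires a CUDA GPU with flash-attn installed; the sizes are illustrative):

import torch

# Toy unpadded batch: two sequences of lengths 3 and 2, i.e. 5 real tokens.
total_tokens, n_heads, head_dim = 5, 4, 16
cu_seqlens = torch.tensor([0, 3, 5], dtype=torch.int32, device="cuda")
max_seqlen = 3

q = torch.randn(total_tokens, n_heads, head_dim, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)
qkv = torch.stack([q, k, v], dim=1)  # (total_tokens, 3, n_heads, head_dim)

out = flash_attn_varlen_qkvpacked_func(
    qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen,
    dropout_p=0.0, window_size=(-1, -1), causal=False,
)  # (total_tokens, n_heads, head_dim)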
212
 
 
 
213
 
214
+ # from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
215
+ class ApplyRotaryEmbUnpad(torch.autograd.Function):
216
+ @staticmethod
217
+ def forward(ctx, qkv, cos, sin, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None):
218
+ # (total_nnz, 3, nheads, headdim)
219
+ qkv = qkv.contiguous()
220
+ total_nnz, _three, _nheads, headdim = qkv.shape
221
+ # We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
222
+ # we get the same tensor
223
+ # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d")
224
+ qk = qkv[:, :2].view(total_nnz, -1, headdim)
225
+ apply_rotary(qk, cos, sin, seqlen_offsets=0, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=False, inplace=True)
226
+
227
+ ctx.save_for_backward(cos, sin, cu_seqlens)
228
+ ctx.max_seqlen = max_seqlen
229
+ return qkv
230
 
231
+ @staticmethod
232
+ def backward(ctx, do):
233
+ cos, sin, cu_seqlens = ctx.saved_tensors
234
+ do = do.contiguous()
235
+ total_nnz, _three, _nheads, headdim = do.shape
236
+ # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
237
+ # we get the same tensor
238
+ dqk = do[:, :2].view(total_nnz, -1, headdim)
239
+ apply_rotary(
240
+ dqk,
241
+ cos,
242
+ sin,
243
+ seqlen_offsets=0,
244
+ cu_seqlens=cu_seqlens,
245
+ max_seqlen=ctx.max_seqlen,
246
+ interleaved=False,
247
+ inplace=True,
248
+ conjugate=True,
249
+ )
250
 
251
+ return do, None, None, None, None, None, None
252
 
253
 
254
+ # from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
255
+ def apply_rotary_unpadded(qkv, cos, sin, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None):
256
+ return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)
257
 
 
 
258
 
259
+ # from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
260
+ class UnpaddedRotaryEmbedding(RotaryEmbedding):
261
+ def __init__(self, dim: int, base: float = 10000.0, max_seqlen: Optional[int] = None):
262
+ super().__init__(dim=dim, base=base, pos_idx_in_fp32=True, device=None, interleaved=False)
263
+ self.max_seqlen = max_seqlen
264
 
265
+ def forward(self, qkv: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: Optional[int] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
266
+ if max_seqlen is not None:
267
+ self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
 
268
 
269
+ qkv = apply_rotary_unpadded(
270
+ qkv,
271
+ self._cos_cached,
272
+ self._sin_cached,
273
+ cu_seqlens=cu_seqlens,
274
+ max_seqlen=max_seqlen,
275
+ )
276
 
277
+ return qkv
 
 
278
 
 
 
 
279
 
280
+ class RotaryPositionalEmbeddings(nn.Module):
281
+ def __init__(self, config, theta: int):
282
+ super().__init__()
283
 
284
+ head_size = config.query_key_head_size
285
+ assert head_size % 2 == 0
286
+ max_seq_len = config.max_sequence_length
287
 
288
+ inv_freq = 1.0 / (theta ** (torch.arange(0, head_size, 2, dtype=torch.float32) / head_size))
289
+ pos = torch.arange(max_seq_len, dtype=torch.float32)
290
+ embedding = torch.einsum('n, d -> nd', pos, inv_freq)
291
+ embedding = torch.cat([embedding, embedding], dim=-1).unsqueeze(0)
292
+ self.register_buffer("cos_matrix", embedding.cos(), persistent=False)
293
+ self.register_buffer("sin_matrix", embedding.sin(), persistent=False)
294
 
295
+ def forward(self, x: torch.Tensor):
296
+ hidden_layer = x.float()
297
 
298
+ seq_len = x.shape[2]
 
299
 
300
+ cos_matrix = self.cos_matrix[:, None, :seq_len, :]
301
+ sin_matrix = self.sin_matrix[:, None, :seq_len, :]
302
 
303
+ x_rotate_half = torch.cat(
304
+ [
305
+ -hidden_layer[:, :, :, x.size(-1) // 2:],
306
+ hidden_layer[:, :, :, :x.size(-1) // 2]
307
+ ],
308
+ dim=-1
309
+ )
310
 
311
+ out = hidden_layer * cos_matrix + x_rotate_half * sin_matrix
312
+ return out.type_as(x)
313
 
314
 
315
  class SelfAttention(nn.Module):
316
+ def __init__(self, config: GptBertConfig, layer_idx: int):
 
317
  super().__init__()
318
+
319
+ self.config = config
320
+ self.layer_idx = layer_idx
321
+
322
+ self.d_qk = config.query_key_head_size
323
+ self.d_v = config.value_head_size
324
  self.num_attention_heads = config.num_attention_heads
325
+ self.num_kv_heads = config.num_attention_heads
326
  self.hidden_size = config.hidden_size
327
 
328
  self.q_out_dim = self.d_qk * self.num_attention_heads
 
333
  self.v_proj = CastedLinearIn(self.hidden_size, self.v_out_dim, bias=False)
334
  self.out_proj = CastedLinearIn(self.d_v*self.num_attention_heads, self.hidden_size, bias=False)
335
 
336
+ self.pre_v_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
337
+ self.pre_qk_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
338
+ self.inter_norm = nn.LayerNorm(self.d_v * self.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=False)
339
+ self.q_norm = nn.LayerNorm(self.d_qk, eps=config.layer_norm_eps, elementwise_affine=False, bias=False)
340
+ self.k_norm = nn.LayerNorm(self.d_qk, eps=config.layer_norm_eps, elementwise_affine=False, bias=False)
341
+ self.k_scale = nn.Parameter(torch.ones(self.num_kv_heads, self.d_qk))
342
+ self.q_scale = nn.Parameter(torch.ones(self.num_attention_heads, self.d_qk))
 
 
343
 
344
+ self.dropout = nn.Dropout(config.hidden_dropout)
345
 
346
+ theta = 160_000 if (layer_idx + 1) % config.local_global_ratio == 0 else 10_000
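+ # Every `local_global_ratio`-th layer is a global-attention layer and uses a larger RoPE
+ # base (160k) for long-range positions; the remaining (local) layers keep the usual 10k base.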
 
347
 
348
+ # Initialize rotary embeddings based on whether FlashAttention is available
349
+ if is_flash_attn_2_available():
350
+ self.rope_embedding = UnpaddedRotaryEmbedding(dim=self.d_qk, base=theta, max_seqlen=config.max_sequence_length)
351
+ else:
352
+ self.rope_embedding = RotaryPositionalEmbeddings(config, theta)
353
 
354
+ self.scale = 1.0 / math.sqrt(self.d_qk)
355
  self.lambdas = nn.Parameter(torch.tensor([0.5]))
356
 
 
 
357
  self.sequence_length = config.max_sequence_length
358
  self.is_causal = config.is_decoder
359
+ self.window_length = None

360
 
361
+ def set_window_length(self, window_length: int):
362
+ self.window_length = window_length
363
 
364
+ def _get_window_mask(self, query_length: int, key_length: int, device: torch.device):
365
+ """Create and cache window attention mask."""
 
 
366
  if self.is_causal:
367
+ mask = torch.ones(query_length, key_length, dtype=torch.bool, device=device)
368
+ mask = mask.tril().triu(diagonal=-self.window_length)
 
 
369
  else:
370
+ mask = torch.ones(query_length, key_length, dtype=torch.bool, device=device)
371
+ mask = mask.tril(diagonal=self.window_length).triu(diagonal=-self.window_length)
372
+ return mask.view(1, 1, query_length, key_length)
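+ # Example: with window_length=2, a non-causal query at position i may attend to keys j with
+ # |i - j| <= 2, while the causal variant attends to j in [i - 2, i].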
 
 
 
 
 
 
 
 
373
 
374
+ def attention_operation(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, padding_mask: Optional[torch.Tensor]) -> torch.Tensor:
375
+ """Standard attention computation with masking."""
376
  batch_size, _, query_length, _ = query.size()
377
  _, _, key_length, _ = key.size()
378
 
379
+ # Build the boolean window mask (recomputed on each call, under no_grad)
380
+ with torch.no_grad():
381
+ window_mask = self._get_window_mask(query_length, key_length, query.device)
382
+ if padding_mask is not None:
383
+ attention_mask = padding_mask & window_mask
384
+ else:
385
+ attention_mask = window_mask
386
+
387
+ output = F.scaled_dot_product_attention(
388
+ query=query,
389
+ key=key,
390
+ value=value,
391
+ attn_mask=attention_mask,
392
+ dropout_p=self.config.attention_dropout if self.training else 0.0,
393
+ is_causal=False  # causality is already encoded in attn_mask; SDPA forbids an explicit mask together with is_causal=True
394
+ )
395
+ return output
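+ # With a boolean `attn_mask`, True marks positions that may be attended to; the result has
+ # shape [batch, num_heads, query_length, d_v].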
396
+
397
+ def forward(self, hidden_layer: torch.Tensor, qk_layer: torch.Tensor, v1: torch.Tensor | None, padding_info):
398
+ # Get original shape info
399
+ if is_flash_attn_2_available():
400
+ # Unpadded case
401
+ indices, cu_seqlens, max_seqlen = padding_info
402
+ total_seqlen = hidden_layer.size(0)
403
+ batch_size = cu_seqlens.size(0) - 1
404
  else:
405
+ # Padded case
406
+ batch_size, seq_length = hidden_layer.size(0), hidden_layer.size(1)
407
 
408
+ hidden_layer = self.pre_v_norm(hidden_layer.float()).type_as(hidden_layer)
409
+ qk_layer = self.pre_qk_norm(qk_layer.float()).type_as(qk_layer)
410
 
411
+ query, key = self.qk_proj(qk_layer).tensor_split([self.q_out_dim], dim=-1)
412
+ value = self.v_proj(hidden_layer)
413
 
414
+ if is_flash_attn_2_available():
415
+ # Reshape for FlashAttention: (total_seqlen, num_heads, head_dim)
416
+ query = query.view(total_seqlen, self.num_attention_heads, self.d_qk)
417
+ key = key.view(total_seqlen, self.num_kv_heads, self.d_qk)
418
+ value = value.view(total_seqlen, self.num_kv_heads, self.d_v)
419
 
420
+ # Apply layer norm and scaling
421
+ query = ((self.q_scale + 1.0).unsqueeze(0) * self.q_norm(query.float())).type_as(query)
422
+ key = ((self.k_scale + 1.0).unsqueeze(0) * self.k_norm(key.float())).type_as(key)
423
 
424
+ if v1 is None:
425
+ v1 = value
426
+ value = (1 - self.lambdas[0]) * value + self.lambdas[0] * v1
427
 
428
+ # Prepare qkv for FlashAttention
429
+ qkv = torch.stack([query, key, value], dim=1) # (total_seqlen, 3, num_heads, head_dim)
430
 
431
+ # Determine window size for local attention
432
+ if self.window_length is not None and self.window_length > 0:
433
+ if self.is_causal:
434
+ local_attention = (self.window_length - 1, 0)
435
+ else:
436
+ local_attention = (self.window_length - 1, self.window_length - 1)
437
+ else:
438
+ local_attention = (-1, -1)
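+ # flash_attn convention: window_size=(left, right) bounds how far attention reaches in each
+ # direction, and (-1, -1) disables the sliding window entirely (full attention).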
439
+
440
+ # Apply FlashAttention
441
+ output = flash_attention_forward(
442
+ qkv,
443
+ self.rope_embedding,
444
+ cu_seqlens,
445
+ max_seqlen,
446
+ self.is_causal,
447
+ local_attention,
448
+ self.config.attention_dropout if self.training else 0.0,
449
+ self.config.deterministic_flash_attn
450
+ )
451
 
452
+ # Reshape output back
453
+ output = output.view(total_seqlen, self.d_v * self.num_attention_heads)
 
454
 
455
+ else:
456
+ # Standard attention path
457
+ query_length = query.size(1)
458
+ key_length = key.size(1)
459
 
460
+ query = query.reshape(batch_size, query_length, self.num_attention_heads, self.d_qk).transpose(1, 2)
461
+ key = key.reshape(batch_size, key_length, self.num_kv_heads, self.d_qk).transpose(1, 2)
462
+ value = value.reshape(batch_size, key_length, self.num_kv_heads, self.d_v).transpose(1, 2)
463
 
464
+ query = ((self.q_scale + 1.0).unsqueeze(1).unsqueeze(0) * self.q_norm(query.float())).type_as(query)
465
+ key = ((self.k_scale + 1.0).unsqueeze(1).unsqueeze(0) * self.k_norm(key.float())).type_as(key)
466
 
467
+ if v1 is None:
468
+ v1 = value
 
 
 
 
 
 
 
 
 
469
  else:
470
+ value = (1 - self.lambdas[0]) * value + self.lambdas[0] * v1
 
 
 
471
 
472
+ # Apply rotary embeddings
473
+ query = self.rope_embedding(query)
474
+ key = self.rope_embedding(key)
 
475
 
476
+ output = self.attention_operation(query, key, value, padding_info)
477
+ output = output.transpose(1, 2).flatten(2, 3) # shape: [B, T, H*D]
 
478
 
479
+ output = self.inter_norm(output.float()).type_as(output)
480
+ output = self.out_proj(output)
481
+ output = self.dropout(output)
482
 
483
+ return output, v1
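+ # `v1` holds the value states of the first layer; later layers blend them back in through
+ # `lambdas[0]` (a learned value residual), which is why it is returned alongside the output.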
484
 
 
485
 
486
+ class FeedForward(nn.Module):
487
+ def __init__(self, config: GptBertConfig):
488
  super().__init__()
489
+ self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
 
 
 
 
 
 
 
 
490
  self.up_proj = MultiCastedLinearOrthoIn(config.hidden_size, [config.intermediate_size, config.intermediate_size], bias=False)
491
  self.activation = GeGLU()
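+ # GeGLU: `up_proj` emits two intermediate projections and the activation is expected to gate
+ # one of them with the GELU of the other.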
492
+ self.inter_norm = nn.LayerNorm(config.intermediate_size, eps=config.layer_norm_eps, elementwise_affine=False)
493
  self.down_proj = CastedLinearIn(config.intermediate_size, config.hidden_size, bias=False)
494
+ self.dropout = nn.Dropout(config.hidden_dropout)
495
+
496
+ def forward(self, x: torch.Tensor):
497
+ x = self.pre_norm(x.float()).type_as(x)
498
+ x = self.up_proj(x)
499
+ x = self.activation(x)
500
+ x = self.inter_norm(x.float()).type_as(x)
501
+ x = self.down_proj(x)
502
+ x = self.dropout(x)
503
+ return x
 
 
 
 
 
 
 
 
 
 
 
 
 
504
 
 
 
505
 
506
+ class Layer(nn.Module):
507
+ def __init__(self, config: GptBertConfig, layer_idx: int):
508
+ super().__init__()
509
 
510
+ self.attention = SelfAttention(config, layer_idx)
511
+ self.mlp = FeedForward(config)
512
+ self.lambdas = nn.Parameter(torch.tensor([0., 0., 1., 0., 1., 0.]))
513
 
514
+ def set_window_length(self, window_length: int):
515
+ self.attention.set_window_length(window_length)
516
 
517
+ def forward(self, hidden_layer: torch.Tensor, embeddings: torch.Tensor, v1: torch.Tensor | None, padding_info):
518
+ attention_output = (1 - self.lambdas[0]) * hidden_layer + self.lambdas[0] * embeddings
519
+ qk_layer = (1 - self.lambdas[1]) * hidden_layer + self.lambdas[1] * embeddings
520
+ mlp_layer = F.softplus(self.lambdas[2]) * ((1 - self.lambdas[3]) * hidden_layer + self.lambdas[3] * embeddings)
521
 
522
+ attention_output, v1 = self.attention(attention_output, qk_layer, v1, padding_info)
523
+ mlp_layer = mlp_layer + attention_output
524
+ hidden_layer = F.softplus(self.lambdas[4]) * ((1 - self.lambdas[5]) * hidden_layer + self.lambdas[5] * embeddings)
525
+ output = hidden_layer + attention_output + self.mlp(mlp_layer)
526
 
527
+ return output, v1
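+ # The six `lambdas` control how much of the static token embeddings is mixed back into the
+ # attention input, the query/key input, the MLP input and the final residual; the two
+ # softplus terms act as positive gains on those mixed states.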
528
 
 
529
 
530
+ class Encoder(nn.Module):
531
+ def __init__(self, config: GptBertConfig):
532
  super().__init__()
533
+ self.layers = nn.ModuleList([Layer(config, i) for i in range(config.num_layers)])
534
+ self.local_global_ratio = config.local_global_ratio
535
 
536
+ def set_window_length(self, config: GptBertConfig):
537
+ for i, layer in enumerate(self.layers):
538
+ if (i + 1) % self.local_global_ratio == 0:
539
+ layer.set_window_length(config.global_window_length)
540
+ else:
541
+ layer.set_window_length(config.local_window_length)
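+ # Attention spans alternate: every `local_global_ratio`-th layer gets the global window,
+ # all other layers use the local window.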
 
 
 
 
 
 
 
 
 
542
 
543
+ def forward(self, hidden_layer: torch.Tensor, padding_info, output_hidden_states=False, checkpoint_activations=False):
544
+ hidden_layers = [hidden_layer] if output_hidden_states else None
545
+ v1 = None
546
+ embeddings = hidden_layer
 
 
547
 
548
+ for layer in self.layers:
549
+ if checkpoint_activations:
550
+ hidden_layer, v1 = torch.utils.checkpoint.checkpoint(layer, hidden_layer, embeddings, v1, padding_info, use_reentrant=True)
551
+ else:
552
+ hidden_layer, v1 = layer(hidden_layer, embeddings, v1, padding_info)
 
553
 
554
+ if output_hidden_states:
555
+ hidden_layers.append(hidden_layer)
556
 
557
+ return hidden_layer, hidden_layers

558
 
559
 
560
  #
 
563
 
564
  class GptBertPreTrainedModel(PreTrainedModel):
565
  config_class = GptBertConfig
566
+ supports_gradient_checkpointing = True
567
+ _supports_flash_attn_2 = True
568
+ _supports_sdpa = True
569
+ _supports_flex_attn = False
570
 
571
  def _init_weights(self, module):
572
  std = math.sqrt(2.0 / (5.0 * self.hidden_size))
573
 
574
+ if isinstance(module, nn.Linear) or isinstance(module, CastedLinearIn):
575
  nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std, a=-2*std, b=2*std)
576
  if module.bias is not None:
577
  module.bias.data.zero_()
 
583
 
584
 
585
  class GptBertModel(GptBertPreTrainedModel):
586
+ def __init__(self, config: GptBertConfig, add_mlm_layer=False, **kwargs):
 
587
  super().__init__(config, **kwargs)
588
  self.config = config
589
  self.hidden_size = config.hidden_size
590
 
591
  self.embedding = Embedding(config)
592
  self.encoder = Encoder(config)
593
+ self.classifier = LMClassifier(config, config.vocab_size) if add_mlm_layer else None
594
  self.set_window_length(config)
595
+ self.gradient_checkpointing = False
596
+ self.post_init()
597
 
598
  def set_window_length(self, config) -> None:
599
  self.encoder.set_window_length(config)
 
607
  def get_contextualized_embeddings(
608
  self,
609
  input_ids: Optional[torch.Tensor] = None,
610
+ attention_mask: Optional[torch.Tensor] = None,
611
+ output_hidden_states: Optional[bool] = None
612
+ ):
613
  if input_ids is not None:
614
  input_shape = input_ids.size()
615
  else:
 
618
  batch_size, seq_length = input_shape
619
  device = input_ids.device
620
 
621
+ if attention_mask is None:
622
+ attention_mask = torch.ones(batch_size, seq_length, dtype=torch.bool, device=device)
623
+ else:
624
+ attention_mask = attention_mask.bool()
625
 
626
+ if is_flash_attn_2_available():
627
+ if len(attention_mask.size()) != 2:
628
+ raise ValueError("Only a 2-dimensional `attention_mask` is supported with FlashAttention.")
629
+ with torch.no_grad():
630
+ input_ids, indices, cu_seqlens, max_seqlen_in_batch = _unpad_input(input_ids, attention_mask)
631
+ padding_info = (indices, cu_seqlens, max_seqlen_in_batch)
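+ # Unpadded path: the batch is flattened into one packed sequence of real tokens; `indices`
+ # maps packed positions back to the padded layout and `cu_seqlens` holds the cumulative
+ # sequence lengths expected by FlashAttention's varlen kernels.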
632
+ else:
633
  if len(attention_mask.size()) == 2:
634
  attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
635
  elif len(attention_mask.size()) == 3:
636
  attention_mask = attention_mask.unsqueeze(1)
637
+ padding_info = attention_mask
638
+
639
+ static_embeddings = self.embedding(input_ids)
640
+
641
+ original_dtype = static_embeddings.dtype
642
+ if torch.cuda.is_available() and torch.cuda.is_bf16_supported() and static_embeddings.dtype == torch.float32:
643
+ static_embeddings = static_embeddings.bfloat16()
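+ # Run the encoder in bfloat16 when the hardware supports it; the outputs are cast back to
+ # the original embedding dtype below.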
644
+
645
+ last_layer, contextualized_embeddings = self.encoder(
646
+ static_embeddings,
647
+ padding_info,
648
+ output_hidden_states=output_hidden_states,
649
+ checkpoint_activations=self.gradient_checkpointing and self.training
650
+ )
651
+
652
+ last_layer = last_layer.to(original_dtype)
653
+ if output_hidden_states:
654
+ contextualized_embeddings = [layer.to(original_dtype) for layer in contextualized_embeddings]
655
 
656
+ # Pad output if using FlashAttention
657
+ if is_flash_attn_2_available():
658
+ last_layer = _pad_output(last_layer, indices, batch_size, seq_length)
659
+ if output_hidden_states:
660
+ contextualized_embeddings = [_pad_output(layer, indices, batch_size, seq_length) for layer in contextualized_embeddings]
661
+ else:
662
+ contextualized_embeddings = None
663
 
664
+ return last_layer, contextualized_embeddings
 
 
 
 
 
 
 
 
665
 
666
  def forward(
667
  self,
668
  input_ids: Optional[torch.Tensor] = None,
669
  attention_mask: Optional[torch.Tensor] = None,
 
 
670
  output_hidden_states: Optional[bool] = None,
671
  output_attentions: Optional[bool] = None,
672
  return_dict: Optional[bool] = None,
 
674
  ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
675
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
676
 
677
+ sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
678
 
679
  if not return_dict:
680
  return (
681
  sequence_output,
682
+ *([contextualized_embeddings] if output_hidden_states else [])
 
683
  )
684
 
685
  return BaseModelOutput(
686
  last_hidden_state=sequence_output,
687
+ hidden_states=contextualized_embeddings if output_hidden_states else None
 
688
  )
689
 
690
 
691
  class GptBertForMaskedLM(GptBertModel):
692
+ _tied_weights_keys = ["classifier.emb2vocab.weight"]
693
 
694
+ def __init__(self, config: GptBertConfig, **kwargs):
695
  super().__init__(config, add_mlm_layer=True, **kwargs)
696
 
697
  def get_output_embeddings(self):
 
704
  self,
705
  input_ids: Optional[torch.Tensor] = None,
706
  attention_mask: Optional[torch.Tensor] = None,
 
 
707
  output_hidden_states: Optional[bool] = None,
 
708
  return_dict: Optional[bool] = None,
709
  labels: Optional[torch.LongTensor] = None,
710
  **kwargs
711
  ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
712
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
713
 
714
+ sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
715
  subword_prediction = self.classifier(sequence_output)
716
  subword_prediction = 30 * torch.sigmoid(subword_prediction / 7.5)
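+ # Soft logit cap: each logit is squashed into the bounded range (0, 30) before the loss is
+ # computed and before the logits are returned.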
717
 
 
721
  subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
722
  masked_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
723
 
724
+ bos_logits = torch.zeros(subword_prediction.size(0), 1, self.config.vocab_size, dtype=subword_prediction.dtype, device=subword_prediction.device)
725
+ bos_logits[:, :, self.config.bos_token_id] = 1.0
726
+ subword_prediction = torch.cat([bos_logits, subword_prediction[:, :-1]], dim=1)
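+ # The model predicts token t from position t-1, so the logits are shifted one step to the
+ # right and a placeholder distribution (favouring the BOS token) is prepended, keeping the
+ # returned logits aligned with the input positions.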
727
+
728
  if not return_dict:
729
  output = (
730
  subword_prediction,
731
+ *([contextualized_embeddings] if output_hidden_states else [])
 
732
  )
733
  return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
734
 
735
  return MaskedLMOutput(
736
  loss=masked_lm_loss,
737
  logits=subword_prediction,
738
+ hidden_states=contextualized_embeddings if output_hidden_states else None
 
739
  )
740
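+ # Usage sketch (the checkpoint name is a placeholder; this modeling file is meant to be
+ # loaded as custom code, so `trust_remote_code=True` is required):
+ #
+ #     model = GptBertForMaskedLM.from_pretrained("<org>/<gpt-bert-checkpoint>", trust_remote_code=True)
+ #     outputs = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
+ #     predicted_ids = outputs.logits.argmax(dim=-1)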
 
741

742
  class GptBertForCausalLM(GptBertModel):
743
+ _tied_weights_keys = ["classifier.emb2vocab.weight"]
744
 
745
+ def __init__(self, config: GptBertConfig, **kwargs):
746
  config.is_decoder = True
747
  super().__init__(config, add_mlm_layer=True, **kwargs)
748
 
 
787
  assert past_key_values is None, "past_key_values is not supported for now"
788
  assert not use_cache, "use_cache is not supported for now"
789
 
790
+ sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
791
  subword_prediction = self.classifier(sequence_output)
792
  subword_prediction = 30 * torch.sigmoid(subword_prediction / 7.5)
793
 
794
+ causal_lm_loss = None
795
  if labels is not None:
796
  labels_flatten = labels[:, 1:].flatten()
797
  subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
798
+ causal_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
799
 
800
  if not return_dict:
801
  output = (
802
  subword_prediction,
803
+ *([contextualized_embeddings] if output_hidden_states else [])
 
804
  )
805
+ return ((causal_lm_loss,) + output) if causal_lm_loss is not None else output
806
 
807
  return CausalLMOutput(
808
+ loss=causal_lm_loss,
809
  logits=subword_prediction,
810
+ hidden_states=contextualized_embeddings if output_hidden_states else None
 
811
  )
812
 
813
  def prepare_inputs_for_generation(
 
863
 
864
 
865
  class GptBertForSequenceClassification(GptBertModel):
866
+ _keys_to_ignore_on_load_missing = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
867
+ _keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
868
 
869
+ def __init__(self, config: GptBertConfig, **kwargs):
870
  super().__init__(config, add_mlm_layer=False, **kwargs)
871
 
872
  self.num_labels = config.num_labels
873
+ self.classifier = Classifier(config, self.num_labels)
874
+ self.post_init()
875
 
876
  def forward(
877
  self,
878
  input_ids: Optional[torch.Tensor] = None,
879
  attention_mask: Optional[torch.Tensor] = None,
 
 
 
880
  output_hidden_states: Optional[bool] = None,
881
  return_dict: Optional[bool] = None,
882
  labels: Optional[torch.LongTensor] = None,
 
884
  ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
885
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
886
 
887
+ sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
888
+ logits = self.classifier(sequence_output[:, 0, :])
889
 
890
  loss = None
891
  if labels is not None:
 
913
  if not return_dict:
914
  output = (
915
  logits,
916
+ *([contextualized_embeddings] if output_hidden_states else [])
 
917
  )
918
  return ((loss,) + output) if loss is not None else output
919
 
920
  return SequenceClassifierOutput(
921
  loss=loss,
922
  logits=logits,
923
+ hidden_states=contextualized_embeddings if output_hidden_states else None
 
924
  )
925
 
926
 
927
  class GptBertForTokenClassification(GptBertModel):
928
+ _keys_to_ignore_on_load_missing = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
929
+ _keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
930
 
931
+ def __init__(self, config: GptBertConfig, **kwargs):
932
  super().__init__(config, add_mlm_layer=False, **kwargs)
933
 
934
  self.num_labels = config.num_labels
935
+ self.classifier = Classifier(config, self.num_labels)
936
+ self.post_init()
937
 
938
  def forward(
939
  self,
940
  input_ids: Optional[torch.Tensor] = None,
941
  attention_mask: Optional[torch.Tensor] = None,
 
 
 
942
  output_hidden_states: Optional[bool] = None,
943
  return_dict: Optional[bool] = None,
944
  labels: Optional[torch.LongTensor] = None,
 
946
  ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
947
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
948
 
949
+ sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
950
+ logits = self.classifier(sequence_output)
951
 
952
  loss = None
953
  if labels is not None:
 
971
 
972
 
973
  class GptBertForQuestionAnswering(GptBertModel):
974
+ _keys_to_ignore_on_load_missing = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
975
+ _keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
976
+
977
+ def __init__(self, config: GptBertConfig, **kwargs):
978
  super().__init__(config, add_mlm_layer=False, **kwargs)
979
 
980
  self.num_labels = config.num_labels
981
+ self.classifier = Classifier(config, self.num_labels)
982
+ self.post_init()
983
 
984
  def forward(
985
  self,
986
  input_ids: Optional[torch.Tensor] = None,
987
  attention_mask: Optional[torch.Tensor] = None,
 
 
 
988
  output_hidden_states: Optional[bool] = None,
989
  return_dict: Optional[bool] = None,
990
  start_positions: Optional[torch.Tensor] = None,
 
993
  ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
994
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
995
 
996
+ sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
997
+ logits = self.classifier(sequence_output)
998
 
999
  start_logits, end_logits = logits.split(1, dim=-1)
1000
  start_logits = start_logits.squeeze(-1).contiguous()
 
1022
  output = (
1023
  start_logits,
1024
  end_logits,
1025
+ *([contextualized_embeddings] if output_hidden_states else [])
 
1026
  )
1027
  return ((total_loss,) + output) if total_loss is not None else output
1028
 
 
1030
  loss=total_loss,
1031
  start_logits=start_logits,
1032
  end_logits=end_logits,
1033
+ hidden_states=contextualized_embeddings if output_hidden_states else None
 
1034
  )
1035
 
1036
 
1037
  class GptBertForMultipleChoice(GptBertModel):
1038
+ _keys_to_ignore_on_load_missing = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
1039
+ _keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
1040
 
1041
+ def __init__(self, config: GptBertConfig, **kwargs):
1042
  super().__init__(config, add_mlm_layer=False, **kwargs)
1043
 
1044
  self.num_labels = getattr(config, "num_labels", 2)
1045
+ self.classifier = Classifier(config, self.num_labels)
1046
+ self.post_init()
1047
 
1048
  def forward(
1049
  self,
1050
  input_ids: Optional[torch.Tensor] = None,
1051
  attention_mask: Optional[torch.Tensor] = None,
 
 
1052
  labels: Optional[torch.Tensor] = None,
 
1053
  output_hidden_states: Optional[bool] = None,
1054
  return_dict: Optional[bool] = None,
1055
  **kwargs
 
1060
  flat_input_ids = input_ids.view(-1, input_ids.size(-1))
1061
  flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1062
 
1063
+ sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(flat_input_ids, flat_attention_mask, output_hidden_states)
1064
+ logits = self.classifier(sequence_output)
1065
  reshaped_logits = logits.view(-1, num_choices)
1066
 
1067
  loss = None
 
1072
  if not return_dict:
1073
  output = (
1074
  reshaped_logits,
1075
+ *([contextualized_embeddings] if output_hidden_states else [])
 
1076
  )
1077
  return ((loss,) + output) if loss is not None else output
1078
 
1079
  return MultipleChoiceModelOutput(
1080
  loss=loss,
1081
  logits=reshaped_logits,
1082
+ hidden_states=contextualized_embeddings if output_hidden_states else None
 
1083
  )