Murhaf committed
Commit f55f78d · verified · 1 parent: 6dd832b

Add new SentenceTransformer model

1_Pooling/config.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "word_embedding_dimension": 640,
3
+ "pooling_mode_cls_token": false,
4
+ "pooling_mode_mean_tokens": true,
5
+ "pooling_mode_max_tokens": false,
6
+ "pooling_mode_mean_sqrt_len_tokens": false,
7
+ "pooling_mode_weightedmean_tokens": false,
8
+ "pooling_mode_lasttoken": false,
9
+ "include_prompt": true
10
+ }
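This pooling configuration enables plain mean pooling over the 640-dimensional token embeddings (every other pooling mode is disabled). A minimal sketch of the operation it describes, using made-up tensors rather than the sentence-transformers internals:

```python
# Hypothetical illustration of the masked mean pooling configured above
# (word_embedding_dimension=640, pooling_mode_mean_tokens=true); not the library internals.
import torch

token_embeddings = torch.randn(2, 75, 640)            # (batch, seq_len, word_embedding_dimension)
attention_mask = torch.ones(2, 75, dtype=torch.long)  # 1 = real token, 0 = padding

mask = attention_mask.unsqueeze(-1).float()            # (batch, seq_len, 1)
sentence_embeddings = (token_embeddings * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
print(sentence_embeddings.shape)                       # torch.Size([2, 640])
```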
README.md ADDED
@@ -0,0 +1,430 @@
1
+ ---
2
+ language:
3
+ - 'no'
4
+ tags:
5
+ - sentence-transformers
6
+ - sentence-similarity
7
+ - feature-extraction
8
+ - dense
9
+ - generated_from_trainer
10
+ - dataset_size:556367
11
+ - loss:CachedMultipleNegativesRankingLoss
12
+ base_model: Murhaf/ltg-norbert4-base_ndla
13
+ widget:
14
+ - source_sentence: Inne i igloen gjør den unge mannen seg klar for sitt overnattingsopphold.
15
+ sentences:
16
+ - Folk danser i gaten.
17
+ - Den unge mannen gjør seg klar for sitt overnattingsopphold.
18
+ - Den unge mannen gjør seg klar til å dra.
19
+ - source_sentence: En kvinne i rullestol snakker med vennen sin mens hun er omgitt
20
+ av andre mennesker som går i parken.
21
+ sentences:
22
+ - Barna blir fotografert.
23
+ - Kvinnen er utendørs.
24
+ - Kvinnen spiser en pølse midt i soverommet sitt.
25
+ - source_sentence: En kvinne løper langs en steinete strand.
26
+ sentences:
27
+ - En mann og en kvinne ser på frukt og grønnsaker.
28
+ - En kvinne løper.
29
+ - En kvinne sitter ved et piknikbord nær den steinete kysten.
30
+ - source_sentence: To basketballspillere i svart og hvitt antrekk står på en basketballbane
31
+ og snakker.
32
+ sentences:
33
+ - De to basketballspillerne snakker sammen.
34
+ - Den unge gutten multitasker.
35
+ - De to basketballspillerne sitter på benken.
36
+ - source_sentence: En mann lager et sandmaleri på gulvet.
37
+ sentences:
38
+ - En mann lager kunst.
39
+ - På fornøyelsesturen var det to jenter som smilte og lo
40
+ - En kvinne ødelegger et sandmaleri.
41
+ datasets:
42
+ - Murhaf/all-nli-norwegian
43
+ pipeline_tag: sentence-similarity
44
+ library_name: sentence-transformers
45
+ metrics:
46
+ - cosine_accuracy
47
+ model-index:
48
+ - name: SentenceTransformer based on Murhaf/ltg-norbert4-base_ndla
49
+ results:
50
+ - task:
51
+ type: triplet
52
+ name: Triplet
53
+ dataset:
54
+ name: nob all nli test
55
+ type: nob_all_nli_test
56
+ metrics:
57
+ - type: cosine_accuracy
58
+ value: 0.9470000267028809
59
+ name: Cosine Accuracy
60
+ ---
61
+
62
+ # SentenceTransformer based on Murhaf/ltg-norbert4-base_ndla
63
+
64
+ This is a [sentence-transformers](https://www.SBERT.net) model finetuned from [Murhaf/ltg-norbert4-base_ndla](https://huggingface.co/Murhaf/ltg-norbert4-base_ndla) on the [all-nli-norwegian](https://huggingface.co/datasets/Murhaf/all-nli-norwegian) dataset. It maps sentences & paragraphs to a 640-dimensional dense vector space and can be used for semantic textual similarity, semantic search, paraphrase mining, text classification, clustering, and more.
65
+
66
+ ## Model Details
67
+
68
+ ### Model Description
69
+ - **Model Type:** Sentence Transformer
70
+ - **Base model:** [Murhaf/ltg-norbert4-base_ndla](https://huggingface.co/Murhaf/ltg-norbert4-base_ndla) <!-- at revision 762fb095e1c571e52d8690bf07ec8b65d3551026 -->
71
+ - **Maximum Sequence Length:** 75 tokens
72
+ - **Output Dimensionality:** 640 dimensions
73
+ - **Similarity Function:** Cosine Similarity
74
+ - **Training Dataset:**
75
+ - [all-nli-norwegian](https://huggingface.co/datasets/Murhaf/all-nli-norwegian)
76
+ - **Language:** no
77
+ <!-- - **License:** Unknown -->
78
+
79
+ ### Model Sources
80
+
81
+ - **Documentation:** [Sentence Transformers Documentation](https://sbert.net)
82
+ - **Repository:** [Sentence Transformers on GitHub](https://github.com/UKPLab/sentence-transformers)
83
+ - **Hugging Face:** [Sentence Transformers on Hugging Face](https://huggingface.co/models?library=sentence-transformers)
84
+
85
+ ### Full Model Architecture
86
+
87
+ ```
88
+ SentenceTransformer(
89
+ (0): Transformer({'max_seq_length': 75, 'do_lower_case': False, 'architecture': 'GptBertModel'})
90
+ (1): Pooling({'word_embedding_dimension': 640, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
91
+ )
92
+ ```
93
+
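The same two-module pipeline (a 75-token Transformer encoder followed by mean pooling) can also be composed by hand with the `models` API. A sketch, assuming the custom GptBertModel architecture requires `trust_remote_code=True`:

```python
# Sketch only: manual composition of the Transformer + Pooling modules shown above.
from sentence_transformers import SentenceTransformer, models

word_embedding_model = models.Transformer(
    "Murhaf/ltg-norbert4-base_ndla-all-nli",
    max_seq_length=75,
    model_args={"trust_remote_code": True},  # assumed necessary for the custom GptBertModel code
)
pooling_model = models.Pooling(
    word_embedding_model.get_word_embedding_dimension(),  # 640
    pooling_mode="mean",
)
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
```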
94
+ ## Usage
95
+
96
+ ### Direct Usage (Sentence Transformers)
97
+
98
+ First install the Sentence Transformers library:
99
+
100
+ ```bash
101
+ pip install -U sentence-transformers
102
+ ```
103
+
104
+ Then you can load this model and run inference.
105
+ ```python
106
+ from sentence_transformers import SentenceTransformer
107
+
108
+ # Download from the 🤗 Hub
109
+ model = SentenceTransformer("Murhaf/ltg-norbert4-base_ndla-all-nli")
110
+ # Run inference
111
+ sentences = [
112
+ 'En mann lager et sandmaleri på gulvet.',
113
+ 'En mann lager kunst.',
114
+ 'En kvinne ødelegger et sandmaleri.',
115
+ ]
116
+ embeddings = model.encode(sentences)
117
+ print(embeddings.shape)
118
+ # [3, 640]
119
+
120
+ # Get the similarity scores for the embeddings
121
+ similarities = model.similarity(embeddings, embeddings)
122
+ print(similarities)
123
+ # tensor([[1.0000, 0.5608, 0.3858],
124
+ # [0.5608, 1.0000, 0.2424],
125
+ # [0.3858, 0.2424, 1.0000]])
126
+ ```
127
+
128
+ <!--
129
+ ### Direct Usage (Transformers)
130
+
131
+ <details><summary>Click to see the direct usage in Transformers</summary>
132
+
133
+ </details>
134
+ -->
135
+
136
+ <!--
137
+ ### Downstream Usage (Sentence Transformers)
138
+
139
+ You can finetune this model on your own dataset.
140
+
141
+ <details><summary>Click to expand</summary>
142
+
143
+ </details>
144
+ -->
145
+
146
+ <!--
147
+ ### Out-of-Scope Use
148
+
149
+ *List how the model may foreseeably be misused and address what users ought not to do with the model.*
150
+ -->
151
+
152
+ ## Evaluation
153
+
154
+ ### Metrics
155
+
156
+ #### Triplet
157
+
158
+ * Dataset: `nob_all_nli_test`
159
+ * Evaluated with [<code>TripletEvaluator</code>](https://sbert.net/docs/package_reference/sentence_transformer/evaluation.html#sentence_transformers.evaluation.TripletEvaluator)
160
+
161
+ | Metric | Value |
162
+ |:--------------------|:----------|
163
+ | **cosine_accuracy** | **0.947** |
164
+
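A sketch of re-running this evaluation with `TripletEvaluator`; the `test` split name and the `anchor`/`positive`/`negative` column names are assumptions based on the dataset description in this card:

```python
# Sketch of reproducing the triplet accuracy above (split/column names are assumptions).
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import TripletEvaluator

model = SentenceTransformer("Murhaf/ltg-norbert4-base_ndla-all-nli")
test = load_dataset("Murhaf/all-nli-norwegian", split="test")

evaluator = TripletEvaluator(
    anchors=test["anchor"],
    positives=test["positive"],
    negatives=test["negative"],
    name="nob_all_nli_test",
)
print(evaluator(model))  # expected to report a cosine accuracy around 0.947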
165
+ <!--
166
+ ## Bias, Risks and Limitations
167
+
168
+ *What are the known or foreseeable issues stemming from this model? You could also flag here known failure cases or weaknesses of the model.*
169
+ -->
170
+
171
+ <!--
172
+ ### Recommendations
173
+
174
+ *What are recommendations with respect to the foreseeable issues? For example, filtering explicit content.*
175
+ -->
176
+
177
+ ## Training Details
178
+
179
+ ### Training Dataset
180
+
181
+ #### all-nli-norwegian
182
+
183
+ * Dataset: [all-nli-norwegian](https://huggingface.co/datasets/Murhaf/all-nli-norwegian) at [98cabde](https://huggingface.co/datasets/Murhaf/all-nli-norwegian/tree/98cabded09bfe5f505757840026ecdf6a357a04c)
184
+ * Size: 556,367 training samples
185
+ * Columns: <code>anchor</code>, <code>positive</code>, and <code>negative</code>
186
+ * Approximate statistics based on the first 1000 samples:
187
+ | | anchor | positive | negative |
188
+ |:--------|:---------------------------------------------------------------------------------|:----------------------------------------------------------------------------------|:---------------------------------------------------------------------------------|
189
+ | type | string | string | string |
190
+ | details | <ul><li>min: 6 tokens</li><li>mean: 9.53 tokens</li><li>max: 47 tokens</li></ul> | <ul><li>min: 5 tokens</li><li>mean: 12.03 tokens</li><li>max: 40 tokens</li></ul> | <ul><li>min: 5 tokens</li><li>mean: 12.7 tokens</li><li>max: 49 tokens</li></ul> |
191
+ * Samples:
192
+ | anchor | positive | negative |
193
+ |:---------------------------------------------------------------|:------------------------------------------------|:---------------------------------------------------------------|
194
+ | <code>En person på en hest hopper over et havarert fly.</code> | <code>En person er utendørs, på en hest.</code> | <code>En person er på en diner og bestiller en omelett.</code> |
195
+ | <code>Barn smiler og vinker til kameraet</code> | <code>Det er barn til stede</code> | <code>Barna rynker pannen</code> |
196
+ | <code>En gutt hopper på skateboard midt på en rød bro.</code> | <code>Gutten gjør et skateboardtriks.</code> | <code>Gutten skater nedover fortauet.</code> |
197
+ * Loss: [<code>CachedMultipleNegativesRankingLoss</code>](https://sbert.net/docs/package_reference/sentence_transformer/losses.html#cachedmultiplenegativesrankingloss) with these parameters:
198
+ ```json
199
+ {
200
+ "scale": 20.0,
201
+ "similarity_fct": "cos_sim",
202
+ "mini_batch_size": 32,
203
+ "gather_across_devices": false
204
+ }
205
+ ```
206
+
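For reference, a hedged sketch of how a loss with these parameters is typically constructed (the base-model id and the `trust_remote_code` flag are assumptions, not taken from the training script):

```python
# Sketch of instantiating the training loss with the parameters listed above.
from sentence_transformers import SentenceTransformer, losses, util

model = SentenceTransformer("Murhaf/ltg-norbert4-base_ndla", trust_remote_code=True)
loss = losses.CachedMultipleNegativesRankingLoss(
    model,
    scale=20.0,
    similarity_fct=util.cos_sim,
    mini_batch_size=32,
)
```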
207
+ ### Evaluation Dataset
208
+
209
+ #### all-nli-norwegian
210
+
211
+ * Dataset: [all-nli-norwegian](https://huggingface.co/datasets/Murhaf/all-nli-norwegian) at [98cabde](https://huggingface.co/datasets/Murhaf/all-nli-norwegian/tree/98cabded09bfe5f505757840026ecdf6a357a04c)
212
+ * Size: 6,561 evaluation samples
213
+ * Columns: <code>anchor</code>, <code>positive</code>, and <code>negative</code>
214
+ * Approximate statistics based on the first 1000 samples:
215
+ | | anchor | positive | negative |
216
+ |:--------|:----------------------------------------------------------------------------------|:---------------------------------------------------------------------------------|:--------------------------------------------------------------------------------|
217
+ | type | string | string | string |
218
+ | details | <ul><li>min: 5 tokens</li><li>mean: 17.72 tokens</li><li>max: 74 tokens</li></ul> | <ul><li>min: 4 tokens</li><li>mean: 8.98 tokens</li><li>max: 31 tokens</li></ul> | <ul><li>min: 3 tokens</li><li>mean: 9.5 tokens</li><li>max: 29 tokens</li></ul> |
219
+ * Samples:
220
+ | anchor | positive | negative |
221
+ |:--------------------------------------------------------------------------------------------------------------------------------|:----------------------------------------------------------|:------------------------------------------------------------|
222
+ | <code>To kvinner klemmer mens de holder take-away pakker.</code> | <code>To kvinner holder pakker.</code> | <code>Mennene slåss utenfor en deli.</code> |
223
+ | <code>To små barn i blå drakter, en med nummer 9 og en med nummer 2, står på trinn i et bad og vasker hendene i en vask.</code> | <code>To barn i nummererte drakter vasker hendene.</code> | <code>To barn i jakker går til skolen.</code> |
224
+ | <code>En mann selger donuts til en kunde under et verdensutstillingsarrangement holdt i byen Angeles</code> | <code>En mann selger donuts til en kunde.</code> | <code>En kvinne drikker kaffen sin på en liten kafé.</code> |
225
+ * Loss: [<code>CachedMultipleNegativesRankingLoss</code>](https://sbert.net/docs/package_reference/sentence_transformer/losses.html#cachedmultiplenegativesrankingloss) with these parameters:
226
+ ```json
227
+ {
228
+ "scale": 20.0,
229
+ "similarity_fct": "cos_sim",
230
+ "mini_batch_size": 32,
231
+ "gather_across_devices": false
232
+ }
233
+ ```
234
+
235
+ ### Training Hyperparameters
236
+ #### Non-Default Hyperparameters
237
+
238
+ - `eval_strategy`: steps
239
+ - `per_device_train_batch_size`: 512
240
+ - `per_device_eval_batch_size`: 256
241
+ - `num_train_epochs`: 1
242
+ - `warmup_ratio`: 0.1
243
+ - `batch_sampler`: no_duplicates
244
+
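A sketch of expressing these non-default values with `SentenceTransformerTrainingArguments` (the `output_dir` is a placeholder):

```python
# Sketch of the non-default hyperparameters above; output_dir is a placeholder.
from sentence_transformers import SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers

args = SentenceTransformerTrainingArguments(
    output_dir="output",
    eval_strategy="steps",
    per_device_train_batch_size=512,
    per_device_eval_batch_size=256,
    num_train_epochs=1,
    warmup_ratio=0.1,
    batch_sampler=BatchSamplers.NO_DUPLICATES,
)
```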
245
+ #### All Hyperparameters
246
+ <details><summary>Click to expand</summary>
247
+
248
+ - `overwrite_output_dir`: False
249
+ - `do_predict`: False
250
+ - `eval_strategy`: steps
251
+ - `prediction_loss_only`: True
252
+ - `per_device_train_batch_size`: 512
253
+ - `per_device_eval_batch_size`: 256
254
+ - `per_gpu_train_batch_size`: None
255
+ - `per_gpu_eval_batch_size`: None
256
+ - `gradient_accumulation_steps`: 1
257
+ - `eval_accumulation_steps`: None
258
+ - `torch_empty_cache_steps`: None
259
+ - `learning_rate`: 5e-05
260
+ - `weight_decay`: 0.0
261
+ - `adam_beta1`: 0.9
262
+ - `adam_beta2`: 0.999
263
+ - `adam_epsilon`: 1e-08
264
+ - `max_grad_norm`: 1.0
265
+ - `num_train_epochs`: 1
266
+ - `max_steps`: -1
267
+ - `lr_scheduler_type`: linear
268
+ - `lr_scheduler_kwargs`: {}
269
+ - `warmup_ratio`: 0.1
270
+ - `warmup_steps`: 0
271
+ - `log_level`: passive
272
+ - `log_level_replica`: warning
273
+ - `log_on_each_node`: True
274
+ - `logging_nan_inf_filter`: True
275
+ - `save_safetensors`: True
276
+ - `save_on_each_node`: False
277
+ - `save_only_model`: False
278
+ - `restore_callback_states_from_checkpoint`: False
279
+ - `no_cuda`: False
280
+ - `use_cpu`: False
281
+ - `use_mps_device`: False
282
+ - `seed`: 42
283
+ - `data_seed`: None
284
+ - `jit_mode_eval`: False
285
+ - `use_ipex`: False
286
+ - `bf16`: False
287
+ - `fp16`: False
288
+ - `fp16_opt_level`: O1
289
+ - `half_precision_backend`: auto
290
+ - `bf16_full_eval`: False
291
+ - `fp16_full_eval`: False
292
+ - `tf32`: None
293
+ - `local_rank`: 2
294
+ - `ddp_backend`: None
295
+ - `tpu_num_cores`: None
296
+ - `tpu_metrics_debug`: False
297
+ - `debug`: []
298
+ - `dataloader_drop_last`: True
299
+ - `dataloader_num_workers`: 0
300
+ - `dataloader_prefetch_factor`: None
301
+ - `past_index`: -1
302
+ - `disable_tqdm`: False
303
+ - `remove_unused_columns`: True
304
+ - `label_names`: None
305
+ - `load_best_model_at_end`: False
306
+ - `ignore_data_skip`: False
307
+ - `fsdp`: []
308
+ - `fsdp_min_num_params`: 0
309
+ - `fsdp_config`: {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}
310
+ - `fsdp_transformer_layer_cls_to_wrap`: None
311
+ - `accelerator_config`: {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None}
312
+ - `parallelism_config`: None
313
+ - `deepspeed`: None
314
+ - `label_smoothing_factor`: 0.0
315
+ - `optim`: adamw_torch
316
+ - `optim_args`: None
317
+ - `adafactor`: False
318
+ - `group_by_length`: False
319
+ - `length_column_name`: length
320
+ - `ddp_find_unused_parameters`: None
321
+ - `ddp_bucket_cap_mb`: None
322
+ - `ddp_broadcast_buffers`: False
323
+ - `dataloader_pin_memory`: True
324
+ - `dataloader_persistent_workers`: False
325
+ - `skip_memory_metrics`: True
326
+ - `use_legacy_prediction_loop`: False
327
+ - `push_to_hub`: False
328
+ - `resume_from_checkpoint`: None
329
+ - `hub_model_id`: None
330
+ - `hub_strategy`: every_save
331
+ - `hub_private_repo`: None
332
+ - `hub_always_push`: False
333
+ - `hub_revision`: None
334
+ - `gradient_checkpointing`: False
335
+ - `gradient_checkpointing_kwargs`: None
336
+ - `include_inputs_for_metrics`: False
337
+ - `include_for_metrics`: []
338
+ - `eval_do_concat_batches`: True
339
+ - `fp16_backend`: auto
340
+ - `push_to_hub_model_id`: None
341
+ - `push_to_hub_organization`: None
342
+ - `mp_parameters`:
343
+ - `auto_find_batch_size`: False
344
+ - `full_determinism`: False
345
+ - `torchdynamo`: None
346
+ - `ray_scope`: last
347
+ - `ddp_timeout`: 1800
348
+ - `torch_compile`: False
349
+ - `torch_compile_backend`: None
350
+ - `torch_compile_mode`: None
351
+ - `include_tokens_per_second`: False
352
+ - `include_num_input_tokens_seen`: False
353
+ - `neftune_noise_alpha`: None
354
+ - `optim_target_modules`: None
355
+ - `batch_eval_metrics`: False
356
+ - `eval_on_start`: False
357
+ - `use_liger_kernel`: False
358
+ - `liger_kernel_config`: None
359
+ - `eval_use_gather_object`: False
360
+ - `average_tokens_across_devices`: True
361
+ - `prompts`: None
362
+ - `batch_sampler`: no_duplicates
363
+ - `multi_dataset_batch_sampler`: proportional
364
+ - `router_mapping`: {}
365
+ - `learning_rate_mapping`: {}
366
+
367
+ </details>
368
+
369
+ ### Training Logs
370
+ | Epoch | Step | Training Loss | Validation Loss | nob_all_nli_test_cosine_accuracy |
371
+ |:------:|:----:|:-------------:|:---------------:|:--------------------------------:|
372
+ | 0.3690 | 100 | 1.8282 | 0.6138 | 0.9420 |
373
+ | 0.7380 | 200 | 1.1887 | 0.5645 | 0.9470 |
374
+
375
+
376
+ ### Framework Versions
377
+ - Python: 3.12.11
378
+ - Sentence Transformers: 5.1.1
379
+ - Transformers: 4.56.2
380
+ - PyTorch: 2.6.0+cu124
381
+ - Accelerate: 1.10.1
382
+ - Datasets: 4.1.1
383
+ - Tokenizers: 0.22.1
384
+
385
+ ## Citation
386
+
387
+ ### BibTeX
388
+
389
+ #### Sentence Transformers
390
+ ```bibtex
391
+ @inproceedings{reimers-2019-sentence-bert,
392
+ title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
393
+ author = "Reimers, Nils and Gurevych, Iryna",
394
+ booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
395
+ month = "11",
396
+ year = "2019",
397
+ publisher = "Association for Computational Linguistics",
398
+ url = "https://arxiv.org/abs/1908.10084",
399
+ }
400
+ ```
401
+
402
+ #### CachedMultipleNegativesRankingLoss
403
+ ```bibtex
404
+ @misc{gao2021scaling,
405
+ title={Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup},
406
+ author={Luyu Gao and Yunyi Zhang and Jiawei Han and Jamie Callan},
407
+ year={2021},
408
+ eprint={2101.06983},
409
+ archivePrefix={arXiv},
410
+ primaryClass={cs.LG}
411
+ }
412
+ ```
413
+
414
+ <!--
415
+ ## Glossary
416
+
417
+ *Clearly define terms in order to be accessible across audiences.*
418
+ -->
419
+
420
+ <!--
421
+ ## Model Card Authors
422
+
423
+ *Lists the people who create the model card, providing recognition and accountability for the detailed work that goes into its construction.*
424
+ -->
425
+
426
+ <!--
427
+ ## Model Card Contact
428
+
429
+ *Provides a way for people who have updates to the Model Card, suggestions, or questions, to contact the Model Card authors.*
430
+ -->
config.json ADDED
@@ -0,0 +1,43 @@
1
+ {
2
+ "architectures": [
3
+ "GptBertModel"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "attn_implementation": null,
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_gptbert.GptBertConfig",
9
+ "AutoModel": "modeling_gptbert.GptBertModel",
10
+ "AutoModelForCausalLM": "modeling_gptbert.GptBertForCausalLM",
11
+ "AutoModelForMaskedLM": "modeling_gptbert.GptBertForMaskedLM",
12
+ "AutoModelForMultipleChoice": "modeling_gptbert.GptBertForMultipleChoice",
13
+ "AutoModelForQuestionAnswering": "modeling_gptbert.GptBertForQuestionAnswering",
14
+ "AutoModelForSequenceClassification": "modeling_gptbert.GptBertForSequenceClassification",
15
+ "AutoModelForTokenClassification": "modeling_gptbert.GptBertForTokenClassification"
16
+ },
17
+ "bos_token_id": 1,
18
+ "classifier_dropout": 0.2,
19
+ "deterministic_flash_attn": false,
20
+ "dtype": "float32",
21
+ "embedding_dropout": 0.1,
22
+ "eos_token_id": 2,
23
+ "global_window_length": 8192,
24
+ "hidden_dropout": 0.0,
25
+ "hidden_size": 640,
26
+ "intermediate_size": 1664,
27
+ "layer_norm_eps": 1e-07,
28
+ "local_global_ratio": 4,
29
+ "local_window_length": 256,
30
+ "mask_token_id": 4,
31
+ "max_sequence_length": 16384,
32
+ "model": "norbert4",
33
+ "num_attention_heads": 10,
34
+ "num_layers": 24,
35
+ "pad_token_id": 3,
36
+ "query_key_head_size": 64,
37
+ "rope_theta": 160000,
38
+ "transformers_version": "4.56.2",
39
+ "unk_token_id": 0,
40
+ "use_cache": false,
41
+ "value_head_size": 64,
42
+ "vocab_size": 51200
43
+ }
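Because `auto_map` points the `GptBertModel` architecture at custom code in this repository, loading the backbone with plain `transformers` needs `trust_remote_code=True`. A minimal sketch, using the finetuned repo id from this card:

```python
# Sketch of loading the custom GptBert backbone directly with transformers.
from transformers import AutoModel, AutoTokenizer

repo_id = "Murhaf/ltg-norbert4-base_ndla-all-nli"
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("En mann lager kunst.", return_tensors="pt")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (1, seq_len, 640)
```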
config_sentence_transformers.json ADDED
@@ -0,0 +1,14 @@
1
+ {
2
+ "model_type": "SentenceTransformer",
3
+ "__version__": {
4
+ "sentence_transformers": "5.1.1",
5
+ "transformers": "4.56.2",
6
+ "pytorch": "2.6.0+cu124"
7
+ },
8
+ "prompts": {
9
+ "query": "",
10
+ "document": ""
11
+ },
12
+ "default_prompt_name": null,
13
+ "similarity_fn_name": "cosine"
14
+ }
configuration_gptbert.py ADDED
@@ -0,0 +1,34 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from pathlib import Path
5
+ import copy
6
+ from transformers.configuration_utils import PretrainedConfig
7
+
8
+
9
+ class GptBertConfig(PretrainedConfig):
10
+
11
+ def __init__(
12
+ self,
13
+ config_file: Path | str | None = None,
14
+ **kwargs
15
+ ):
16
+ super().__init__(**kwargs)
17
+ self.model = "norbert4"
18
+
19
+ if config_file is not None:
20
+ if type(config_file) is str:
21
+ config_file = Path(config_file)
22
+ assert isinstance(config_file, Path), "The config_file should either be a Path or str"
23
+ with config_file.open("r") as file:
24
+ config = json.load(file)
25
+
26
+ for attr, value in config.items():
27
+ if isinstance(value, str):
28
+ value = value.lower()
29
+ setattr(self, attr, value)
30
+
31
+ for attr, value in kwargs.items():
32
+ if isinstance(value, str):
33
+ value = value.lower()
34
+ setattr(self, attr, value)
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:041900178a3b1d3d04eaeafe056d49541f203dc201c591bebf88f3d735e004f8
3
+ size 595640976
modeling_gptbert.py ADDED
@@ -0,0 +1,1105 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+ from torch import _softmax_backward_data as _softmax_backward_data
7
+
8
+ from functools import partial, lru_cache
9
+
10
+ from .configuration_gptbert import GptBertConfig
11
+ from transformers.modeling_utils import PreTrainedModel
12
+ from transformers.activations import gelu_new
13
+ from transformers.utils import is_flash_attn_2_available, logging
14
+ from transformers.modeling_outputs import (
15
+ MaskedLMOutput,
16
+ MultipleChoiceModelOutput,
17
+ QuestionAnsweringModelOutput,
18
+ SequenceClassifierOutput,
19
+ TokenClassifierOutput,
20
+ BaseModelOutput,
21
+ CausalLMOutput
22
+ )
23
+ import math
24
+ from typing import TYPE_CHECKING, Optional, Union, Tuple, List
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ # Workaround for transformers < 4.36.0 check_imports issue
30
+ # See: https://github.com/huggingface/transformers/issues/28459
31
+ try:
32
+ if is_flash_attn_2_available():
33
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
34
+ from flash_attn.layers.rotary import RotaryEmbedding
35
+ from flash_attn.ops.triton.rotary import apply_rotary
36
+ else:
37
+ flash_attn_varlen_qkvpacked_func, RotaryEmbedding, apply_rotary = None, object, None
38
+ logger.warning_once(
39
+ "NorBERT4 støtter FlashAttention, men det er ikke funnet i miljøet ditt. Du bør vurdere å oppdatere miljøet ditt for å få raskere og mindre minnekrevende behandling."
40
+ )
41
+ except ImportError:
42
+ flash_attn_varlen_qkvpacked_func, RotaryEmbedding, apply_rotary = None, object, None
43
+ logger.warning_once(
44
+ "NorBERT4 støtter FlashAttention, men det er ikke funnet i miljøet ditt. Du bør vurdere å oppdatere miljøet ditt for å få raskere og mindre minnekrevende behandling."
45
+ )
46
+
47
+
48
+ # from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
49
+ @torch.compiler.disable()
50
+ def _unpad_input(input_ids: torch.Tensor, attention_mask: torch.Tensor):
51
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
52
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
53
+ max_seqlen_in_batch = int(seqlens_in_batch.max().item())
54
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
55
+
56
+ if input_ids.dim() == 2:
57
+ unpadded_inputs = input_ids.flatten()[indices]
58
+ else:
59
+ batch_size, sequence_length, *rest = input_ids.shape
60
+ shape = batch_size * sequence_length
61
+ unpadded_inputs = input_ids.view(shape, *rest)[indices]
62
+
63
+ return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch
64
+
65
+
66
+ # from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
67
+ def _pad_output(input_ids: torch.Tensor, indices: torch.Tensor, batch_size: int, sequence_length: int) -> torch.Tensor:
68
+ if input_ids.dim() == 1:
69
+ output = torch.zeros(batch_size * sequence_length, dtype=input_ids.dtype, device=input_ids.device)
70
+ output[indices] = input_ids
71
+ padded_inputs = output.view(batch_size, sequence_length)
72
+ else:
73
+ _, *rest = input_ids.shape
74
+ output = torch.zeros(batch_size * sequence_length, *rest, dtype=input_ids.dtype, device=input_ids.device)
75
+ output[indices] = input_ids
76
+ padded_inputs = output.view(batch_size, sequence_length, *rest)
77
+
78
+ return padded_inputs
79
+
80
+
81
+ class CastedLinear(nn.Linear):
82
+ def __init__(self, in_features, out_features, bias):
83
+ super().__init__(in_features, out_features, bias=bias)
84
+
85
+ def forward(self, x):
86
+ return F.linear(x, self.weight.type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
87
+
88
+
89
+ class CastedLinearIn(nn.Linear):
90
+ def __init__(self, in_features, out_features, bias):
91
+ super().__init__(in_features, out_features, bias=bias)
92
+ self.scale = nn.Parameter(torch.ones(in_features))
93
+
94
+ def forward(self, x):
95
+ return F.linear(x, (self.weight * (self.scale + 1.0).unsqueeze(0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
96
+
97
+
98
+ class MultiCastedLinearOrthoIn(nn.Module):
99
+ def __init__(self, in_features, out_features, bias):
100
+ super().__init__()
101
+
102
+ self.in_features = in_features
103
+ self.out_features = out_features
104
+
105
+ self.weights = nn.ParameterList()
106
+ for out_feature in out_features:
107
+ self.weights.append(nn.Parameter(torch.empty((out_feature, in_features))))
108
+
109
+ if bias:
110
+ self.bias = nn.Parameter(torch.zeros(sum(out_features)))
111
+ else:
112
+ self.bias = self.register_parameter("bias", None)
113
+
114
+ self.scale = nn.Parameter(torch.ones(in_features))
115
+
116
+ def forward(self, x):
117
+ return F.linear(x, (torch.cat([weight for weight in self.weights], dim=0) * (self.scale + 1.0).unsqueeze(0)).type_as(x), bias=self.bias.type_as(x) if self.bias is not None else None)
118
+
119
+
120
+ class GeGLU(nn.Module):
121
+ def forward(self, x):
122
+ x, gate = x.chunk(2, dim=-1)
123
+ return x * gelu_new(gate)
124
+
125
+
126
+ class Embedding(nn.Module):
127
+ def __init__(self, config: GptBertConfig):
128
+ super().__init__()
129
+
130
+ self.word_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
131
+ self.word_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False, bias=False)
132
+ self.word_scale = nn.Parameter(torch.zeros(config.hidden_size))
133
+ self.dropout = nn.Dropout(config.embedding_dropout)
134
+
135
+ def forward(self, input_ids: torch.Tensor):
136
+ word_embedding = self.word_embedding(input_ids)
137
+ word_embedding = self.word_norm(word_embedding)
138
+ word_embedding = word_embedding * (self.word_scale + 1.0)
139
+
140
+ return self.dropout(word_embedding)
141
+
142
+
143
+ class LMClassifier(nn.Module):
144
+ def __init__(self, config: GptBertConfig, n_labels: int):
145
+ super().__init__()
146
+
147
+ self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
148
+ self.projection = CastedLinearIn(config.hidden_size, config.hidden_size, bias=False)
149
+ self.post_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
150
+ self.emb2vocab = CastedLinearIn(config.hidden_size, n_labels, bias=True)
151
+
152
+ def forward(self, x: torch.Tensor):
153
+ x = self.pre_norm(x.float()).type_as(x)
154
+ x = self.projection(x)
155
+ x = gelu_new(x)
156
+ x = self.post_norm(x.float()).type_as(x)
157
+ x = self.emb2vocab(x)
158
+ return x
159
+
160
+
161
+ class Classifier(nn.Module):
162
+ def __init__(self, config: GptBertConfig, n_labels: int):
163
+ super().__init__()
164
+
165
+ self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
166
+ self.projection = CastedLinearIn(config.hidden_size, config.hidden_size, bias=False)
167
+ self.post_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
168
+ self.dropout = nn.Dropout(config.classifier_dropout)
169
+ self.output_projection = CastedLinearIn(config.hidden_size, n_labels, bias=True)
170
+
171
+ def forward(self, x: torch.Tensor):
172
+ x = self.pre_norm(x.float()).type_as(x)
173
+ x = self.projection(x)
174
+ x = gelu_new(x)
175
+ x = self.post_norm(x.float()).type_as(x)
176
+ x = self.dropout(x)
177
+ x = self.output_projection(x)
178
+ return x
179
+
180
+
181
+ # from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
182
+ def flash_attention_forward(qkv: torch.Tensor, rotary_emb: UnpaddedRotaryEmbedding, cu_seqlens: torch.Tensor, max_seqlen: int, causal: bool, local_attention: Tuple[int, int], dropout_p: float, deterministic: bool, target_dtype: torch.dtype = torch.bfloat16, **_kwargs):
183
+ qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
184
+
185
+ convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
186
+ if convert_dtype:
187
+ # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
188
+ # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
189
+ orig_dtype = qkv.dtype
190
+ qkv = qkv.to(target_dtype)
191
+
192
+ attn = flash_attn_varlen_qkvpacked_func(
193
+ qkv,
194
+ cu_seqlens=cu_seqlens,
195
+ max_seqlen=max_seqlen,
196
+ dropout_p=dropout_p,
197
+ deterministic=deterministic,
198
+ window_size=local_attention,
199
+ causal=False
200
+ )
201
+ attn = attn.to(orig_dtype) # type: ignore
202
+ else:
203
+ attn = flash_attn_varlen_qkvpacked_func(
204
+ qkv,
205
+ cu_seqlens=cu_seqlens,
206
+ max_seqlen=max_seqlen,
207
+ dropout_p=dropout_p,
208
+ deterministic=deterministic,
209
+ window_size=local_attention,
210
+ causal=False
211
+ )
212
+ return attn
213
+
214
+
215
+ # from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
216
+ class ApplyRotaryEmbUnpad(torch.autograd.Function):
217
+ @staticmethod
218
+ def forward(ctx, qkv, cos, sin, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None):
219
+ # (total_nnz, 3, nheads, headdim)
220
+ qkv = qkv.contiguous()
221
+ total_nnz, _three, _nheads, headdim = qkv.shape
222
+ # We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
223
+ # we get the same tensor
224
+ # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d")
225
+ qk = qkv[:, :2].view(total_nnz, -1, headdim)
226
+ apply_rotary(qk, cos, sin, seqlen_offsets=0, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=False, inplace=True)
227
+
228
+ ctx.save_for_backward(cos, sin, cu_seqlens)
229
+ ctx.max_seqlen = max_seqlen
230
+ return qkv
231
+
232
+ @staticmethod
233
+ def backward(ctx, do):
234
+ cos, sin, cu_seqlens = ctx.saved_tensors
235
+ do = do.contiguous()
236
+ total_nnz, _three, _nheads, headdim = do.shape
237
+ # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
238
+ # we get the same tensor
239
+ dqk = do[:, :2].view(total_nnz, -1, headdim)
240
+ apply_rotary(
241
+ dqk,
242
+ cos,
243
+ sin,
244
+ seqlen_offsets=0,
245
+ cu_seqlens=cu_seqlens,
246
+ max_seqlen=ctx.max_seqlen,
247
+ interleaved=False,
248
+ inplace=True,
249
+ conjugate=True,
250
+ )
251
+
252
+ return do, None, None, None, None, None, None
253
+
254
+
255
+ # from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
256
+ def apply_rotary_unpadded(qkv, cos, sin, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None):
257
+ return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)
258
+
259
+
260
+ # from https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modeling_modernbert.py
261
+ class UnpaddedRotaryEmbedding(RotaryEmbedding):
262
+ def __init__(self, dim: int, base: float = 10000.0, max_seqlen: Optional[int] = None):
263
+ super().__init__(dim=dim, base=base, device=None, interleaved=False)
264
+ self.max_seqlen = max_seqlen
265
+
266
+ def forward(self, qkv: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: Optional[int] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
267
+ if max_seqlen is not None:
268
+ self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
269
+
270
+ qkv = apply_rotary_unpadded(
271
+ qkv,
272
+ self._cos_cached,
273
+ self._sin_cached,
274
+ cu_seqlens=cu_seqlens,
275
+ max_seqlen=max_seqlen,
276
+ )
277
+
278
+ return qkv
279
+
280
+
281
+ class RotaryPositionalEmbeddings(nn.Module):
282
+ def __init__(self, config, theta: int):
283
+ super().__init__()
284
+
285
+ head_size = config.query_key_head_size
286
+ assert head_size % 2 == 0
287
+ max_seq_len = config.max_sequence_length
288
+
289
+ inv_freq = 1.0 / (theta ** (torch.arange(0, head_size, 2, dtype=torch.float32) / head_size))
290
+ pos = torch.arange(max_seq_len, dtype=torch.float32)
291
+ embedding = torch.einsum('n, d -> nd', pos, inv_freq)
292
+ embedding = torch.cat([embedding, embedding], dim=-1).unsqueeze(0)
293
+ self.register_buffer("cos_matrix", embedding.cos(), persistent=False)
294
+ self.register_buffer("sin_matrix", embedding.sin(), persistent=False)
295
+
296
+ def forward(self, x: torch.Tensor):
297
+ hidden_layer = x.float()
298
+
299
+ seq_len = x.shape[2]
300
+
301
+ cos_matrix = self.cos_matrix[:, None, :seq_len, :]
302
+ sin_matrix = self.sin_matrix[:, None, :seq_len, :]
303
+
304
+ x_rotate_half = torch.cat(
305
+ [
306
+ -hidden_layer[:, :, :, x.size(-1) // 2:],
307
+ hidden_layer[:, :, :, :x.size(-1) // 2]
308
+ ],
309
+ dim=-1
310
+ )
311
+
312
+ out = hidden_layer * cos_matrix + x_rotate_half * sin_matrix
313
+ return out.type_as(x)
314
+
315
+
316
+ class MaskedSoftmax(torch.autograd.Function):
317
+ @staticmethod
318
+ def forward(ctx, x: torch.Tensor, mask: torch.BoolTensor, dim: int) -> torch.Tensor:
319
+ ctx.dim = dim
320
+ x.masked_fill_(mask, float('-inf'))
321
+ x = torch.softmax(x, ctx.dim)
322
+ x.masked_fill_(mask, 0.0)
323
+ ctx.save_for_backward(x)
324
+ return x
325
+
326
+ @staticmethod
327
+ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None]:
328
+ output: torch.Tensor
329
+
330
+ output, = ctx.saved_tensors
331
+ inputGrad: torch.Tensor = _softmax_backward_data(grad_output, output, ctx.dim, output.dtype)
332
+ return inputGrad, None, None
333
+
334
+
335
+ class SelfAttention(nn.Module):
336
+ def __init__(self, config: GptBertConfig, layer_idx: int):
337
+ super().__init__()
338
+
339
+ self.config = config
340
+ self.layer_idx = layer_idx
341
+
342
+ self.d_qk = config.query_key_head_size
343
+ self.d_v = config.value_head_size
344
+ self.num_attention_heads = config.num_attention_heads
345
+ self.num_kv_heads = config.num_attention_heads
346
+ self.hidden_size = config.hidden_size
347
+
348
+ self.q_out_dim = self.d_qk * self.num_attention_heads
349
+ self.k_out_dim = self.d_qk * self.num_kv_heads
350
+ self.v_out_dim = self.d_v * self.num_kv_heads
351
+
352
+ self.qk_proj = MultiCastedLinearOrthoIn(self.hidden_size, [self.q_out_dim, self.k_out_dim], bias=False)
353
+ self.v_proj = CastedLinearIn(self.hidden_size, self.v_out_dim, bias=False)
354
+ self.out_proj = CastedLinearIn(self.d_v*self.num_attention_heads, self.hidden_size, bias=False)
355
+
356
+ self.pre_v_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
357
+ self.pre_qk_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
358
+ self.inter_norm = nn.LayerNorm(self.d_v * self.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=False)
359
+ self.q_norm = nn.LayerNorm(self.d_qk, eps=config.layer_norm_eps, elementwise_affine=False, bias=False)
360
+ self.k_norm = nn.LayerNorm(self.d_qk, eps=config.layer_norm_eps, elementwise_affine=False, bias=False)
361
+ self.k_scale = nn.Parameter(torch.ones(self.num_kv_heads, self.d_qk))
362
+ self.q_scale = nn.Parameter(torch.ones(self.num_attention_heads, self.d_qk))
363
+
364
+ self.attention_dropout = nn.Dropout(config.attention_dropout)
365
+ self.dropout = nn.Dropout(config.hidden_dropout)
366
+
367
+ theta = 160_000 if (layer_idx + 1) % config.local_global_ratio == 0 else 10_000
368
+
369
+ # Initialize rotary embeddings based on whether FlashAttention is available
370
+ if flash_attn_varlen_qkvpacked_func is not None:
371
+ self.rope_embedding = UnpaddedRotaryEmbedding(dim=self.d_qk, base=theta, max_seqlen=config.max_sequence_length)
372
+ else:
373
+ self.rope_embedding = RotaryPositionalEmbeddings(config, theta)
374
+
375
+ self.scale = 1.0 / math.sqrt(self.d_qk)
376
+ self.lambdas = nn.Parameter(torch.tensor([0.5]))
377
+
378
+ self.sequence_length = config.max_sequence_length
379
+ self.is_causal = config.is_decoder
380
+ self.window_length = None
381
+
382
+ def set_window_length(self, window_length: int):
383
+ self.window_length = window_length
384
+
385
+ def _get_window_mask(self, query_length: int, key_length: int, device: torch.device):
386
+ """Create and cache window attention mask."""
387
+ if self.is_causal:
388
+ mask = torch.ones(query_length, key_length, dtype=torch.bool, device=device)
389
+ mask = mask.tril().triu(diagonal=-self.window_length)
390
+ else:
391
+ mask = torch.ones(query_length, key_length, dtype=torch.bool, device=device)
392
+ mask = mask.tril(diagonal=self.window_length).triu(diagonal=-self.window_length)
393
+ return mask.view(1, 1, query_length, key_length)
394
+
395
+ def attention_operation(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, padding_mask: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
396
+ """Standard attention computation with masking."""
397
+ batch_size, _, query_length, _ = query.size()
398
+ _, _, key_length, _ = key.size()
399
+
400
+ # Use cached window mask
401
+ with torch.no_grad():
402
+ window_mask = self._get_window_mask(query_length, key_length, query.device)
403
+ if padding_mask is not None:
404
+ attention_mask = padding_mask & window_mask
405
+ else:
406
+ attention_mask = window_mask
407
+
408
+ attention_scores = torch.bmm(query.flatten(0, 1), key.transpose(-1, -2).flatten(0, 1)) * self.scale # shape: [B*H, Q_T, K_T]
409
+ attention_scores = attention_scores.view(batch_size, self.num_attention_heads, query_length, key_length)
410
+
411
+ attention_probabilities = MaskedSoftmax.apply(attention_scores, ~attention_mask, -1)
412
+ attention_probabilities = self.attention_dropout(attention_probabilities)
413
+
414
+ output = torch.bmm(attention_probabilities.flatten(0, 1), value.flatten(0, 1))
415
+ output = output.view(batch_size, self.num_attention_heads, query_length, self.d_v)
416
+
417
+ return output
418
+
419
+ def forward(self, hidden_layer: torch.Tensor, qk_layer: torch.Tensor, v1: torch.Tensor | None, padding_info):
420
+ # Get original shape info
421
+ if flash_attn_varlen_qkvpacked_func is not None:
422
+ # Unpadded case
423
+ indices, cu_seqlens, max_seqlen = padding_info
424
+ total_seqlen = hidden_layer.size(0)
425
+ batch_size = cu_seqlens.size(0) - 1
426
+ else:
427
+ # Padded case
428
+ batch_size, seq_length = hidden_layer.size(0), hidden_layer.size(1)
429
+
430
+ hidden_layer = self.pre_v_norm(hidden_layer.float()).type_as(hidden_layer)
431
+ qk_layer = self.pre_qk_norm(qk_layer.float()).type_as(qk_layer)
432
+
433
+ query, key = self.qk_proj(qk_layer).tensor_split([self.q_out_dim], dim=-1)
434
+ value = self.v_proj(hidden_layer)
435
+
436
+ if flash_attn_varlen_qkvpacked_func is not None:
437
+ # Reshape for FlashAttention: (total_seqlen, num_heads, head_dim)
438
+ query = query.view(total_seqlen, self.num_attention_heads, self.d_qk)
439
+ key = key.view(total_seqlen, self.num_kv_heads, self.d_qk)
440
+ value = value.view(total_seqlen, self.num_kv_heads, self.d_v)
441
+
442
+ # Apply layer norm and scaling
443
+ query = ((self.q_scale + 1.0).unsqueeze(0) * self.q_norm(query.float())).type_as(query)
444
+ key = ((self.k_scale + 1.0).unsqueeze(0) * self.k_norm(key.float())).type_as(key)
445
+
446
+ if v1 is None:
447
+ v1 = value
448
+ value = (1 - self.lambdas[0]) * value + self.lambdas[0] * v1
449
+
450
+ # Prepare qkv for FlashAttention
451
+ qkv = torch.stack([query, key, value], dim=1) # (total_seqlen, 3, num_heads, head_dim)
452
+
453
+ # Determine window size for local attention
454
+ if self.window_length is not None and self.window_length > 0:
455
+ if self.is_causal:
456
+ local_attention = (self.window_length - 1, 0)
457
+ else:
458
+ local_attention = (self.window_length - 1, self.window_length - 1)
459
+ else:
460
+ local_attention = (-1, -1)
461
+
462
+ # Apply FlashAttention
463
+ output = flash_attention_forward(
464
+ qkv,
465
+ self.rope_embedding,
466
+ cu_seqlens,
467
+ max_seqlen,
468
+ self.is_causal,
469
+ local_attention,
470
+ self.config.attention_dropout if self.training else 0.0,
471
+ self.config.deterministic_flash_attn
472
+ )
473
+
474
+ # Reshape output back
475
+ output = output.view(total_seqlen, self.d_v * self.num_attention_heads)
476
+
477
+ else:
478
+ # Standard attention path
479
+ query_length = query.size(1)
480
+ key_length = key.size(1)
481
+
482
+ query = query.reshape(batch_size, query_length, self.num_attention_heads, self.d_qk).transpose(1, 2)
483
+ key = key.reshape(batch_size, key_length, self.num_kv_heads, self.d_qk).transpose(1, 2)
484
+ value = value.reshape(batch_size, key_length, self.num_kv_heads, self.d_v).transpose(1, 2)
485
+
486
+ query = ((self.q_scale + 1.0).unsqueeze(1).unsqueeze(0) * self.q_norm(query.float())).type_as(query)
487
+ key = ((self.k_scale + 1.0).unsqueeze(1).unsqueeze(0) * self.k_norm(key.float())).type_as(key)
488
+
489
+ if v1 is None:
490
+ v1 = value
491
+ else:
492
+ value = (1 - self.lambdas[0]) * value + self.lambdas[0] * v1
493
+
494
+ # Apply rotary embeddings
495
+ query = self.rope_embedding(query)
496
+ key = self.rope_embedding(key)
497
+
498
+ output = self.attention_operation(query, key, value, padding_info)
499
+ output = output.transpose(1, 2).flatten(2, 3) # shape: [B, T, H*D]
500
+
501
+ output = self.inter_norm(output.float()).type_as(output)
502
+ output = self.out_proj(output)
503
+ output = self.dropout(output)
504
+
505
+ return output, v1
506
+
507
+
508
+ class FeedForward(nn.Module):
509
+ def __init__(self, config: GptBertConfig):
510
+ super().__init__()
511
+ self.pre_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
512
+ self.up_proj = MultiCastedLinearOrthoIn(config.hidden_size, [config.intermediate_size, config.intermediate_size], bias=False)
513
+ self.activation = GeGLU()
514
+ self.inter_norm = nn.LayerNorm(config.intermediate_size, eps=config.layer_norm_eps, elementwise_affine=False)
515
+ self.down_proj = CastedLinearIn(config.intermediate_size, config.hidden_size, bias=False)
516
+ self.dropout = nn.Dropout(config.hidden_dropout)
517
+
518
+ def forward(self, x: torch.Tensor):
519
+ x = self.pre_norm(x.float()).type_as(x)
520
+ x = self.up_proj(x)
521
+ x = self.activation(x)
522
+ x = self.inter_norm(x.float()).type_as(x)
523
+ x = self.down_proj(x)
524
+ x = self.dropout(x)
525
+ return x
526
+
527
+
528
+ class Layer(nn.Module):
529
+ def __init__(self, config: GptBertConfig, layer_idx: int):
530
+ super().__init__()
531
+
532
+ self.attention = SelfAttention(config, layer_idx)
533
+ self.mlp = FeedForward(config)
534
+ self.lambdas = nn.Parameter(torch.tensor([0., 0., 1., 0., 1., 0.]))
535
+
536
+ def set_window_length(self, window_length: int):
537
+ self.attention.set_window_length(window_length)
538
+
539
+ def forward(self, hidden_layer: torch.Tensor, embeddings: torch.Tensor, v1: torch.Tensor | None, padding_info):
540
+ attention_output = (1 - self.lambdas[0]) * hidden_layer + self.lambdas[0] * embeddings
541
+ qk_layer = (1 - self.lambdas[1]) * hidden_layer + self.lambdas[1] * embeddings
542
+ mlp_layer = F.softplus(self.lambdas[2]) * ((1 - self.lambdas[3]) * hidden_layer + self.lambdas[3] * embeddings)
543
+
544
+ attention_output, v1 = self.attention(attention_output, qk_layer, v1, padding_info)
545
+ mlp_layer = mlp_layer + attention_output
546
+ hidden_layer = F.softplus(self.lambdas[4]) * ((1 - self.lambdas[5]) * hidden_layer + self.lambdas[5] * embeddings)
547
+ output = hidden_layer + attention_output + self.mlp(mlp_layer)
548
+
549
+ return output, v1
550
+
551
+
552
+ class Encoder(nn.Module):
553
+ def __init__(self, config: GptBertConfig):
554
+ super().__init__()
555
+ self.layers = nn.ModuleList([Layer(config, i) for i in range(config.num_layers)])
556
+ self.local_global_ratio = config.local_global_ratio
557
+
558
+ def set_window_length(self, config: GptBertConfig):
559
+ for i, layer in enumerate(self.layers):
560
+ if (i + 1) % self.local_global_ratio == 0:
561
+ layer.set_window_length(config.global_window_length)
562
+ else:
563
+ layer.set_window_length(config.local_window_length)
564
+
565
+ def forward(self, hidden_layer: torch.Tensor, padding_info, output_hidden_states=False, checkpoint_activations=False):
566
+ hidden_layers = [hidden_layer] if output_hidden_states else None
567
+ v1 = None
568
+ embeddings = hidden_layer
569
+
570
+ for layer in self.layers:
571
+ if checkpoint_activations:
572
+ hidden_layer, v1 = torch.utils.checkpoint.checkpoint(layer, hidden_layer, embeddings, v1, padding_info, use_reentrant=True)
573
+ else:
574
+ hidden_layer, v1 = layer(hidden_layer, embeddings, v1, padding_info)
575
+
576
+ if output_hidden_states:
577
+ hidden_layers.append(hidden_layer)
578
+
579
+ return hidden_layer, hidden_layers
580
+
581
+
582
+ #
583
+ # HuggingFace wrappers
584
+ #
585
+
586
+ class GptBertPreTrainedModel(PreTrainedModel):
587
+ config_class = GptBertConfig
588
+ supports_gradient_checkpointing = True
589
+ _supports_flash_attn_2 = True
590
+ _supports_sdpa = True
591
+ _supports_flex_attn = False
592
+
593
+ def _init_weights(self, module):
594
+ std = math.sqrt(2.0 / (5.0 * self.hidden_size))
595
+
596
+ if isinstance(module, nn.Linear) or isinstance(module, CastedLinearIn):
597
+ nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std, a=-2*std, b=2*std)
598
+ if module.bias is not None:
599
+ module.bias.data.zero_()
600
+ elif isinstance(module, nn.Embedding):
601
+ nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std, a=-2*std, b=2*std)
602
+ elif isinstance(module, nn.LayerNorm):
603
+ module.bias.data.zero_()
604
+ module.weight.data.fill_(1.0)
605
+
606
+
607
+ class GptBertModel(GptBertPreTrainedModel):
608
+ def __init__(self, config: GptBertConfig, add_mlm_layer=False, **kwargs):
609
+ super().__init__(config, **kwargs)
610
+ self.config = config
611
+ self.hidden_size = config.hidden_size
612
+
613
+ self.embedding = Embedding(config)
614
+ self.encoder = Encoder(config)
615
+ self.classifier = LMClassifier(config, config.vocab_size) if add_mlm_layer else None
616
+ self.set_window_length(config)
617
+ self.gradient_checkpointing = False
618
+ self.post_init()
619
+
620
+ def set_window_length(self, config) -> None:
621
+ self.encoder.set_window_length(config)
622
+
623
+ def get_input_embeddings(self):
624
+ return self.embedding.word_embedding
625
+
626
+ def set_input_embeddings(self, value):
627
+ self.embedding.word_embedding = value
628
+
629
+ def get_contextualized_embeddings(
630
+ self,
631
+ input_ids: Optional[torch.Tensor] = None,
632
+ attention_mask: Optional[torch.Tensor] = None,
633
+ output_hidden_states: Optional[bool] = None
634
+ ):
635
+ if input_ids is not None:
636
+ input_shape = input_ids.size()
637
+ else:
638
+ raise ValueError("You have to specify input_ids")
639
+
640
+ batch_size, seq_length = input_shape
641
+ device = input_ids.device
642
+
643
+ if attention_mask is None:
644
+ attention_mask = torch.ones(batch_size, seq_length, dtype=torch.bool, device=device)
645
+ else:
646
+ attention_mask = attention_mask.bool()
647
+
648
+ if flash_attn_varlen_qkvpacked_func is not None:
649
+ if len(attention_mask.size()) != 2:
650
+ raise ValueError("Bare `attention_mask` med to dimensjoner støttes nå for FlashAttention.")
651
+ with torch.no_grad():
652
+ input_ids, indices, cu_seqlens, max_seqlen_in_batch = _unpad_input(input_ids, attention_mask)
653
+ padding_info = (indices, cu_seqlens, max_seqlen_in_batch)
654
+ else:
655
+ if len(attention_mask.size()) == 2:
656
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
657
+ elif len(attention_mask.size()) == 3:
658
+ attention_mask = attention_mask.unsqueeze(1)
659
+ padding_info = attention_mask
660
+
661
+ static_embeddings = self.embedding(input_ids)
662
+
663
+ original_dtype = static_embeddings.dtype
664
+ if torch.cuda.is_available() and torch.cuda.is_bf16_supported() and static_embeddings.dtype == torch.float32:
665
+ static_embeddings = static_embeddings.bfloat16()
666
+
667
+ last_layer, contextualized_embeddings = self.encoder(
668
+ static_embeddings,
669
+ padding_info,
670
+ output_hidden_states=output_hidden_states,
671
+ checkpoint_activations=self.gradient_checkpointing and self.training
672
+ )
673
+
674
+ last_layer = last_layer.to(original_dtype)
675
+ if output_hidden_states:
676
+ contextualized_embeddings = [layer.to(original_dtype) for layer in contextualized_embeddings]
677
+
678
+ # Pad output if using FlashAttention
679
+ if flash_attn_varlen_qkvpacked_func is not None:
680
+ last_layer = _pad_output(last_layer, indices, batch_size, seq_length)
681
+ if output_hidden_states:
682
+ contextualized_embeddings = [_pad_output(layer, indices, batch_size, seq_length) for layer in contextualized_embeddings]
683
+ else:
684
+ contextualized_embeddings = None
685
+
686
+ return last_layer, contextualized_embeddings
687
+
688
+ def forward(
689
+ self,
690
+ input_ids: Optional[torch.Tensor] = None,
691
+ attention_mask: Optional[torch.Tensor] = None,
692
+ output_hidden_states: Optional[bool] = None,
693
+ output_attentions: Optional[bool] = None,
694
+ return_dict: Optional[bool] = None,
695
+ **kwargs
696
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
697
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
698
+
699
+ sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
700
+
701
+ if not return_dict:
702
+ return (
703
+ sequence_output,
704
+ *([contextualized_embeddings] if output_hidden_states else [])
705
+ )
706
+
707
+ return BaseModelOutput(
708
+ last_hidden_state=sequence_output,
709
+ hidden_states=contextualized_embeddings if output_hidden_states else None
710
+ )
711
+
712
+
713
+ class GptBertForMaskedLM(GptBertModel):
714
+ _tied_weights_keys = ["classifier.emb2vocab.weight"]
715
+
716
+ def __init__(self, config: GptBertConfig, **kwargs):
717
+ super().__init__(config, add_mlm_layer=True, **kwargs)
718
+
719
+ def get_output_embeddings(self):
720
+ return self.classifier.emb2vocab.weight
721
+
722
+ def set_output_embeddings(self, new_embeddings):
723
+ self.classifier.emb2vocab.weight = new_embeddings
724
+
725
+ def forward(
726
+ self,
727
+ input_ids: Optional[torch.Tensor] = None,
728
+ attention_mask: Optional[torch.Tensor] = None,
729
+ output_hidden_states: Optional[bool] = None,
730
+ return_dict: Optional[bool] = None,
731
+ labels: Optional[torch.LongTensor] = None,
732
+ **kwargs
733
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
734
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
735
+
736
+ sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
737
+ subword_prediction = self.classifier(sequence_output)
738
+ subword_prediction = 30 * torch.sigmoid(subword_prediction / 7.5)
739
+
740
+ masked_lm_loss = None
741
+ if labels is not None:
742
+ labels_flatten = labels[:, 1:].flatten()
743
+ subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
744
+ masked_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
745
+
746
+ bos_logits = torch.zeros(subword_prediction.size(0), 1, self.config.vocab_size, dtype=subword_prediction.dtype, device=subword_prediction.device)
747
+ bos_logits[:, :, self.config.bos_token_id] = 1.0
748
+ subword_prediction = torch.cat([bos_logits, subword_prediction[:, :-1]], dim=1)
749
+
750
+ if not return_dict:
751
+ output = (
752
+ subword_prediction,
753
+ *([contextualized_embeddings] if output_hidden_states else [])
754
+ )
755
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
756
+
757
+ return MaskedLMOutput(
758
+ loss=masked_lm_loss,
759
+ logits=subword_prediction,
760
+ hidden_states=contextualized_embeddings if output_hidden_states else None
761
+ )
762
+
763
+
764
+ class GptBertForCausalLM(GptBertModel):
765
+ _tied_weights_keys = ["classifier.emb2vocab.weight"]
766
+
767
+ def __init__(self, config: GptBertConfig, **kwargs):
768
+ config.is_decoder = True
769
+ super().__init__(config, add_mlm_layer=True, **kwargs)
770
+
771
+ def get_output_embeddings(self):
772
+ return self.classifier.emb2vocab.weight
773
+
774
+ def set_output_embeddings(self, new_embeddings):
775
+ self.classifier.emb2vocab.weight = new_embeddings
776
+
777
+ def get_input_embeddings(self):
778
+ return self.embedding.word_embedding
779
+
780
+ def set_input_embeddings(self, value):
781
+ self.embedding.word_embedding = value
782
+
783
+ def set_decoder(self, decoder):
784
+ self.encoder = decoder
785
+
786
+ def get_decoder(self):
787
+ return self.encoder
788
+
789
+ def can_generate(self):
790
+ return True
791
+
792
+ def forward(
793
+ self,
794
+ input_ids: torch.LongTensor = None,
795
+ attention_mask: Optional[torch.Tensor] = None,
796
+ position_ids: Optional[torch.LongTensor] = None,
797
+ token_type_ids: Optional[torch.Tensor] = None,
798
+ past_key_values: Optional[torch.Tensor] = None,
799
+ inputs_embeds: Optional[torch.FloatTensor] = None,
800
+ labels: Optional[torch.LongTensor] = None,
801
+ use_cache: Optional[bool] = None,
802
+ cache_position: Optional[torch.LongTensor] = None,
803
+ output_attentions: Optional[bool] = None,
804
+ output_hidden_states: Optional[bool] = None,
805
+ return_dict: Optional[bool] = None
806
+ ) -> Union[Tuple, CausalLMOutput]:
807
+
808
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ assert inputs_embeds is None, "inputs_embeds is not supported for now"
809
+ assert past_key_values is None, "past_key_values is not supported for now"
810
+ assert not use_cache, "use_cache is not supported for now"
811
+
812
+ sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
813
+ subword_prediction = self.classifier(sequence_output)
814
+ subword_prediction = 30 * torch.sigmoid(subword_prediction / 7.5)
815
+
816
+ causal_lm_loss = None
817
+ if labels is not None:
818
+ labels_flatten = labels[:, 1:].flatten()
819
+ subword_prediction_flatten = subword_prediction[:, :-1].flatten(0, 1)
820
+ causal_lm_loss = F.cross_entropy(subword_prediction_flatten, labels_flatten)
821
+
822
+ if not return_dict:
823
+ output = (
824
+ subword_prediction,
825
+ *([contextualized_embeddings] if output_hidden_states else [])
826
+ )
827
+ return ((causal_lm_loss,) + output) if causal_lm_loss is not None else output
828
+
829
+ return CausalLMOutput(
830
+ loss=causal_lm_loss,
831
+ logits=subword_prediction,
832
+ hidden_states=contextualized_embeddings if output_hidden_states else None
833
+ )
834
+
835
+ def prepare_inputs_for_generation(
836
+ self,
837
+ input_ids: torch.Tensor,
838
+ past_key_values: Optional[torch.Tensor] = None,
839
+ attention_mask: Optional[torch.Tensor] = None,
840
+ inputs_embeds: Optional[torch.Tensor] = None,
841
+ cache_position: Optional[torch.LongTensor] = None,
842
+ position_ids: Optional[torch.LongTensor] = None,
843
+ use_cache: bool = True,
844
+ num_logits_to_keep: Optional[int] = None,
845
+ **kwargs,
846
+ ):
847
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
848
+ # Exception 1: when passing input_embeds, input_ids may be missing entries
849
+ # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
850
+ if past_key_values is not None:
851
+ if inputs_embeds is not None: # Exception 1
852
+ input_ids = input_ids[:, -cache_position.shape[0] :]
853
+ elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
854
+ input_ids = input_ids[:, cache_position]
855
+
856
+ if attention_mask is not None and position_ids is None:
857
+ # create position_ids on the fly for batch generation
858
+ position_ids = attention_mask.long().cumsum(-1) - 1
859
+ position_ids.masked_fill_(attention_mask == 0, 1)
860
+ if past_key_values:
861
+ position_ids = position_ids[:, -input_ids.shape[1] :]
862
+
863
+ # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the input `position_ids` would have varying strides during decoding. Here, simply using `.contiguous()` is not sufficient: in the batch size = 1 case, `position_ids` is already contiguous but with varying stride, which retriggers a capture.
864
+ position_ids = position_ids.clone(memory_format=torch.contiguous_format)
865
+
866
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
867
+ if inputs_embeds is not None and cache_position[0] == 0:
868
+ model_inputs = {"inputs_embeds": inputs_embeds}
869
+ else:
870
+ model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
871
+
872
+ if num_logits_to_keep is not None:
873
+ model_inputs["num_logits_to_keep"] = num_logits_to_keep
874
+
875
+ model_inputs.update(
876
+ {
877
+ "position_ids": position_ids,
878
+ "cache_position": cache_position,
879
+ "past_key_values": past_key_values,
880
+ "use_cache": use_cache,
881
+ "attention_mask": attention_mask,
882
+ }
883
+ )
884
+ return model_inputs
885
+
886
+
887
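The causal head above asserts that `past_key_values` and `use_cache` are unsupported, so generation has to run cache-free. A minimal generation sketch under that constraint; the repository id below is a placeholder, and loading needs `trust_remote_code=True` because the GptBert* classes live in this file:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "your-org/your-gptbert-checkpoint"  # placeholder, not the actual repository
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("Dette er en test", return_tensors="pt")
with torch.no_grad():
    # use_cache=False because past_key_values / KV caching is not implemented above.
    output_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False, use_cache=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```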
+ class GptBertForSequenceClassification(GptBertModel):
888
+ _keys_to_ignore_on_load_missing = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
889
+ _keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
890
+
891
+ def __init__(self, config: GptBertConfig, **kwargs):
892
+ super().__init__(config, add_mlm_layer=False, **kwargs)
893
+
894
+ self.num_labels = config.num_labels
895
+ self.classifier = Classifier(config, self.num_labels)
896
+ self.post_init()
897
+
898
+ def forward(
899
+ self,
900
+ input_ids: Optional[torch.Tensor] = None,
901
+ attention_mask: Optional[torch.Tensor] = None,
902
+ output_hidden_states: Optional[bool] = None,
903
+ return_dict: Optional[bool] = None,
904
+ labels: Optional[torch.LongTensor] = None,
905
+ **kwargs
906
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
907
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
908
+
909
+ sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
910
+ logits = self.classifier(sequence_output[:, 0, :])
911
+
912
+ loss = None
913
+ if labels is not None:
914
+ if self.config.problem_type is None:
915
+ if self.num_labels == 1:
916
+ self.config.problem_type = "regression"
917
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
918
+ self.config.problem_type = "single_label_classification"
919
+ else:
920
+ self.config.problem_type = "multi_label_classification"
921
+
922
+ if self.config.problem_type == "regression":
923
+ loss_fct = nn.MSELoss()
924
+ if self.num_labels == 1:
925
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
926
+ else:
927
+ loss = loss_fct(logits, labels)
928
+ elif self.config.problem_type == "single_label_classification":
929
+ loss_fct = nn.CrossEntropyLoss()
930
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
931
+ elif self.config.problem_type == "multi_label_classification":
932
+ loss_fct = nn.BCEWithLogitsLoss()
933
+ loss = loss_fct(logits, labels)
934
+
935
+ if not return_dict:
936
+ output = (
937
+ logits,
938
+ *([contextualized_embeddings] if output_hidden_states else [])
939
+ )
940
+ return ((loss,) + output) if loss is not None else output
941
+
942
+ return SequenceClassifierOutput(
943
+ loss=loss,
944
+ logits=logits,
945
+ hidden_states=contextualized_embeddings if output_hidden_states else None
946
+ )
947
+
948
+
949
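The sequence-classification head above infers `problem_type` from `num_labels` and the label dtype, then picks MSE, cross-entropy, or BCE-with-logits accordingly. A small standalone illustration of the three branches with hypothetical tensors:

```python
import torch
import torch.nn as nn

logits = torch.randn(4, 3)  # batch of 4, num_labels = 3

# single_label_classification: integer class ids -> CrossEntropyLoss
int_labels = torch.tensor([0, 2, 1, 2])
ce = nn.CrossEntropyLoss()(logits.view(-1, 3), int_labels.view(-1))

# multi_label_classification: float multi-hot targets -> BCEWithLogitsLoss
multi_hot = torch.tensor([[1., 0., 1.], [0., 1., 0.], [1., 1., 0.], [0., 0., 1.]])
bce = nn.BCEWithLogitsLoss()(logits, multi_hot)

# regression: num_labels == 1 -> MSELoss on squeezed logits
reg_logits, reg_labels = torch.randn(4, 1), torch.randn(4)
mse = nn.MSELoss()(reg_logits.squeeze(), reg_labels.squeeze())

print(ce.item(), bce.item(), mse.item())
```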
+ class GptBertForTokenClassification(GptBertModel):
950
+ _keys_to_ignore_on_load_missing = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
951
+ _keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
952
+
953
+ def __init__(self, config: GptBertConfig, **kwargs):
954
+ super().__init__(config, add_mlm_layer=False, **kwargs)
955
+
956
+ self.num_labels = config.num_labels
957
+ self.classifier = Classifier(config, self.num_labels)
958
+ self.post_init()
959
+
960
+ def forward(
961
+ self,
962
+ input_ids: Optional[torch.Tensor] = None,
963
+ attention_mask: Optional[torch.Tensor] = None,
964
+ output_hidden_states: Optional[bool] = None,
965
+ return_dict: Optional[bool] = None,
966
+ labels: Optional[torch.LongTensor] = None,
967
+ **kwargs
968
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
969
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
970
+
971
+ sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
972
+ logits = self.classifier(sequence_output)
973
+
974
+ loss = None
975
+ if labels is not None:
976
+ loss_fct = nn.CrossEntropyLoss()
977
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
978
+
979
+ if not return_dict:
980
+ output = (
981
+ logits,
982
+ *([contextualized_embeddings] if output_hidden_states else [])
983
984
+ )
985
+ return ((loss,) + output) if loss is not None else output
986
+
987
+ return TokenClassifierOutput(
988
+ loss=loss,
989
+ logits=logits,
990
+ hidden_states=contextualized_embeddings if output_hidden_states else None
991
992
+ )
993
+
994
+
995
+ class GptBertForQuestionAnswering(GptBertModel):
996
+ _keys_to_ignore_on_load_missing = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
997
+ _keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
998
+
999
+ def __init__(self, config: GptBertConfig, **kwargs):
1000
+ super().__init__(config, add_mlm_layer=False, **kwargs)
1001
+
1002
+ self.num_labels = config.num_labels
1003
+ self.classifier = Classifier(config, self.num_labels)
1004
+ self.post_init()
1005
+
1006
+ def forward(
1007
+ self,
1008
+ input_ids: Optional[torch.Tensor] = None,
1009
+ attention_mask: Optional[torch.Tensor] = None,
1010
+ output_hidden_states: Optional[bool] = None,
1011
+ return_dict: Optional[bool] = None,
1012
+ start_positions: Optional[torch.Tensor] = None,
1013
+ end_positions: Optional[torch.Tensor] = None,
1014
+ **kwargs
1015
+ ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
1016
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1017
+
1018
+ sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(input_ids, attention_mask, output_hidden_states)
1019
+ logits = self.classifier(sequence_output)
1020
+
1021
+ start_logits, end_logits = logits.split(1, dim=-1)
1022
+ start_logits = start_logits.squeeze(-1).contiguous()
1023
+ end_logits = end_logits.squeeze(-1).contiguous()
1024
+
1025
+ total_loss = None
1026
+ if start_positions is not None and end_positions is not None:
1027
+ # If we are on multi-GPU, the start/end positions may carry an extra dimension; squeeze it
1028
+ if len(start_positions.size()) > 1:
1029
+ start_positions = start_positions.squeeze(-1)
1030
+ if len(end_positions.size()) > 1:
1031
+ end_positions = end_positions.squeeze(-1)
1032
+
1033
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1034
+ ignored_index = start_logits.size(1)
1035
+ start_positions = start_positions.clamp(0, ignored_index)
1036
+ end_positions = end_positions.clamp(0, ignored_index)
1037
+
1038
+ loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
1039
+ start_loss = loss_fct(start_logits, start_positions)
1040
+ end_loss = loss_fct(end_logits, end_positions)
1041
+ total_loss = (start_loss + end_loss) / 2
1042
+
1043
+ if not return_dict:
1044
+ output = (
1045
+ start_logits,
1046
+ end_logits,
1047
+ *([contextualized_embeddings] if output_hidden_states else [])
1048
+ )
1049
+ return ((total_loss,) + output) if total_loss is not None else output
1050
+
1051
+ return QuestionAnsweringModelOutput(
1052
+ loss=total_loss,
1053
+ start_logits=start_logits,
1054
+ end_logits=end_logits,
1055
+ hidden_states=contextualized_embeddings if output_hidden_states else None
1056
+ )
1057
+
1058
+
1059
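The question-answering head above returns per-token start and end logits. A common way to turn them into a predicted span is greedy selection, sketched below with hypothetical logits (the real ones come from `GptBertForQuestionAnswering`):

```python
import torch

start_logits = torch.randn(1, 12)  # hypothetical logits for a 12-token input
end_logits = torch.randn(1, 12)

# Greedy span selection: best start, then best end at or after the chosen start.
start = int(start_logits.argmax(dim=-1))
end = start + int(end_logits[0, start:].argmax())
span_token_indices = list(range(start, end + 1))
print(span_token_indices)
```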
+ class GptBertForMultipleChoice(GptBertModel):
1060
+ _keys_to_ignore_on_load_missing = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
1061
+ _keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
1062
+
1063
+ def __init__(self, config: GptBertConfig, **kwargs):
1064
+ super().__init__(config, add_mlm_layer=False, **kwargs)
1065
+
1066
+ self.num_labels = getattr(config, "num_labels", 2)
1067
+ self.classifier = Classifier(config, 1)  # one score per choice; reshaped to (batch, num_choices) in forward
1068
+ self.post_init()
1069
+
1070
+ def forward(
1071
+ self,
1072
+ input_ids: Optional[torch.Tensor] = None,
1073
+ attention_mask: Optional[torch.Tensor] = None,
1074
+ labels: Optional[torch.Tensor] = None,
1075
+ output_hidden_states: Optional[bool] = None,
1076
+ return_dict: Optional[bool] = None,
1077
+ **kwargs
1078
+ ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
1079
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1080
+ num_choices = input_ids.shape[1]
1081
+
1082
+ flat_input_ids = input_ids.view(-1, input_ids.size(-1))
1083
+ flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1084
+
1085
+ sequence_output, contextualized_embeddings = self.get_contextualized_embeddings(flat_input_ids, flat_attention_mask, output_hidden_states)
1086
+ logits = self.classifier(sequence_output[:, 0, :])
1087
+ reshaped_logits = logits.view(-1, num_choices)
1088
+
1089
+ loss = None
1090
+ if labels is not None:
1091
+ loss_fct = nn.CrossEntropyLoss()
1092
+ loss = loss_fct(reshaped_logits, labels)
1093
+
1094
+ if not return_dict:
1095
+ output = (
1096
+ reshaped_logits,
1097
+ *([contextualized_embeddings] if output_hidden_states else [])
1098
+ )
1099
+ return ((loss,) + output) if loss is not None else output
1100
+
1101
+ return MultipleChoiceModelOutput(
1102
+ loss=loss,
1103
+ logits=reshaped_logits,
1104
+ hidden_states=contextualized_embeddings if output_hidden_states else None
1105
+ )
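With the auto-mapping that usually accompanies a remote-code modeling file like this one, the encoder can be loaded directly through `transformers`. A minimal sketch; the repository id is a placeholder and `trust_remote_code=True` is required because the GptBert* classes are defined in this file:

```python
from transformers import AutoModel, AutoTokenizer

repo_id = "your-org/your-gptbert-checkpoint"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

encoded = tokenizer(["Dette er en setning."], return_tensors="pt")
outputs = model(**encoded)
print(outputs.last_hidden_state.shape)  # (batch, sequence_length, hidden_size)
```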
modules.json ADDED
@@ -0,0 +1,14 @@
1
+ [
2
+ {
3
+ "idx": 0,
4
+ "name": "0",
5
+ "path": "",
6
+ "type": "sentence_transformers.models.Transformer"
7
+ },
8
+ {
9
+ "idx": 1,
10
+ "name": "1",
11
+ "path": "1_Pooling",
12
+ "type": "sentence_transformers.models.Pooling"
13
+ }
14
+ ]
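modules.json wires the checkpoint into a two-stage sentence-transformers pipeline: the Transformer module at the repository root followed by the Pooling module configured under `1_Pooling/`. A short usage sketch (placeholder repository id):

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("your-org/your-checkpoint", trust_remote_code=True)  # placeholder id
embeddings = model.encode(["Dette er en setning.", "Dette er en annen setning."])
print(embeddings.shape)                           # (2, embedding_dim)
print(model.similarity(embeddings, embeddings))   # pairwise similarity matrix
```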
sentence_bert_config.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "max_seq_length": 75,
3
+ "do_lower_case": false
4
+ }
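`max_seq_length: 75` above is the truncation length applied by the Transformer module; it can be inspected or overridden after loading (placeholder repository id):

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("your-org/your-checkpoint", trust_remote_code=True)  # placeholder id
print(model.max_seq_length)  # 75, taken from sentence_bert_config.json
model.max_seq_length = 64    # can be lowered at run time if needed
```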
special_tokens_map.json ADDED
@@ -0,0 +1,51 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "cls_token": {
10
+ "content": "<s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "eos_token": {
17
+ "content": "</s>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "mask_token": {
24
+ "content": "<mask>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "pad_token": {
31
+ "content": "<pad>",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ },
37
+ "sep_token": {
38
+ "content": "</s>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false
43
+ },
44
+ "unk_token": {
45
+ "content": "<unk>",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false
50
+ }
51
+ }
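The map above reuses `<s>` as both BOS and CLS and `</s>` as both EOS and SEP. A quick sanity check with the tokenizer from the same repository (placeholder id):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("your-org/your-checkpoint")  # placeholder id
assert tokenizer.cls_token == tokenizer.bos_token == "<s>"
assert tokenizer.sep_token == tokenizer.eos_token == "</s>"
print(tokenizer("Dette er en setning.").input_ids)
```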
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,150 @@
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<unk>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "<s>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "</s>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "<pad>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "4": {
36
+ "content": "<mask>",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ },
43
+ "5": {
44
+ "content": "<special_0>",
45
+ "lstrip": false,
46
+ "normalized": false,
47
+ "rstrip": false,
48
+ "single_word": false,
49
+ "special": true
50
+ },
51
+ "6": {
52
+ "content": "<special_1>",
53
+ "lstrip": false,
54
+ "normalized": false,
55
+ "rstrip": false,
56
+ "single_word": false,
57
+ "special": true
58
+ },
59
+ "7": {
60
+ "content": "<special_2>",
61
+ "lstrip": false,
62
+ "normalized": false,
63
+ "rstrip": false,
64
+ "single_word": false,
65
+ "special": true
66
+ },
67
+ "8": {
68
+ "content": "<special_3>",
69
+ "lstrip": false,
70
+ "normalized": false,
71
+ "rstrip": false,
72
+ "single_word": false,
73
+ "special": true
74
+ },
75
+ "9": {
76
+ "content": "<special_4>",
77
+ "lstrip": false,
78
+ "normalized": false,
79
+ "rstrip": false,
80
+ "single_word": false,
81
+ "special": true
82
+ },
83
+ "10": {
84
+ "content": "<special_5>",
85
+ "lstrip": false,
86
+ "normalized": false,
87
+ "rstrip": false,
88
+ "single_word": false,
89
+ "special": true
90
+ },
91
+ "11": {
92
+ "content": "<special_6>",
93
+ "lstrip": false,
94
+ "normalized": false,
95
+ "rstrip": false,
96
+ "single_word": false,
97
+ "special": true
98
+ },
99
+ "12": {
100
+ "content": "<special_7>",
101
+ "lstrip": false,
102
+ "normalized": false,
103
+ "rstrip": false,
104
+ "single_word": false,
105
+ "special": true
106
+ },
107
+ "13": {
108
+ "content": "<special_8>",
109
+ "lstrip": false,
110
+ "normalized": false,
111
+ "rstrip": false,
112
+ "single_word": false,
113
+ "special": true
114
+ },
115
+ "14": {
116
+ "content": "<special_9>",
117
+ "lstrip": false,
118
+ "normalized": false,
119
+ "rstrip": false,
120
+ "single_word": false,
121
+ "special": true
122
+ },
123
+ "15": {
124
+ "content": "<special_10>",
125
+ "lstrip": false,
126
+ "normalized": false,
127
+ "rstrip": false,
128
+ "single_word": false,
129
+ "special": true
130
+ }
131
+ },
132
+ "bos_token": "<s>",
133
+ "clean_up_tokenization_spaces": false,
134
+ "cls_token": "<s>",
135
+ "eos_token": "</s>",
136
+ "extra_special_tokens": {},
137
+ "mask_token": "<mask>",
138
+ "max_length": 128,
139
+ "model_max_length": 1000000000000000019884624838656,
140
+ "pad_to_multiple_of": null,
141
+ "pad_token": "<pad>",
142
+ "pad_token_type_id": 0,
143
+ "padding_side": "right",
144
+ "sep_token": "</s>",
145
+ "stride": 0,
146
+ "tokenizer_class": "PreTrainedTokenizerFast",
147
+ "truncation_side": "right",
148
+ "truncation_strategy": "longest_first",
149
+ "unk_token": "<unk>"
150
+ }
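`added_tokens_decoder` above pins ids 0 to 15 to the special tokens, including the reserved `<special_0>` to `<special_10>` placeholders, and `model_max_length` is left at the library's unbounded sentinel, so the effective 75-token limit comes from `sentence_bert_config.json`. A quick inspection sketch (placeholder repository id):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("your-org/your-checkpoint")  # placeholder id
print(tokenizer.convert_ids_to_tokens(list(range(16))))
# expected: ['<unk>', '<s>', '</s>', '<pad>', '<mask>', '<special_0>', ..., '<special_10>']
```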