update
Browse files- README.md +16 -8
- all_results.json +12 -0
- config.json +3 -3
- eval_results.json +9 -9
- modeling_lsg_bart.py +77 -471
- pytorch_model.bin +1 -1
- tokenizer.json +1 -10
- tokenizer_config.json +1 -1
README.md
CHANGED
|
@@ -21,19 +21,26 @@ should probably proofread and complete it, then remove this comment. -->
|
|
| 21 |
# ccdv/lsg-bart-base-16384-arxiv
|
| 22 |
|
| 23 |
This model is a fine-tuned version of [ccdv/lsg-bart-base-4096-arxiv](https://huggingface.co/ccdv/lsg-bart-base-4096-arxiv) on the scientific_papers arxiv dataset. \
|
|
|
|
| 24 |
It achieves the following results on the test set:
|
| 25 |
|
| 26 |
-
| Length | Global tokens |
|
| 27 |
|:------ |:------------- |:----------- |:---------- |:-------- | :--------- |:----- |:----- |:----- |:----- |
|
| 28 |
-
| 16384 | 64 |
|
|
|
|
|
|
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
## Model description
|
| 32 |
The model relies on Local-Sparse-Global attention to handle long sequences:
|
| 33 |

|
| 34 |
|
| 35 |
The model has about ~145 millions parameters (6 encoder layers - 6 decoder layers). \
|
| 36 |
-
The model is warm started from [ccdv/lsg-bart-base-4096-arxiv](https://huggingface.co/ccdv/lsg-bart-base-4096-arxiv), converted to handle long sequences (encoder only) and fine tuned.
|
| 37 |
|
| 38 |
## Intended uses & limitations
|
| 39 |
|
|
@@ -49,12 +56,13 @@ More information needed
|
|
| 49 |
|
| 50 |
The following hyperparameters were used during training:
|
| 51 |
- learning_rate: 8e-05
|
| 52 |
-
- train_batch_size:
|
| 53 |
- seed: 42
|
| 54 |
-
- gradient_accumulation_steps:
|
| 55 |
- total_train_batch_size: 32
|
| 56 |
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
|
| 57 |
- lr_scheduler_type: linear
|
|
|
|
| 58 |
- num_epochs: 1.0
|
| 59 |
|
| 60 |
### Generate hyperparameters
|
|
@@ -62,14 +70,14 @@ The following hyperparameters were used during training:
|
|
| 62 |
The following hyperparameters were used during generation:
|
| 63 |
- dataset_name: scientific_papers
|
| 64 |
- dataset_config_name: arxiv
|
| 65 |
-
- eval_batch_size:
|
|
|
|
| 66 |
- early_stopping: True
|
| 67 |
- ignore_pad_token_for_loss: True
|
| 68 |
- length_penalty: 2.0
|
| 69 |
- max_length: 320
|
| 70 |
-
- min_length:
|
| 71 |
- num_beams: 5
|
| 72 |
-
- num_samples: None
|
| 73 |
- no_repeat_ngram_size: None
|
| 74 |
- seed: 123
|
| 75 |
|
|
|
|
| 21 |
# ccdv/lsg-bart-base-16384-arxiv
|
| 22 |
|
| 23 |
This model is a fine-tuned version of [ccdv/lsg-bart-base-4096-arxiv](https://huggingface.co/ccdv/lsg-bart-base-4096-arxiv) on the scientific_papers arxiv dataset. \
|
| 24 |
+
The model is converted to handle 16384 long sequences and fine-tuned accordingly during 1 epoch. \
|
| 25 |
It achieves the following results on the test set:
|
| 26 |
|
| 27 |
+
| Length | Global tokens | Fine-tuning | Block Size | Sparsity | Connexions | R1 | R2 | RL | RLsum |
|
| 28 |
|:------ |:------------- |:----------- |:---------- |:-------- | :--------- |:----- |:----- |:----- |:----- |
|
| 29 |
+
| 16384 | 64 | Full | 256 | 0 | 768 | 48.74 | 20.88 | 28.50 | 44.23 |
|
| 30 |
+
| 16384 | 64 | Global only | 256 | 0 | 768 | 48.08 | 20.42 | 28.00 | 43.65 |
|
| 31 |
+
| 16384 | 1 | None | 256 | 0 | 768 | 47.03 | 20.19 | 28.26 | 42.69 |
|
| 32 |
|
| 33 |
+
Reference model:
|
| 34 |
+
| Length | Global tokens | Fine-tuning | Block Size | Sparsity | Connexions | R1 | R2 | RL | RLsum |
|
| 35 |
+
|:------ |:------------- |:----------- |:---------- |:-------- | :--------- |:----- |:----- |:----- |:----- |
|
| 36 |
+
| 4096 | 1 | - | 256 | 0 | 768 | 46.65 | 18.91 | 26.90 | 42.18 |
|
| 37 |
|
| 38 |
## Model description
|
| 39 |
The model relies on Local-Sparse-Global attention to handle long sequences:
|
| 40 |

|
| 41 |
|
| 42 |
The model has about ~145 millions parameters (6 encoder layers - 6 decoder layers). \
|
| 43 |
+
The model is warm started from [ccdv/lsg-bart-base-4096-arxiv](https://huggingface.co/ccdv/lsg-bart-base-4096-arxiv), converted to handle long sequences (encoder only) and fine tuned.
|
| 44 |
|
| 45 |
## Intended uses & limitations
|
| 46 |
|
|
|
|
| 56 |
|
| 57 |
The following hyperparameters were used during training:
|
| 58 |
- learning_rate: 8e-05
|
| 59 |
+
- train_batch_size: 8
|
| 60 |
- seed: 42
|
| 61 |
+
- gradient_accumulation_steps: 4
|
| 62 |
- total_train_batch_size: 32
|
| 63 |
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
|
| 64 |
- lr_scheduler_type: linear
|
| 65 |
+
- lr_scheduler_warmup_ratio: 0.1
|
| 66 |
- num_epochs: 1.0
|
| 67 |
|
| 68 |
### Generate hyperparameters
|
|
|
|
| 70 |
The following hyperparameters were used during generation:
|
| 71 |
- dataset_name: scientific_papers
|
| 72 |
- dataset_config_name: arxiv
|
| 73 |
+
- eval_batch_size: 8
|
| 74 |
+
- eval_samples: 6440
|
| 75 |
- early_stopping: True
|
| 76 |
- ignore_pad_token_for_loss: True
|
| 77 |
- length_penalty: 2.0
|
| 78 |
- max_length: 320
|
| 79 |
+
- min_length: 32
|
| 80 |
- num_beams: 5
|
|
|
|
| 81 |
- no_repeat_ngram_size: None
|
| 82 |
- seed: 123
|
| 83 |
|
all_results.json
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"eval_gen_len": 215.796,
|
| 3 |
+
"eval_loss": 1.7052853107452393,
|
| 4 |
+
"eval_rouge1": 48.7438,
|
| 5 |
+
"eval_rouge2": 20.88,
|
| 6 |
+
"eval_rougeL": 28.4965,
|
| 7 |
+
"eval_rougeLsum": 44.2266,
|
| 8 |
+
"eval_runtime": 18597.9286,
|
| 9 |
+
"eval_samples": 6440,
|
| 10 |
+
"eval_samples_per_second": 0.346,
|
| 11 |
+
"eval_steps_per_second": 0.087
|
| 12 |
+
}
|
config.json
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
{
|
| 2 |
-
"_name_or_path": "
|
| 3 |
"activation_dropout": 0.1,
|
| 4 |
"activation_function": "gelu",
|
| 5 |
"adaptive": true,
|
|
@@ -68,7 +68,7 @@
|
|
| 68 |
"scale_embedding": false,
|
| 69 |
"sparse_block_size": 0,
|
| 70 |
"sparsity_factor": 4,
|
| 71 |
-
"sparsity_type": "
|
| 72 |
"task_specific_params": {
|
| 73 |
"summarization": {
|
| 74 |
"length_penalty": 1.0,
|
|
@@ -90,7 +90,7 @@
|
|
| 90 |
}
|
| 91 |
},
|
| 92 |
"torch_dtype": "float32",
|
| 93 |
-
"transformers_version": "4.
|
| 94 |
"use_cache": true,
|
| 95 |
"vocab_size": 50265
|
| 96 |
}
|
|
|
|
| 1 |
{
|
| 2 |
+
"_name_or_path": "/data/ccondevaux/lsg/text-summarization/tmp_final/arxiv/lsg_local_16384_trained",
|
| 3 |
"activation_dropout": 0.1,
|
| 4 |
"activation_function": "gelu",
|
| 5 |
"adaptive": true,
|
|
|
|
| 68 |
"scale_embedding": false,
|
| 69 |
"sparse_block_size": 0,
|
| 70 |
"sparsity_factor": 4,
|
| 71 |
+
"sparsity_type": "none",
|
| 72 |
"task_specific_params": {
|
| 73 |
"summarization": {
|
| 74 |
"length_penalty": 1.0,
|
|
|
|
| 90 |
}
|
| 91 |
},
|
| 92 |
"torch_dtype": "float32",
|
| 93 |
+
"transformers_version": "4.19.2",
|
| 94 |
"use_cache": true,
|
| 95 |
"vocab_size": 50265
|
| 96 |
}
|
eval_results.json
CHANGED
|
@@ -1,12 +1,12 @@
|
|
| 1 |
{
|
| 2 |
-
"eval_gen_len":
|
| 3 |
-
"eval_loss": 1.
|
| 4 |
-
"eval_rouge1": 48.
|
| 5 |
-
"eval_rouge2": 20.
|
| 6 |
-
"eval_rougeL": 28.
|
| 7 |
-
"eval_rougeLsum": 44.
|
| 8 |
-
"eval_runtime":
|
| 9 |
"eval_samples": 6440,
|
| 10 |
-
"eval_samples_per_second": 0.
|
| 11 |
-
"eval_steps_per_second": 0.
|
| 12 |
}
|
|
|
|
| 1 |
{
|
| 2 |
+
"eval_gen_len": 215.796,
|
| 3 |
+
"eval_loss": 1.7052853107452393,
|
| 4 |
+
"eval_rouge1": 48.7438,
|
| 5 |
+
"eval_rouge2": 20.88,
|
| 6 |
+
"eval_rougeL": 28.4965,
|
| 7 |
+
"eval_rougeLsum": 44.2266,
|
| 8 |
+
"eval_runtime": 18597.9286,
|
| 9 |
"eval_samples": 6440,
|
| 10 |
+
"eval_samples_per_second": 0.346,
|
| 11 |
+
"eval_steps_per_second": 0.087
|
| 12 |
}
|
modeling_lsg_bart.py
CHANGED
|
@@ -54,17 +54,32 @@ class LSGBartConfig(BartConfig):
|
|
| 54 |
self.sparsity_factor = sparsity_factor
|
| 55 |
self.sparsity_type = sparsity_type
|
| 56 |
|
| 57 |
-
if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride"]:
|
| 58 |
logger.warning(
|
| 59 |
-
"[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride'], setting sparsity_type=None, computation will skip sparse attention")
|
| 60 |
self.sparsity_type = None
|
| 61 |
|
| 62 |
-
if self.sparsity_type
|
| 63 |
if self.sparsity_factor > self.encoder_attention_heads:
|
| 64 |
logger.warning(
|
| 65 |
-
"[WARNING CONFIG]: sparsity_factor > encoder_attention_heads is not recommended for stride sparsity"
|
| 66 |
)
|
| 67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
|
| 70 |
"""
|
|
@@ -217,8 +232,6 @@ class LSGAttentionProduct(nn.Module):
|
|
| 217 |
# Shape of blocks
|
| 218 |
self.local_shapes = (self.block_size*3, self.block_size)
|
| 219 |
if self.sparse_block_size and self.sparsity_factor > 0:
|
| 220 |
-
assert self.block_size % self.sparsity_factor == 0, "block_size must be divisible by sparsity_factor"
|
| 221 |
-
assert self.block_size//self.sparsity_factor >= 1, "Config is wrong, make sure block_size >= sparsity_factor"
|
| 222 |
self.sparse_shapes = (self.sparse_block_size*3, self.block_size//self.sparsity_factor)
|
| 223 |
|
| 224 |
self.attention = BaseAttentionProduct(config)
|
|
@@ -399,6 +412,7 @@ class LSGBartEncoderAttention(BaseSelfAttention):
|
|
| 399 |
"pooling": self.get_sparse_tokens_with_pooling,
|
| 400 |
"lsh": self.get_sparse_tokens_with_lsh,
|
| 401 |
"stride": self.get_sparse_tokens_with_stride,
|
|
|
|
| 402 |
}
|
| 403 |
|
| 404 |
self.sparsity_type = config.sparsity_type
|
|
@@ -410,7 +424,7 @@ class LSGBartEncoderAttention(BaseSelfAttention):
|
|
| 410 |
def get_sparse_tokens_with_norm(self, keys, values, mask):
|
| 411 |
|
| 412 |
if self.sparsity_factor == 1:
|
| 413 |
-
return keys, values, mask
|
| 414 |
|
| 415 |
with torch.no_grad():
|
| 416 |
|
|
@@ -438,7 +452,7 @@ class LSGBartEncoderAttention(BaseSelfAttention):
|
|
| 438 |
def get_sparse_tokens_with_pooling(self, keys, values, mask):
|
| 439 |
|
| 440 |
if self.sparsity_factor == 1:
|
| 441 |
-
return keys, values, mask
|
| 442 |
|
| 443 |
keys = self.chunk(keys, self.sparsity_factor)
|
| 444 |
values = self.chunk(values, self.sparsity_factor)
|
|
@@ -460,7 +474,7 @@ class LSGBartEncoderAttention(BaseSelfAttention):
|
|
| 460 |
def get_sparse_tokens_with_stride(self, keys, values, mask):
|
| 461 |
|
| 462 |
if self.sparsity_factor == 1:
|
| 463 |
-
return keys, values, mask
|
| 464 |
|
| 465 |
n, h, t, d = keys.size()
|
| 466 |
sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device) * self.sparsity_factor
|
|
@@ -473,10 +487,30 @@ class LSGBartEncoderAttention(BaseSelfAttention):
|
|
| 473 |
|
| 474 |
return keys, values, mask
|
| 475 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 476 |
def get_sparse_tokens_with_lsh(self, keys, values, mask):
|
| 477 |
|
| 478 |
if self.sparsity_factor == 1:
|
| 479 |
-
return keys, values, mask
|
| 480 |
|
| 481 |
block_size = min(self.block_size, self.sparse_block_size)
|
| 482 |
keys = self.chunk(keys, block_size)
|
|
@@ -1307,6 +1341,7 @@ class LSGBartDecoder(LSGBartPretrainedModel):
|
|
| 1307 |
self.padding_idx = config.pad_token_id
|
| 1308 |
self.max_target_positions = config.max_position_embeddings
|
| 1309 |
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
|
|
|
| 1310 |
|
| 1311 |
if embed_tokens is not None:
|
| 1312 |
self.embed_tokens = embed_tokens
|
|
@@ -1349,6 +1384,15 @@ class LSGBartDecoder(LSGBartPretrainedModel):
|
|
| 1349 |
|
| 1350 |
return combined_attention_mask
|
| 1351 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1352 |
def forward(
|
| 1353 |
self,
|
| 1354 |
input_ids=None,
|
|
@@ -1389,12 +1433,14 @@ class LSGBartDecoder(LSGBartPretrainedModel):
|
|
| 1389 |
if inputs_embeds is None:
|
| 1390 |
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
| 1391 |
|
| 1392 |
-
#
|
| 1393 |
-
|
| 1394 |
-
|
| 1395 |
-
|
| 1396 |
-
|
| 1397 |
-
|
|
|
|
|
|
|
| 1398 |
|
| 1399 |
attention_mask = self._prepare_decoder_attention_mask(
|
| 1400 |
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
|
@@ -1488,6 +1534,9 @@ class LSGBartDecoder(LSGBartPretrainedModel):
|
|
| 1488 |
if encoder_hidden_states is not None:
|
| 1489 |
all_cross_attentions += (layer_outputs[2],)
|
| 1490 |
|
|
|
|
|
|
|
|
|
|
| 1491 |
# add hidden states from the last decoder layer
|
| 1492 |
if output_hidden_states:
|
| 1493 |
all_hidden_states += (hidden_states,)
|
|
@@ -1624,14 +1673,14 @@ class LSGBartModel(LSGBartPretrainedModel):
|
|
| 1624 |
)
|
| 1625 |
|
| 1626 |
|
| 1627 |
-
class LSGBartForConditionalGeneration(LSGBartPretrainedModel):
|
| 1628 |
|
| 1629 |
base_model_prefix = "model"
|
| 1630 |
_keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"]
|
| 1631 |
|
| 1632 |
def __init__(self, config):
|
| 1633 |
|
| 1634 |
-
|
| 1635 |
self.model = LSGBartModel(config)
|
| 1636 |
self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
|
| 1637 |
self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
|
|
@@ -1639,157 +1688,12 @@ class LSGBartForConditionalGeneration(LSGBartPretrainedModel):
|
|
| 1639 |
# Initialize weights and apply final processing
|
| 1640 |
self.post_init()
|
| 1641 |
|
| 1642 |
-
def get_encoder(self):
|
| 1643 |
-
return self.model.get_encoder()
|
| 1644 |
-
|
| 1645 |
-
def get_decoder(self):
|
| 1646 |
-
return self.model.get_decoder()
|
| 1647 |
-
|
| 1648 |
-
def resize_token_embeddings(self, new_num_tokens):
|
| 1649 |
-
new_embeddings = super().resize_token_embeddings(new_num_tokens)
|
| 1650 |
-
self._resize_final_logits_bias(new_num_tokens)
|
| 1651 |
-
return new_embeddings
|
| 1652 |
-
|
| 1653 |
-
def _resize_final_logits_bias(self, new_num_tokens):
|
| 1654 |
-
old_num_tokens = self.final_logits_bias.shape[-1]
|
| 1655 |
-
if new_num_tokens <= old_num_tokens:
|
| 1656 |
-
new_bias = self.final_logits_bias[:, :new_num_tokens]
|
| 1657 |
-
else:
|
| 1658 |
-
extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
|
| 1659 |
-
new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
|
| 1660 |
-
self.register_buffer("final_logits_bias", new_bias)
|
| 1661 |
-
|
| 1662 |
-
def get_output_embeddings(self):
|
| 1663 |
-
return self.lm_head
|
| 1664 |
-
|
| 1665 |
-
def set_output_embeddings(self, new_embeddings):
|
| 1666 |
-
self.lm_head = new_embeddings
|
| 1667 |
-
|
| 1668 |
-
def forward(
|
| 1669 |
-
self,
|
| 1670 |
-
input_ids=None,
|
| 1671 |
-
attention_mask=None,
|
| 1672 |
-
decoder_input_ids=None,
|
| 1673 |
-
decoder_attention_mask=None,
|
| 1674 |
-
head_mask=None,
|
| 1675 |
-
decoder_head_mask=None,
|
| 1676 |
-
cross_attn_head_mask=None,
|
| 1677 |
-
encoder_outputs=None,
|
| 1678 |
-
past_key_values=None,
|
| 1679 |
-
inputs_embeds=None,
|
| 1680 |
-
decoder_inputs_embeds=None,
|
| 1681 |
-
labels=None,
|
| 1682 |
-
use_cache=None,
|
| 1683 |
-
output_attentions=None,
|
| 1684 |
-
output_hidden_states=None,
|
| 1685 |
-
return_dict=None,
|
| 1686 |
-
):
|
| 1687 |
-
|
| 1688 |
-
r"""
|
| 1689 |
-
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
| 1690 |
-
Labels for computing the masked language modeling loss. Indices should either be in ``[0, ...,
|
| 1691 |
-
config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored
|
| 1692 |
-
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
|
| 1693 |
-
Returns:
|
| 1694 |
-
"""
|
| 1695 |
|
| 1696 |
-
|
| 1697 |
|
| 1698 |
-
|
| 1699 |
-
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
| 1700 |
-
decoder_input_ids = shift_tokens_right(
|
| 1701 |
-
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
| 1702 |
-
)
|
| 1703 |
|
| 1704 |
-
|
| 1705 |
-
input_ids,
|
| 1706 |
-
attention_mask=attention_mask,
|
| 1707 |
-
decoder_input_ids=decoder_input_ids,
|
| 1708 |
-
encoder_outputs=encoder_outputs,
|
| 1709 |
-
decoder_attention_mask=decoder_attention_mask,
|
| 1710 |
-
head_mask=head_mask,
|
| 1711 |
-
decoder_head_mask=decoder_head_mask,
|
| 1712 |
-
cross_attn_head_mask=cross_attn_head_mask,
|
| 1713 |
-
past_key_values=past_key_values,
|
| 1714 |
-
inputs_embeds=inputs_embeds,
|
| 1715 |
-
decoder_inputs_embeds=decoder_inputs_embeds,
|
| 1716 |
-
use_cache=use_cache,
|
| 1717 |
-
output_attentions=output_attentions,
|
| 1718 |
-
output_hidden_states=output_hidden_states,
|
| 1719 |
-
return_dict=return_dict,
|
| 1720 |
-
)
|
| 1721 |
-
|
| 1722 |
-
|
| 1723 |
-
lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
|
| 1724 |
-
|
| 1725 |
-
masked_lm_loss = None
|
| 1726 |
-
if labels is not None:
|
| 1727 |
-
loss_fct = CrossEntropyLoss()
|
| 1728 |
-
masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
|
| 1729 |
-
|
| 1730 |
-
if not return_dict:
|
| 1731 |
-
output = (lm_logits,) + outputs[1:]
|
| 1732 |
-
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
| 1733 |
-
|
| 1734 |
-
return Seq2SeqLMOutput(
|
| 1735 |
-
loss=masked_lm_loss,
|
| 1736 |
-
logits=lm_logits,
|
| 1737 |
-
past_key_values=outputs.past_key_values,
|
| 1738 |
-
decoder_hidden_states=outputs.decoder_hidden_states,
|
| 1739 |
-
decoder_attentions=outputs.decoder_attentions,
|
| 1740 |
-
cross_attentions=outputs.cross_attentions,
|
| 1741 |
-
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
| 1742 |
-
encoder_hidden_states=outputs.encoder_hidden_states,
|
| 1743 |
-
encoder_attentions=outputs.encoder_attentions,
|
| 1744 |
-
)
|
| 1745 |
-
|
| 1746 |
-
def prepare_inputs_for_generation(
|
| 1747 |
-
self,
|
| 1748 |
-
decoder_input_ids,
|
| 1749 |
-
past=None,
|
| 1750 |
-
attention_mask=None,
|
| 1751 |
-
head_mask=None,
|
| 1752 |
-
decoder_head_mask=None,
|
| 1753 |
-
cross_attn_head_mask=None,
|
| 1754 |
-
use_cache=None,
|
| 1755 |
-
encoder_outputs=None,
|
| 1756 |
-
**kwargs
|
| 1757 |
-
):
|
| 1758 |
-
# cut decoder_input_ids if past is used
|
| 1759 |
-
if past is not None:
|
| 1760 |
-
decoder_input_ids = decoder_input_ids[:, -1:]
|
| 1761 |
-
|
| 1762 |
-
return {
|
| 1763 |
-
"input_ids": None, # encoder_outputs is defined. input_ids not needed
|
| 1764 |
-
"encoder_outputs": encoder_outputs,
|
| 1765 |
-
"past_key_values": past,
|
| 1766 |
-
"decoder_input_ids": decoder_input_ids,
|
| 1767 |
-
"attention_mask": attention_mask,
|
| 1768 |
-
"head_mask": head_mask,
|
| 1769 |
-
"decoder_head_mask": decoder_head_mask,
|
| 1770 |
-
"cross_attn_head_mask": cross_attn_head_mask,
|
| 1771 |
-
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
| 1772 |
-
}
|
| 1773 |
-
|
| 1774 |
-
def prepare_decoder_input_ids_from_labels(self, labels):
|
| 1775 |
-
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
| 1776 |
-
|
| 1777 |
-
@staticmethod
|
| 1778 |
-
def _reorder_cache(past, beam_idx):
|
| 1779 |
-
reordered_past = ()
|
| 1780 |
-
for layer_past in past:
|
| 1781 |
-
# cached cross_attention states don't have to be reordered -> they are always the same
|
| 1782 |
-
reordered_past += (
|
| 1783 |
-
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
|
| 1784 |
-
)
|
| 1785 |
-
return reordered_past
|
| 1786 |
-
|
| 1787 |
-
|
| 1788 |
-
class LSGBartForSequenceClassification(LSGBartPretrainedModel):
|
| 1789 |
-
|
| 1790 |
-
def __init__(self, config, **kwargs):
|
| 1791 |
-
|
| 1792 |
-
super().__init__(config, **kwargs)
|
| 1793 |
self.model = LSGBartModel(config)
|
| 1794 |
self.classification_head = LSGBartClassificationHead(
|
| 1795 |
config.d_model,
|
|
@@ -1800,115 +1704,12 @@ class LSGBartForSequenceClassification(LSGBartPretrainedModel):
|
|
| 1800 |
self.model._init_weights(self.classification_head.dense)
|
| 1801 |
self.model._init_weights(self.classification_head.out_proj)
|
| 1802 |
|
| 1803 |
-
def forward(
|
| 1804 |
-
self,
|
| 1805 |
-
input_ids=None,
|
| 1806 |
-
attention_mask=None,
|
| 1807 |
-
decoder_input_ids=None,
|
| 1808 |
-
decoder_attention_mask=None,
|
| 1809 |
-
head_mask=None,
|
| 1810 |
-
decoder_head_mask=None,
|
| 1811 |
-
cross_attn_head_mask=None,
|
| 1812 |
-
encoder_outputs=None,
|
| 1813 |
-
inputs_embeds=None,
|
| 1814 |
-
decoder_inputs_embeds=None,
|
| 1815 |
-
labels=None,
|
| 1816 |
-
use_cache=None,
|
| 1817 |
-
output_attentions=None,
|
| 1818 |
-
output_hidden_states=None,
|
| 1819 |
-
return_dict=None,
|
| 1820 |
-
):
|
| 1821 |
-
|
| 1822 |
-
r"""
|
| 1823 |
-
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
| 1824 |
-
Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
|
| 1825 |
-
config.num_labels - 1]`. If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 1826 |
-
"""
|
| 1827 |
-
|
| 1828 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1829 |
-
if labels is not None:
|
| 1830 |
-
use_cache = False
|
| 1831 |
-
|
| 1832 |
-
if input_ids is None and inputs_embeds is not None:
|
| 1833 |
-
raise NotImplementedError(
|
| 1834 |
-
f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
|
| 1835 |
-
)
|
| 1836 |
-
|
| 1837 |
-
outputs = self.model(
|
| 1838 |
-
input_ids,
|
| 1839 |
-
attention_mask=attention_mask,
|
| 1840 |
-
decoder_input_ids=decoder_input_ids,
|
| 1841 |
-
decoder_attention_mask=decoder_attention_mask,
|
| 1842 |
-
head_mask=head_mask,
|
| 1843 |
-
decoder_head_mask=decoder_head_mask,
|
| 1844 |
-
cross_attn_head_mask=cross_attn_head_mask,
|
| 1845 |
-
encoder_outputs=encoder_outputs,
|
| 1846 |
-
inputs_embeds=inputs_embeds,
|
| 1847 |
-
decoder_inputs_embeds=decoder_inputs_embeds,
|
| 1848 |
-
use_cache=use_cache,
|
| 1849 |
-
output_attentions=output_attentions,
|
| 1850 |
-
output_hidden_states=output_hidden_states,
|
| 1851 |
-
return_dict=return_dict,
|
| 1852 |
-
)
|
| 1853 |
-
hidden_states = outputs[0] # last hidden state
|
| 1854 |
-
|
| 1855 |
-
eos_mask = input_ids.eq(self.config.eos_token_id)
|
| 1856 |
-
|
| 1857 |
-
t, t_ = eos_mask.size()[-1], hidden_states.size()[-2]
|
| 1858 |
-
if t > t_:
|
| 1859 |
-
eos_mask = eos_mask[:, :t_]
|
| 1860 |
-
|
| 1861 |
-
if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
|
| 1862 |
-
raise ValueError("All examples must have the same number of <eos> tokens.")
|
| 1863 |
-
sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
|
| 1864 |
-
:, -1, :
|
| 1865 |
-
]
|
| 1866 |
-
logits = self.classification_head(sentence_representation)
|
| 1867 |
-
|
| 1868 |
-
loss = None
|
| 1869 |
-
if labels is not None:
|
| 1870 |
-
if self.config.problem_type is None:
|
| 1871 |
-
if self.config.num_labels == 1:
|
| 1872 |
-
self.config.problem_type = "regression"
|
| 1873 |
-
elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
| 1874 |
-
self.config.problem_type = "single_label_classification"
|
| 1875 |
-
else:
|
| 1876 |
-
self.config.problem_type = "multi_label_classification"
|
| 1877 |
-
|
| 1878 |
-
if self.config.problem_type == "regression":
|
| 1879 |
-
loss_fct = MSELoss()
|
| 1880 |
-
if self.config.num_labels == 1:
|
| 1881 |
-
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
| 1882 |
-
else:
|
| 1883 |
-
loss = loss_fct(logits, labels)
|
| 1884 |
-
elif self.config.problem_type == "single_label_classification":
|
| 1885 |
-
loss_fct = CrossEntropyLoss()
|
| 1886 |
-
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
| 1887 |
-
elif self.config.problem_type == "multi_label_classification":
|
| 1888 |
-
loss_fct = BCEWithLogitsLoss()
|
| 1889 |
-
loss = loss_fct(logits, labels)
|
| 1890 |
-
if not return_dict:
|
| 1891 |
-
output = (logits,) + outputs[1:]
|
| 1892 |
-
return ((loss,) + output) if loss is not None else output
|
| 1893 |
-
|
| 1894 |
-
return Seq2SeqSequenceClassifierOutput(
|
| 1895 |
-
loss=loss,
|
| 1896 |
-
logits=logits,
|
| 1897 |
-
past_key_values=outputs.past_key_values,
|
| 1898 |
-
decoder_hidden_states=outputs.decoder_hidden_states,
|
| 1899 |
-
decoder_attentions=outputs.decoder_attentions,
|
| 1900 |
-
cross_attentions=outputs.cross_attentions,
|
| 1901 |
-
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
| 1902 |
-
encoder_hidden_states=outputs.encoder_hidden_states,
|
| 1903 |
-
encoder_attentions=outputs.encoder_attentions,
|
| 1904 |
-
)
|
| 1905 |
|
|
|
|
| 1906 |
|
| 1907 |
-
|
| 1908 |
|
| 1909 |
-
|
| 1910 |
-
|
| 1911 |
-
super().__init__(config)
|
| 1912 |
|
| 1913 |
config.num_labels = 2
|
| 1914 |
self.num_labels = config.num_labels
|
|
@@ -1918,102 +1719,6 @@ class LSGBartForQuestionAnswering(LSGBartPretrainedModel):
|
|
| 1918 |
|
| 1919 |
self.model._init_weights(self.qa_outputs)
|
| 1920 |
|
| 1921 |
-
def forward(
|
| 1922 |
-
self,
|
| 1923 |
-
input_ids=None,
|
| 1924 |
-
attention_mask=None,
|
| 1925 |
-
decoder_input_ids=None,
|
| 1926 |
-
decoder_attention_mask=None,
|
| 1927 |
-
head_mask=None,
|
| 1928 |
-
decoder_head_mask=None,
|
| 1929 |
-
cross_attn_head_mask=None,
|
| 1930 |
-
encoder_outputs=None,
|
| 1931 |
-
start_positions=None,
|
| 1932 |
-
end_positions=None,
|
| 1933 |
-
inputs_embeds=None,
|
| 1934 |
-
decoder_inputs_embeds=None,
|
| 1935 |
-
use_cache=None,
|
| 1936 |
-
output_attentions=None,
|
| 1937 |
-
output_hidden_states=None,
|
| 1938 |
-
return_dict=None,
|
| 1939 |
-
):
|
| 1940 |
-
|
| 1941 |
-
r"""
|
| 1942 |
-
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
| 1943 |
-
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
| 1944 |
-
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
| 1945 |
-
are not taken into account for computing the loss.
|
| 1946 |
-
end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
| 1947 |
-
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
| 1948 |
-
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
| 1949 |
-
are not taken into account for computing the loss.
|
| 1950 |
-
"""
|
| 1951 |
-
|
| 1952 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1953 |
-
if start_positions is not None and end_positions is not None:
|
| 1954 |
-
use_cache = False
|
| 1955 |
-
|
| 1956 |
-
outputs = self.model(
|
| 1957 |
-
input_ids,
|
| 1958 |
-
attention_mask=attention_mask,
|
| 1959 |
-
decoder_input_ids=decoder_input_ids,
|
| 1960 |
-
decoder_attention_mask=decoder_attention_mask,
|
| 1961 |
-
head_mask=head_mask,
|
| 1962 |
-
decoder_head_mask=decoder_head_mask,
|
| 1963 |
-
cross_attn_head_mask=cross_attn_head_mask,
|
| 1964 |
-
encoder_outputs=encoder_outputs,
|
| 1965 |
-
inputs_embeds=inputs_embeds,
|
| 1966 |
-
decoder_inputs_embeds=decoder_inputs_embeds,
|
| 1967 |
-
use_cache=use_cache,
|
| 1968 |
-
output_attentions=output_attentions,
|
| 1969 |
-
output_hidden_states=output_hidden_states,
|
| 1970 |
-
return_dict=return_dict,
|
| 1971 |
-
)
|
| 1972 |
-
|
| 1973 |
-
sequence_output = outputs[0]
|
| 1974 |
-
|
| 1975 |
-
logits = self.qa_outputs(sequence_output)
|
| 1976 |
-
start_logits, end_logits = logits.split(1, dim=-1)
|
| 1977 |
-
start_logits = start_logits.squeeze(-1).contiguous()
|
| 1978 |
-
end_logits = end_logits.squeeze(-1).contiguous()
|
| 1979 |
-
|
| 1980 |
-
total_loss = None
|
| 1981 |
-
if start_positions is not None and end_positions is not None:
|
| 1982 |
-
# If we are on multi-GPU, split add a dimension
|
| 1983 |
-
if len(start_positions.size()) > 1:
|
| 1984 |
-
start_positions = start_positions.squeeze(-1)
|
| 1985 |
-
if len(end_positions.size()) > 1:
|
| 1986 |
-
end_positions = end_positions.squeeze(-1)
|
| 1987 |
-
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
| 1988 |
-
ignored_index = start_logits.size(1)
|
| 1989 |
-
start_positions = start_positions.clamp(0, ignored_index)
|
| 1990 |
-
end_positions = end_positions.clamp(0, ignored_index)
|
| 1991 |
-
|
| 1992 |
-
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
| 1993 |
-
start_loss = loss_fct(start_logits, start_positions)
|
| 1994 |
-
end_loss = loss_fct(end_logits, end_positions)
|
| 1995 |
-
total_loss = (start_loss + end_loss) / 2
|
| 1996 |
-
|
| 1997 |
-
if not return_dict:
|
| 1998 |
-
output = (
|
| 1999 |
-
start_logits,
|
| 2000 |
-
end_logits,
|
| 2001 |
-
) + outputs[1:]
|
| 2002 |
-
return ((total_loss,) + output) if total_loss is not None else output
|
| 2003 |
-
|
| 2004 |
-
return Seq2SeqQuestionAnsweringModelOutput(
|
| 2005 |
-
loss=total_loss,
|
| 2006 |
-
start_logits=start_logits,
|
| 2007 |
-
end_logits=end_logits,
|
| 2008 |
-
past_key_values=outputs.past_key_values,
|
| 2009 |
-
decoder_hidden_states=outputs.decoder_hidden_states,
|
| 2010 |
-
decoder_attentions=outputs.decoder_attentions,
|
| 2011 |
-
cross_attentions=outputs.cross_attentions,
|
| 2012 |
-
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
| 2013 |
-
encoder_hidden_states=outputs.encoder_hidden_states,
|
| 2014 |
-
encoder_attentions=outputs.encoder_attentions,
|
| 2015 |
-
)
|
| 2016 |
-
|
| 2017 |
|
| 2018 |
class LSGBartDecoderWrapper(LSGBartPretrainedModel):
|
| 2019 |
"""
|
|
@@ -2021,7 +1726,7 @@ class LSGBartDecoderWrapper(LSGBartPretrainedModel):
|
|
| 2021 |
used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
|
| 2022 |
"""
|
| 2023 |
|
| 2024 |
-
def __init__(self, config):
|
| 2025 |
super().__init__(config)
|
| 2026 |
self.decoder = LSGBartDecoder(config)
|
| 2027 |
|
|
@@ -2029,14 +1734,14 @@ class LSGBartDecoderWrapper(LSGBartPretrainedModel):
|
|
| 2029 |
return self.decoder(*args, **kwargs)
|
| 2030 |
|
| 2031 |
|
| 2032 |
-
class LSGBartForCausalLM(LSGBartPretrainedModel):
|
| 2033 |
|
| 2034 |
-
def __init__(self, config):
|
| 2035 |
|
| 2036 |
-
super().__init__(config)
|
| 2037 |
config = copy.deepcopy(config)
|
| 2038 |
config.is_decoder = True
|
| 2039 |
config.is_encoder_decoder = False
|
|
|
|
| 2040 |
self.model = LSGBartDecoderWrapper(config)
|
| 2041 |
|
| 2042 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
|
@@ -2044,105 +1749,6 @@ class LSGBartForCausalLM(LSGBartPretrainedModel):
|
|
| 2044 |
# Initialize weights and apply final processing
|
| 2045 |
self.post_init()
|
| 2046 |
|
| 2047 |
-
def get_input_embeddings(self):
|
| 2048 |
-
return self.model.decoder.embed_tokens
|
| 2049 |
-
|
| 2050 |
-
def set_input_embeddings(self, value):
|
| 2051 |
-
self.model.decoder.embed_tokens = value
|
| 2052 |
-
|
| 2053 |
-
def get_output_embeddings(self):
|
| 2054 |
-
return self.lm_head
|
| 2055 |
-
|
| 2056 |
-
def set_output_embeddings(self, new_embeddings):
|
| 2057 |
-
self.lm_head = new_embeddings
|
| 2058 |
-
|
| 2059 |
-
def set_decoder(self, decoder):
|
| 2060 |
-
self.model.decoder = decoder
|
| 2061 |
-
|
| 2062 |
-
def get_decoder(self):
|
| 2063 |
-
return self.model.decoder
|
| 2064 |
-
|
| 2065 |
-
def forward(
|
| 2066 |
-
self,
|
| 2067 |
-
input_ids=None,
|
| 2068 |
-
attention_mask=None,
|
| 2069 |
-
encoder_hidden_states=None,
|
| 2070 |
-
encoder_attention_mask=None,
|
| 2071 |
-
head_mask=None,
|
| 2072 |
-
cross_attn_head_mask=None,
|
| 2073 |
-
past_key_values=None,
|
| 2074 |
-
inputs_embeds=None,
|
| 2075 |
-
labels=None,
|
| 2076 |
-
use_cache=None,
|
| 2077 |
-
output_attentions=None,
|
| 2078 |
-
output_hidden_states=None,
|
| 2079 |
-
return_dict=None,
|
| 2080 |
-
):
|
| 2081 |
-
|
| 2082 |
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 2083 |
-
output_hidden_states = (
|
| 2084 |
-
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 2085 |
-
)
|
| 2086 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 2087 |
-
|
| 2088 |
-
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 2089 |
-
outputs = self.model.decoder(
|
| 2090 |
-
input_ids=input_ids,
|
| 2091 |
-
attention_mask=attention_mask,
|
| 2092 |
-
encoder_hidden_states=encoder_hidden_states,
|
| 2093 |
-
encoder_attention_mask=encoder_attention_mask,
|
| 2094 |
-
head_mask=head_mask,
|
| 2095 |
-
cross_attn_head_mask=cross_attn_head_mask,
|
| 2096 |
-
past_key_values=past_key_values,
|
| 2097 |
-
inputs_embeds=inputs_embeds,
|
| 2098 |
-
use_cache=use_cache,
|
| 2099 |
-
output_attentions=output_attentions,
|
| 2100 |
-
output_hidden_states=output_hidden_states,
|
| 2101 |
-
return_dict=return_dict,
|
| 2102 |
-
)
|
| 2103 |
-
|
| 2104 |
-
logits = self.lm_head(outputs[0])
|
| 2105 |
-
|
| 2106 |
-
loss = None
|
| 2107 |
-
if labels is not None:
|
| 2108 |
-
loss_fct = CrossEntropyLoss()
|
| 2109 |
-
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
|
| 2110 |
-
|
| 2111 |
-
if not return_dict:
|
| 2112 |
-
output = (logits,) + outputs[1:]
|
| 2113 |
-
return (loss,) + output if loss is not None else output
|
| 2114 |
-
|
| 2115 |
-
return CausalLMOutputWithCrossAttentions(
|
| 2116 |
-
loss=loss,
|
| 2117 |
-
logits=logits,
|
| 2118 |
-
past_key_values=outputs.past_key_values,
|
| 2119 |
-
hidden_states=outputs.hidden_states,
|
| 2120 |
-
attentions=outputs.attentions,
|
| 2121 |
-
cross_attentions=outputs.cross_attentions,
|
| 2122 |
-
)
|
| 2123 |
-
|
| 2124 |
-
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
|
| 2125 |
-
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
| 2126 |
-
if attention_mask is None:
|
| 2127 |
-
attention_mask = input_ids.new_ones(input_ids.shape)
|
| 2128 |
-
|
| 2129 |
-
if past:
|
| 2130 |
-
input_ids = input_ids[:, -1:]
|
| 2131 |
-
# first step, decoder_cached_states are empty
|
| 2132 |
-
return {
|
| 2133 |
-
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
|
| 2134 |
-
"attention_mask": attention_mask,
|
| 2135 |
-
"past_key_values": past,
|
| 2136 |
-
"use_cache": use_cache,
|
| 2137 |
-
}
|
| 2138 |
-
|
| 2139 |
-
@staticmethod
|
| 2140 |
-
def _reorder_cache(past, beam_idx):
|
| 2141 |
-
reordered_past = ()
|
| 2142 |
-
for layer_past in past:
|
| 2143 |
-
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
| 2144 |
-
return reordered_past
|
| 2145 |
-
|
| 2146 |
|
| 2147 |
def str_to_class(classname):
|
| 2148 |
return getattr(sys.modules[__name__], classname)
|
|
|
|
| 54 |
self.sparsity_factor = sparsity_factor
|
| 55 |
self.sparsity_type = sparsity_type
|
| 56 |
|
| 57 |
+
if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
|
| 58 |
logger.warning(
|
| 59 |
+
"[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], setting sparsity_type=None, computation will skip sparse attention")
|
| 60 |
self.sparsity_type = None
|
| 61 |
|
| 62 |
+
if self.sparsity_type in ["stride", "block_stride"]:
|
| 63 |
if self.sparsity_factor > self.encoder_attention_heads:
|
| 64 |
logger.warning(
|
| 65 |
+
"[WARNING CONFIG]: sparsity_factor > encoder_attention_heads is not recommended for stride/block_stride sparsity"
|
| 66 |
)
|
| 67 |
|
| 68 |
+
if self.num_global_tokens < 1:
|
| 69 |
+
logger.warning(
|
| 70 |
+
"[WARNING CONFIG]: num_global_tokens < 1 is not compatible, setting num_global_tokens=1"
|
| 71 |
+
)
|
| 72 |
+
self.num_global_tokens = 1
|
| 73 |
+
elif self.num_global_tokens > 512:
|
| 74 |
+
logger.warning(
|
| 75 |
+
"[WARNING CONFIG]: num_global_tokens > 512 is not compatible, setting num_global_tokens=512"
|
| 76 |
+
)
|
| 77 |
+
self.num_global_tokens = 512
|
| 78 |
+
|
| 79 |
+
if self.sparsity_factor > 0:
|
| 80 |
+
assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
|
| 81 |
+
assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
|
| 82 |
+
|
| 83 |
|
| 84 |
def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
|
| 85 |
"""
|
|
|
|
| 232 |
# Shape of blocks
|
| 233 |
self.local_shapes = (self.block_size*3, self.block_size)
|
| 234 |
if self.sparse_block_size and self.sparsity_factor > 0:
|
|
|
|
|
|
|
| 235 |
self.sparse_shapes = (self.sparse_block_size*3, self.block_size//self.sparsity_factor)
|
| 236 |
|
| 237 |
self.attention = BaseAttentionProduct(config)
|
|
|
|
| 412 |
"pooling": self.get_sparse_tokens_with_pooling,
|
| 413 |
"lsh": self.get_sparse_tokens_with_lsh,
|
| 414 |
"stride": self.get_sparse_tokens_with_stride,
|
| 415 |
+
"block_stride": self.get_sparse_tokens_with_block_stride,
|
| 416 |
}
|
| 417 |
|
| 418 |
self.sparsity_type = config.sparsity_type
|
|
|
|
| 424 |
def get_sparse_tokens_with_norm(self, keys, values, mask):
|
| 425 |
|
| 426 |
if self.sparsity_factor == 1:
|
| 427 |
+
return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
|
| 428 |
|
| 429 |
with torch.no_grad():
|
| 430 |
|
|
|
|
| 452 |
def get_sparse_tokens_with_pooling(self, keys, values, mask):
|
| 453 |
|
| 454 |
if self.sparsity_factor == 1:
|
| 455 |
+
return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
|
| 456 |
|
| 457 |
keys = self.chunk(keys, self.sparsity_factor)
|
| 458 |
values = self.chunk(values, self.sparsity_factor)
|
|
|
|
| 474 |
def get_sparse_tokens_with_stride(self, keys, values, mask):
|
| 475 |
|
| 476 |
if self.sparsity_factor == 1:
|
| 477 |
+
return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
|
| 478 |
|
| 479 |
n, h, t, d = keys.size()
|
| 480 |
sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device) * self.sparsity_factor
|
|
|
|
| 487 |
|
| 488 |
return keys, values, mask
|
| 489 |
|
| 490 |
+
def get_sparse_tokens_with_block_stride(self, keys, values, mask):
|
| 491 |
+
|
| 492 |
+
if self.sparsity_factor == 1:
|
| 493 |
+
return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
|
| 494 |
+
|
| 495 |
+
n, h, t, d = keys.size()
|
| 496 |
+
|
| 497 |
+
t, b = self.block_size, t // self.block_size
|
| 498 |
+
sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device)
|
| 499 |
+
sparse_idx = sparse_idx.reshape(1, 1, 1, -1, 1) + torch.arange(h, device=keys.device).reshape(1, h, 1, 1, 1) * (t // self.sparsity_factor)
|
| 500 |
+
sparse_idx = (sparse_idx % t)
|
| 501 |
+
sparse_idx = sparse_idx + torch.arange(b, device=keys.device).reshape(1, 1, -1, 1, 1) * t
|
| 502 |
+
sparse_idx = sparse_idx.reshape(1, h, -1, 1).expand(n, h, -1, 1)
|
| 503 |
+
|
| 504 |
+
keys = keys.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
|
| 505 |
+
values = values.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
|
| 506 |
+
mask = mask.expand(-1, h, -1, -1).transpose(-1, -2).gather(dim=-2, index=sparse_idx).transpose(-1, -2)
|
| 507 |
+
|
| 508 |
+
return keys, values, mask
|
| 509 |
+
|
| 510 |
def get_sparse_tokens_with_lsh(self, keys, values, mask):
|
| 511 |
|
| 512 |
if self.sparsity_factor == 1:
|
| 513 |
+
return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
|
| 514 |
|
| 515 |
block_size = min(self.block_size, self.sparse_block_size)
|
| 516 |
keys = self.chunk(keys, block_size)
|
|
|
|
| 1341 |
self.padding_idx = config.pad_token_id
|
| 1342 |
self.max_target_positions = config.max_position_embeddings
|
| 1343 |
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
| 1344 |
+
self.adaptive = config.adaptive
|
| 1345 |
|
| 1346 |
if embed_tokens is not None:
|
| 1347 |
self.embed_tokens = embed_tokens
|
|
|
|
| 1384 |
|
| 1385 |
return combined_attention_mask
|
| 1386 |
|
| 1387 |
+
def resize_inputs(self, inputs_embeds, attention_mask):
|
| 1388 |
+
pad = 0
|
| 1389 |
+
|
| 1390 |
+
max_len = int(attention_mask.sum(dim=-1).max())
|
| 1391 |
+
pad = attention_mask.size()[-1] - max_len
|
| 1392 |
+
inputs_embeds = inputs_embeds[:, :max_len]
|
| 1393 |
+
attention_mask = attention_mask[..., :max_len]
|
| 1394 |
+
return pad, inputs_embeds, attention_mask
|
| 1395 |
+
|
| 1396 |
def forward(
|
| 1397 |
self,
|
| 1398 |
input_ids=None,
|
|
|
|
| 1433 |
if inputs_embeds is None:
|
| 1434 |
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
| 1435 |
|
| 1436 |
+
# Resize to reduce computation
|
| 1437 |
+
pad = 0
|
| 1438 |
+
if self.adaptive:
|
| 1439 |
+
if attention_mask is not None:
|
| 1440 |
+
pad, inputs_embeds, attention_mask = self.resize_inputs(inputs_embeds, attention_mask)
|
| 1441 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 1442 |
+
if encoder_attention_mask is not None:
|
| 1443 |
+
_, encoder_hidden_states, encoder_attention_mask = self.resize_inputs(encoder_hidden_states, encoder_attention_mask)
|
| 1444 |
|
| 1445 |
attention_mask = self._prepare_decoder_attention_mask(
|
| 1446 |
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
|
|
|
| 1534 |
if encoder_hidden_states is not None:
|
| 1535 |
all_cross_attentions += (layer_outputs[2],)
|
| 1536 |
|
| 1537 |
+
# Resize to original shape
|
| 1538 |
+
hidden_states = torch.nn.functional.pad(hidden_states.transpose(-1, -2), pad=(0, pad), value=0).transpose(-1, -2)
|
| 1539 |
+
|
| 1540 |
# add hidden states from the last decoder layer
|
| 1541 |
if output_hidden_states:
|
| 1542 |
all_hidden_states += (hidden_states,)
|
|
|
|
| 1673 |
)
|
| 1674 |
|
| 1675 |
|
| 1676 |
+
class LSGBartForConditionalGeneration(BartForConditionalGeneration, LSGBartPretrainedModel):
|
| 1677 |
|
| 1678 |
base_model_prefix = "model"
|
| 1679 |
_keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"]
|
| 1680 |
|
| 1681 |
def __init__(self, config):
|
| 1682 |
|
| 1683 |
+
LSGBartPretrainedModel.__init__(self, config)
|
| 1684 |
self.model = LSGBartModel(config)
|
| 1685 |
self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
|
| 1686 |
self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
|
|
|
|
| 1688 |
# Initialize weights and apply final processing
|
| 1689 |
self.post_init()
|
| 1690 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1691 |
|
| 1692 |
+
class LSGBartForSequenceClassification(BartForSequenceClassification, LSGBartPretrainedModel):
|
| 1693 |
|
| 1694 |
+
def __init__(self, config: LSGBartConfig, **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1695 |
|
| 1696 |
+
LSGBartPretrainedModel.__init__(self, config, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1697 |
self.model = LSGBartModel(config)
|
| 1698 |
self.classification_head = LSGBartClassificationHead(
|
| 1699 |
config.d_model,
|
|
|
|
| 1704 |
self.model._init_weights(self.classification_head.dense)
|
| 1705 |
self.model._init_weights(self.classification_head.out_proj)
|
| 1706 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1707 |
|
| 1708 |
+
class LSGBartForQuestionAnswering(BartForQuestionAnswering, LSGBartPretrainedModel):
|
| 1709 |
|
| 1710 |
+
def __init__(self, config: LSGBartConfig):
|
| 1711 |
|
| 1712 |
+
LSGBartPretrainedModel.__init__(self, config)
|
|
|
|
|
|
|
| 1713 |
|
| 1714 |
config.num_labels = 2
|
| 1715 |
self.num_labels = config.num_labels
|
|
|
|
| 1719 |
|
| 1720 |
self.model._init_weights(self.qa_outputs)
|
| 1721 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1722 |
|
| 1723 |
class LSGBartDecoderWrapper(LSGBartPretrainedModel):
|
| 1724 |
"""
|
|
|
|
| 1726 |
used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
|
| 1727 |
"""
|
| 1728 |
|
| 1729 |
+
def __init__(self, config: LSGBartConfig):
|
| 1730 |
super().__init__(config)
|
| 1731 |
self.decoder = LSGBartDecoder(config)
|
| 1732 |
|
|
|
|
| 1734 |
return self.decoder(*args, **kwargs)
|
| 1735 |
|
| 1736 |
|
| 1737 |
+
class LSGBartForCausalLM(BartForCausalLM, LSGBartPretrainedModel):
|
| 1738 |
|
| 1739 |
+
def __init__(self, config: LSGBartConfig):
|
| 1740 |
|
|
|
|
| 1741 |
config = copy.deepcopy(config)
|
| 1742 |
config.is_decoder = True
|
| 1743 |
config.is_encoder_decoder = False
|
| 1744 |
+
LSGBartPretrainedModel.__init__(self, config)
|
| 1745 |
self.model = LSGBartDecoderWrapper(config)
|
| 1746 |
|
| 1747 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
|
|
|
| 1749 |
# Initialize weights and apply final processing
|
| 1750 |
self.post_init()
|
| 1751 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1752 |
|
| 1753 |
def str_to_class(classname):
|
| 1754 |
return getattr(sys.modules[__name__], classname)
|
pytorch_model.bin
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 653914167
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b88fc0094b185dff97f0e9d44c155c561da2a130efd7c1860a1af192272ef286
|
| 3 |
size 653914167
|
tokenizer.json
CHANGED
|
@@ -6,16 +6,7 @@
|
|
| 6 |
"strategy": "LongestFirst",
|
| 7 |
"stride": 0
|
| 8 |
},
|
| 9 |
-
"padding":
|
| 10 |
-
"strategy": {
|
| 11 |
-
"Fixed": 320
|
| 12 |
-
},
|
| 13 |
-
"direction": "Right",
|
| 14 |
-
"pad_to_multiple_of": null,
|
| 15 |
-
"pad_id": 1,
|
| 16 |
-
"pad_type_id": 0,
|
| 17 |
-
"pad_token": "<pad>"
|
| 18 |
-
},
|
| 19 |
"added_tokens": [
|
| 20 |
{
|
| 21 |
"id": 0,
|
|
|
|
| 6 |
"strategy": "LongestFirst",
|
| 7 |
"stride": 0
|
| 8 |
},
|
| 9 |
+
"padding": null,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
"added_tokens": [
|
| 11 |
{
|
| 12 |
"id": 0,
|
tokenizer_config.json
CHANGED
|
@@ -1 +1 @@
|
|
| 1 |
-
{"errors": "replace", "bos_token": "<s>", "eos_token": "</s>", "sep_token": "</s>", "cls_token": "<s>", "unk_token": "<unk>", "pad_token": "<pad>", "mask_token": "<mask>", "add_prefix_space": false, "trim_offsets": true, "model_max_length": 16384, "special_tokens_map_file": null, "name_or_path": "/data/ccondevaux/lsg/text-summarization/
|
|
|
|
| 1 |
+
{"errors": "replace", "bos_token": "<s>", "eos_token": "</s>", "sep_token": "</s>", "cls_token": "<s>", "unk_token": "<unk>", "pad_token": "<pad>", "mask_token": "<mask>", "add_prefix_space": false, "trim_offsets": true, "model_max_length": 16384, "special_tokens_map_file": null, "name_or_path": "/data/ccondevaux/lsg/text-summarization/tmp_final/arxiv/lsg_local_16384_trained", "tokenizer_class": "BartTokenizer"}
|