Add NeoBERTForTokenClassification class
- config.json +23 -2
- model.py +70 -0
config.json CHANGED
@@ -6,7 +6,8 @@
     "AutoConfig": "model.NeoBERTConfig",
     "AutoModel": "model.NeoBERT",
     "AutoModelForMaskedLM": "model.NeoBERTLMHead",
-    "AutoModelForSequenceClassification": "model.NeoBERTForSequenceClassification"
+    "AutoModelForSequenceClassification": "model.NeoBERTForSequenceClassification",
+    "AutoModelForTokenClassification": "model.NeoBERTForTokenClassification"
   },
   "classifier_init_range": 0.02,
   "decoder_init_range": 0.02,
@@ -15,8 +16,28 @@
   "hidden_size": 768,
   "intermediate_size": 3072,
   "kwargs": {
+    "_commit_hash": null,
+    "architectures": [
+      "NeoBERTLMHead"
+    ],
+    "attn_implementation": null,
+    "auto_map": {
+      "AutoConfig": "model.NeoBERTConfig",
+      "AutoModel": "model.NeoBERT",
+      "AutoModelForMaskedLM": "model.NeoBERTLMHead",
+      "AutoModelForSequenceClassification": "model.NeoBERTForSequenceClassification"
+    },
     "classifier_init_range": 0.02,
+    "dim_head": 64,
+    "kwargs": {
+      "classifier_init_range": 0.02,
+      "pretrained_model_name_or_path": "google-bert/bert-base-uncased",
+      "trust_remote_code": true
+    },
+    "model_type": "neobert",
     "pretrained_model_name_or_path": "google-bert/bert-base-uncased",
+    "torch_dtype": "float32",
+    "transformers_version": "4.48.2",
     "trust_remote_code": true
   },
   "max_length": 4096,
@@ -27,7 +48,7 @@
   "pad_token_id": 0,
   "pretrained_model_name_or_path": "google-bert/bert-base-uncased",
   "torch_dtype": "float32",
-  "transformers_version": "4.
+  "transformers_version": "4.51.3",
   "trust_remote_code": true,
   "vocab_size": 30522
 }
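With the new "AutoModelForTokenClassification" entry in auto_map, the token-classification head can be loaded through the standard Auto class as long as trust_remote_code=True is passed. A minimal loading sketch, assuming the checkpoint id chandar-lab/NeoBERT (substitute the actual repo) and a hypothetical 9-label tag set:

from transformers import AutoTokenizer, AutoModelForTokenClassification

repo_id = "chandar-lab/NeoBERT"  # assumption: replace with the actual checkpoint id

tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForTokenClassification.from_pretrained(
    repo_id,
    trust_remote_code=True,  # the class is defined in model.py inside the repo
    num_labels=9,            # hypothetical, e.g. a CoNLL-2003-style NER tag set
)

inputs = tokenizer("NeoBERT supports sequences up to 4096 tokens.", return_tensors="pt")
outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
token_predictions = outputs.logits.argmax(dim=-1)  # (batch, seq_len) predicted tag ids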
model.py CHANGED
@@ -27,6 +27,7 @@ from transformers.modeling_outputs import (
     BaseModelOutput,
     MaskedLMOutput,
     SequenceClassifierOutput,
+    TokenClassifierOutput
 )

 from .rotary import precompute_freqs_cis, apply_rotary_emb
@@ -432,3 +433,72 @@ class NeoBERTForSequenceClassification(NeoBERTPreTrainedModel):
             hidden_states=output.hidden_states if output_hidden_states else None,
             attentions=output.attentions if output_attentions else None,
         )
+
+class NeoBERTForTokenClassification(NeoBERTPreTrainedModel):
+    config_class = NeoBERTConfig
+
+    def __init__(self, config: NeoBERTConfig):
+        super().__init__(config)
+
+        self.config = config
+        self.num_labels = getattr(config, "num_labels", 2)
+        self.classifier_dropout = getattr(config, "classifier_dropout", 0.1)
+        self.classifier_init_range = getattr(config, "classifier_init_range", 0.02)
+
+        self.model = NeoBERT(config)
+
+        self.dropout = nn.Dropout(self.classifier_dropout)
+        self.classifier = nn.Linear(config.hidden_size, self.num_labels)
+
+        self.post_init()
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=self.classifier_init_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        max_seqlen: Optional[int] = None,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+        labels: Optional[torch.Tensor] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        output = self.model(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            max_seqlen=max_seqlen,
+            cu_seqlens=cu_seqlens,
+            attention_mask=attention_mask,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+        )
+
+        sequence_output = output.last_hidden_state
+        sequence_output = self.dropout(sequence_output)
+        logits = self.classifier(sequence_output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            # Reshape logits and labels to compute the per-token classification loss
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,)
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=output.hidden_states if output_hidden_states else None,
+            attentions=output.attentions if output_attentions else None,
+        )
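Because the loss path flattens logits to (batch * seq_len, num_labels) and labels to (batch * seq_len) with a plain CrossEntropyLoss, the usual token-classification label convention applies: positions to exclude from the loss (padding, special tokens, sub-word continuations) can be set to -100, which CrossEntropyLoss ignores by default. A small standalone sketch with made-up shapes and labels:

import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_len, num_labels = 2, 6, 5               # hypothetical shapes
logits = torch.randn(batch_size, seq_len, num_labels)   # stand-in for the classifier output
labels = torch.randint(0, num_labels, (batch_size, seq_len))
labels[:, 0] = -100    # e.g. mask the [CLS] position
labels[:, -1] = -100   # e.g. mask padding / [SEP]

loss_fct = CrossEntropyLoss()  # ignore_index defaults to -100
loss = loss_fct(logits.view(-1, num_labels), labels.view(-1))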