import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import PretrainedConfig, PreTrainedModel, AutoConfig, AutoModelForMaskedLM
from transformers.modeling_outputs import BaseModelOutput, MaskedLMOutput


class ModernALBERTConfig(PretrainedConfig):
    """Configuration for ModernALBERT: an ALBERT-style encoder with a factorized
    embedding (a small token embedding projected up to hidden_size)."""

    model_type = "ModernALBERT"

    def __init__(
        self,
        vocab_size=30522,
        embedding_size=128,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        initializer_range=0.02,
        use_cache=True,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        self.use_cache = use_cache


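# Illustrative sketch (hypothetical helper, not used elsewhere in this module): the
# config keeps ALBERT's factorized-embedding hyperparameters, so embedding_size can
# stay small while hidden_size grows, and it serializes like any PretrainedConfig.
def _example_config() -> ModernALBERTConfig:
    config = ModernALBERTConfig(
        embedding_size=128,
        hidden_size=512,
        num_hidden_layers=6,
        num_attention_heads=8,
        intermediate_size=2048,
    )
    assert config.to_dict()["model_type"] == "ModernALBERT"
    return config

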
class ModernALBERTModel(PreTrainedModel):
    config_class = ModernALBERTConfig

    def __init__(self, config):
        super().__init__(config)
        # Factorized embedding: a small token embedding projected up to hidden_size.
        self.embeddings = nn.Embedding(config.vocab_size, config.embedding_size)
        self.embed_proj = nn.Linear(config.embedding_size, config.hidden_size)
        # Note: nn.TransformerEncoder deep-copies the layer, so parameters are NOT
        # shared across layers (unlike the original ALBERT's cross-layer sharing),
        # and no positional information is added here (max_position_embeddings is
        # stored in the config but unused).
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_attention_heads,
                dim_feedforward=config.intermediate_size,
                dropout=config.hidden_dropout_prob,
                activation=config.hidden_act,
                layer_norm_eps=config.layer_norm_eps,
                batch_first=True,
            ),
            num_layers=config.num_hidden_layers,
        )
        self.post_init()

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        x = self.embeddings(input_ids)
        x = self.embed_proj(x)
        # nn.TransformerEncoder expects True for positions to be ignored, which is
        # the inverse of the Hugging Face attention_mask convention (1 = attend).
        x = self.encoder(
            x,
            src_key_padding_mask=~attention_mask.bool() if attention_mask is not None else None,
        )
        return BaseModelOutput(last_hidden_state=x, hidden_states=(x,))


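# Illustrative sketch (hypothetical helper): run the bare encoder on dummy ids and a
# padding mask; the last hidden state comes back as (batch, seq_len, hidden_size).
def _example_encoder_forward():
    config = ModernALBERTConfig(
        hidden_size=256, num_hidden_layers=2, num_attention_heads=4, intermediate_size=512
    )
    model = ModernALBERTModel(config).eval()
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones(2, 16, dtype=torch.long)
    attention_mask[:, 12:] = 0  # mark the last positions as padding
    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask)
    assert out.last_hidden_state.shape == (2, 16, config.hidden_size)
    return out

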
class ModernALBERTForMaskedLM(PreTrainedModel):
    config_class = ModernALBERTConfig

    def __init__(self, config):
        super().__init__(config)
        self.albert = ModernALBERTModel(config)
        self.mlm_head = nn.Linear(config.hidden_size, config.vocab_size)
        self.post_init()

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        outputs = self.albert(input_ids=input_ids, attention_mask=attention_mask)
        logits = self.mlm_head(outputs.last_hidden_state)
        loss = None
        if labels is not None:
            # CrossEntropyLoss ignores positions labelled -100 by default, matching
            # the usual Hugging Face masked-LM labelling convention.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=None,
        )


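# Illustrative sketch (hypothetical helper): compute an MLM loss by labelling only the
# masked positions and setting every other label to -100 so the loss ignores them.
def _example_mlm_loss():
    config = ModernALBERTConfig(
        hidden_size=256, num_hidden_layers=2, num_attention_heads=4, intermediate_size=512
    )
    model = ModernALBERTForMaskedLM(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    labels = torch.full_like(input_ids, -100)
    labels[:, 3] = input_ids[:, 3]  # pretend position 3 was masked out
    input_ids[:, 3] = 103           # a hypothetical [MASK] token id for this vocab
    out = model(input_ids=input_ids, labels=labels)
    return out.loss

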
# Register the custom classes so the Auto* factories can resolve model_type "ModernALBERT".
AutoConfig.register("ModernALBERT", ModernALBERTConfig)
AutoModelForMaskedLM.register(ModernALBERTConfig, ModernALBERTForMaskedLM)
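

if __name__ == "__main__":
    # Smoke test (sketch): with the registrations above, the Auto* factories can
    # build the model from any config whose model_type is "ModernALBERT". The
    # hyperparameters below are arbitrary small values for a quick check.
    config = AutoConfig.for_model(
        "ModernALBERT", hidden_size=256, num_hidden_layers=2,
        num_attention_heads=4, intermediate_size=512,
    )
    model = AutoModelForMaskedLM.from_config(config)
    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    logits = model(input_ids=input_ids).logits
    print(type(model).__name__, tuple(logits.shape))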