import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel, AutoConfig, AutoModelForMaskedLM
from transformers.modeling_outputs import BaseModelOutput, MaskedLMOutput


class ModernALBERTConfig(PretrainedConfig):
    model_type = "ModernALBERT"

    def __init__(
        self,
        vocab_size=30522,
        embedding_size=128,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        initializer_range=0.02,
        use_cache=True,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        self.use_cache = use_cache


class ModernALBERTModel(PreTrainedModel):
    """Encoder backbone with ALBERT-style factorized embeddings (embedding_size -> hidden_size)."""

    config_class = ModernALBERTConfig

    def __init__(self, config):
        super().__init__(config)
        # Token embeddings live in the smaller embedding_size space and are
        # projected up to hidden_size before entering the encoder.
        self.embeddings = nn.Embedding(
            config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id
        )
        # Learned absolute position embeddings, also in the embedding_size space
        # (sized by config.max_position_embeddings).
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
        self.embed_proj = nn.Linear(config.embedding_size, config.hidden_size)
        # Note: nn.TransformerEncoder makes an independent copy of the layer for
        # each depth, so parameters are not shared across layers here.
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_attention_heads,
                dim_feedforward=config.intermediate_size,
                dropout=config.hidden_dropout_prob,
                activation=config.hidden_act,
                layer_norm_eps=config.layer_norm_eps,
                batch_first=True,
            ),
            num_layers=config.num_hidden_layers,
        )
        self.post_init()

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        seq_len = input_ids.size(1)
        position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        x = self.embeddings(input_ids) + self.position_embeddings(position_ids)
        x = self.embed_proj(x)
        # nn.TransformerEncoder expects True at positions that should be ignored,
        # while Hugging Face attention masks use 1 for real tokens, so invert it.
        padding_mask = ~attention_mask.bool() if attention_mask is not None else None
        x = self.encoder(x, src_key_padding_mask=padding_mask)
        return BaseModelOutput(last_hidden_state=x, hidden_states=(x,))


class ModernALBERTForMaskedLM(PreTrainedModel):
    config_class = ModernALBERTConfig

    def __init__(self, config):
        super().__init__(config)
        self.albert = ModernALBERTModel(config)
        self.mlm_head = nn.Linear(config.hidden_size, config.vocab_size)
        self.post_init()

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        outputs = self.albert(input_ids=input_ids, attention_mask=attention_mask)
        logits = self.mlm_head(outputs.last_hidden_state)
        loss = None
        if labels is not None:
            # Positions labeled -100 are skipped (CrossEntropyLoss's default
            # ignore_index), matching the Hugging Face convention for unmasked tokens.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=None,
        )


# Register with Hugging Face so AutoConfig / AutoModelForMaskedLM resolve "ModernALBERT".
AutoConfig.register("ModernALBERT", ModernALBERTConfig)
AutoModelForMaskedLM.register(ModernALBERTConfig, ModernALBERTForMaskedLM)
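

# --- Usage sketch (not part of the model definition) ---
# A minimal smoke test, assuming the classes and registrations above. The sizes
# below (2 layers, batch of 2, sequence length 16) are arbitrary values chosen
# only to keep the example fast; nothing here is required by the model itself.
if __name__ == "__main__":
    config = ModernALBERTConfig(num_hidden_layers=2)

    # Because the config class is registered, the Auto class can build the model.
    model = AutoModelForMaskedLM.from_config(config)

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones(2, 16, dtype=torch.long)

    # Use the inputs as labels, but only score the first two positions;
    # -100 marks positions excluded from the loss.
    labels = input_ids.clone()
    labels[:, 2:] = -100

    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    print(outputs.loss, outputs.logits.shape)  # scalar loss, (2, 16, vocab_size)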