import torch
import torch.nn as nn
from typing import Optional

from .configuration_bert import FlexBertConfig
from .normalization import get_norm_layer
from .initialization import ModuleType, init_weights


class BertAlibiEmbeddings(nn.Module):
    """Construct the embeddings for words, ignoring position.

    There are no positional embeddings, since ALiBi provides position
    information; only word and token_type embeddings are used.

    This module is modeled after Hugging Face BERT's
    :class:`~transformers.model.bert.modeling_bert.BertEmbeddings`, but is
    modified as part of Mosaic BERT's ALiBi implementation. The key change is
    that position embeddings are removed. Position information instead comes
    from attention biases that scale linearly with the position distance
    between query and key tokens.

    This module ignores the `position_ids` input to the `forward` method.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

        if getattr(config, "token_type_embeddings", True):
            self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
            self.use_token_type_embeddings = True
        else:
            self.use_token_type_embeddings = False

        self.LayerNorm = get_norm_layer(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        if self.use_token_type_embeddings:
            # All-zero token type ids, shaped (1, max_position_embeddings) so they can be sliced
            # and expanded to (batch_size, seq_length) in `forward` when none are provided.
            self.register_buffer(
                "token_type_ids",
                torch.zeros((1, config.max_position_embeddings), dtype=torch.long),
                persistent=False,
            )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        if (input_ids is not None) == (inputs_embeds is not None):
            raise ValueError("Must specify exactly one of input_ids or inputs_embeds!")
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            assert inputs_embeds is not None
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        # `position_ids` is intentionally unused: with ALiBi, position information is injected
        # through the attention biases rather than through positional embeddings.

        if self.use_token_type_embeddings and token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                device = input_ids.device if input_ids is not None else inputs_embeds.device
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        if self.use_token_type_embeddings:
            token_type_embeddings = self.token_type_embeddings(token_type_ids)
            embeddings = inputs_embeds + token_type_embeddings
        else:
            embeddings = inputs_embeds

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
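

# Illustrative sketch, not used by the model: the docstring above states that position information
# comes from attention biases that scale linearly with the query/key distance (ALiBi). The helper
# below shows one way such a bias tensor could be built for a bidirectional encoder, assuming the
# standard geometric slope schedule from the ALiBi paper; the real biases are produced by the
# attention modules elsewhere in this codebase, so treat this purely as an illustration.
def _alibi_bias_example(seq_len: int, num_heads: int) -> torch.Tensor:
    """Return an example (num_heads, seq_len, seq_len) ALiBi-style bias tensor."""
    # One slope per head, decaying geometrically: m_h = 2 ** (-8 * h / num_heads) for h = 1..num_heads.
    slopes = torch.tensor([2.0 ** (-8.0 * (h + 1) / num_heads) for h in range(num_heads)])
    positions = torch.arange(seq_len)
    # Symmetric |i - j| distances, since a bidirectional encoder attends in both directions.
    distances = (positions[None, :] - positions[:, None]).abs()
    # More distant query/key pairs receive a more negative bias when added to the attention scores.
    return -slopes[:, None, None] * distances

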
class FlexBertEmbeddingsBase(nn.Module):
    """A FlexBERT embeddings base class for type hints."""

    def __init__(self, config: FlexBertConfig):
        super().__init__()
        self.config = config

    def _init_weights(self, reset_params: bool = False):
        raise NotImplementedError("This is a base class and should not be used directly.")

    def reset_parameters(self):
        self._init_weights(reset_params=True)

    def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
        raise NotImplementedError("This is a base class and should not be used directly.")


class FlexBertAbsoluteEmbeddings(FlexBertEmbeddingsBase):
    """Construct the embeddings with absolute positional embeddings."""

    def __init__(self, config: FlexBertConfig):
        super().__init__(config)
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        self.norm = get_norm_layer(config) if config.embed_norm else nn.Identity()
        self.drop = nn.Dropout(config.embed_dropout_prob) if config.embed_dropout_prob > 0.0 else nn.Identity()

        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def _init_weights(self, reset_params: bool = False):
        init_weights(self.config, self.tok_embeddings, type_of_module=ModuleType.emb)
        init_weights(self.config, self.position_embeddings, type_of_module=ModuleType.emb)

        if reset_params:
            if self.config.embed_norm:
                self.norm.reset_parameters()

    def forward(
        self,
        input_ids: torch.LongTensor,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        if position_ids is None:
            position_ids = self.position_ids[:, 0 : input_ids.shape[1]]

        embeddings = self.tok_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        embeddings = self.norm(embeddings + position_embeddings)
        return self.drop(embeddings)


class FlexBertCompiledSansPositionEmbeddings(FlexBertEmbeddingsBase):
    """Construct the embeddings from token embeddings without any positional embeddings.

    The forward pass is wrapped in `torch.compile` so the embedding lookup, normalization, and
    dropout can be optimized together.
    """

    def __init__(self, config: FlexBertConfig):
        super().__init__(config)
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

        self.norm = get_norm_layer(config, compiled_norm=config.compile_model) if config.embed_norm else nn.Identity()
        self.drop = nn.Dropout(config.embed_dropout_prob) if config.embed_dropout_prob > 0.0 else nn.Identity()

    def _init_weights(self, reset_params: bool = False):
        init_weights(self.config, self.tok_embeddings, type_of_module=ModuleType.emb)

        if reset_params:
            if self.config.embed_norm:
                self.norm.reset_parameters()

    # dynamic=True avoids recompiling the graph every time the sequence length changes.
    @torch.compile(dynamic=True)
    def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
        return self.drop(self.norm(self.tok_embeddings(input_ids)))


class FlexBertSansPositionEmbeddings(FlexBertEmbeddingsBase):
    """Construct the embeddings from token embeddings without any positional embeddings."""

    def __init__(self, config: FlexBertConfig):
        super().__init__(config)
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

        self.norm = get_norm_layer(config) if config.embed_norm else nn.Identity()
        self.drop = nn.Dropout(config.embed_dropout_prob) if config.embed_dropout_prob > 0.0 else nn.Identity()

    def _init_weights(self, reset_params: bool = False):
        init_weights(self.config, self.tok_embeddings, type_of_module=ModuleType.emb)

        if reset_params:
            if self.config.embed_norm:
                self.norm.reset_parameters()

    def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
        return self.drop(self.norm(self.tok_embeddings(input_ids)))


EBB2CLS = {
    "absolute_pos": FlexBertAbsoluteEmbeddings,
    "sans_pos": FlexBertSansPositionEmbeddings,
}


def get_embedding_layer(config: FlexBertConfig) -> FlexBertEmbeddingsBase:
    try:
        if config.compile_model and config.embedding_layer == "sans_pos":
            return FlexBertCompiledSansPositionEmbeddings(config)
        elif config.compile_model:
            raise ValueError(f"{config.compile_model=} only supports 'sans_pos' embeddings.")
        return EBB2CLS[config.embedding_layer](config)
    except KeyError:
        raise ValueError(
            f"Invalid embedding layer type: {config.embedding_layer=}, must be one of {list(EBB2CLS.keys())}."
        )
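

# Usage sketch (assumptions flagged): a minimal example of wiring the factory above into a model.
# The FlexBertConfig keyword arguments below are assumptions based on the attributes referenced in
# this file (vocab_size, hidden_size, pad_token_id, embed_norm, embed_dropout_prob, embedding_layer,
# compile_model); check the actual config class for the full set of required fields. Because this
# module uses relative imports, it would need to be run as part of its package (python -m ...).
if __name__ == "__main__":
    config = FlexBertConfig(
        vocab_size=30522,
        hidden_size=768,
        pad_token_id=0,
        embedding_layer="sans_pos",
        embed_norm=True,
        embed_dropout_prob=0.0,
        compile_model=False,
    )
    embeddings = get_embedding_layer(config)  # -> FlexBertSansPositionEmbeddings
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    hidden_states = embeddings(input_ids)
    print(hidden_states.shape)  # expected: torch.Size([2, 16, 768])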