from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast

from configuration_spect1 import SpecT1Config


class SpecT1MTPLayers(nn.Module):
    """Single multi-token-prediction (MTP) block: fuses token embeddings with the
    previous hidden state, then applies pre-norm self-attention and an MLP."""

    def __init__(self, config: SpecT1Config):
        super().__init__()
        # Layer norms for the attention/MLP sub-blocks and for the two fused input streams.
        self.input_layernorm = nn.LayerNorm(config.hidden_size)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)
        self.token_layernorm = nn.LayerNorm(config.hidden_size)
        self.hidden_layernorm = nn.LayerNorm(config.hidden_size)
        self.final_layernorm = nn.LayerNorm(config.hidden_size)
        # Projects the concatenated [previous hidden state; token embedding] back to hidden_size.
        self.input_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            batch_first=True,
        )
        # Two-layer feed-forward network.
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size),
            nn.ReLU(),
            nn.Linear(config.intermediate_size, config.hidden_size),
        )

    def forward(
        self,
        input_embeds: torch.Tensor,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        position_embedding: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        cache_position=None,
        **kwargs,
    ) -> torch.Tensor:
        # position_ids, past_key_values, use_cache, cache_position and position_embedding
        # are accepted for interface compatibility but are not used by this block.
        # Normalize both input streams, concatenate them, and project back to hidden_size.
        input_embeds = self.token_layernorm(input_embeds)
        previous_hidden_states = self.hidden_layernorm(hidden_states)
        hidden_states = self.input_proj(torch.cat([previous_hidden_states, input_embeds], dim=-1))
        # Pre-norm self-attention with a residual connection. attention_mask is forwarded as
        # nn.MultiheadAttention's attn_mask, so it must be a bool or additive float mask of
        # shape (L, S) or (batch * num_heads, L, S), not a 2-D padding mask.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_output, _ = self.self_attn(hidden_states, hidden_states, hidden_states, attn_mask=attention_mask)
        hidden_states = residual + attn_output
        # Pre-norm MLP with a residual connection, followed by a final layer norm.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        mlp_output = self.mlp(hidden_states)
        hidden_states = residual + mlp_output
        hidden_states = self.final_layernorm(hidden_states)
        return hidden_states
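

# Illustrative helper (not part of the original module): nn.MultiheadAttention expects
# attn_mask either as a bool tensor of shape (L, S) where True marks positions that may
# not be attended to, or as an additive float mask. A minimal sketch of building a causal
# mask that can be passed to SpecT1MTPLayers.forward, assuming self-attention (L == S):
def build_causal_attn_mask(seq_len: int, device: Optional[torch.device] = None) -> torch.Tensor:
    """Return a (seq_len, seq_len) bool mask with True above the diagonal (future positions masked)."""
    return torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=1)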


class SpecT1Model(nn.Module):
    """Stack of MTP blocks. Each block sees the original input embeddings together with
    the hidden states produced by the block before it."""

    config_class = SpecT1Config

    def __init__(self, config: SpecT1Config):
        super().__init__()
        self.config = config
        self.mtp_layers = nn.ModuleList([
            SpecT1MTPLayers(config) for _ in range(config.num_nextn_predict_layers)
        ])

    def forward(
        self,
        input_embeds: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        # The first block receives input_embeds for both streams; later blocks fuse the
        # original embeddings with the hidden states produced by the previous block.
        hidden_states = input_embeds
        for layer in self.mtp_layers:
            hidden_states = layer(
                input_embeds=input_embeds,
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                **kwargs,
            )
        return hidden_states


class SpecT1ForCausalLM(PreTrainedModel):
    """Causal LM wrapper around the MTP stack. It operates directly on pre-computed
    input embeddings (e.g. hidden states from a base model) rather than on token ids."""

    config_class = SpecT1Config

    def __init__(self, config: SpecT1Config):
        super().__init__(config)
        self.config = config
        self.model = SpecT1Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        # This head consumes embeddings rather than token ids, so inputs_embeds is required.
        if inputs_embeds is None:
            raise ValueError("inputs_embeds must be provided for SpecT1ForCausalLM")
        hidden_states = self.model(
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **kwargs,
        )
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            # No token shift is applied here: labels are assumed to already be aligned
            # with the positions predicted by the MTP layers.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
        if not return_dict:
            # Follow the usual Hugging Face convention of returning the loss first.
            return (loss, logits) if loss is not None else (logits,)
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
            past_key_values=None,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        # Generation is also driven by embeddings; input_ids alone is not sufficient.
        if inputs_embeds is None:
            raise ValueError("SpecT1ForCausalLM requires inputs_embeds for generation")
        return {
            "inputs_embeds": inputs_embeds,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache", True),
        }
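

# Minimal smoke test, assuming SpecT1Config subclasses transformers.PretrainedConfig and
# accepts the attributes referenced above (hidden_size, num_attention_heads,
# attention_dropout, intermediate_size, num_nextn_predict_layers, vocab_size) as keyword
# arguments. This is an illustrative sketch, not part of the published model code.
if __name__ == "__main__":
    config = SpecT1Config(
        hidden_size=64,
        num_attention_heads=4,
        attention_dropout=0.0,
        intermediate_size=256,
        num_nextn_predict_layers=2,
        vocab_size=1000,
    )
    model = SpecT1ForCausalLM(config)
    batch_size, seq_len = 2, 8
    # The draft head consumes embeddings produced by a base model, so random hidden states
    # of shape (batch, seq_len, hidden_size) stand in for real ones here.
    inputs_embeds = torch.randn(batch_size, seq_len, config.hidden_size)
    attn_mask = build_causal_attn_mask(seq_len)
    outputs = model(inputs_embeds=inputs_embeds, attention_mask=attn_mask)
    print(outputs.logits.shape)  # expected: torch.Size([2, 8, 1000])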