from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast

from configuration_spect1 import SpecT1Config

class SpecT1MTPLayers(nn.Module):
    """Single multi-token-prediction (MTP) block: fuses the token embeddings with the
    previous hidden states, then applies self-attention and an MLP with residuals."""

    def __init__(self, config: SpecT1Config):
        super().__init__()
        # Layer norms for the attention/MLP sub-blocks and for the two fused inputs.
        self.input_layernorm = nn.LayerNorm(config.hidden_size)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)
        self.token_layernorm = nn.LayerNorm(config.hidden_size)
        self.hidden_layernorm = nn.LayerNorm(config.hidden_size)
        self.final_layernorm = nn.LayerNorm(config.hidden_size)
        # Projects the concatenation [previous hidden state, token embedding] back to hidden_size.
        self.input_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            batch_first=True,
        )
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size),
            nn.ReLU(),
            nn.Linear(config.intermediate_size, config.hidden_size),
        )

    def forward(
        self,
        input_embeds: torch.Tensor,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        position_embedding: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        # Normalize both inputs independently, then fuse them with a linear projection.
        input_embeds = self.token_layernorm(input_embeds)
        previous_hidden_states = self.hidden_layernorm(hidden_states)
        hidden_states = self.input_proj(torch.cat([previous_hidden_states, input_embeds], dim=-1))

        # Self-attention sub-block with a residual connection.
        # NOTE: nn.MultiheadAttention expects attn_mask of shape (L, S) or
        # (batch * num_heads, L, S), not the 2D (batch, seq_len) padding mask
        # used elsewhere in transformers.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_output, _ = self.self_attn(hidden_states, hidden_states, hidden_states, attn_mask=attention_mask)
        hidden_states = residual + attn_output

        # MLP sub-block with a residual connection.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        mlp_output = self.mlp(hidden_states)
        hidden_states = residual + mlp_output

        hidden_states = self.final_layernorm(hidden_states)
        return hidden_states

class SpecT1Model(nn.Module):
    """Stack of MTP layers. Every layer re-reads the original token embeddings and
    refines the running hidden state produced by the previous layer."""

    config_class = SpecT1Config

    def __init__(self, config: SpecT1Config):
        super().__init__()
        self.config = config
        self.mtp_layers = nn.ModuleList(
            [SpecT1MTPLayers(config) for _ in range(config.num_nextn_predict_layers)]
        )

    def forward(
        self,
        input_embeds: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        hidden_states = input_embeds
        for layer in self.mtp_layers:
            hidden_states = layer(
                input_embeds=input_embeds,
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                **kwargs,
            )
        return hidden_states

class SpecT1ForCausalLM(PreTrainedModel):
    """Causal-LM head over the MTP stack. Consumes pre-computed input embeddings
    (e.g. from a base model) rather than raw token ids."""

    config_class = SpecT1Config

    def __init__(self, config: SpecT1Config):
        super().__init__(config)
        self.config = config
        self.model = SpecT1Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        if inputs_embeds is None:
            raise ValueError("inputs_embeds must be provided for SpecT1ForCausalLM")

        hidden_states = self.model(
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **kwargs,
        )
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Labels are compared position-wise; no causal shift is applied here.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            return (logits, loss) if loss is not None else (logits,)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
            past_key_values=None,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        # Generation always consumes embeddings; raw input_ids are ignored.
        if inputs_embeds is None:
            raise ValueError("SpecT1ForCausalLM requires inputs_embeds for generation")
        return {
            "inputs_embeds": inputs_embeds,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache", True),
        }
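

# Minimal smoke-test sketch of the forward path, not part of the released checkpoint
# code. It assumes SpecT1Config is a PretrainedConfig subclass that accepts the fields
# used above (hidden_size, num_attention_heads, attention_dropout, intermediate_size,
# num_nextn_predict_layers, vocab_size) as keyword arguments; the concrete values
# below are illustrative placeholders, not the released configuration.
if __name__ == "__main__":
    config = SpecT1Config(
        hidden_size=64,
        num_attention_heads=4,
        attention_dropout=0.0,
        intermediate_size=128,
        num_nextn_predict_layers=2,
        vocab_size=1000,
    )
    model = SpecT1ForCausalLM(config)

    batch_size, seq_len = 2, 8
    # In real use these embeddings would come from the base model's embedding table.
    dummy_embeds = torch.randn(batch_size, seq_len, config.hidden_size)
    with torch.no_grad():
        out = model(inputs_embeds=dummy_embeds)
    print(out.logits.shape)  # expected: torch.Size([2, 8, 1000])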