# automodel-test0 / modeling_avey.py
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import PreTrainedModel, GenerationMixin
from transformers.modeling_outputs import CausalLMOutput

from configuration_avey import AveyConfig
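
# Architecture overview (as implemented below): token embeddings pass through a
# stack of pre-norm residual AveyBlocks, each an RMSNorm followed by a
# NeuralContextualizerLayer (linear expansion -> GELU -> SGU gating -> linear
# fusion); a final RMSNorm and a weight-tied projection onto the embedding
# matrix produce the logits.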


class SGU(nn.Module):
    """Gating unit: mixes one feature half across tokens with a causal,
    similarity-weighted learned matrix, then uses it to gate the other half."""

    def __init__(self, config: AveyConfig):
        super().__init__()
        # Learned (context_len x context_len) token-mixing matrix.
        self.ctxt_mat = nn.Parameter(torch.empty(config.context_len, config.context_len))
        nn.init.xavier_normal_(self.ctxt_mat)

    def cosim(self, embeddings: torch.Tensor) -> torch.Tensor:
        # Pairwise cosine similarity between token embeddings.
        norm = torch.sqrt(torch.sum(embeddings ** 2, dim=-1, keepdim=True) + 1e-8)
        normalized = embeddings / norm
        return torch.matmul(normalized, normalized.transpose(-1, -2))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x0, x1 = x.chunk(2, dim=-1)
        # Causal (lower-triangular) mixing weights, modulated by cosine similarity.
        # Note: broadcasting requires the (padded) sequence length to equal
        # config.context_len.
        c = torch.tril(self.cosim(x0)) * torch.tril(self.ctxt_mat)
        x0 = c @ x0
        # The mixed half gates the untouched half element-wise.
        return x0 * x1


class NeuralContextualizerLayer(nn.Module):
    """Expand the embedding, gate part of it with the SGU, and fuse back to d_embed."""

    def __init__(self, config: AveyConfig):
        super().__init__()
        # Split of the expanded features: 75% go through the SGU, 25% bypass it.
        self.split_factor = [
            int(config.d_embed * config.expansion_factor * 0.75),
            int(config.d_embed * config.expansion_factor * 0.25),
        ]
        self.enricher = nn.Linear(config.d_embed, config.d_embed * config.expansion_factor)
        self.sgu = SGU(config)
        # The SGU halves its input, so the concatenation in forward() has width
        # 0.625 * expansion_factor * d_embed; the expression below equals that
        # only when expansion_factor == 4.
        proj_in_features = int(
            config.d_embed * config.expansion_factor * 0.5 + config.d_embed * 0.5
        )
        self.fuser = nn.Linear(proj_in_features, config.d_embed)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_proj = F.gelu(self.enricher(x))
        x0, x1 = x_proj.split(self.split_factor, dim=-1)
        x0 = self.sgu(x0)
        combined = torch.cat([x0, x1], dim=-1)
        return self.fuser(combined)
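
# Shape trace through NeuralContextualizerLayer, with d = d_embed and
# E = expansion_factor (illustrative; assumes E = 4, for which the fuser
# widths line up):
#   enricher: (B, T, d)         -> (B, T, E*d)
#   split:    (B, T, 0.75*E*d)  and (B, T, 0.25*E*d)
#   sgu:      (B, T, 0.75*E*d)  -> (B, T, 0.375*E*d)
#   concat:   (B, T, 0.625*E*d)
#   fuser:    (B, T, 0.625*E*d) -> (B, T, d)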


class AveyBlock(nn.Module):
    """Pre-norm residual block: x + NeuralContextualizerLayer(RMSNorm(x))."""

    def __init__(self, config: AveyConfig):
        super().__init__()
        self.rms_norm = nn.RMSNorm(config.d_embed, eps=1e-10)
        self.ctxt = NeuralContextualizerLayer(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.ctxt(self.rms_norm(x))


class AveyForCausalLM(PreTrainedModel, GenerationMixin):
    config_class = AveyConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wte = nn.Embedding(config.vocab_size, config.d_embed)
        nn.init.xavier_normal_(self.wte.weight)
        self.blocks = nn.ModuleList([AveyBlock(config) for _ in range(config.n_blocks)])
        self.ln_f = nn.RMSNorm(config.d_embed, eps=1e-10)

    def forward(self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None, **kwargs):
        x = self.wte(input_ids)
        B, T, E = x.shape

        # Right-pad the sequence to a multiple of context_len so it lines up with
        # the SGU's (context_len x context_len) mixing matrix.
        padded = False
        orig_T = T
        if T % self.config.context_len != 0:
            pad_length = self.config.context_len - (T % self.config.context_len)
            pad_tensor = torch.zeros(B, pad_length, E, device=x.device, dtype=x.dtype)
            x = torch.cat([x, pad_tensor], dim=1)
            T = x.shape[1]
            padded = True

        for block in self.blocks:
            x = block(x)

        # Weight-tied LM head: project hidden states onto the embedding matrix.
        logits = F.linear(self.ln_f(x), self.wte.weight)
        if padded:
            # Drop the padded positions before returning or computing the loss.
            logits = logits[:, :orig_T, :]

        if labels is not None:
            # Causal-LM convention: position t predicts token t + 1, so shift the
            # logits and labels before the cross-entropy.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )
            return CausalLMOutput(loss=loss, logits=logits)
        return CausalLMOutput(logits=logits)
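

# Minimal usage sketch (illustrative only, not part of the uploaded code path).
# It assumes AveyConfig accepts these field names as keyword arguments, which is
# not verified here; expansion_factor = 4 is chosen so the fuser widths match.
#
#     config = AveyConfig(
#         vocab_size=32000, d_embed=512, expansion_factor=4,
#         context_len=512, n_blocks=12,
#     )
#     model = AveyForCausalLM(config)
#     input_ids = torch.randint(0, config.vocab_size, (1, config.context_len))
#     out = model(input_ids=input_ids, labels=input_ids)
#     print(out.logits.shape, out.loss)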