"""Avey causal language model: a stack of pre-norm blocks built around a causal,
similarity-weighted spatial gating unit, with tied input/output embeddings."""

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, GenerationMixin
from transformers.modeling_outputs import CausalLMOutput

from configuration_avey import AveyConfig


class SGU(nn.Module):
    """Spatial gating unit: causal mixing across sequence positions, then gating.

    The last dimension of the input is split in half. The first half is mixed
    along the sequence axis (dim -2) with a lower-triangular weight matrix.
    That matrix is the element-wise product of the first half's pairwise
    cosine similarities and a learned ``context_len x context_len`` matrix.
    The mixed result then multiplies the second half element-wise.
    """

    def __init__(self, config: AveyConfig):
        super().__init__()
        # Learned position-to-position weights; only the lower triangle is ever used.
        self.ctxt_mat = nn.Parameter(torch.empty(config.context_len, config.context_len))
        nn.init.xavier_normal_(self.ctxt_mat)

    def cosim(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Return pairwise cosine similarities between the rows of ``embeddings``.

        For an input of shape ``(..., T, D)`` the result has shape ``(..., T, T)``.
        The ``1e-8`` term keeps the norm strictly positive for all-zero rows.
        """
        norm = torch.sqrt(torch.sum(embeddings ** 2, dim=-1, keepdim=True) + 1e-8)
        normalized = embeddings / norm
        return torch.matmul(normalized, normalized.transpose(-1, -2))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Mix the first half of ``x`` causally over positions and gate the second half.

        Args:
            x: tensor whose dim -2 indexes sequence positions and whose last
                dim is split into two equal halves (it must be even).
                Presumably ``(batch, seq, features)`` -- TODO confirm.

        Returns:
            Tensor with the same leading dims as ``x`` and half its last dim.

        Raises:
            ValueError: if the sequence length exceeds ``context_len``, which
                is the size of the learned ``ctxt_mat``.
        """
        x0, x1 = x.chunk(2, dim=-1)
        seq_len = x0.shape[-2]
        ctx_len = self.ctxt_mat.shape[0]
        if seq_len > ctx_len:
            raise ValueError(
                f"SGU received a sequence of length {seq_len}, but ctxt_mat only "
                f"covers context_len={ctx_len} positions"
            )
        # Both factors are lower-triangular, so row i only reads positions j <= i.
        # Slicing ctxt_mat to the actual length is therefore exact for shorter
        # inputs: these rows match the first seq_len rows of any longer input
        # that shares this prefix.
        position_weights = torch.tril(self.ctxt_mat[:seq_len, :seq_len])
        mixing = torch.tril(self.cosim(x0)) * position_weights
        return (mixing @ x0) * x1


class NeuralContextualizerLayer(nn.Module):
    """Expand, split into a gated (SGU) path and a bypass path, then fuse back to d_embed."""

    def __init__(self, config: AveyConfig):
        super().__init__()
        expanded = int(config.d_embed * config.expansion_factor)
        gated = int(expanded * 0.75)
        # 75% of the expanded features go through the SGU and the rest bypass it.
        # The second size is the remainder, so the two always sum to the
        # enricher's output width (two independent int() truncations may not).
        self.split_factor = [gated, expanded - gated]
        self.enricher = nn.Linear(config.d_embed, expanded)
        self.sgu = SGU(config)
        # SGU halves its input width (chunk + elementwise product); the bypass
        # keeps its width. For expansion_factor == 4 this equals the previous
        # formula int(d*ef*0.5 + d*0.5), so existing checkpoints keep their
        # shapes. For other factors it is the width that actually reaches the
        # fuser.
        proj_in_features = gated // 2 + (expanded - gated)
        self.fuser = nn.Linear(proj_in_features, config.d_embed)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the contextualized features, projected back to ``d_embed`` width."""
        enriched = F.gelu(self.enricher(x))
        gated_part, bypass_part = enriched.split(self.split_factor, dim=-1)
        mixed = self.sgu(gated_part)
        return self.fuser(torch.cat([mixed, bypass_part], dim=-1))


class AveyBlock(nn.Module):
    """Pre-norm residual block: ``x + contextualizer(RMSNorm(x))``."""

    def __init__(self, config: AveyConfig):
        super().__init__()
        self.rms_norm = nn.RMSNorm(config.d_embed, eps=1e-10)
        self.ctxt = NeuralContextualizerLayer(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.ctxt(self.rms_norm(x))


class AveyForCausalLM(PreTrainedModel, GenerationMixin):
    """Token embeddings -> ``n_blocks`` AveyBlocks -> RMSNorm -> tied output projection.

    The output projection reuses ``wte.weight`` (weight tying), so there is no
    separate LM-head parameter.
    """

    config_class = AveyConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wte = nn.Embedding(config.vocab_size, config.d_embed)
        nn.init.xavier_normal_(self.wte.weight)
        self.blocks = nn.ModuleList([AveyBlock(config) for _ in range(config.n_blocks)])
        self.ln_f = nn.RMSNorm(config.d_embed, eps=1e-10)

    def forward(self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None, **kwargs):
        """Compute next-token logits and, if ``labels`` is given, the cross-entropy loss.

        Args:
            input_ids: 2-D tensor of token ids (the embedding output is unpacked
                as ``(batch, seq, d_embed)``). The length must not exceed
                ``config.context_len``.
            labels: optional target ids, flattened and compared position-by-position
                with the logits. No shift is applied here, unlike many Hugging Face
                ``*ForCausalLM`` models. NOTE(review): callers presumably pass
                pre-shifted labels -- confirm. Entries equal to -100
                (``F.cross_entropy``'s default ``ignore_index``) are ignored.
            **kwargs: accepted for Hugging Face API compatibility and ignored
                (e.g. any ``attention_mask`` has no effect).

        Returns:
            ``CausalLMOutput`` with ``logits`` of shape ``(batch, seq, vocab_size)``
            and ``loss`` when ``labels`` is provided.

        Raises:
            ValueError: (from ``SGU``) if the input is longer than ``config.context_len``.
        """
        x = self.wte(input_ids)
        _, orig_len, _ = x.shape  # unpacking also enforces 3-D embeddings

        # Right-pad with zero vectors up to a multiple of context_len. The padding
        # follows every real token. All mixing across positions (SGU) is causal
        # and everything else is per-position, so padding cannot change the
        # logits kept below.
        pad_length = (-orig_len) % self.config.context_len
        if pad_length:
            x = F.pad(x, (0, 0, 0, pad_length))

        for block in self.blocks:
            x = block(x)

        logits = F.linear(self.ln_f(x), self.wte.weight)
        if pad_length:
            logits = logits[:, :orig_len, :]

        if labels is None:
            return CausalLMOutput(logits=logits)

        # reshape, not view: slicing off the padding leaves logits non-contiguous
        # whenever batch > 1, and view() raises on that layout.
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
        return CausalLMOutput(logits=logits, loss=loss)