| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import PreTrainedModel, GenerationMixin | |
| from transformers.modeling_outputs import CausalLMOutput | |
| from configuration_avey import AveyConfig | |
class SGU(nn.Module):
    """Spatial gating unit: causal, similarity-weighted token mixing plus a gate.

    The input's last dimension is split into two halves ``x0`` and ``x1``.
    ``x0`` is mixed across the sequence dimension with a weight matrix equal to
    the lower-triangular part of (cosine-similarity matrix of ``x0``) times
    (learned ``ctxt_mat``), element-wise. The mixed ``x0`` then gates ``x1``
    element-wise.
    """

    def __init__(self, config: "AveyConfig"):
        super().__init__()
        # Learned (context_len x context_len) position-mixing weights; sequences
        # shorter than context_len use its top-left sub-matrix.
        self.ctxt_mat = nn.Parameter(torch.empty(config.context_len, config.context_len))
        nn.init.xavier_normal_(self.ctxt_mat)

    def cosim(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Return pairwise cosine similarities between the rows along dim -2.

        Args:
            embeddings: tensor of shape (..., seq_len, features).

        Returns:
            Tensor of shape (..., seq_len, seq_len).
        """
        # The 1e-8 keeps all-zero rows (e.g. zero padding) from dividing by zero.
        norm = torch.sqrt(torch.sum(embeddings ** 2, dim=-1, keepdim=True) + 1e-8)
        unit = embeddings / norm
        return torch.matmul(unit, unit.transpose(-1, -2))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Mix the first feature half across positions and gate the second half.

        Args:
            x: tensor of shape (..., seq_len, features) with
                ``seq_len <= context_len``. ``features`` is expected to be even,
                because ``chunk(2)`` must yield two equal halves for the gate.

        Returns:
            Tensor of shape (..., seq_len, features // 2).

        Raises:
            ValueError: if ``seq_len`` exceeds ``context_len``.
        """
        seq_len = x.shape[-2]
        max_len = self.ctxt_mat.shape[0]
        if seq_len > max_len:
            raise ValueError(
                f"SGU received sequence length {seq_len}, but context_len is {max_len}"
            )
        x0, x1 = x.chunk(2, dim=-1)
        # Slicing the learned matrix lets shorter sequences run directly. Because
        # the weights are lower-triangular, this gives the same result on the real
        # positions as right-padding to context_len.
        weights = self.ctxt_mat[:seq_len, :seq_len]
        # tril(a) * tril(b) == tril(a * b): mask once after the element-wise product.
        c = torch.tril(self.cosim(x0) * weights)
        return (c @ x0) * x1
class NeuralContextualizerLayer(nn.Module):
    """Expand features, contextualize part of them with an SGU, and project back.

    Forward pipeline:
      1. ``enricher``: Linear d_embed -> d_embed * expansion_factor, then GELU.
      2. Split the expanded features into ``split_factor = [n_sgu, n_pass]``
         (roughly 75% / 25%).
      3. Feed the ``n_sgu`` part through the SGU; its output is half as wide.
      4. Concatenate with the ``n_pass`` part and project to d_embed with ``fuser``.
    """

    def __init__(self, config: AveyConfig):
        super().__init__()
        expanded = int(config.d_embed * config.expansion_factor)
        n_sgu = int(expanded * 0.75)
        # Use the remainder rather than int(expanded * 0.25) so the two parts
        # always sum to the enricher width, which Tensor.split requires.
        n_pass = expanded - n_sgu
        self.split_factor = [n_sgu, n_pass]
        self.enricher = nn.Linear(config.d_embed, config.d_embed * config.expansion_factor)
        # n_sgu is expected to be even: the SGU chunks it into two equal halves.
        self.sgu = SGU(config)
        # The SGU output is n_sgu // 2 wide (chunk(2) then gate), so the fuser
        # sees n_sgu // 2 + n_pass features. Deriving this from the split keeps it
        # correct for any expansion_factor. For expansion_factor == 4 it equals
        # floor(2.5 * d_embed), the size used by existing checkpoints.
        proj_in_features = n_sgu // 2 + n_pass
        self.fuser = nn.Linear(proj_in_features, config.d_embed)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the contextualized features, with the same last-dim size (d_embed) as ``x``."""
        expanded = F.gelu(self.enricher(x))
        sgu_part, pass_part = expanded.split(self.split_factor, dim=-1)
        mixed = self.sgu(sgu_part)
        return self.fuser(torch.cat([mixed, pass_part], dim=-1))
class AveyBlock(nn.Module):
    """Pre-norm residual block computing ``x + ctxt(rms_norm(x))``."""

    def __init__(self, config: AveyConfig):
        super().__init__()
        # Attribute names are kept as-is: they determine the state_dict keys.
        self.rms_norm = nn.RMSNorm(config.d_embed, eps=1e-10)
        self.ctxt = NeuralContextualizerLayer(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize, contextualize, and add the result back onto the input."""
        normed = self.rms_norm(x)
        update = self.ctxt(normed)
        return x + update
class AveyForCausalLM(PreTrainedModel, GenerationMixin):
    """Avey causal language model.

    Embeds tokens with ``wte``, runs them through ``config.n_blocks``
    ``AveyBlock`` layers, applies a final RMSNorm (``ln_f``), and projects back
    to the vocabulary with the embedding matrix (tied weights).
    """

    config_class = AveyConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wte = nn.Embedding(config.vocab_size, config.d_embed)
        nn.init.xavier_normal_(self.wte.weight)
        self.blocks = nn.ModuleList([AveyBlock(config) for _ in range(config.n_blocks)])
        self.ln_f = nn.RMSNorm(config.d_embed, eps=1e-10)

    def forward(self, input_ids: torch.Tensor, labels: torch.Tensor = None, **kwargs):
        """Compute vocabulary logits and, optionally, a cross-entropy loss.

        Args:
            input_ids: integer token ids of shape (batch, seq_len), with
                ``seq_len <= config.context_len``.
            labels: optional integer targets with batch * seq_len elements.
                They are compared position-by-position with the logits; no shift
                is applied here. Entries equal to -100 are ignored
                (F.cross_entropy's default ignore_index). NOTE(review): Hugging
                Face causal-LM models usually shift labels internally; confirm
                that callers pass pre-shifted targets.
            **kwargs: accepted and ignored (presumably so generic HF callers can
                pass extra inputs such as ``attention_mask``).

        Returns:
            ``CausalLMOutput`` with ``logits`` of shape (batch, seq_len, vocab_size),
            and ``loss`` when ``labels`` is given.

        Raises:
            ValueError: if ``seq_len`` exceeds ``config.context_len``.
        """
        x = self.wte(input_ids)
        B, T, E = x.shape
        context_len = self.config.context_len
        if T > context_len:
            # Each SGU mixes tokens with a (context_len x context_len) matrix, so
            # longer inputs would otherwise fail with an opaque broadcasting error
            # inside the first block.
            raise ValueError(
                f"input length {T} exceeds config.context_len={context_len}"
            )
        padded = T < context_len
        if padded:
            # Right-pad with zero embeddings up to context_len. SGU mixing is
            # lower-triangular and every other op is per-position, so the padding
            # does not change the outputs at real positions. It is sliced off below.
            pad_tensor = torch.zeros(B, context_len - T, E, device=x.device, dtype=x.dtype)
            x = torch.cat([x, pad_tensor], dim=1)
        for block in self.blocks:
            x = block(x)
        # Output projection is tied to the input embedding matrix.
        logits = F.linear(self.ln_f(x), self.wte.weight)
        if padded:
            logits = logits[:, :T, :]
        loss = None
        if labels is not None:
            # reshape rather than view: once the padding is sliced off, logits is
            # non-contiguous for batch > 1, and .view(-1, vocab) would raise.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
        return CausalLMOutput(loss=loss, logits=logits)