import torch
from torch import nn

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class ContextualizerBlock(nn.Module):
    def __init__(self, d_model, d_ffn, dropout, num_tokens):
        super().__init__()
        self.context_mlp = FeedForward(d_model, d_ffn, dropout)
        self.mlp = FeedForward(d_model, d_ffn, dropout)
        self.norm = nn.LayerNorm(d_model)
        # Nearest-neighbour resampling along the flattened (token * feature) axis:
        # downsample reduces the T*D signal to D elements, upsample expands it back.
        self.upsample = nn.Upsample(scale_factor=num_tokens, mode='nearest')
        self.downsample = nn.Upsample(scale_factor=1 / num_tokens, mode='nearest')

    def forward(self, x):
        # Context sub-block: pre-norm, resample down, MLP, resample up, residual.
        res = x
        x = self.norm(x)
        context = x
        batch, num_tokens, d_model = context.shape
        context = context.reshape(batch, 1, num_tokens * d_model)  # (B, 1, T*D)
        context = self.downsample(context)                         # (B, 1, D)
        context = context.reshape(batch, d_model)
        context = self.context_mlp(context)
        context = context.reshape(batch, 1, d_model)
        context = self.upsample(context)                           # (B, 1, T*D)
        context = context.reshape(batch, num_tokens, d_model)      # (B, T, D)
        x = context + res
        # Token-wise MLP sub-block: pre-norm, MLP, residual
        # (the same LayerNorm instance is reused for both sub-blocks).
        res = x
        x = self.norm(x)
        x = self.mlp(x)
        return x + res

class MixerGatingUnit(nn.Module):
    def __init__(self, d_model, d_ffn, dropout, num_tokens):
        super().__init__()
        self.Mixer = ContextualizerBlock(d_model, d_ffn, dropout, num_tokens)
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        # Element-wise gating: a linear projection of x is modulated by the
        # contextualizer output.
        u, v = x, x
        u = self.proj(u)
        v = self.Mixer(v)
        out = u * v
        return out

class ContextualizerNiNBlock(nn.Module):
    def __init__(self, d_model, d_ffn, dropout, num_tokens):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.mgu = MixerGatingUnit(d_model, d_ffn, dropout, num_tokens)
        self.ffn = FeedForward(d_model, d_ffn, dropout)

    def forward(self, x):
        # Two pre-norm residual sub-blocks (a single LayerNorm instance is
        # reused for both): gating unit, then feed-forward network.
        residual = x
        x = self.norm(x)
        x = self.mgu(x)
        x = x + residual
        residual = x
        x = self.norm(x)
        x = self.ffn(x)
        out = x + residual
        return out

class ContextualizerNiN(nn.Module):
    def __init__(self, d_model, d_ffn, num_layers, dropout, num_tokens):
        super().__init__()
        self.model = nn.Sequential(
            *[ContextualizerNiNBlock(d_model, d_ffn, dropout, num_tokens) for _ in range(num_layers)],
        )

    def forward(self, x):
        x = self.model(x)
        return x
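

# Example usage (a minimal sketch, not from the source): the hyperparameter
# values below are illustrative assumptions. Note that the reshape/resampling
# logic in ContextualizerBlock requires the input sequence length to equal
# num_tokens.
if __name__ == "__main__":
    batch_size, num_tokens, d_model = 4, 16, 128
    model = ContextualizerNiN(d_model=d_model, d_ffn=256, num_layers=2,
                              dropout=0.1, num_tokens=num_tokens)
    x = torch.randn(batch_size, num_tokens, d_model)
    y = model(x)
    print(y.shape)  # torch.Size([4, 16, 128])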