from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn
import torch.nn.functional as F

from .encoderblock import TransformerEncoder, get_activation_fn


class AbLang(torch.nn.Module):
    """
    AbLang model, inspired by ESM-2's architecture.
    """

    def __init__(
        self,
        vocab_size,
        hidden_embed_size,
        n_attn_heads,
        n_encoder_blocks,
        padding_tkn,
        mask_tkn,
        layer_norm_eps: float = 1e-12,
        a_fn: str = "gelu",
        dropout: float = 0.0,
    ):
        super().__init__()

        self.AbRep = AbRep(
            vocab_size,
            hidden_embed_size,
            n_attn_heads,
            n_encoder_blocks,
            padding_tkn,
            mask_tkn,
            layer_norm_eps,
            a_fn,
            dropout,
        )
        self.AbHead = AbHead(
            vocab_size,
            hidden_embed_size,
            self.AbRep.aa_embed_layer.weight,  # tie the output projection to the aa embedding weights
            layer_norm_eps,
            a_fn,
        )

    def forward(self, tokens, return_attn_weights=False, return_rep_layers=[]):

        representations = self.AbRep(tokens, return_attn_weights, return_rep_layers)

        # Three output modes: attention weights, selected hidden layers, or per-residue logits.
        if return_attn_weights:
            return representations.attention_weights
        elif return_rep_layers != []:
            return representations.many_hidden_states
        else:
            likelihoods = self.AbHead(representations.last_hidden_states)
            return likelihoods

    def get_aa_embeddings(self):
        "Extracts the trained aa_embeddings."
        return self.AbRep.aa_embed_layer
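
# Usage sketch (hedged; the hyperparameter values are placeholders, not the
# trained model's configuration): `tokens` is assumed to be a (batch, seq_len)
# LongTensor of vocabulary indices from an external tokenizer.
#
#   model = AbLang(
#       vocab_size=26, hidden_embed_size=480, n_attn_heads=20,
#       n_encoder_blocks=12, padding_tkn=21, mask_tkn=23,
#   )
#   logits = model(tokens)                           # (batch, seq_len, vocab_size) logits
#   attn = model(tokens, return_attn_weights=True)   # per-block attention weights
#   reps = model(tokens, return_rep_layers=[0, 12])  # {layer index: hidden states}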


class AbRep(torch.nn.Module):
    """
    AbRep (antibody representations) takes the tokenized sequence and creates hidden_embed (representations).
    """

    def __init__(
        self,
        vocab_size,
        hidden_embed_size,
        n_attn_heads,
        n_encoder_blocks,
        padding_tkn,
        mask_tkn,
        layer_norm_eps: float = 1e-12,
        a_fn: str = "gelu",
        dropout: float = 0.1,
    ):
        super().__init__()
        self.padding_tkn = padding_tkn
        self.mask_tkn = mask_tkn

        self.aa_embed_layer = nn.Embedding(
            vocab_size,
            hidden_embed_size,
            padding_idx=padding_tkn,
        )
        self.encoder_blocks = nn.ModuleList(
            [TransformerEncoder(
                hidden_embed_size,
                n_attn_heads,
                attn_dropout=dropout,
                layer_norm_eps=layer_norm_eps,
                a_fn=a_fn,
            ) for _ in range(n_encoder_blocks)]
        )
        self.layer_norm_after_encoder_blocks = nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps)

    def forward(
        self,
        tokens,
        return_attn_weights=False,
        return_rep_layers=[],
    ):

        assert tokens.ndim == 2
        padding_mask = tokens.eq(self.padding_tkn)

        hidden_embed = self.aa_embed_layer(tokens)

        return_rep_layers = set(return_rep_layers)
        rep_layers = {}
        if 0 in return_rep_layers:
            rep_layers[0] = hidden_embed  # layer 0 is the raw embedding output

        all_attn_weights = []

        for n_layer, encoder_block in enumerate(self.encoder_blocks):
            hidden_embed, attn_weights = encoder_block(hidden_embed, padding_mask, return_attn_weights)

            if (n_layer + 1) in return_rep_layers:
                rep_layers[n_layer + 1] = hidden_embed

            if return_attn_weights:
                all_attn_weights.append(attn_weights)

        hidden_embed = self.layer_norm_after_encoder_blocks(hidden_embed)

        return DataAbRep(
            last_hidden_states=hidden_embed,
            many_hidden_states=rep_layers,
            attention_weights=all_attn_weights,
        )
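
# Shape sketch (hedged; constructor values are placeholders and the encoder
# blocks are assumed to preserve the (batch, seq_len, hidden) layout): for
# `tokens` of shape (batch, seq_len), AbRep returns a DataAbRep whose
# last_hidden_states is (batch, seq_len, hidden_embed_size), whose
# many_hidden_states maps each requested layer index to its hidden states
# (0 = the embedding output), and whose attention_weights holds one entry per
# encoder block when return_attn_weights=True (entry shape depends on the
# TransformerEncoder in .encoderblock).
#
#   abrep = AbRep(vocab_size=26, hidden_embed_size=480, n_attn_heads=20,
#                 n_encoder_blocks=12, padding_tkn=21, mask_tkn=23)
#   out = abrep(tokens, return_rep_layers=[0, 12])
#   out.last_hidden_states.shape   # torch.Size([batch, seq_len, 480])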


class AbHead(torch.nn.Module):
    """
    AbHead (antibody head model) creates amino acid logits (pre-softmax probabilities) for each position based on the hidden_embed (representations).
    """

    def __init__(
        self,
        vocab_size,
        hidden_embed_size,
        weights,
        layer_norm_eps: float = 1e-12,
        a_fn: str = "gelu",
    ):
        super().__init__()

        activation_fn, scale = get_activation_fn(a_fn)

        # `scale` widens the intermediate projection so that the activation's output
        # comes back to hidden_embed_size (required by the LayerNorm below), e.g. for
        # gated activations that halve their input width.
        self.ff = torch.nn.Sequential(
            nn.Linear(hidden_embed_size, hidden_embed_size * scale),
            activation_fn(),
            nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps),
        )

        # Weight tying: `weights` is the aa embedding matrix from AbRep, reused as
        # the output projection; only the bias is a new parameter here.
        self.weights = weights
        self.bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, hidden_embed):

        hidden_embed = self.ff(hidden_embed)
        logits = F.linear(hidden_embed, self.weights) + self.bias

        return logits
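
# Probability sketch (hedged; `ab_head` and `hidden_embed` are assumed to
# exist): AbHead returns unnormalised logits, which are also what
# AbLang.forward hands back as `likelihoods`; per-position amino-acid
# probabilities come from a softmax over the vocabulary dimension.
#
#   logits = ab_head(hidden_embed)          # (batch, seq_len, vocab_size)
#   probs = torch.softmax(logits, dim=-1)   # each position sums to 1 over the vocabulary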


@dataclass
class DataAbRep:
    """
    Dataclass used to store AbRep output.
    """

    last_hidden_states: torch.FloatTensor
    many_hidden_states: Optional[Dict[int, torch.FloatTensor]] = None
    attention_weights: Optional[List[torch.FloatTensor]] = None
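
# Masked-residue sketch (hedged; a trained `model`, integer `tokens`, and the
# mask token id `mask_tkn` are assumed to exist): a masked position can be
# infilled by taking the argmax of the logits at that position.
#
#   tokens[0, 10] = mask_tkn
#   logits = model(tokens)
#   predicted_aa_idx = logits[0, 10].argmax(-1)   # index into the amino-acid vocabulary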