Add ablang.py and encoderblock.py to root directory for Hugging Face compatibility

Files changed:
- ablang.py +181 -0
- encoderblock.py +173 -0
- modeling_ablang2paired.py +10 -20
- test_ablang2_HF_implementation.ipynb +74 -297
- test_module_loading.py +19 -0
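With ablang.py and encoderblock.py at the repo root, transformers' remote-code loader can resolve the model definition directly from the Hub repo instead of requiring a pip-installed ablang2 package. A minimal loading sketch, mirroring the test_module_loading.py added in this commit:

from transformers import AutoModel, AutoTokenizer

# trust_remote_code=True makes transformers fetch the configuration/modeling
# files (and the modules they reference) from the Hub repo itself.
model = AutoModel.from_pretrained("hemantn/ablang2", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("hemantn/ablang2", trust_remote_code=True)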
ablang.py
ADDED
@@ -0,0 +1,181 @@
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from .encoderblock import TransformerEncoder, get_activation_fn
+
+
+class AbLang(torch.nn.Module):
+    """
+    AbLang inspired by ESM-2's architecture.
+    """
+
+    def __init__(
+        self,
+        vocab_size,
+        hidden_embed_size,
+        n_attn_heads,
+        n_encoder_blocks,
+        padding_tkn,
+        mask_tkn,
+        layer_norm_eps: float = 1e-12,
+        a_fn: str = "gelu",
+        dropout: float = 0.0,
+    ):
+        super().__init__()
+
+        self.AbRep = AbRep(
+            vocab_size,
+            hidden_embed_size,
+            n_attn_heads,
+            n_encoder_blocks,
+            padding_tkn,
+            mask_tkn,
+            layer_norm_eps,
+            a_fn,
+            dropout,
+        )
+        self.AbHead = AbHead(
+            vocab_size,
+            hidden_embed_size,
+            self.AbRep.aa_embed_layer.weight,
+            layer_norm_eps,
+            a_fn,
+        )
+
+    def forward(self, tokens, return_attn_weights=False, return_rep_layers=[]):
+
+        representations = self.AbRep(tokens, return_attn_weights, return_rep_layers)
+
+        if return_attn_weights:
+            return representations.attention_weights
+        elif return_rep_layers != []:
+            return representations.many_hidden_states
+        else:
+            likelihoods = self.AbHead(representations.last_hidden_states)
+            return likelihoods
+
+    def get_aa_embeddings(self):
+        "Extracts the trained aa_embeddings."
+        return self.AbRep.aa_embed_layer
+
+
+class AbRep(torch.nn.Module):
+    """
+    AbRep (antibody representations) takes the tokenized sequence and creates hidden_embed (representations).
+    """
+
+    def __init__(
+        self,
+        vocab_size,
+        hidden_embed_size,
+        n_attn_heads,
+        n_encoder_blocks,
+        padding_tkn,
+        mask_tkn,
+        layer_norm_eps: float = 1e-12,
+        a_fn: str = "gelu",
+        dropout: float = 0.1,
+    ):
+        super().__init__()
+        self.padding_tkn = padding_tkn
+        self.mask_tkn = mask_tkn
+
+        self.aa_embed_layer = nn.Embedding(
+            vocab_size,
+            hidden_embed_size,
+            padding_idx=padding_tkn,
+        )
+        self.encoder_blocks = nn.ModuleList(
+            [TransformerEncoder(
+                hidden_embed_size,
+                n_attn_heads,
+                attn_dropout = dropout,
+                layer_norm_eps = layer_norm_eps,
+                a_fn = a_fn,
+            ) for _ in range(n_encoder_blocks)]
+        )
+        self.layer_norm_after_encoder_blocks = nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps)
+
+    def forward(self,
+        tokens,
+        return_attn_weights=False,
+        return_rep_layers=[],
+    ):
+
+        assert tokens.ndim == 2
+        padding_mask = tokens.eq(self.padding_tkn)
+
+        hidden_embed = self.aa_embed_layer(tokens)
+
+        return_rep_layers = set(return_rep_layers)
+        rep_layers = {}
+        if 0 in return_rep_layers: rep_layers[0] = hidden_embed
+
+        all_attn_weights = []
+
+        for n_layer, encoder_block in enumerate(self.encoder_blocks):
+            hidden_embed, attn_weights = encoder_block(hidden_embed, padding_mask, return_attn_weights)
+
+            if (n_layer + 1) in return_rep_layers:
+                rep_layers[n_layer + 1] = hidden_embed
+
+            if return_attn_weights:
+                all_attn_weights.append(attn_weights)
+
+        hidden_embed = self.layer_norm_after_encoder_blocks(hidden_embed)
+
+        return DataAbRep(
+            last_hidden_states=hidden_embed,
+            many_hidden_states=rep_layers,
+            attention_weights=all_attn_weights
+        )
+
+
+class AbHead(torch.nn.Module):
+    """
+    AbHead (antibody head model) creates amino acid probabilities for each position based on the hidden_embed (representations).
+    """
+
+    def __init__(
+        self,
+        vocab_size,
+        hidden_embed_size,
+        weights,
+        layer_norm_eps: float = 1e-12,
+        a_fn: str = "gelu",
+    ):
+        super().__init__()
+
+        activation_fn, scale = get_activation_fn(a_fn)
+
+        self.ff = torch.nn.Sequential(
+            nn.Linear(hidden_embed_size, hidden_embed_size * scale),
+            activation_fn(),
+            nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps),
+        )
+
+        self.weights = weights
+        self.bias = nn.Parameter(torch.zeros(vocab_size))
+
+    def forward(self, hidden_embed):
+
+        hidden_embed = self.ff(hidden_embed)
+        logits = F.linear(hidden_embed, self.weights) + self.bias
+
+        return logits
+
+
+@dataclass
+class DataAbRep():
+    """
+    Dataclass used to store AbRep output.
+    """
+
+    last_hidden_states: torch.FloatTensor
+    many_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attention_weights: Optional[Tuple[torch.FloatTensor]] = None
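For orientation, a minimal usage sketch of the classes above. The hyperparameters here are illustrative only: hidden_embed_size=480 and vocab_size=26 match the output shapes seen in the test notebook later in this commit, while the head/block counts and special-token ids are hypothetical placeholders; the real values come from the checkpoint's AbLang2PairedConfig.

import torch

model = AbLang(
    vocab_size=26,           # matches the (..., 26) probability outputs in the notebook
    hidden_embed_size=480,   # matches the (..., 480) embedding outputs in the notebook
    n_attn_heads=20,         # hypothetical; must divide hidden_embed_size
    n_encoder_blocks=12,     # hypothetical
    padding_tkn=21,          # hypothetical token ids
    mask_tkn=23,
)

tokens = torch.randint(0, 21, (2, 50))         # [batch, seq_len] token ids (no padding)
logits = model(tokens)                         # [2, 50, 26] per-position logits
reps = model.AbRep(tokens).last_hidden_states  # [2, 50, 480] final-layer embeddings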
encoderblock.py
ADDED
@@ -0,0 +1,173 @@
+import torch
+import math
+from torch import nn
+import torch.nn.functional as F
+import einops
+from rotary_embedding_torch import RotaryEmbedding
+
+class TransformerEncoder(torch.nn.Module):
+    """
+    Single Transformer Encoder.
+    """
+    def __init__(
+        self,
+        hidden_embed_size,
+        n_attn_heads,
+        attn_dropout: float = 0.0,
+        layer_norm_eps: float = 1e-05,
+        a_fn: str = "gelu",
+    ):
+        super().__init__()
+
+        assert hidden_embed_size % n_attn_heads == 0, \
+            "Embedding dimension must be divisible by the number of heads."
+
+        self.multihead_attention = MultiHeadAttention(
+            embed_dim = hidden_embed_size,
+            num_heads = n_attn_heads,
+            attention_dropout_prob = attn_dropout
+        )
+
+        activation_fn, scale = get_activation_fn(a_fn)
+
+        self.intermediate_layer = torch.nn.Sequential(
+            torch.nn.Linear(hidden_embed_size, hidden_embed_size * 4 * scale),
+            activation_fn(),
+            torch.nn.Linear(hidden_embed_size * 4, hidden_embed_size),
+        )
+
+        self.pre_attn_layer_norm = torch.nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps)
+        self.final_layer_norm = torch.nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps)
+
+    def forward(self, hidden_embed, attn_mask=None, return_attn_weights: bool = False):
+
+        residual = hidden_embed
+        hidden_embed = self.pre_attn_layer_norm(hidden_embed.clone())
+        hidden_embed, attn_weights = self.multihead_attention(
+            hidden_embed,
+            attn_mask=attn_mask,
+            return_attn_weights=return_attn_weights
+        )
+        hidden_embed = residual + hidden_embed
+
+        residual = hidden_embed
+        hidden_embed = self.final_layer_norm(hidden_embed)
+        hidden_embed = self.intermediate_layer(hidden_embed)
+        hidden_embed = residual + hidden_embed
+        return hidden_embed, attn_weights
+
+class MultiHeadAttention(torch.nn.Module):
+
+    def __init__(
+        self,
+        embed_dim,
+        num_heads,
+        attention_dropout_prob: float = 0.0,
+        bias: bool = True,
+    ):
+        super().__init__()
+
+        self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
+
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.head_dim = embed_dim // num_heads
+        assert (self.head_dim * num_heads == self.embed_dim), "embed_dim must be divisible by num_heads"
+        self.scaling = self.head_dim**-0.5
+
+        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+        self.reset_parameters()
+
+        self.rotary_emb = RotaryEmbedding(dim = self.head_dim)
+
+    def reset_parameters(self):
+
+        nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
+        nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
+        nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
+
+        nn.init.xavier_uniform_(self.out_proj.weight)
+        if self.out_proj.bias is not None:
+            nn.init.constant_(self.out_proj.bias, 0.0)
+
+    def attention(self, q, k, v, attn_mask=None):
+
+        attn_weights = torch.matmul(q, k.transpose(-2, -1))
+        attn_weights = attn_weights / math.sqrt(self.head_dim)
+
+        if attn_mask is not None:
+            attn_mask = einops.rearrange(
+                attn_mask,
+                'b_size (h1 h2 seq_len) -> b_size h1 h2 seq_len',
+                h1=1, h2=1
+            )
+            attn_weights = attn_weights.masked_fill(attn_mask, float("-inf"))
+
+        attn_weights = F.softmax(attn_weights, dim=-1)
+
+        attn = self.attention_dropout(attn_weights)
+        attn = torch.matmul(attn, v)
+        return attn, attn_weights
+
+    def forward(self, x, attn_mask=None, return_attn_weights: bool = False):
+
+        batch_size, seq_len, embed_dim = x.size()
+
+        q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+        q *= self.scaling
+
+        q = q.contiguous().view(
+            batch_size,
+            seq_len,
+            self.num_heads,
+            self.head_dim
+        ).transpose(1, 2)  # [n_batch, n_heads, seq_len, head_dim]
+        k = k.contiguous().view(
+            batch_size,
+            seq_len,
+            self.num_heads,
+            self.head_dim
+        ).transpose(1, 2)  # [n_batch, n_heads, seq_len, head_dim]
+        v = v.contiguous().view(
+            batch_size,
+            seq_len,
+            self.num_heads,
+            self.head_dim
+        ).transpose(1, 2)  # [n_batch, n_heads, seq_len, head_dim]
+
+        q = self.rotary_emb.rotate_queries_or_keys(q)
+        k = self.rotary_emb.rotate_queries_or_keys(k)
+
+        # Determine value outputs
+        attn, attn_weights = self.attention(
+            q, k, v,
+            attn_mask=attn_mask
+        )  # attn_weights [n_batch, n_heads, seq_len (target), seq_len (source)]
+
+        attn = attn.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
+        attn = self.out_proj(attn)
+
+        if return_attn_weights:
+            return attn, attn_weights
+        else:
+            return attn, None
+
+class SwiGLU(torch.nn.Module):
+    def forward(self, x):
+        x, gate = x.chunk(2, dim=-1)
+        return F.silu(gate) * x
+
+def get_activation_fn(a_fn):
+
+    if a_fn == "gelu":
+        return torch.nn.GELU, 1
+
+    elif a_fn == "swiglu":
+        return SwiGLU, 2
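The scale value returned by get_activation_fn exists because SwiGLU halves its input along the last dimension, so the Linear layer feeding it must emit twice the target width. A quick illustrative check of that contract (the width 480 is just an example):

import torch

act_cls, scale = get_activation_fn("swiglu")    # -> (SwiGLU, 2)
hidden = 480                                    # example width
ff_in = torch.randn(2, 10, hidden * 4 * scale)  # what the first Linear emits
ff_out = act_cls()(ff_in)                       # chunk in half, then F.silu(gate) * x
assert ff_out.shape == (2, 10, hidden * 4)      # 4x hidden, as the second Linear expects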
modeling_ablang2paired.py
CHANGED
@@ -9,29 +9,19 @@ try:
 except ImportError:
     from configuration_ablang2paired import AbLang2PairedConfig
 
-# Import the AbLang model from …
-import …
-
-    ablang_path = os.path.join(current_dir, "ablang2", "models", "ablang2", "ablang.py")
-
-    if os.path.exists(ablang_path):
-        spec = importlib.util.spec_from_file_location("ablang", ablang_path)
-        ablang_module = importlib.util.module_from_spec(spec)
-        spec.loader.exec_module(ablang_module)
-        return ablang_module.AbLang
-    else:
-        # If not found, raise an error with helpful message
-        raise ImportError(
-            "Could not find AbLang module. Please ensure … "
-            "in the repository."
-        )
+# Import the AbLang model from local files
+try:
+    from ablang import AbLang
+except ImportError:
+    # Fallback: try to import from the current directory
+    try:
+        from .ablang import AbLang
+    except ImportError:
+        raise ImportError(
+            "Could not find AbLang module. Please ensure ablang.py is present in the repository."
+        )
 
 
 class AbLang2PairedHFModel(PreTrainedModel):
     config_class = AbLang2PairedConfig
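The switch from filesystem-path loading to plain imports matters for trust_remote_code loading: the re-run notebook below captures the failure mode this commit addresses, transformers' import scan (check_imports in dynamic_module_utils) treating absolute imports in a remote-code file as pip requirements and aborting with "This modeling file requires the following packages that were not found in your environment: ablang2. Run `pip install ablang2`". The rewritten block appears to avoid that: imports wrapped in try/except seem to be skipped by the scan, and the relative `from .ablang import AbLang` fallback tells the loader to fetch ablang.py from the same Hub repo, which is why this commit places it at the repo root.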
test_ablang2_HF_implementation.ipynb
CHANGED
@@ -86,34 +86,77 @@
Cell "6d66ad84" (model loading): stale outputs replaced with those of a fresh run.
- Removed stderr: two "A new version of the following files was downloaded from https://huggingface.co/hemantn/ablang2: … Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision." warnings (for configuration_ablang2paired.py and modeling_ablang2paired.py) and a huggingface_hub FutureWarning that `resume_download` is deprecated.
+ Added: download-progress widget outputs for config.json, configuration_ablang2paired.py, and modeling_ablang2paired.py; the remote-code warning for configuration_ablang2paired.py; the stderr line "Encountered exception while importing ablang2: No module named 'ablang2'"; and an ImportError raised by transformers' check_imports during AutoModel.from_pretrained("hemantn/ablang2", trust_remote_code=True): "This modeling file requires the following packages that were not found in your environment: ablang2. Run `pip install ablang2`".

@@ -162,7 +205,7 @@
Cell "ceae4a88-0679-4704-8bad-c06a4569c497": "execution_count" reset to null.

@@ -187,30 +230,10 @@
Cell "d22f4302-1262-4cc1-8a1c-a36daa8c710c" (source: ablang(all_seqs, mode='seqcoding')): "execution_count" reset to null; removed the cached execute_result, a (5, 480) sequence-embedding array.

@@ -231,85 +254,10 @@
Cell "6227f661-575f-4b1e-9646-cfba7b10c3b4" (source: ablang(all_seqs, mode='rescoding', stepwise_masking = False)): "execution_count" reset to null; removed the cached execute_result, a list of five per-residue float32 arrays of shapes (238, 480), (222, 480), (124, 480), (115, 480), and (238, 480).

@@ -330,80 +278,10 @@
Cell "e4bc0cb1-f5b0-4255-9e93-d643ae1396df" (source: results = ablang(only_both_chains_seqs, mode='likelihood', align=True)): "execution_count" reset to null; removed the cached stdout, which held the aligned numbering labels (e.g. '1 ' … '112A' … '128 '), the aligned heavy|light sequences, and the per-position likelihood arrays.

@@ -414,60 +292,10 @@
Cell "56be8cad" (source: ablang(only_both_chains_seqs, mode='probability')): "execution_count" reset to null; removed the cached execute_result, three per-position probability float32 arrays of shapes (238, 26), (222, 26), and (238, 26).

@@ -492,21 +320,10 @@
Cell "83f3064b-48a7-42fb-ba82-ec153ea946da" (source: results = ablang(all_seqs, mode='pseudo_log_likelihood'); np.exp(-results)): "execution_count" reset to null; removed the cached execute_result, the pseudo-perplexities array([1.96673731, 2.04801253, 2.09881898, 1.82533665, 1.97255249]).

@@ -514,22 +331,10 @@
Cell "42cc8b34-5ae9-4857-93fe-a438a0f2a868" (source: results = ablang(all_seqs, mode='confidence'); np.exp(-results)): "execution_count" reset to null; removed the cached execute_result, array([1.2636038, 1.126463, 1.3123759, 1.2140924, 1.1805094], dtype=float32).

@@ -547,24 +352,10 @@
Cell "2d5b725c-4eac-4a4b-9331-357c3ac140f7" (source: restored = ablang(only_both_chains_seqs, mode='restore')): "execution_count" reset to null; removed the cached execute_result, three restored paired heavy|light sequences (dtype '<U238').

@@ -572,24 +363,10 @@
Cell "0e9615f7-c490-4947-96f4-7617266c686e" (source: restored = ablang(only_both_chains_seqs, mode='restore', align = True)): "execution_count" reset to null; removed the cached execute_result, three restored, aligned paired sequences (dtype '<U238').
test_module_loading.py
ADDED
@@ -0,0 +1,19 @@
+import sys
+import os
+import numpy as np
+from transformers import AutoModel, AutoTokenizer
+from transformers.utils import cached_file
+
+# Load model and tokenizer from Hugging Face Hub
+model = AutoModel.from_pretrained("hemantn/ablang2", trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained("hemantn/ablang2", trust_remote_code=True)
+
+# Find the cached model directory and import adapter
+adapter_path = cached_file("hemantn/ablang2", "adapter.py")
+cached_model_dir = os.path.dirname(adapter_path)
+sys.path.insert(0, cached_model_dir)
+
+# Import and create the adapter
+from adapter import AbLang2PairedHuggingFaceAdapter
+ablang = AbLang2PairedHuggingFaceAdapter(model=model, tokenizer=tokenizer)
+
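Once the adapter is constructed it exposes the same call interface the notebook exercises; all_seqs and only_both_chains_seqs below stand for the sequence lists defined in the notebook (not shown in this diff):

import numpy as np

# Modes exercised in test_ablang2_HF_implementation.ipynb:
ablang(all_seqs, mode='seqcoding')                          # per-sequence embeddings, (n_seqs, 480)
ablang(all_seqs, mode='rescoding', stepwise_masking=False)  # per-residue embeddings
ablang(only_both_chains_seqs, mode='probability')           # per-position probabilities, (seq_len, 26)
results = ablang(all_seqs, mode='pseudo_log_likelihood')
np.exp(-results)                                            # convert to pseudo perplexity
ablang(only_both_chains_seqs, mode='restore', align=True)   # restore masked/missing residues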