hemantn committed
Commit 7ebbadf · 1 Parent(s): cf7aac0

Add ablang.py and encoderblock.py to root directory for Hugging Face compatibility

ablang.py ADDED
@@ -0,0 +1,181 @@
+ from dataclasses import dataclass
+ from typing import Optional, Tuple
+
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+ from .encoderblock import TransformerEncoder, get_activation_fn
+
+
+ class AbLang(torch.nn.Module):
+     """
+     AbLang inspired by ESM-2's architecture.
+     """
+
+     def __init__(
+         self,
+         vocab_size,
+         hidden_embed_size,
+         n_attn_heads,
+         n_encoder_blocks,
+         padding_tkn,
+         mask_tkn,
+         layer_norm_eps: float = 1e-12,
+         a_fn: str = "gelu",
+         dropout: float = 0.0,
+     ):
+         super().__init__()
+
+         self.AbRep = AbRep(
+             vocab_size,
+             hidden_embed_size,
+             n_attn_heads,
+             n_encoder_blocks,
+             padding_tkn,
+             mask_tkn,
+             layer_norm_eps,
+             a_fn,
+             dropout,
+         )
+         self.AbHead = AbHead(
+             vocab_size,
+             hidden_embed_size,
+             self.AbRep.aa_embed_layer.weight,
+             layer_norm_eps,
+             a_fn,
+         )
+
+     def forward(self, tokens, return_attn_weights=False, return_rep_layers=[]):
+
+         representations = self.AbRep(tokens, return_attn_weights, return_rep_layers)
+
+         if return_attn_weights:
+             return representations.attention_weights
+
+         elif return_rep_layers != []:
+             return representations.many_hidden_states
+         else:
+             likelihoods = self.AbHead(representations.last_hidden_states)
+             return likelihoods
+
+     def get_aa_embeddings(self):
+         "Extracts the trained aa_embeddings."
+         return self.AbRep.aa_embed_layer
+
+
+ class AbRep(torch.nn.Module):
+     """
+     AbRep (antibody representations) takes the tokenized sequence and creates hidden_embed (representations).
+     """
+
+     def __init__(
+         self,
+         vocab_size,
+         hidden_embed_size,
+         n_attn_heads,
+         n_encoder_blocks,
+         padding_tkn,
+         mask_tkn,
+         layer_norm_eps: float = 1e-12,
+         a_fn: str = "gelu",
+         dropout: float = 0.1,
+     ):
+         super().__init__()
+         self.padding_tkn = padding_tkn
+         self.mask_tkn = mask_tkn
+
+         self.aa_embed_layer = nn.Embedding(
+             vocab_size,
+             hidden_embed_size,
+             padding_idx=padding_tkn,
+         )
+         self.encoder_blocks = nn.ModuleList(
+             [TransformerEncoder(
+                 hidden_embed_size,
+                 n_attn_heads,
+                 attn_dropout = dropout,
+                 layer_norm_eps = layer_norm_eps,
+                 a_fn = a_fn,
+             ) for _ in range(n_encoder_blocks)]
+         )
+         self.layer_norm_after_encoder_blocks = nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps)
+
+     def forward(self,
+                 tokens,
+                 return_attn_weights=False,
+                 return_rep_layers=[],
+                 ):
+
+         assert tokens.ndim == 2
+         padding_mask = tokens.eq(self.padding_tkn)
+
+         hidden_embed = self.aa_embed_layer(tokens)
+
+         return_rep_layers = set(return_rep_layers)
+         rep_layers = {}
+         if 0 in return_rep_layers: rep_layers[0] = hidden_embed
+
+         all_attn_weights = []
+
+         for n_layer, encoder_block in enumerate(self.encoder_blocks):
+             hidden_embed, attn_weights = encoder_block(hidden_embed, padding_mask, return_attn_weights)
+
+             if (n_layer + 1) in return_rep_layers:
+                 rep_layers[n_layer + 1] = hidden_embed
+
+             if return_attn_weights:
+                 all_attn_weights.append(attn_weights)
+
+         hidden_embed = self.layer_norm_after_encoder_blocks(hidden_embed)
+
+         return DataAbRep(
+             last_hidden_states=hidden_embed,
+             many_hidden_states=rep_layers,
+             attention_weights=all_attn_weights
+         )
+
+
+ class AbHead(torch.nn.Module):
+     """
+     AbHead (antibody head model) creates amino acid probabilities for each position based on the hidden_embed (representations).
+     """
+
+     def __init__(
+         self,
+         vocab_size,
+         hidden_embed_size,
+         weights,
+         layer_norm_eps: float = 1e-12,
+         a_fn: str = "gelu",
+     ):
+         super().__init__()
+
+         activation_fn, scale = get_activation_fn(a_fn)
+
+         self.ff = torch.nn.Sequential(
+             nn.Linear(hidden_embed_size, hidden_embed_size * scale),
+             activation_fn(),
+             nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps),
+         )
+
+         self.weights = weights
+         self.bias = nn.Parameter(torch.zeros(vocab_size))
+
+     def forward(self, hidden_embed):
+
+         hidden_embed = self.ff(hidden_embed)
+         logits = F.linear(hidden_embed, self.weights) + self.bias
+
+         return logits
+
+
+ @dataclass
+ class DataAbRep():
+     """
+     Dataclass used to store AbRep output.
+     """
+
+     last_hidden_states: torch.FloatTensor
+     many_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+     attention_weights: Optional[Tuple[torch.FloatTensor]] = None
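
For orientation, a minimal usage sketch of the classes above. The hyperparameter values are illustrative placeholders (the shipped configuration lives in the repo's config files), and note that ablang.py imports encoderblock with a relative import, so in practice it is loaded as part of a package (for example through transformers' trust_remote_code machinery) rather than as a standalone script:

import torch
from ablang import AbLang  # hypothetical flat import; see the note on relative imports above

model = AbLang(
    vocab_size=26,          # placeholder vocabulary size
    hidden_embed_size=480,  # placeholder embedding width
    n_attn_heads=20,        # placeholder head count
    n_encoder_blocks=8,     # placeholder block count
    padding_tkn=21,         # placeholder id of the padding token
    mask_tkn=23,            # placeholder id of the mask token
).eval()

tokens = torch.randint(0, 26, (2, 50))  # [batch, seq_len] token ids
with torch.no_grad():
    logits = model(tokens)                          # per-residue logits from AbHead
    reps = model.AbRep(tokens).last_hidden_states   # per-residue representations from AbRep
    attn = model(tokens, return_attn_weights=True)  # list of per-layer attention maps
print(logits.shape, reps.shape, len(attn))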
encoderblock.py ADDED
@@ -0,0 +1,173 @@
+ import torch
+ import math
+ from torch import nn
+ import torch.nn.functional as F
+ import einops
+ from rotary_embedding_torch import RotaryEmbedding
+
+ class TransformerEncoder(torch.nn.Module):
+     """
+     Single Transformer Encoder.
+
+     """
+     def __init__(
+         self,
+         hidden_embed_size,
+         n_attn_heads,
+         attn_dropout: float = 0.0,
+         layer_norm_eps: float = 1e-05,
+         a_fn: str = "gelu",
+     ):
+         super().__init__()
+
+         assert hidden_embed_size % n_attn_heads == 0, \
+             "Embedding dimension must be divisible by the number of heads."
+
+         self.multihead_attention = MultiHeadAttention(
+             embed_dim = hidden_embed_size,
+             num_heads = n_attn_heads,
+             attention_dropout_prob = attn_dropout
+         )
+
+         activation_fn, scale = get_activation_fn(a_fn)
+
+         self.intermediate_layer = torch.nn.Sequential(
+             torch.nn.Linear(hidden_embed_size, hidden_embed_size * 4 * scale),
+             activation_fn(),
+             torch.nn.Linear(hidden_embed_size * 4, hidden_embed_size),
+         )
+
+         self.pre_attn_layer_norm = torch.nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps)
+         self.final_layer_norm = torch.nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps)
+
+     def forward(self, hidden_embed, attn_mask=None, return_attn_weights: bool = False):
+
+         residual = hidden_embed
+         hidden_embed = self.pre_attn_layer_norm(hidden_embed.clone())
+         hidden_embed, attn_weights = self.multihead_attention(
+             hidden_embed,
+             attn_mask=attn_mask,
+             return_attn_weights=return_attn_weights
+         )
+         hidden_embed = residual + hidden_embed
+
+         residual = hidden_embed
+         hidden_embed = self.final_layer_norm(hidden_embed)
+         hidden_embed = self.intermediate_layer(hidden_embed)
+         hidden_embed = residual + hidden_embed
+         return hidden_embed, attn_weights
+
+ class MultiHeadAttention(torch.nn.Module):
+
+     def __init__(
+         self,
+         embed_dim,
+         num_heads,
+         attention_dropout_prob: float = 0.0,
+         bias: bool = True,
+     ):
+         super().__init__()
+
+         self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
+
+         self.embed_dim = embed_dim
+         self.num_heads = num_heads
+         self.head_dim = embed_dim // num_heads
+         assert (self.head_dim * num_heads == self.embed_dim), "embed_dim must be divisible by num_heads"
+         self.scaling = self.head_dim**-0.5
+
+         self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+         self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+         self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+         self.reset_parameters()
+
+         self.rotary_emb = RotaryEmbedding(dim = self.head_dim)
+
+     def reset_parameters(self):
+
+         nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
+         nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
+         nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
+
+         nn.init.xavier_uniform_(self.out_proj.weight)
+         if self.out_proj.bias is not None:
+             nn.init.constant_(self.out_proj.bias, 0.0)
+
+     def attention(self, q, k, v, attn_mask=None):
+
+         attn_weights = torch.matmul(q, k.transpose(-2, -1))
+         attn_weights = attn_weights / math.sqrt(self.head_dim)
+
+         if attn_mask is not None:
+             attn_mask = einops.rearrange(
+                 attn_mask,
+                 'b_size (h1 h2 seq_len) -> b_size h1 h2 seq_len',
+                 h1=1, h2=1
+             )
+             attn_weights = attn_weights.masked_fill(attn_mask, float("-inf"))
+
+         attn_weights = F.softmax(attn_weights, dim=-1)
+
+         attn = self.attention_dropout(attn_weights)
+         attn = torch.matmul(attn, v)
+         return attn, attn_weights
+
+     def forward(self, x, attn_mask=None, return_attn_weights: bool = False):
+
+         batch_size, seq_len, embed_dim = x.size()
+
+         q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+         q *= self.scaling
+
+         q = q.contiguous().view(
+             batch_size,
+             seq_len,
+             self.num_heads,
+             self.head_dim
+         ).transpose(1, 2) # [n_batch, n_heads, seq_len, head_dim]
+         k = k.contiguous().view(
+             batch_size,
+             seq_len,
+             self.num_heads,
+             self.head_dim
+         ).transpose(1, 2) # [n_batch, n_heads, seq_len, head_dim]
+         v = v.contiguous().view(
+             batch_size,
+             seq_len,
+             self.num_heads,
+             self.head_dim
+         ).transpose(1, 2) # [n_batch, n_heads, seq_len, head_dim]
+
+         q = self.rotary_emb.rotate_queries_or_keys(q)
+         k = self.rotary_emb.rotate_queries_or_keys(k)
+
+         # Determine value outputs
+         attn, attn_weights = self.attention(
+             q, k, v,
+             attn_mask=attn_mask
+         ) # attn_weights [n_batch, n_heads, seq_len (target), seq_len (source)]
+
+         attn = attn.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
+         attn = self.out_proj(attn)
+
+         if return_attn_weights:
+             return attn, attn_weights
+         else:
+             return attn, None
+
+ class SwiGLU(torch.nn.Module):
+     def forward(self, x):
+         x, gate = x.chunk(2, dim=-1)
+         return F.silu(gate) * x
+
+ def get_activation_fn(a_fn):
+
+     if a_fn == "gelu":
+         return torch.nn.GELU, 1
+
+     elif a_fn == "swiglu":
+         return SwiGLU, 2
+
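
A short sketch of how the pieces above compose; the sizes are illustrative placeholders, and einops plus rotary_embedding_torch must be installed (see the imports at the top of the file):

import torch
from encoderblock import TransformerEncoder, get_activation_fn  # assumes encoderblock.py is on sys.path

# get_activation_fn returns (activation class, width multiplier): GELU keeps the
# projection width, while SwiGLU doubles the projection and halves it again by gating.
act_cls, scale = get_activation_fn("gelu")    # (torch.nn.GELU, 1)
act_cls, scale = get_activation_fn("swiglu")  # (SwiGLU, 2)

block = TransformerEncoder(hidden_embed_size=480, n_attn_heads=20)  # placeholder sizes
x = torch.randn(2, 50, 480)                           # [batch, seq_len, hidden]
padding_mask = torch.zeros(2, 50, dtype=torch.bool)   # True marks padded positions
out, attn = block(x, attn_mask=padding_mask, return_attn_weights=True)
print(out.shape, attn.shape)  # torch.Size([2, 50, 480]) torch.Size([2, 20, 50, 50])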
modeling_ablang2paired.py CHANGED
@@ -9,29 +9,19 @@ try:
  except ImportError:
      from configuration_ablang2paired import AbLang2PairedConfig
 
- # Import the AbLang model from the local file structure
- import importlib.util
- import os
-
- def load_ablang_module():
-     """Load the AbLang module from the local directory structure."""
-     # Try to find the ablang.py file in the local directory
-     current_dir = os.path.dirname(os.path.abspath(__file__))
-     ablang_path = os.path.join(current_dir, "ablang2", "models", "ablang2", "ablang.py")
-
-     if os.path.exists(ablang_path):
-         spec = importlib.util.spec_from_file_location("ablang", ablang_path)
-         ablang_module = importlib.util.module_from_spec(spec)
-         spec.loader.exec_module(ablang_module)
-         return ablang_module.AbLang
-     else:
-         # If not found, raise an error with helpful message
+ # Import the AbLang model from local files
+ try:
+     from ablang import AbLang
+ except ImportError:
+     # Fallback: try to import from the current directory
+     try:
+         from .ablang import AbLang
+     except ImportError:
          raise ImportError(
-             "Could not find AbLang module. Please ensure the ablang2 directory structure is present "
-             "in the repository."
+             "Could not find AbLang module. Please ensure ablang.py is present in the repository."
          )
 
- AbLang = load_ablang_module()
+
 
  class AbLang2PairedHFModel(PreTrainedModel):
      config_class = AbLang2PairedConfig
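
With ablang.py now importable next to the modeling file, loading with remote code enabled should follow the usual pattern (this mirrors the call already used in the notebook and in test_module_loading.py below):

from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("hemantn/ablang2", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("hemantn/ablang2", trust_remote_code=True)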
test_ablang2_HF_implementation.ipynb CHANGED
@@ -86,34 +86,77 @@
86
  "id": "6d66ad84",
87
  "metadata": {},
88
  "outputs": [
89
  {
90
  "name": "stderr",
91
  "output_type": "stream",
92
  "text": [
93
  "A new version of the following files was downloaded from https://huggingface.co/hemantn/ablang2:\n",
94
  "- configuration_ablang2paired.py\n",
95
- ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n",
96
- "A new version of the following files was downloaded from https://huggingface.co/hemantn/ablang2:\n",
97
- "- modeling_ablang2paired.py\n",
98
- ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n",
99
- "/home/hn533621/.conda/envs/lib_transformer/lib/python3.10/site-packages/huggingface_hub/file_download.py:943: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
100
- " warnings.warn(\n"
101
  ]
102
  },
103
  {
104
- "name": "stdout",
105
- "output_type": "stream",
106
- "text": [
107
- "✅ Loaded custom weights from: /home/hn533621/.cache/huggingface/hub/models--hemantn--ablang2/snapshots/e1df3c0a25269eaeb91c4891125dd9a8580a01b7/model.pt\n"
108
- ]
109
  },
110
  {
111
  "name": "stderr",
112
  "output_type": "stream",
113
  "text": [
114
- "A new version of the following files was downloaded from https://huggingface.co/hemantn/ablang2:\n",
115
- "- tokenizer_ablang2paired.py\n",
116
- ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n"
117
  ]
118
  }
119
  ],
@@ -162,7 +205,7 @@
  },
  {
   "cell_type": "code",
- "execution_count": 7,
  "id": "ceae4a88-0679-4704-8bad-c06a4569c497",
  "metadata": {},
  "outputs": [],
@@ -187,30 +230,10 @@
  },
  {
   "cell_type": "code",
- "execution_count": 8,
  "id": "d22f4302-1262-4cc1-8a1c-a36daa8c710c",
  "metadata": {},
- "outputs": [
-  {
-   "data": {
-    "text/plain": [
-     "array([[-0.25206311, 0.18189634, 0.00887137, ..., 0.15365517,\n",
-     " -0.14508603, -0.13381317],\n",
-     " [-0.25149415, 0.2086455 , 0.07518203, ..., 0.19478269,\n",
-     " -0.15227772, -0.08241647],\n",
-     " [-0.27468949, 0.16507216, 0.08667156, ..., 0.18776284,\n",
-     " -0.14165082, -0.16389885],\n",
-     " [-0.1982213 , 0.16841085, -0.04925933, ..., 0.11400164,\n",
-     " -0.14723683, -0.09713171],\n",
-     " [-0.29553188, 0.17239201, 0.05676926, ..., 0.15943622,\n",
-     " -0.16615383, -0.15569784]], shape=(5, 480))"
-    ]
-   },
-   "execution_count": 8,
-   "metadata": {},
-   "output_type": "execute_result"
-  }
- ],
  "source": [
  "ablang(all_seqs, mode='seqcoding')\n"
  ]
@@ -231,85 +254,10 @@
231
  },
232
  {
233
  "cell_type": "code",
234
- "execution_count": 9,
235
  "id": "6227f661-575f-4b1e-9646-cfba7b10c3b4",
236
  "metadata": {},
237
- "outputs": [
238
- {
239
- "data": {
240
- "text/plain": [
241
- "[array([[-0.40741208, -0.5118987 , 0.06096708, ..., 0.3268144 ,\n",
242
- " 0.03920235, -0.36715826],\n",
243
- " [-0.5768883 , 0.38245413, -0.21791998, ..., 0.01250262,\n",
244
- " -0.08844463, -0.32367525],\n",
245
- " [-0.1475935 , 0.39639047, -0.38226923, ..., -0.10119921,\n",
246
- " -0.41469565, -0.00319315],\n",
247
- " ...,\n",
248
- " [-0.14358369, 0.3124389 , -0.30157998, ..., -0.13289244,\n",
249
- " -0.45353398, -0.07878865],\n",
250
- " [ 0.17538925, 0.24394299, 0.20141171, ..., 0.14587352,\n",
251
- " -0.38479003, 0.07409196],\n",
252
- " [-0.23031706, -0.35487285, 0.1960684 , ..., -0.1283362 ,\n",
253
- " 0.31107333, -0.3265108 ]], shape=(238, 480), dtype=float32),\n",
254
- " array([[-0.41981837, -0.3666375 , 0.10595217, ..., 0.3903574 ,\n",
255
- " 0.0382378 , -0.36337993],\n",
256
- " [-0.5054137 , 0.38347068, -0.10992069, ..., -0.05231472,\n",
257
- " -0.13636623, -0.34830108],\n",
258
- " [-0.06784609, 0.69349885, -0.4212398 , ..., -0.24805346,\n",
259
- " -0.39583805, -0.10972726],\n",
260
- " ...,\n",
261
- " [-0.2090099 , 0.29489496, -0.11039071, ..., -0.24245434,\n",
262
- " -0.60625184, -0.02307999],\n",
263
- " [ 0.19134358, 0.21744648, 0.2575827 , ..., 0.15845427,\n",
264
- " -0.34743664, 0.10218249],\n",
265
- " [-0.2551157 , -0.21778448, 0.21906358, ..., -0.09656111,\n",
266
- " 0.22394855, -0.20267345]], shape=(222, 480), dtype=float32),\n",
267
- " array([[-0.40043733, -0.48596814, 0.0886725 , ..., 0.38941646,\n",
268
- " 0.06195956, -0.40999672],\n",
269
- " [-0.54576075, 0.4312959 , -0.3451486 , ..., -0.09285564,\n",
270
- " 0.03116508, -0.45269737],\n",
271
- " [ 0.0221165 , 0.53196615, -0.30137214, ..., -0.1889072 ,\n",
272
- " -0.32587305, 0.05078396],\n",
273
- " ...,\n",
274
- " [ 0.2630385 , -0.22976042, 0.5510368 , ..., 0.47436473,\n",
275
- " -0.42733562, -0.83135855],\n",
276
- " [-0.13752195, 0.28678602, -0.18887053, ..., 0.28262627,\n",
277
- " 0.1254679 , -0.6496486 ],\n",
278
- " [-0.4541417 , 0.24564984, 0.2132735 , ..., 0.03287445,\n",
279
- " 0.03825552, -0.34259132]], shape=(124, 480), dtype=float32),\n",
280
- " array([[-0.26863217, 0.32259187, 0.10813517, ..., 0.03953876,\n",
281
- " 0.18312076, -0.00498045],\n",
282
- " [-0.2165424 , -0.38562432, -0.02696264, ..., 0.20541488,\n",
283
- " 0.18698391, -0.22639504],\n",
284
- " [-0.41950518, 0.04743317, 0.0048816 , ..., 0.11408642,\n",
285
- " -0.05384652, 0.1025871 ],\n",
286
- " ...,\n",
287
- " [-0.10960457, 0.35151365, -0.21752454, ..., -0.21448943,\n",
288
- " -0.6396219 , -0.00839792],\n",
289
- " [ 0.20491892, 0.36294487, 0.19217414, ..., 0.07750722,\n",
290
- " -0.5039212 , 0.03793833],\n",
291
- " [-0.11638474, -0.35350856, 0.13215722, ..., -0.1606055 ,\n",
292
- " 0.23913842, -0.2565337 ]], shape=(115, 480), dtype=float32),\n",
293
- " array([[-0.42062947, -0.44009134, 0.00152371, ..., 0.27141467,\n",
294
- " 0.03798106, -0.397461 ],\n",
295
- " [-0.57318133, 0.5258899 , -0.17001636, ..., -0.23864633,\n",
296
- " 0.2088059 , -0.57877594],\n",
297
- " [-0.38988614, 0.46168196, -0.3429413 , ..., -0.14872643,\n",
298
- " -0.46576905, -0.21224979],\n",
299
- " ...,\n",
300
- " [-0.21528634, 0.30046722, -0.25216463, ..., -0.11576828,\n",
301
- " -0.4704907 , -0.0740136 ],\n",
302
- " [ 0.0633081 , 0.22700705, 0.28184187, ..., 0.15967266,\n",
303
- " -0.377182 , 0.06188517],\n",
304
- " [-0.27826303, -0.37297496, 0.21229912, ..., -0.14886017,\n",
305
- " 0.24998347, -0.35954213]], shape=(238, 480), dtype=float32)]"
306
- ]
307
- },
308
- "execution_count": 9,
309
- "metadata": {},
310
- "output_type": "execute_result"
311
- }
312
- ],
313
  "source": [
314
  "ablang(all_seqs, mode='rescoding', stepwise_masking = False)"
315
  ]
@@ -330,80 +278,10 @@
330
  },
331
  {
332
  "cell_type": "code",
333
- "execution_count": 10,
334
  "id": "e4bc0cb1-f5b0-4255-9e93-d643ae1396df",
335
  "metadata": {},
336
- "outputs": [
337
- {
338
- "name": "stdout",
339
- "output_type": "stream",
340
- "text": [
341
- "['<' '1 ' '2 ' '3 ' '4 ' '5 ' '6 ' '7 ' '8 ' '9 ' '11 ' '12 ' '13 ' '14 '\n",
342
- " '15 ' '16 ' '17 ' '18 ' '19 ' '20 ' '21 ' '22 ' '23 ' '24 ' '25 ' '26 '\n",
343
- " '27 ' '28 ' '29 ' '30 ' '35 ' '36 ' '37 ' '38 ' '39 ' '40 ' '41 ' '42 '\n",
344
- " '43 ' '44 ' '45 ' '46 ' '47 ' '48 ' '49 ' '50 ' '51 ' '52 ' '53 ' '54 '\n",
345
- " '55 ' '56 ' '57 ' '58 ' '59 ' '62 ' '63 ' '64 ' '65 ' '66 ' '67 ' '68 '\n",
346
- " '69 ' '70 ' '71 ' '72 ' '74 ' '75 ' '76 ' '77 ' '78 ' '79 ' '80 ' '81 '\n",
347
- " '82 ' '83 ' '84 ' '85 ' '86 ' '87 ' '88 ' '89 ' '90 ' '91 ' '92 ' '93 '\n",
348
- " '94 ' '95 ' '96 ' '97 ' '98 ' '99 ' '100 ' '101 ' '102 ' '103 ' '104 '\n",
349
- " '105 ' '106 ' '107 ' '108 ' '109 ' '110 ' '111 ' '112A' '112 ' '113 '\n",
350
- " '114 ' '115 ' '116 ' '117 ' '118 ' '119 ' '120 ' '121 ' '122 ' '123 '\n",
351
- " '124 ' '125 ' '126 ' '127 ' '128 ' '>' '|' '<' '1 ' '2 ' '3 ' '4 ' '5 '\n",
352
- " '6 ' '7 ' '8 ' '9 ' '10 ' '11 ' '12 ' '13 ' '14 ' '15 ' '16 ' '17 ' '18 '\n",
353
- " '19 ' '20 ' '21 ' '22 ' '23 ' '24 ' '25 ' '26 ' '27 ' '28 ' '29 ' '30 '\n",
354
- " '31 ' '32 ' '34 ' '35 ' '36 ' '37 ' '38 ' '39 ' '40 ' '41 ' '42 ' '43 '\n",
355
- " '44 ' '45 ' '46 ' '47 ' '48 ' '49 ' '50 ' '51 ' '52 ' '53 ' '54 ' '55 '\n",
356
- " '56 ' '57 ' '64 ' '65 ' '66 ' '67 ' '68 ' '69 ' '70 ' '71 ' '72 ' '74 '\n",
357
- " '75 ' '76 ' '77 ' '78 ' '79 ' '80 ' '83 ' '84 ' '85 ' '86 ' '87 ' '88 '\n",
358
- " '89 ' '90 ' '91 ' '92 ' '93 ' '94 ' '95 ' '96 ' '97 ' '98 ' '99 ' '100 '\n",
359
- " '101 ' '102 ' '103 ' '104 ' '105 ' '106 ' '107 ' '108 ' '109 ' '114 '\n",
360
- " '115 ' '116 ' '117 ' '118 ' '119 ' '120 ' '121 ' '122 ' '123 ' '124 '\n",
361
- " '125 ' '126 ' '127 ' '>']\n",
362
- "['<EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTTVTVSS>|<DIQLTQSPLSLPVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKI-SNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK>', '<EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTT----->|<-----------PVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKI-SNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK>', '<------SGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCAR**PGHGAAFMDVWGTGTTVTVSS>|<DIQLTQSPLSLPVTLGQPASISCRSS*SLEASDTNIYLSWFQQRPGQSPRRLIYKI*N-RDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK>']\n",
363
- "[[[ 9.31621838 -3.42184329 -3.59397745 ... -14.73707485 -6.8935833\n",
364
- " -0.23662776]\n",
365
- " [ -3.54718232 -5.84866619 -4.02423859 ... -12.93966579 -9.5614481\n",
366
- " -4.48473835]\n",
367
- " [-11.94997597 -2.245543 -5.69481373 ... -15.19639015 -17.97454071\n",
368
- " -12.56952095]\n",
369
- " ...\n",
370
- " [ -8.94504833 -0.42261261 -4.95588207 ... -16.66817474 -15.2224741\n",
371
- " -10.37267494]\n",
372
- " [-11.65150356 -5.44477606 -2.95585775 ... -16.25555801 -9.75158596\n",
373
- " -11.75897026]\n",
374
- " [ 1.79469728 -1.95846701 -3.59784532 ... -14.95585823 -7.47080708\n",
375
- " -0.95226753]]\n",
376
- "\n",
377
- " [[ 8.55518723 -3.83663297 -2.33595967 ... -13.87456799 -8.14840603\n",
378
- " -0.42472434]\n",
379
- " [ -4.40701294 -5.53201008 -3.69397402 ... -12.97877789 -9.86258411\n",
380
- " -4.95414352]\n",
381
- " [-11.95642853 -3.86210871 -5.80935192 ... -14.89213085 -16.94556236\n",
382
- " -11.36959839]\n",
383
- " ...\n",
384
- " [ -7.75924015 -0.66524202 -4.08643246 ... -16.16580772 -14.76507473\n",
385
- " -8.3507061 ]\n",
386
- " [-11.91039753 -4.86995983 -2.74777436 ... -16.07694817 -8.44974899\n",
387
- " -10.45223904]\n",
388
- " [ 0.86006832 -2.37964034 -3.58130741 ... -15.35423565 -7.73035526\n",
389
- " -1.11989737]]\n",
390
- "\n",
391
- " [[ -4.37902737 -7.55587149 1.21958363 ... -15.48622513 -6.021842\n",
392
- " -3.79647374]\n",
393
- " [ 0. 0. 0. ... 0. 0.\n",
394
- " 0. ]\n",
395
- " [ 0. 0. 0. ... 0. 0.\n",
396
- " 0. ]\n",
397
- " ...\n",
398
- " [ -8.94207573 -0.51090252 -5.09760332 ... -16.69521713 -15.45450687\n",
399
- " -10.50823212]\n",
400
- " [-11.92354965 -5.55152607 -2.87666893 ... -16.40607834 -10.19431686\n",
401
- " -12.1328764 ]\n",
402
- " [ 2.42200375 -2.01573253 -3.61701298 ... -14.9590435 -7.19029331\n",
403
- " -0.89830256]]]\n"
404
- ]
405
- }
406
- ],
407
  "source": [
408
  "results = ablang(only_both_chains_seqs, mode='likelihood', align=True)\n",
409
  "\n",
@@ -414,60 +292,10 @@
414
  },
415
  {
416
  "cell_type": "code",
417
- "execution_count": 9,
418
  "id": "56be8cad",
419
  "metadata": {},
420
- "outputs": [
421
- {
422
- "data": {
423
- "text/plain": [
424
- "[array([[9.9955505e-01, 2.9358694e-06, 2.4716087e-06, ..., 3.5776201e-11,\n",
425
- " 9.1196831e-08, 7.0967326e-05],\n",
426
- " [4.1573694e-06, 4.1619489e-07, 2.5800944e-06, ..., 3.4650952e-10,\n",
427
- " 1.0159109e-08, 1.6279575e-06],\n",
428
- " [7.8059600e-08, 1.2794037e-03, 4.0645118e-05, ..., 3.0375720e-09,\n",
429
- " 1.8879491e-10, 4.2010839e-08],\n",
430
- " ...,\n",
431
- " [3.4210879e-07, 1.7195340e-03, 1.8477240e-05, ..., 1.5137445e-10,\n",
432
- " 6.4255873e-10, 8.2064140e-08],\n",
433
- " [9.1038084e-09, 4.5161755e-06, 5.4411950e-05, ..., 9.1139631e-11,\n",
434
- " 6.0862085e-08, 8.1761966e-09],\n",
435
- " [8.5759175e-04, 2.0104915e-05, 3.9023766e-06, ..., 4.5562460e-11,\n",
436
- " 8.1156479e-08, 5.4990651e-05]], shape=(238, 26), dtype=float32),\n",
437
- " array([[9.9939799e-01, 4.1499175e-06, 1.8611167e-05, ..., 1.8139243e-10,\n",
438
- " 5.5649299e-08, 1.2583815e-04],\n",
439
- " [1.6735513e-06, 5.4332406e-07, 3.4143472e-06, ..., 3.1693398e-10,\n",
440
- " 7.1501400e-09, 9.6832969e-07],\n",
441
- " [3.7784993e-08, 1.2377645e-04, 1.7658784e-05, ..., 2.0061326e-09,\n",
442
- " 2.5737484e-10, 6.7947965e-08],\n",
443
- " ...,\n",
444
- " [1.1050455e-06, 1.3312638e-03, 4.3497097e-05, ..., 2.4686178e-10,\n",
445
- " 1.0018089e-09, 6.1165900e-07],\n",
446
- " [5.7270397e-09, 6.5396339e-06, 5.4601755e-05, ..., 8.8801404e-11,\n",
447
- " 1.8233513e-07, 2.4615032e-08],\n",
448
- " [7.3952030e-04, 2.8970928e-05, 8.7113440e-06, ..., 6.7168833e-11,\n",
449
- " 1.3746008e-07, 1.0210846e-04]], shape=(222, 26), dtype=float32),\n",
450
- " array([[9.99685407e-01, 3.35662639e-06, 1.14241482e-06, ...,\n",
451
- " 2.32460891e-11, 6.88188067e-08, 5.69467156e-05],\n",
452
- " [6.38133372e-07, 1.01300586e-07, 5.64459742e-06, ...,\n",
453
- " 4.09234556e-11, 2.53804799e-09, 4.31722100e-07],\n",
454
- " [1.49096788e-08, 2.04515047e-04, 9.23794141e-06, ...,\n",
455
- " 7.46306961e-10, 2.92107380e-11, 2.21786500e-08],\n",
456
- " ...,\n",
457
- " [2.15093763e-07, 1.06453872e-03, 1.62486140e-05, ...,\n",
458
- " 1.12102910e-10, 1.47300866e-10, 4.73037538e-08],\n",
459
- " [4.30136682e-09, 3.09317988e-06, 3.96632568e-05, ...,\n",
460
- " 5.24226877e-11, 2.39579450e-08, 3.86403221e-09],\n",
461
- " [9.77773685e-04, 1.29533228e-05, 2.78623725e-06, ...,\n",
462
- " 2.73364300e-11, 3.96418649e-08, 4.04014427e-05]],\n",
463
- " shape=(238, 26), dtype=float32)]"
464
- ]
465
- },
466
- "execution_count": 9,
467
- "metadata": {},
468
- "output_type": "execute_result"
469
- }
470
- ],
471
  "source": [
472
  "ablang(only_both_chains_seqs, mode='probability')"
473
  ]
@@ -492,21 +320,10 @@
  },
  {
   "cell_type": "code",
- "execution_count": 12,
  "id": "83f3064b-48a7-42fb-ba82-ec153ea946da",
  "metadata": {},
- "outputs": [
-  {
-   "data": {
-    "text/plain": [
-     "array([1.96673731, 2.04801253, 2.09881898, 1.82533665, 1.97255249])"
-    ]
-   },
-   "execution_count": 12,
-   "metadata": {},
-   "output_type": "execute_result"
-  }
- ],
  "source": [
  "results = ablang(all_seqs, mode='pseudo_log_likelihood')\n",
  "np.exp(-results) # convert to pseudo perplexity"
@@ -514,22 +331,10 @@
  },
  {
   "cell_type": "code",
- "execution_count": 13,
  "id": "42cc8b34-5ae9-4857-93fe-a438a0f2a868",
  "metadata": {},
- "outputs": [
-  {
-   "data": {
-    "text/plain": [
-     "array([1.2636038, 1.126463 , 1.3123759, 1.2140924, 1.1805094],\n",
-     " dtype=float32)"
-    ]
-   },
-   "execution_count": 13,
-   "metadata": {},
-   "output_type": "execute_result"
-  }
- ],
  "source": [
  "results = ablang(all_seqs, mode='confidence')\n",
  "np.exp(-results)"
@@ -547,24 +352,10 @@
547
  },
548
  {
549
  "cell_type": "code",
550
- "execution_count": 14,
551
  "id": "2d5b725c-4eac-4a4b-9331-357c3ac140f7",
552
  "metadata": {},
553
- "outputs": [
554
- {
555
- "data": {
556
- "text/plain": [
557
- "array(['<EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTTVTVSS>|<DIQLTQSPLSLPVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK>',\n",
558
- " '<EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTT>|<PVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK>',\n",
559
- " '<EVQLVQSGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDPPGHGAAFMDVWGTGTTVTVSS>|<DIQLTQSPLSLPVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK>'],\n",
560
- " dtype='<U238')"
561
- ]
562
- },
563
- "execution_count": 14,
564
- "metadata": {},
565
- "output_type": "execute_result"
566
- }
567
- ],
568
  "source": [
569
  "restored = ablang(only_both_chains_seqs, mode='restore')\n",
570
  "restored"
@@ -572,24 +363,10 @@
572
  },
573
  {
574
  "cell_type": "code",
575
- "execution_count": 15,
576
  "id": "0e9615f7-c490-4947-96f4-7617266c686e",
577
  "metadata": {},
578
- "outputs": [
579
- {
580
- "data": {
581
- "text/plain": [
582
- "array(['<EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTTVTVSS>|<DIQLTQSPLSLPVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK>',\n",
583
- " '<EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTTVTVSS>|<DVVMTQSPLSLPVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK>',\n",
584
- " '<QVQLVQSGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDPPGHGAAFMDVWGTGTTVTVSS>|<DIQLTQSPLSLPVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK>'],\n",
585
- " dtype='<U238')"
586
- ]
587
- },
588
- "execution_count": 15,
589
- "metadata": {},
590
- "output_type": "execute_result"
591
- }
592
- ],
593
  "source": [
594
  "restored = ablang(only_both_chains_seqs, mode='restore', align = True)\n",
595
  "restored"
 
86
  "id": "6d66ad84",
87
  "metadata": {},
88
  "outputs": [
89
+ {
90
+ "data": {
91
+ "application/vnd.jupyter.widget-view+json": {
92
+ "model_id": "a5acedae3cc4420ea2971400b0915426",
93
+ "version_major": 2,
94
+ "version_minor": 0
95
+ },
96
+ "text/plain": [
97
+ "config.json: 0%| | 0.00/560 [00:00<?, ?B/s]"
98
+ ]
99
+ },
100
+ "metadata": {},
101
+ "output_type": "display_data"
102
+ },
103
+ {
104
+ "data": {
105
+ "application/vnd.jupyter.widget-view+json": {
106
+ "model_id": "5727addb151447cf9bb091ef1159717c",
107
+ "version_major": 2,
108
+ "version_minor": 0
109
+ },
110
+ "text/plain": [
111
+ "configuration_ablang2paired.py: 0.00B [00:00, ?B/s]"
112
+ ]
113
+ },
114
+ "metadata": {},
115
+ "output_type": "display_data"
116
+ },
117
  {
118
  "name": "stderr",
119
  "output_type": "stream",
120
  "text": [
121
  "A new version of the following files was downloaded from https://huggingface.co/hemantn/ablang2:\n",
122
  "- configuration_ablang2paired.py\n",
123
+ ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n"
124
  ]
125
  },
126
  {
127
+ "data": {
128
+ "application/vnd.jupyter.widget-view+json": {
129
+ "model_id": "845b45d4aed542dc86ab7b7ac3305a0e",
130
+ "version_major": 2,
131
+ "version_minor": 0
132
+ },
133
+ "text/plain": [
134
+ "modeling_ablang2paired.py: 0.00B [00:00, ?B/s]"
135
+ ]
136
+ },
137
+ "metadata": {},
138
+ "output_type": "display_data"
139
  },
140
  {
141
  "name": "stderr",
142
  "output_type": "stream",
143
  "text": [
144
+ "Encountered exception while importing ablang2: No module named 'ablang2'\n"
145
+ ]
146
+ },
147
+ {
148
+ "ename": "ImportError",
149
+ "evalue": "This modeling file requires the following packages that were not found in your environment: ablang2. Run `pip install ablang2`",
150
+ "output_type": "error",
151
+ "traceback": [
152
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
153
+ "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
154
+ "Cell \u001b[0;32mIn[3], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Load model and tokenizer from Hugging Face Hub\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[43mAutoModel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mhemantn/ablang2\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrust_remote_code\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 3\u001b[0m tokenizer \u001b[38;5;241m=\u001b[39m AutoTokenizer\u001b[38;5;241m.\u001b[39mfrom_pretrained(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhemantn/ablang2\u001b[39m\u001b[38;5;124m\"\u001b[39m, trust_remote_code\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 5\u001b[0m \u001b[38;5;66;03m# Find the cached model directory and import adapter\u001b[39;00m\n",
155
+ "File \u001b[0;32m~/.conda/envs/lib_transformer/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py:582\u001b[0m, in \u001b[0;36m_BaseAutoModelClass.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 579\u001b[0m kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124madapter_kwargs\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m adapter_kwargs\n\u001b[1;32m 581\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_remote_code \u001b[38;5;129;01mand\u001b[39;00m trust_remote_code:\n\u001b[0;32m--> 582\u001b[0m model_class \u001b[38;5;241m=\u001b[39m \u001b[43mget_class_from_dynamic_module\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 583\u001b[0m \u001b[43m \u001b[49m\u001b[43mclass_ref\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpretrained_model_name_or_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcode_revision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcode_revision\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mhub_kwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m 584\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 585\u001b[0m _ \u001b[38;5;241m=\u001b[39m hub_kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcode_revision\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[1;32m 586\u001b[0m \u001b[38;5;66;03m# This block handles the case where the user is loading a model with `trust_remote_code=True`\u001b[39;00m\n\u001b[1;32m 587\u001b[0m \u001b[38;5;66;03m# but a library model exists with the same name. We don't want to override the autoclass\u001b[39;00m\n\u001b[1;32m 588\u001b[0m \u001b[38;5;66;03m# mappings in this case, or all future loads of that model will be the remote code model.\u001b[39;00m\n",
156
+ "File \u001b[0;32m~/.conda/envs/lib_transformer/lib/python3.10/site-packages/transformers/dynamic_module_utils.py:570\u001b[0m, in \u001b[0;36mget_class_from_dynamic_module\u001b[0;34m(class_reference, pretrained_model_name_or_path, cache_dir, force_download, resume_download, proxies, token, revision, local_files_only, repo_type, code_revision, **kwargs)\u001b[0m\n\u001b[1;32m 568\u001b[0m code_revision \u001b[38;5;241m=\u001b[39m revision\n\u001b[1;32m 569\u001b[0m \u001b[38;5;66;03m# And lastly we get the class inside our newly created module\u001b[39;00m\n\u001b[0;32m--> 570\u001b[0m final_module \u001b[38;5;241m=\u001b[39m \u001b[43mget_cached_module_file\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 571\u001b[0m \u001b[43m \u001b[49m\u001b[43mrepo_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 572\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodule_file\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m.py\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 573\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 574\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforce_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 575\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 576\u001b[0m \u001b[43m \u001b[49m\u001b[43mproxies\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mproxies\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 577\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 578\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcode_revision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 579\u001b[0m \u001b[43m \u001b[49m\u001b[43mlocal_files_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlocal_files_only\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 580\u001b[0m \u001b[43m \u001b[49m\u001b[43mrepo_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrepo_type\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 581\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 582\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m get_class_in_module(class_name, final_module, force_reload\u001b[38;5;241m=\u001b[39mforce_download)\n",
157
+ "File \u001b[0;32m~/.conda/envs/lib_transformer/lib/python3.10/site-packages/transformers/dynamic_module_utils.py:393\u001b[0m, in \u001b[0;36mget_cached_module_file\u001b[0;34m(pretrained_model_name_or_path, module_file, cache_dir, force_download, resume_download, proxies, token, revision, local_files_only, repo_type, _commit_hash, **deprecated_kwargs)\u001b[0m\n\u001b[1;32m 390\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m\n\u001b[1;32m 392\u001b[0m \u001b[38;5;66;03m# Check we have all the requirements in our environment\u001b[39;00m\n\u001b[0;32m--> 393\u001b[0m modules_needed \u001b[38;5;241m=\u001b[39m \u001b[43mcheck_imports\u001b[49m\u001b[43m(\u001b[49m\u001b[43mresolved_module_file\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 395\u001b[0m \u001b[38;5;66;03m# Now we move the module inside our cached dynamic modules.\u001b[39;00m\n\u001b[1;32m 396\u001b[0m full_submodule \u001b[38;5;241m=\u001b[39m TRANSFORMERS_DYNAMIC_MODULE_NAME \u001b[38;5;241m+\u001b[39m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39msep \u001b[38;5;241m+\u001b[39m submodule\n",
158
+ "File \u001b[0;32m~/.conda/envs/lib_transformer/lib/python3.10/site-packages/transformers/dynamic_module_utils.py:225\u001b[0m, in \u001b[0;36mcheck_imports\u001b[0;34m(filename)\u001b[0m\n\u001b[1;32m 222\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m\n\u001b[1;32m 224\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(missing_packages) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m--> 225\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mImportError\u001b[39;00m(\n\u001b[1;32m 226\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThis modeling file requires the following packages that were not found in your environment: \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 227\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(missing_packages)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. Run `pip install \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(missing_packages)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m`\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 228\u001b[0m )\n\u001b[1;32m 230\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m get_relative_imports(filename)\n",
159
+ "\u001b[0;31mImportError\u001b[0m: This modeling file requires the following packages that were not found in your environment: ablang2. Run `pip install ablang2`"
160
  ]
161
  }
162
  ],
 
  },
  {
   "cell_type": "code",
+ "execution_count": null,
  "id": "ceae4a88-0679-4704-8bad-c06a4569c497",
  "metadata": {},
  "outputs": [],
 
  },
  {
   "cell_type": "code",
+ "execution_count": null,
  "id": "d22f4302-1262-4cc1-8a1c-a36daa8c710c",
  "metadata": {},
+ "outputs": [],
  "source": [
  "ablang(all_seqs, mode='seqcoding')\n"
  ]
 
  },
  {
   "cell_type": "code",
+ "execution_count": null,
  "id": "6227f661-575f-4b1e-9646-cfba7b10c3b4",
  "metadata": {},
+ "outputs": [],
  "source": [
  "ablang(all_seqs, mode='rescoding', stepwise_masking = False)"
  ]
 
  },
  {
   "cell_type": "code",
+ "execution_count": null,
  "id": "e4bc0cb1-f5b0-4255-9e93-d643ae1396df",
  "metadata": {},
+ "outputs": [],
  "source": [
  "results = ablang(only_both_chains_seqs, mode='likelihood', align=True)\n",
  "\n",
 
  },
  {
   "cell_type": "code",
+ "execution_count": null,
  "id": "56be8cad",
  "metadata": {},
+ "outputs": [],
  "source": [
  "ablang(only_both_chains_seqs, mode='probability')"
  ]
 
  },
  {
   "cell_type": "code",
+ "execution_count": null,
  "id": "83f3064b-48a7-42fb-ba82-ec153ea946da",
  "metadata": {},
+ "outputs": [],
  "source": [
  "results = ablang(all_seqs, mode='pseudo_log_likelihood')\n",
  "np.exp(-results) # convert to pseudo perplexity"
 
  },
  {
   "cell_type": "code",
+ "execution_count": null,
  "id": "42cc8b34-5ae9-4857-93fe-a438a0f2a868",
  "metadata": {},
+ "outputs": [],
  "source": [
  "results = ablang(all_seqs, mode='confidence')\n",
  "np.exp(-results)"
 
  },
  {
   "cell_type": "code",
+ "execution_count": null,
  "id": "2d5b725c-4eac-4a4b-9331-357c3ac140f7",
  "metadata": {},
+ "outputs": [],
  "source": [
  "restored = ablang(only_both_chains_seqs, mode='restore')\n",
  "restored"
 
  },
  {
   "cell_type": "code",
+ "execution_count": null,
  "id": "0e9615f7-c490-4947-96f4-7617266c686e",
  "metadata": {},
+ "outputs": [],
  "source": [
  "restored = ablang(only_both_chains_seqs, mode='restore', align = True)\n",
  "restored"
test_module_loading.py ADDED
@@ -0,0 +1,19 @@
+ import sys
+ import os
+ import numpy as np
+ from transformers import AutoModel, AutoTokenizer
+ from transformers.utils import cached_file
+
+ # Load model and tokenizer from Hugging Face Hub
+ model = AutoModel.from_pretrained("hemantn/ablang2", trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained("hemantn/ablang2", trust_remote_code=True)
+
+ # Find the cached model directory and import adapter
+ adapter_path = cached_file("hemantn/ablang2", "adapter.py")
+ cached_model_dir = os.path.dirname(adapter_path)
+ sys.path.insert(0, cached_model_dir)
+
+ # Import and create the adapter
+ from adapter import AbLang2PairedHuggingFaceAdapter
+ ablang = AbLang2PairedHuggingFaceAdapter(model=model, tokenizer=tokenizer)
+