|
import os |
|
import sys |
|
import shutil |
|
|
|
|
|
current_dir = os.path.dirname(os.path.abspath(__file__)) |
|
if current_dir not in sys.path: |
|
sys.path.insert(0, current_dir) |
|
|
|
|
|
|
|
|
|
UTILITY_FILES = [ |
|
'restoration.py', |
|
'ablang_encodings.py', |
|
'alignment.py', |
|
'scores.py', |
|
'extra_utils.py', |
|
'ablang.py', |
|
'encoderblock.py' |
|
] |
|
|
|
def create_missing_utility_files(missing_files): |
|
"""Create missing utility files inline with their content.""" |
|
|
|
|
|
utility_contents = { |
|
'restoration.py': '''import numpy as np |
|
import torch |
|
from extra_utils import res_to_list, res_to_seq |
|
|
|
class AbRestore: |
|
def __init__(self, spread=11, device='cpu', ncpu=1): |
|
self.spread = spread |
|
self.device = device |
|
self.ncpu = ncpu |
|
|
|
def _initiate_abrestore(self, model, tokenizer): |
|
self.AbLang = model |
|
self.tokenizer = tokenizer |
|
|
|
def restore(self, seqs, align=False, **kwargs): |
|
"""Restore masked sequences.""" |
|
# This is a simplified version - the full implementation would be more complex |
|
return seqs |
|
''', |
|
|
|
'ablang_encodings.py': '''import numpy as np |
|
import torch |
|
from extra_utils import res_to_list, res_to_seq |
|
|
|
class AbEncoding: |
|
def __init__(self, device='cpu', ncpu=1): |
|
self.device = device |
|
self.ncpu = ncpu |
|
|
|
def _initiate_abencoding(self, model, tokenizer): |
|
self.AbLang = model |
|
self.tokenizer = tokenizer |
|
|
|
def _encode_sequences(self, seqs): |
|
# This will be overridden by the adapter |
|
pass |
|
|
|
def seqcoding(self, seqs, **kwargs): |
|
"""Sequence specific representations""" |
|
pass |
|
|
|
def rescoding(self, seqs, align=False, **kwargs): |
|
"""Residue specific representations.""" |
|
pass |
|
|
|
def likelihood(self, seqs, align=False, stepwise_masking=False, **kwargs): |
|
"""Likelihood of mutations""" |
|
pass |
|
|
|
def probability(self, seqs, align=False, stepwise_masking=False, **kwargs): |
|
"""Probability of mutations""" |
|
pass |
|
''', |
|
|
|
'alignment.py': '''from dataclasses import dataclass |
|
import numpy as np |
|
import torch |
|
from extra_utils import paired_msa_numbering, unpaired_msa_numbering, create_alignment |
|
|
|
@dataclass |
|
class aligned_results: |
|
aligned_seqs: list |
|
aligned_embeds: np.ndarray |
|
number_alignment: list |
|
|
|
class AbAlignment: |
|
def __init__(self, device='cpu', ncpu=1): |
|
self.device = device |
|
self.ncpu = ncpu |
|
|
|
def number_sequences(self, seqs, chain='H', fragmented=False): |
|
if chain == 'HL': |
|
numbered_seqs, seqs, number_alignment = paired_msa_numbering(seqs, fragmented=fragmented, n_jobs=self.ncpu) |
|
else: |
|
numbered_seqs, seqs, number_alignment = unpaired_msa_numbering(seqs, chain=chain, fragmented=fragmented, n_jobs=self.ncpu) |
|
return numbered_seqs, seqs, number_alignment |
|
|
|
def align_encodings(self, encodings, numbered_seqs, seqs, number_alignment): |
|
aligned_encodings = [] |
|
for res_embed, numbered_seq, seq in zip(encodings, numbered_seqs, seqs): |
|
aligned_encodings.append(create_alignment(res_embed, numbered_seq, seq, number_alignment)) |
|
        return np.stack(aligned_encodings, axis=0)
|
|
|
def reformat_subsets(self, subset_list, mode='seqcoding', align=False, numbered_seqs=None, seqs=None, number_alignment=None): |
|
        if mode in ['seqcoding', 'pseudo_log_likelihood', 'confidence']:
            return np.concatenate(subset_list)
        elif mode == 'restore':
            # Restored sequences need no further reformatting, aligned or not
            return subset_list[0] if len(subset_list) == 1 else subset_list
|
        elif align:
            aligned_subsets = []
            start_idx = 0  # running offset; the final batch may be smaller than the rest
            for subset in subset_list:
                end_idx = start_idx + len(subset)
                aligned_subset = self.align_encodings(
                    subset,
                    numbered_seqs[start_idx:end_idx],
                    seqs[start_idx:end_idx],
                    number_alignment
                )
                aligned_subsets.append(aligned_subset)
                start_idx = end_idx
            subset = np.concatenate(aligned_subsets)
|
return aligned_results( |
|
aligned_seqs=[''.join(alist) for alist in subset[:,:,-1]], |
|
aligned_embeds=subset[:,:,:-1].astype(float), |
|
number_alignment=number_alignment.apply(lambda x: '{}{}'.format(*x[0]), axis=1).values |
|
) |
|
        else:
            return sum(subset_list, [])
|
''', |
|
|
|
'scores.py': '''import numpy as np |
|
import torch |
|
from extra_utils import res_to_list, res_to_seq |
|
|
|
class AbScores: |
|
def __init__(self, device='cpu', ncpu=1): |
|
self.device = device |
|
self.ncpu = ncpu |
|
|
|
def _initiate_abencoding(self, model, tokenizer): |
|
self.AbLang = model |
|
self.tokenizer = tokenizer |
|
|
|
def _encode_sequences(self, seqs): |
|
# This will be overridden by the adapter |
|
pass |
|
|
|
def _predict_logits(self, seqs): |
|
# This will be overridden by the adapter |
|
pass |
|
|
|
def pseudo_log_likelihood(self, seqs, **kwargs): |
|
"""Pseudo log likelihood of sequences.""" |
|
pass |
|
''', |
|
|
|
'extra_utils.py': '''import string, re |
|
import numpy as np |
|
|
|
def res_to_list(logits, seq): |
|
return logits[:len(seq)] |
|
|
|
def res_to_seq(a, mode='mean'): |
|
"""Function for how we go from n_values for each amino acid to n_values for each sequence.""" |
|
if mode=='sum': |
|
return a[0:(int(a[-1]))].sum() |
|
elif mode=='mean': |
|
return a[0:(int(a[-1]))].mean() |
|
elif mode=='restore': |
|
return a[0][0:(int(a[-1]))] |
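
# Illustrative example of the convention above: the trailing element of the
# input array stores the sequence length, e.g.
#   res_to_seq(np.array([1.0, 2.0, 3.0, 2.0]), mode='mean') == 1.5  # mean of first 2 values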
|
|
|
def get_number_alignment(numbered_seqs): |
|
"""Creates a number alignment from the anarci results.""" |
|
import pandas as pd |
|
alist = [pd.DataFrame(aligned_seq, columns=[0,1,'resi']) for aligned_seq in numbered_seqs] |
|
unsorted_alignment = pd.concat(alist).drop_duplicates(subset=0) |
|
max_alignment = get_max_alignment() |
|
return max_alignment.merge(unsorted_alignment.query("resi!='-'"), left_on=0, right_on=0)[[0,1]] |
|
|
|
def get_max_alignment(): |
|
"""Create maximum possible alignment for sorting""" |
|
import pandas as pd |
|
sortlist = [[("<", "")]] |
|
for num in range(1, 128+1): |
|
if num in [33,61,112]: |
|
for char in string.ascii_uppercase[::-1]: |
|
sortlist.append([(num, char)]) |
|
sortlist.append([(num,' ')]) |
|
else: |
|
sortlist.append([(num,' ')]) |
|
for char in string.ascii_uppercase: |
|
sortlist.append([(num, char)]) |
|
return pd.DataFrame(sortlist + [[(">", "")]]) |
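
# Note on the ordering built above: at IMGT positions 33, 61 and 112 the
# insertion codes are enumerated in reverse alphabetical order before the bare
# position, while every other position sorts alphabetically after it.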
|
|
|
def paired_msa_numbering(ab_seqs, fragmented=False, n_jobs=10): |
|
import pandas as pd |
|
tmp_seqs = [pairs.replace(">", "").replace("<", "").split("|") for pairs in ab_seqs] |
|
numbered_seqs_heavy, seqs_heavy, number_alignment_heavy = unpaired_msa_numbering([i[0] for i in tmp_seqs], 'H', fragmented=fragmented, n_jobs=n_jobs) |
|
numbered_seqs_light, seqs_light, number_alignment_light = unpaired_msa_numbering([i[1] for i in tmp_seqs], 'L', fragmented=fragmented, n_jobs=n_jobs) |
|
number_alignment = pd.concat([number_alignment_heavy, pd.DataFrame([[("|",""), "|"]]), number_alignment_light]).reset_index(drop=True) |
|
seqs = [f"{heavy}|{light}" for heavy, light in zip(seqs_heavy, seqs_light)] |
|
numbered_seqs = [heavy + [(("|",""), "|", "|")] + light for heavy, light in zip(numbered_seqs_heavy, numbered_seqs_light)] |
|
return numbered_seqs, seqs, number_alignment |
|
|
|
def unpaired_msa_numbering(seqs, chain='H', fragmented=False, n_jobs=10): |
|
numbered_seqs = number_with_anarci(seqs, chain=chain, fragmented=fragmented, n_jobs=n_jobs) |
|
number_alignment = get_number_alignment(numbered_seqs) |
|
number_alignment[1] = chain |
|
seqs = [''.join([i[2] for i in numbered_seq]).replace('-','') for numbered_seq in numbered_seqs] |
|
return numbered_seqs, seqs, number_alignment |
|
|
|
def number_with_anarci(seqs, chain='H', fragmented=False, n_jobs=1): |
|
import anarci |
|
import pandas as pd |
|
anarci_out = anarci.run_anarci(pd.DataFrame(seqs).reset_index().values.tolist(), ncpu=n_jobs, scheme='imgt', allowed_species=['human', 'mouse']) |
|
numbered_seqs = [] |
|
for onarci in anarci_out[1]: |
|
numbered_seq = [] |
|
for i in onarci[0][0]: |
|
if i[1] != '-': |
|
numbered_seq.append((i[0], chain, i[1])) |
|
if fragmented: |
|
numbered_seqs.append(numbered_seq) |
|
else: |
|
numbered_seqs.append([(("<",""), chain, "<")] + numbered_seq + [((">",""), chain, ">")]) |
|
return numbered_seqs |
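
# Shape of the output above (illustrative): each numbered sequence is a list of
# (IMGT position, chain, residue) triples bracketed by start/stop markers, e.g.
#   [(('<', ''), 'H', '<'), ((1, ' '), 'H', 'E'), ((2, ' '), 'H', 'V'), ..., (('>', ''), 'H', '>')]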
|
|
|
def create_alignment(res_embeds, numbered_seqs, seq, number_alignment): |
|
import pandas as pd |
|
datadf = pd.DataFrame(numbered_seqs) |
|
sequence_alignment = number_alignment.merge(datadf, how='left', on=[0, 1]).fillna('-')[2] |
|
idxs = np.where(sequence_alignment.values == '-')[0] |
|
idxs = [idx-num for num, idx in enumerate(idxs)] |
|
aligned_embeds = pd.DataFrame(np.insert(res_embeds[:len(seq)], idxs, 0, axis=0)) |
|
return pd.concat([aligned_embeds, sequence_alignment], axis=1).values |
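
# create_alignment returns one array per sequence: the last column holds the
# gap-padded residue characters and the remaining columns the aligned embeddings.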
|
''', |
|
|
|
        'ablang.py': '''from types import SimpleNamespace
import torch
from torch import nn
# Absolute import: these generated modules are loaded from sys.path, not as a package
from encoderblock import TransformerEncoder, get_activation_fn
|
|
|
class AbLang(torch.nn.Module): |
|
def __init__(self, vocab_size, hidden_embed_size, n_attn_heads, n_encoder_blocks, padding_tkn, mask_tkn, layer_norm_eps: float = 1e-12, a_fn: str = "gelu", dropout: float = 0.0): |
|
super().__init__() |
|
self.AbRep = AbRep(vocab_size, hidden_embed_size, n_attn_heads, n_encoder_blocks, padding_tkn, mask_tkn, layer_norm_eps, a_fn, dropout) |
|
self.AbHead = AbHead(vocab_size, hidden_embed_size, self.AbRep.aa_embed_layer.weight, layer_norm_eps, a_fn) |
|
|
|
def forward(self, tokens, return_attn_weights=False, return_rep_layers=[]): |
|
representations = self.AbRep(tokens, return_attn_weights, return_rep_layers) |
|
if return_attn_weights: |
|
return representations.attention_weights |
|
elif return_rep_layers != []: |
|
return representations.many_hidden_states |
|
else: |
|
likelihoods = self.AbHead(representations.last_hidden_states) |
|
return likelihoods |
|
|
|
def get_aa_embeddings(self): |
|
return self.AbRep.aa_embed_layer |
|
|
|
class AbRep(torch.nn.Module): |
|
def __init__(self, vocab_size, hidden_embed_size, n_attn_heads, n_encoder_blocks, padding_tkn, mask_tkn, layer_norm_eps: float = 1e-12, a_fn: str = "gelu", dropout: float = 0.0): |
|
super().__init__() |
|
self.aa_embed_layer = nn.Embedding(vocab_size, hidden_embed_size, padding_idx=padding_tkn) |
|
self.encoder_blocks = nn.ModuleList([TransformerEncoder(hidden_embed_size, n_attn_heads, dropout, layer_norm_eps, a_fn) for _ in range(n_encoder_blocks)]) |
|
|
|
    def forward(self, tokens, return_attn_weights=False, return_rep_layers=[]):
        hidden_states = self.aa_embed_layer(tokens)
        attn_weights_per_layer, rep_layers = [], {}
        for i, encoder_block in enumerate(self.encoder_blocks):
            hidden_states, attn_weights = encoder_block(hidden_states, return_attn_weights=return_attn_weights)
            attn_weights_per_layer.append(attn_weights)
            if i + 1 in return_rep_layers:
                rep_layers[i + 1] = hidden_states
        # Expose the attributes AbLang.forward expects on the returned object
        return SimpleNamespace(last_hidden_states=hidden_states, attention_weights=attn_weights_per_layer, many_hidden_states=rep_layers)
|
|
|
class AbHead(torch.nn.Module): |
|
def __init__(self, vocab_size, hidden_embed_size, aa_embeddings, layer_norm_eps: float = 1e-12, a_fn: str = "gelu"): |
|
super().__init__() |
|
self.layer_norm = nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps) |
|
self.aa_embeddings = aa_embeddings |
|
|
|
def forward(self, hidden_states): |
|
hidden_states = self.layer_norm(hidden_states) |
|
return torch.matmul(hidden_states, self.aa_embeddings.transpose(0, 1)) |
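
# Note: AbHead reuses the embedding matrix passed in from AbRep as its output
# projection (weight tying), so logits are produced over the same residue vocabulary.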
|
''', |
|
|
|
        'encoderblock.py': '''import torch
from torch import nn
import torch.nn.functional as F
|
|
|
class TransformerEncoder(torch.nn.Module): |
|
def __init__(self, hidden_embed_size, n_attn_heads, attn_dropout: float = 0.0, layer_norm_eps: float = 1e-05, a_fn: str = "gelu"): |
|
super().__init__() |
|
        assert hidden_embed_size % n_attn_heads == 0, "Embedding dimension must be divisible by the number of heads."
|
self.multihead_attention = MultiHeadAttention(embed_dim=hidden_embed_size, num_heads=n_attn_heads, attention_dropout_prob=attn_dropout) |
|
activation_fn, scale = get_activation_fn(a_fn) |
|
        self.intermediate_layer = torch.nn.Sequential(
            # 'scale' widens the first projection for gated activations; every
            # activation returned by get_activation_fn below uses scale=1
            torch.nn.Linear(hidden_embed_size, hidden_embed_size * 4 * scale),
            activation_fn(),
            torch.nn.Linear(hidden_embed_size * 4, hidden_embed_size),
        )
|
self.pre_attn_layer_norm = torch.nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps) |
|
self.final_layer_norm = torch.nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps) |
|
|
|
def forward(self, hidden_embed, attn_mask=None, return_attn_weights: bool = False): |
|
residual = hidden_embed |
|
hidden_embed = self.pre_attn_layer_norm(hidden_embed.clone()) |
|
hidden_embed, attn_weights = self.multihead_attention(hidden_embed, attn_mask=attn_mask, return_attn_weights=return_attn_weights) |
|
hidden_embed = residual + hidden_embed |
|
residual = hidden_embed |
|
hidden_embed = self.final_layer_norm(hidden_embed) |
|
hidden_embed = self.intermediate_layer(hidden_embed) |
|
hidden_embed = residual + hidden_embed |
|
return hidden_embed, attn_weights |
|
|
|
class MultiHeadAttention(torch.nn.Module): |
|
def __init__(self, embed_dim, num_heads, attention_dropout_prob=0.0): |
|
super().__init__() |
|
self.embed_dim = embed_dim |
|
self.num_heads = num_heads |
|
self.head_dim = embed_dim // num_heads |
|
self.scaling = self.head_dim ** -0.5 |
|
self.q_proj = nn.Linear(embed_dim, embed_dim) |
|
self.k_proj = nn.Linear(embed_dim, embed_dim) |
|
self.v_proj = nn.Linear(embed_dim, embed_dim) |
|
self.out_proj = nn.Linear(embed_dim, embed_dim) |
|
self.dropout = nn.Dropout(attention_dropout_prob) |
|
|
|
def forward(self, x, attn_mask=None, return_attn_weights=False): |
|
batch_size, seq_len, embed_dim = x.shape |
|
q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) |
|
k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) |
|
v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) |
|
|
|
attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scaling |
|
if attn_mask is not None: |
|
attn_weights = attn_weights.masked_fill(attn_mask == 0, float('-inf')) |
|
attn_weights = F.softmax(attn_weights, dim=-1) |
|
attn_weights = self.dropout(attn_weights) |
|
|
|
attn_output = torch.matmul(attn_weights, v) |
|
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim) |
|
attn_output = self.out_proj(attn_output) |
|
|
|
        # Always return a tuple so callers can unpack unconditionally
        return attn_output, attn_weights if return_attn_weights else None
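
# Shape sketch for the attention above (illustrative): for x of shape (B, L, D)
# with H heads, q/k/v are (B, H, L, D//H), attn_weights is (B, H, L, L), and the
# final output is projected back to (B, L, D).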
|
|
|
def get_activation_fn(activation_fn):
    """Map an activation name to (module class, width scale for the feed-forward layer)."""
    if activation_fn == "gelu":
|
return torch.nn.GELU, 1 |
|
elif activation_fn == "relu": |
|
return torch.nn.ReLU, 1 |
|
elif activation_fn == "swish": |
|
return torch.nn.SiLU, 1 |
|
else: |
|
raise ValueError(f"Unsupported activation function: {activation_fn}") |
|
''' |
|
} |
|
|
|
|
|
for file in missing_files: |
|
if file in utility_contents: |
|
with open(file, 'w') as f: |
|
f.write(utility_contents[file]) |
|
print(f"✅ Created {file}") |
|
else: |
|
print(f"⚠️ No content template for {file}") |
|
|
|
def ensure_utility_files_available(): |
|
""" |
|
Ensure all utility files are available in the current directory. |
|
If any are missing, try to copy them from the repository root. |
|
""" |
|
missing_files = [] |
|
for file in UTILITY_FILES: |
|
if not os.path.exists(file): |
|
missing_files.append(file) |
|
|
|
if missing_files: |
|
print(f"🔍 Looking for missing utility files: {missing_files}") |
|
|
|
|
|
|
|
possible_paths = [ |
|
current_dir, |
|
os.path.join(current_dir, '..'), |
|
os.path.join(current_dir, '..', '..'), |
|
os.path.join(current_dir, '..', '..', '..'), |
|
os.path.join(os.path.expanduser('~'), 'ablang2'), |
|
'/data/hn533621/ablang2', |
|
'/content/ablang2', |
|
'/tmp/ablang2', |
|
] |
|
|
|
|
|
is_hf_cache = 'huggingface' in current_dir and 'cache' in current_dir |
|
if is_hf_cache: |
|
print("🔍 Detected Hugging Face cache directory - will create utility files inline") |
|
|
|
possible_paths = [] |
|
|
|
|
|
cache_dir = os.path.dirname(current_dir) |
|
if 'huggingface' in cache_dir: |
|
|
|
repo_root = os.path.join(cache_dir, '..', '..', '..', '..') |
|
possible_paths.append(repo_root) |
|
|
|
for path in possible_paths: |
|
if os.path.exists(path): |
|
print(f"🔍 Checking path: {path}") |
|
|
|
all_found = True |
|
for file in missing_files: |
|
file_path = os.path.join(path, file) |
|
if not os.path.exists(file_path): |
|
all_found = False |
|
print(f" ❌ Missing: {file}") |
|
break |
|
else: |
|
print(f" ✅ Found: {file}") |
|
|
|
if all_found: |
|
print(f"🎯 Found all files in: {path}") |
|
|
|
for file in missing_files: |
|
src = os.path.join(path, file) |
|
dst = os.path.join(current_dir, file) |
|
shutil.copy2(src, dst) |
|
print(f"✅ Copied {file} to cached directory") |
|
return True |
|
|
|
|
|
print(f"❌ Could not find utility files in any of the searched paths:") |
|
for path in possible_paths: |
|
print(f" - {path}") |
|
|
|
|
|
print("🔧 Attempting to create missing utility files inline...") |
|
try: |
|
create_missing_utility_files(missing_files) |
|
print("✅ Successfully created missing utility files") |
|
return True |
|
except Exception as e: |
|
print(f"❌ Failed to create utility files: {e}") |
|
|
|
|
|
        if 'google.colab' in sys.modules:
|
raise FileNotFoundError( |
|
f"Missing utility files: {missing_files}. " |
|
"This appears to be a Google Colab environment. " |
|
"Please ensure you have cloned the repository and the utility files are available. " |
|
"Try running: !git clone https://huggingface.co/hemantn/ablang2" |
|
) |
|
else: |
|
raise FileNotFoundError( |
|
f"Missing utility files: {missing_files}. " |
|
"These files are required for the adapter to work. " |
|
"Please ensure the repository is properly set up." |
|
) |
|
|
|
return True |
|
|
|
|
|
ensure_utility_files_available() |
|
|
|
|
|
print(f"📁 Files in current directory ({current_dir}):") |
|
for f in os.listdir(current_dir): |
|
if f.endswith('.py'): |
|
print(f" {f}") |
|
|
|
|
|
# The cache directory is the directory this file lives in; it was already added
# to sys.path at the top of this module.
cache_dir = current_dir
|
|
|
|
|
content_path = '/content' |
|
if content_path in sys.path: |
|
sys.path.remove(content_path) |
|
print(f"✅ Removed {content_path} from sys.path to avoid import conflicts") |
|
|
|
|
|
try: |
|
from restoration import AbRestore |
|
from ablang_encodings import AbEncoding |
|
from alignment import AbAlignment |
|
from scores import AbScores |
|
import torch |
|
import numpy as np |
|
from extra_utils import res_to_seq, res_to_list |
|
print("✅ Successfully imported utility modules from cache directory") |
|
except ImportError as e: |
|
print(f"❌ Import error: {e}") |
|
print(f"🔧 Current sys.path: {sys.path}") |
|
print(f"🔧 Cache directory: {cache_dir}") |
|
raise |
|
|
|
class HuggingFaceTokenizerAdapter: |
|
def __init__(self, tokenizer, device): |
|
self.tokenizer = tokenizer |
|
self.device = device |
|
self.pad_token_id = tokenizer.pad_token_id |
|
        mask_id = getattr(tokenizer, 'mask_token_id', None)
        # Explicit None check: a valid mask id of 0 would be falsy under 'or'
        self.mask_token_id = mask_id if mask_id is not None else tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
|
self.vocab = tokenizer.get_vocab() if hasattr(tokenizer, 'get_vocab') else tokenizer.vocab |
|
self.inv_vocab = {v: k for k, v in self.vocab.items()} |
|
self.all_special_tokens = tokenizer.all_special_tokens |
|
|
|
    def __call__(self, seqs, pad=True, w_extra_tkns=False, device=None, mode=None):
        if mode == 'decode':
            # In decode mode seqs is a tensor of token ids, so it must be handled
            # before the encode path below tries to tokenize it
            if isinstance(seqs, torch.Tensor):
                seqs = seqs.cpu().numpy()
            decoded = []
            for seq in seqs:
                chars = []
                for t in seq:
                    token = self.inv_vocab.get(int(t), '')
                    # Drop padding, masking and start/stop tokens; keep the '|' separator
                    if token and token not in {'-', '*', '<', '>'}:
                        chars.append(token)
                decoded.append(''.join(chars))
            return decoded
        tokens = self.tokenizer(seqs, padding=True, return_tensors='pt')
        return tokens['input_ids'].to(self.device if device is None else device)
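
# Usage sketch for the adapter above (illustrative; assumes an AbLang2-style
# HuggingFace tokenizer object is in scope):
#   adapter = HuggingFaceTokenizerAdapter(hf_tokenizer, device='cpu')
#   ids = adapter(['<EVQLV...>|<DIQMT...>'])  # encode -> (n, L) tensor of input ids
#   seqs = adapter(ids, mode='decode')        # decode -> list of plain sequences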
|
|
|
class HFAbRestore(AbRestore): |
|
def __init__(self, hf_model, hf_tokenizer, spread=11, device='cpu', ncpu=1): |
|
super().__init__(spread=spread, device=device, ncpu=ncpu) |
|
self.used_device = device |
|
self._hf_model = hf_model |
|
self.tokenizer = HuggingFaceTokenizerAdapter(hf_tokenizer, device) |
|
|
|
@property |
|
def AbLang(self): |
|
def model_call(x): |
|
output = self._hf_model(x) |
|
if hasattr(output, 'last_hidden_state'): |
|
return output.last_hidden_state |
|
return output |
|
return model_call |
|
|
|
def restore(self, seqs, align=False, **kwargs): |
|
"""Restore masked residues in antibody sequences.""" |
|
if isinstance(seqs, str): |
|
seqs = [seqs] |
|
|
|
n_seqs = len(seqs) |
|
|
|
if align: |
|
|
|
seqs = self._sequence_aligning(seqs) |
|
nr_seqs = len(seqs)//self.spread |
|
|
|
tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device) |
|
            predictions = self.AbLang(tokens)[:, :, 1:21]  # keep only the 20 amino-acid logits (vocab ids 1-20)
|
|
|
|
|
tokens = tokens.reshape(nr_seqs, self.spread, -1) |
|
predictions = predictions.reshape(nr_seqs, self.spread, -1, 20) |
|
seqs = seqs.reshape(nr_seqs, -1) |
|
|
|
|
|
best_seq_idx = torch.argmax(torch.max(predictions, -1).values[:,:,1:2].mean(2), -1) |
|
|
|
|
|
tokens = tokens.gather(1, best_seq_idx.view(-1, 1).unsqueeze(1).repeat(1, 1, tokens.shape[-1])).squeeze(1) |
|
predictions = predictions[range(predictions.shape[0]), best_seq_idx] |
|
seqs = np.take_along_axis(seqs, best_seq_idx.view(-1, 1).cpu().numpy(), axis=1) |
|
else: |
|
tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device) |
|
predictions = self.AbLang(tokens)[:,:,1:21] |
|
|
|
        predicted_tokens = torch.max(predictions, -1).indices + 1
        # Fill masked positions with predictions (the original AbLang2 hardcodes mask id 23;
        # the tokenizer's own mask id is used here instead)
        restored_tokens = torch.where(tokens == self.tokenizer.mask_token_id, predicted_tokens, tokens)
|
|
|
restored_seqs = self.tokenizer(restored_tokens, mode="decode") |
|
|
|
if n_seqs < len(restored_seqs): |
|
restored_seqs = [f"{h}|{l}".replace('-','') for h,l in zip(restored_seqs[:n_seqs], restored_seqs[n_seqs:])] |
|
seqs = [f"{h}|{l}" for h,l in zip(seqs[:n_seqs], seqs[n_seqs:])] |
|
|
|
        # res_to_seq is imported at module level; np.c_ pairs each restored
        # sequence with its input length so padding can be trimmed
        return np.array([res_to_seq(seq, 'restore') for seq in np.c_[restored_seqs, np.vectorize(len)(seqs)]])
|
|
|
def _sequence_aligning(self, seqs): |
|
"""Create spread sequences using ANARCI alignment.""" |
|
tmp_seqs = [pairs.replace(">", "").replace("<", "").split("|") for pairs in seqs] |
|
|
|
spread_heavy = [f"<{seq}>" for seq in self._create_spread_of_sequences(tmp_seqs, chain = 'H')] |
|
spread_light = [f"<{seq}>" for seq in self._create_spread_of_sequences(tmp_seqs, chain = 'L')] |
|
|
|
return np.concatenate([np.array(spread_heavy),np.array(spread_light)]) |
|
|
|
def _create_spread_of_sequences(self, seqs, chain = 'H'): |
|
"""Create spread sequences using ANARCI.""" |
|
import pandas as pd |
|
import anarci |
|
|
|
chain_idx = 0 if chain == 'H' else 1 |
|
numbered_seqs = anarci.run_anarci( |
|
pd.DataFrame([seq[chain_idx].replace('*', 'X') for seq in seqs]).reset_index().values.tolist(), |
|
ncpu=self.ncpu, |
|
scheme='imgt', |
|
allowed_species=['human', 'mouse'], |
|
) |
|
|
|
anarci_data = pd.DataFrame( |
|
            [str(res[0][0]) if res else 'ANARCI_error' for res in numbered_seqs[1]],
|
columns=['anarci'] |
|
).astype('<U90') |
|
|
|
max_position = 128 if chain == 'H' else 127 |
|
|
|
|
|
import re |
|
|
|
def get_sequences_from_anarci(out_anarci, max_position, spread): |
|
""" |
|
            Ensure correct masking on each side of the sequence.
|
""" |
|
|
|
if out_anarci == 'ANARCI_error': |
|
return np.array(['ANARCI-ERR']*spread) |
|
|
|
end_position = int(re.search(r'\d+', out_anarci[::-1]).group()[::-1]) |
|
|
|
start_position = int(re.search(r'\d+,\s\'.\'\),\s\'[^-]+\'\),\s\(\(\d+,\s\'.\'\),\s\'[^-]+\'\),\s\(\(\d+,\s\'.\'\),\s\'[^-]+\'\),\s\(\(\d+,\s\'.\'\),\s\'[^-]+', |
|
out_anarci).group().split(',')[0]) - 1 |
|
|
|
sequence = "".join(re.findall(r"(?i)[A-Z*]", "".join(re.findall(r'\),\s\'[A-Z*]', out_anarci)))) |
|
|
|
sequence_j = ''.join(sequence).replace('-','').replace('X','*') + '*'*(max_position-int(end_position)) |
|
|
|
return get_spread_sequences(sequence_j, spread, start_position) |
|
|
|
def get_spread_sequences(seq, spread, start_position): |
|
""" |
|
            Generate candidate sequences from 8 positions shorter (position 10 + max CDR1 gap of 7) up to 2 positions longer (possible insertions).
|
""" |
|
spread_sequences = [] |
|
|
|
for diff in range(start_position-8, start_position+2+1): |
|
spread_sequences.append('*'*diff+seq) |
|
|
|
return np.array(spread_sequences) |
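
        # Example (illustrative): with spread=11 and start_position=8, the helper
        # above prepends 0..10 leading '*' pads, producing 11 candidate frames.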
|
seqs = anarci_data.apply( |
|
lambda x: get_sequences_from_anarci( |
|
x.anarci, |
|
max_position, |
|
self.spread |
|
), axis=1, result_type='expand' |
|
).to_numpy().reshape(-1) |
|
|
|
return seqs |
|
|
|
def add_angle_brackets(seq):
    """Wrap the heavy and light chains of a 'VH|VL' sequence in angle brackets."""
    if '|' in seq:
        vh, vl = seq.split('|', 1)
    else:
        vh, vl = seq, ''
    return f"<{vh}>|<{vl}>"
|
|
|
class AbLang2PairedHuggingFaceAdapter(AbEncoding, AbRestore, AbAlignment, AbScores): |
|
""" |
|
Adapter to use pretrained utilities with a HuggingFace-loaded ablang2_paired model and tokenizer. |
|
Automatically uses CUDA if available, otherwise CPU. |
|
""" |
|
def __init__(self, model, tokenizer, device=None, ncpu=1): |
|
super().__init__() |
|
if device is None: |
|
self.used_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
else: |
|
self.used_device = torch.device(device) |
|
self.AbLang = model |
|
self.tokenizer = tokenizer |
|
self.AbLang.to(self.used_device) |
|
self.AbLang.eval() |
|
|
|
        if hasattr(self.AbLang, 'model') and hasattr(self.AbLang.model, 'AbRep'):
            self.AbRep = self.AbLang.model.AbRep
        elif hasattr(self.AbLang, 'AbRep'):
            self.AbRep = self.AbLang.AbRep
        else:
            raise AttributeError("Could not find AbRep in the HuggingFace model or its underlying model.")
|
self.ncpu = ncpu |
|
        self.spread = 11

    def freeze(self):
|
self.AbLang.eval() |
|
|
|
def unfreeze(self): |
|
self.AbLang.train() |
|
|
|
def _encode_sequences(self, seqs): |
|
|
|
tokens = self.tokenizer(seqs, padding=True, return_tensors='pt') |
|
tokens = extract_input_ids(tokens, self.used_device) |
|
return self.AbRep(tokens).last_hidden_states.detach() |
|
|
|
def _predict_logits(self, seqs): |
|
|
|
tokens = self.tokenizer(seqs, padding=True, return_tensors='pt') |
|
tokens = extract_input_ids(tokens, self.used_device) |
|
output = self.AbLang(tokens) |
|
if hasattr(output, 'last_hidden_state'): |
|
return output.last_hidden_state.detach() |
|
return output.detach() |
|
|
|
def _predict_logits_with_step_masking(self, seqs): |
|
|
|
tokens = self.tokenizer(seqs, padding=True, return_tensors='pt') |
|
tokens = extract_input_ids(tokens, self.used_device) |
|
|
|
logits = [] |
|
for single_seq_tokens in tokens: |
|
tkn_len = len(single_seq_tokens) |
|
masked_tokens = single_seq_tokens.repeat(tkn_len, 1) |
|
for num in range(tkn_len): |
|
masked_tokens[num, num] = self.tokenizer.mask_token_id |
|
|
|
            with torch.no_grad():
                logits_tmp = self.AbLang(masked_tokens)
            if hasattr(logits_tmp, 'last_hidden_state'):
                logits_tmp = logits_tmp.last_hidden_state

            # For each masked copy, keep the logits at its masked position
            logits_tmp = torch.stack([logits_tmp[num, num] for num in range(tkn_len)])
            logits.append(logits_tmp)
|
|
|
return torch.stack(logits, dim=0) |
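
    # Note: stepwise masking runs one forward pass per token position, i.e. O(L)
    # model calls per sequence, whereas _predict_logits needs a single call.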
|
|
|
def _preprocess_labels(self, labels): |
|
labels = extract_input_ids(labels, self.used_device) |
|
return labels |
|
|
|
def __call__(self, seqs, mode='seqcoding', align=False, stepwise_masking=False, fragmented=False, batch_size=50): |
|
""" |
|
        Dispatch to the requested mode for different use cases, mimicking the original pretrained class.
|
""" |
|
|
|
def format_seq_input(seqs, fragmented=False): |
|
"""Format input sequences for processing.""" |
|
            if isinstance(seqs[0], str):
                seqs = [seqs]  # a single (heavy, light) pair becomes a batch of one
|
|
|
if fragmented: |
|
|
|
formatted_seqs = [] |
|
for seq in seqs: |
|
if isinstance(seq, (list, tuple)) and len(seq) == 2: |
|
heavy, light = seq[0], seq[1] |
|
formatted_seqs.append(f"{heavy}|{light}") |
|
else: |
|
formatted_seqs.append(seq) |
|
return formatted_seqs, 'HL' |
|
else: |
|
|
|
formatted_seqs = [] |
|
for seq in seqs: |
|
if isinstance(seq, (list, tuple)) and len(seq) == 2: |
|
heavy, light = seq[0], seq[1] |
|
|
|
heavy_part = f"<{heavy}>" if heavy else "<>" |
|
light_part = f"<{light}>" if light else "<>" |
|
formatted_seqs.append(f"{heavy_part}|{light_part}".replace("<>", "")) |
|
else: |
|
formatted_seqs.append(seq) |
|
|
|
return formatted_seqs, 'HL' |
|
|
|
valid_modes = [ |
|
'rescoding', 'seqcoding', 'restore', 'likelihood', 'probability', |
|
'pseudo_log_likelihood', 'confidence' |
|
] |
|
if mode not in valid_modes: |
|
raise SyntaxError(f"Given mode doesn't exist. Please select one of the following: {valid_modes}.") |
|
|
|
seqs, chain = format_seq_input(seqs, fragmented=fragmented) |
|
|
|
if align: |
|
numbered_seqs, seqs, number_alignment = self.number_sequences( |
|
seqs, chain=chain, fragmented=fragmented |
|
) |
|
else: |
|
numbered_seqs = None |
|
number_alignment = None |
|
|
|
subset_list = [] |
|
for subset in [seqs[x:x+batch_size] for x in range(0, len(seqs), batch_size)]: |
|
subset_list.append(getattr(self, mode)(subset, align=align, stepwise_masking=stepwise_masking)) |
|
|
|
return self.reformat_subsets( |
|
subset_list, |
|
mode=mode, |
|
align=align, |
|
numbered_seqs=numbered_seqs, |
|
seqs=seqs, |
|
number_alignment=number_alignment, |
|
) |
|
|
|
def pseudo_log_likelihood(self, seqs, **kwargs): |
|
""" |
|
Original (non-vectorized) pseudo log-likelihood computation matching notebook behavior. |
|
""" |
|
|
|
formatted_seqs = [] |
|
for s in seqs: |
|
if isinstance(s, (list, tuple)): |
|
formatted_seqs.append('|'.join(s)) |
|
else: |
|
formatted_seqs.append(s) |
|
|
|
|
|
labels = self.tokenizer( |
|
formatted_seqs, padding=True, return_tensors='pt' |
|
) |
|
labels = extract_input_ids(labels, self.used_device) |
|
|
|
|
|
        # all_special_tokens may hold ids (ints) or token strings depending on the tokenizer
        if isinstance(self.tokenizer.all_special_tokens[0], int):
            special_token_ids = set(self.tokenizer.all_special_tokens)
        else:
            special_token_ids = set(self.tokenizer.convert_tokens_to_ids(tok) for tok in self.tokenizer.all_special_tokens)
|
pad_token_id = self.tokenizer.pad_token_id |
|
|
|
mask_token_id = getattr(self.tokenizer, 'mask_token_id', None) |
|
if mask_token_id is None: |
|
mask_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token) |
|
|
|
plls = [] |
|
with torch.no_grad(): |
|
for i, seq_label in enumerate(labels): |
|
seq_pll = [] |
|
for j, token_id in enumerate(seq_label): |
|
if token_id.item() in special_token_ids or token_id.item() == pad_token_id: |
|
continue |
|
masked = seq_label.clone() |
|
masked[j] = mask_token_id |
|
logits = self.AbLang(masked.unsqueeze(0)) |
|
if hasattr(logits, 'last_hidden_state'): |
|
logits = logits.last_hidden_state |
|
logits = logits[0, j] |
|
nll = torch.nn.functional.cross_entropy( |
|
logits.unsqueeze(0), token_id.unsqueeze(0), reduction="none" |
|
) |
|
seq_pll.append(-nll.item()) |
|
if seq_pll: |
|
plls.append(np.mean(seq_pll)) |
|
else: |
|
plls.append(float('nan')) |
|
return np.array(plls) |
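
    # Sketch of the quantity computed above: for a sequence s,
    #   PLL(s) = mean_i log P(s_i | s with position i masked)
    # with special and padding tokens excluded from the mean.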
|
|
|
def seqcoding(self, seqs, **kwargs): |
|
"""Sequence specific representations - returns 480-dimensional embeddings for each sequence.""" |
|
|
|
formatted_seqs = [] |
|
for s in seqs: |
|
if isinstance(s, (list, tuple)): |
|
formatted_seqs.append('|'.join(s)) |
|
else: |
|
formatted_seqs.append(s) |
|
|
|
|
|
embeddings = self._encode_sequences(formatted_seqs) |
|
|
|
|
|
|
|
        # Mean-pool over the token dimension (note: this includes any padding positions)
        if len(embeddings.shape) == 3:
            seq_embeddings = embeddings.mean(dim=1)
        else:
            seq_embeddings = embeddings

        return seq_embeddings.cpu().numpy()
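
    # Example (sketch): self([['EVQLV...', 'DIQMT...']], mode='seqcoding') returns
    # an array of shape (n_seqs, hidden); the docstring above assumes hidden=480.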
|
|
|
def rescoding(self, seqs, align=False, **kwargs): |
|
"""Residue specific representations - returns 480-dimensional embeddings for each residue.""" |
|
|
|
formatted_seqs = [] |
|
for s in seqs: |
|
if isinstance(s, (list, tuple)): |
|
formatted_seqs.append('|'.join(s)) |
|
else: |
|
formatted_seqs.append(s) |
|
|
|
|
|
embeddings = self._encode_sequences(formatted_seqs) |
|
|
|
|
|
|
|
if len(embeddings.shape) == 3: |
|
|
|
embeddings_np = embeddings.cpu().numpy() |
|
return [embeddings_np[i] for i in range(embeddings_np.shape[0])] |
|
else: |
|
return embeddings.cpu().numpy() |
|
|
|
def likelihood(self, seqs, align=False, stepwise_masking=False, **kwargs): |
|
"""Likelihood of mutations - returns logits for each amino acid at each position.""" |
|
|
|
formatted_seqs = [] |
|
for s in seqs: |
|
if isinstance(s, (list, tuple)): |
|
formatted_seqs.append('|'.join(s)) |
|
else: |
|
formatted_seqs.append(s) |
|
|
|
|
|
if stepwise_masking: |
|
logits = self._predict_logits_with_step_masking(formatted_seqs) |
|
else: |
|
logits = self._predict_logits(formatted_seqs) |
|
|
|
|
|
return logits.cpu().numpy() |
|
|
|
def confidence(self, seqs, **kwargs): |
|
"""Confidence calculation - match original ablang2 implementation by excluding all special tokens from loss.""" |
|
|
|
formatted_seqs = [] |
|
for s in seqs: |
|
if isinstance(s, (list, tuple)): |
|
formatted_seqs.append('|'.join(s)) |
|
else: |
|
formatted_seqs.append(s) |
|
|
|
        # Determine special-token ids once; all_special_tokens may hold ids (ints)
        # or token strings depending on the tokenizer
        if isinstance(self.tokenizer.all_special_tokens[0], int):
            special_token_ids = set(self.tokenizer.all_special_tokens)
        else:
            special_token_ids = set(self.tokenizer.convert_tokens_to_ids(tok) for tok in self.tokenizer.all_special_tokens)
        special_ids_tensor = torch.tensor(list(special_token_ids))

        plls = []
|
for seq in formatted_seqs: |
|
tokens = self.tokenizer([seq], padding=True, return_tensors='pt') |
|
input_ids = extract_input_ids(tokens, self.used_device) |
|
|
|
with torch.no_grad(): |
|
output = self.AbLang(input_ids) |
|
if hasattr(output, 'last_hidden_state'): |
|
logits = output.last_hidden_state |
|
else: |
|
logits = output |
|
|
|
|
|
logits = logits[0] |
|
input_ids = input_ids[0] |
|
|
|
|
|
            valid_mask = ~torch.isin(input_ids, special_ids_tensor.to(input_ids.device))
|
|
|
if valid_mask.sum() > 0: |
|
valid_logits = logits[valid_mask] |
|
valid_labels = input_ids[valid_mask] |
|
|
|
|
|
nll = torch.nn.functional.cross_entropy( |
|
valid_logits, |
|
valid_labels, |
|
reduction="mean" |
|
) |
|
pll = -nll.item() |
|
else: |
|
pll = 0.0 |
|
|
|
plls.append(pll) |
|
|
|
return np.array(plls, dtype=np.float32) |
|
|
|
def probability(self, seqs, align=False, stepwise_masking=False, **kwargs): |
|
""" |
|
Probability of mutations - applies softmax to logits to get probabilities |
|
""" |
|
|
|
formatted_seqs = [] |
|
for s in seqs: |
|
if isinstance(s, (list, tuple)): |
|
formatted_seqs.append('|'.join(s)) |
|
else: |
|
formatted_seqs.append(s) |
|
|
|
|
|
        if stepwise_masking:
            logits = self._predict_logits_with_step_masking(formatted_seqs)
        else:
            logits = self._predict_logits(formatted_seqs)
|
|
|
|
|
probs = logits.softmax(-1).cpu().numpy() |
|
|
|
if align: |
|
return probs |
|
else: |
|
|
|
return [res_to_list(state, seq) for state, seq in zip(probs, formatted_seqs)] |
|
|
|
def restore(self, seqs, align=False, **kwargs): |
|
hf_abrestore = HFAbRestore(self.AbLang, self.tokenizer, spread=self.spread, device=self.used_device, ncpu=self.ncpu) |
|
restored = hf_abrestore.restore(seqs, align=align) |
|
|
|
if isinstance(restored, np.ndarray): |
|
restored = np.array([add_angle_brackets(seq) for seq in restored]) |
|
else: |
|
restored = [add_angle_brackets(seq) for seq in restored] |
|
return restored |
|
|
|
def extract_input_ids(tokens, device):
    """Pull the input_ids tensor out of a tokenizer output, whatever its container type."""
    if hasattr(tokens, 'input_ids'):
        return tokens.input_ids.to(device)
    if isinstance(tokens, dict):
        if 'input_ids' in tokens:
            return tokens['input_ids'].to(device)
        # Fall back to the first tensor value in the dict
        for v in tokens.values():
            if torch.is_tensor(v):
                return v.to(device)
    elif torch.is_tensor(tokens):
        return tokens.to(device)
    raise ValueError("Could not extract input_ids from tokenizer output")
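
# End-to-end usage sketch (illustrative, not executed on import). Assumes the
# model and tokenizer can be loaded from the 'hemantn/ablang2' repo mentioned
# above, with remote code enabled:
#
#   from transformers import AutoModel, AutoTokenizer
#   model = AutoModel.from_pretrained('hemantn/ablang2', trust_remote_code=True)
#   tokenizer = AutoTokenizer.from_pretrained('hemantn/ablang2', trust_remote_code=True)
#   ablang = AbLang2PairedHuggingFaceAdapter(model, tokenizer)
#   embeddings = ablang([['EVQLV...', 'DIQMT...']], mode='seqcoding')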