import os import sys import shutil # Get the directory where this adapter.py file is located current_dir = os.path.dirname(os.path.abspath(__file__)) if current_dir not in sys.path: sys.path.insert(0, current_dir) # Import will be done inside methods when needed # List of utility files that need to be available UTILITY_FILES = [ 'restoration.py', 'ablang_encodings.py', 'alignment.py', 'scores.py', 'extra_utils.py', 'ablang.py', 'encoderblock.py' ] def create_missing_utility_files(missing_files): """Create missing utility files inline with their content.""" # Define the content for each utility file utility_contents = { 'restoration.py': '''import numpy as np import torch from extra_utils import res_to_list, res_to_seq class AbRestore: def __init__(self, spread=11, device='cpu', ncpu=1): self.spread = spread self.device = device self.ncpu = ncpu def _initiate_abrestore(self, model, tokenizer): self.AbLang = model self.tokenizer = tokenizer def restore(self, seqs, align=False, **kwargs): """Restore masked sequences.""" # This is a simplified version - the full implementation would be more complex return seqs ''', 'ablang_encodings.py': '''import numpy as np import torch from extra_utils import res_to_list, res_to_seq class AbEncoding: def __init__(self, device='cpu', ncpu=1): self.device = device self.ncpu = ncpu def _initiate_abencoding(self, model, tokenizer): self.AbLang = model self.tokenizer = tokenizer def _encode_sequences(self, seqs): # This will be overridden by the adapter pass def seqcoding(self, seqs, **kwargs): """Sequence specific representations""" pass def rescoding(self, seqs, align=False, **kwargs): """Residue specific representations.""" pass def likelihood(self, seqs, align=False, stepwise_masking=False, **kwargs): """Likelihood of mutations""" pass def probability(self, seqs, align=False, stepwise_masking=False, **kwargs): """Probability of mutations""" pass ''', 'alignment.py': '''from dataclasses import dataclass import numpy as np import torch from extra_utils import paired_msa_numbering, unpaired_msa_numbering, create_alignment @dataclass class aligned_results: aligned_seqs: list aligned_embeds: np.ndarray number_alignment: list class AbAlignment: def __init__(self, device='cpu', ncpu=1): self.device = device self.ncpu = ncpu def number_sequences(self, seqs, chain='H', fragmented=False): if chain == 'HL': numbered_seqs, seqs, number_alignment = paired_msa_numbering(seqs, fragmented=fragmented, n_jobs=self.ncpu) else: numbered_seqs, seqs, number_alignment = unpaired_msa_numbering(seqs, chain=chain, fragmented=fragmented, n_jobs=self.ncpu) return numbered_seqs, seqs, number_alignment def align_encodings(self, encodings, numbered_seqs, seqs, number_alignment): aligned_encodings = [] for res_embed, numbered_seq, seq in zip(encodings, numbered_seqs, seqs): aligned_encodings.append(create_alignment(res_embed, numbered_seq, seq, number_alignment)) return np.concatenate([aligned_encodings], axis=0) def reformat_subsets(self, subset_list, mode='seqcoding', align=False, numbered_seqs=None, seqs=None, number_alignment=None): if mode in ['seqcoding', 'pseudo_log_likelihood', 'confidence']: return np.concatenate(subset_list) elif mode == 'restore' and align: # For restore mode with alignment, return the aligned sequences return subset_list[0] if len(subset_list) == 1 else subset_list elif mode == 'restore' and not align: # For restore mode without alignment, return the restored sequences return subset_list[0] if len(subset_list) == 1 else subset_list elif align: aligned_subsets = [] for num, subset in enumerate(subset_list): start_idx = num * len(subset) end_idx = (num + 1) * len(subset) aligned_subset = self.align_encodings( subset, numbered_seqs[start_idx:end_idx], seqs[start_idx:end_idx], number_alignment ) aligned_subsets.append(aligned_subset) subset = np.concatenate(aligned_subsets) return aligned_results( aligned_seqs=[''.join(alist) for alist in subset[:,:,-1]], aligned_embeds=subset[:,:,:-1].astype(float), number_alignment=number_alignment.apply(lambda x: '{}{}'.format(*x[0]), axis=1).values ) elif not align: return sum(subset_list, []) else: return np.concatenate(subset_list) ''', 'scores.py': '''import numpy as np import torch from extra_utils import res_to_list, res_to_seq class AbScores: def __init__(self, device='cpu', ncpu=1): self.device = device self.ncpu = ncpu def _initiate_abencoding(self, model, tokenizer): self.AbLang = model self.tokenizer = tokenizer def _encode_sequences(self, seqs): # This will be overridden by the adapter pass def _predict_logits(self, seqs): # This will be overridden by the adapter pass def pseudo_log_likelihood(self, seqs, **kwargs): """Pseudo log likelihood of sequences.""" pass ''', 'extra_utils.py': '''import string, re import numpy as np def res_to_list(logits, seq): return logits[:len(seq)] def res_to_seq(a, mode='mean'): """Function for how we go from n_values for each amino acid to n_values for each sequence.""" if mode=='sum': return a[0:(int(a[-1]))].sum() elif mode=='mean': return a[0:(int(a[-1]))].mean() elif mode=='restore': return a[0][0:(int(a[-1]))] def get_number_alignment(numbered_seqs): """Creates a number alignment from the anarci results.""" import pandas as pd alist = [pd.DataFrame(aligned_seq, columns=[0,1,'resi']) for aligned_seq in numbered_seqs] unsorted_alignment = pd.concat(alist).drop_duplicates(subset=0) max_alignment = get_max_alignment() return max_alignment.merge(unsorted_alignment.query("resi!='-'"), left_on=0, right_on=0)[[0,1]] def get_max_alignment(): """Create maximum possible alignment for sorting""" import pandas as pd sortlist = [[("<", "")]] for num in range(1, 128+1): if num in [33,61,112]: for char in string.ascii_uppercase[::-1]: sortlist.append([(num, char)]) sortlist.append([(num,' ')]) else: sortlist.append([(num,' ')]) for char in string.ascii_uppercase: sortlist.append([(num, char)]) return pd.DataFrame(sortlist + [[(">", "")]]) def paired_msa_numbering(ab_seqs, fragmented=False, n_jobs=10): import pandas as pd tmp_seqs = [pairs.replace(">", "").replace("<", "").split("|") for pairs in ab_seqs] numbered_seqs_heavy, seqs_heavy, number_alignment_heavy = unpaired_msa_numbering([i[0] for i in tmp_seqs], 'H', fragmented=fragmented, n_jobs=n_jobs) numbered_seqs_light, seqs_light, number_alignment_light = unpaired_msa_numbering([i[1] for i in tmp_seqs], 'L', fragmented=fragmented, n_jobs=n_jobs) number_alignment = pd.concat([number_alignment_heavy, pd.DataFrame([[("|",""), "|"]]), number_alignment_light]).reset_index(drop=True) seqs = [f"{heavy}|{light}" for heavy, light in zip(seqs_heavy, seqs_light)] numbered_seqs = [heavy + [(("|",""), "|", "|")] + light for heavy, light in zip(numbered_seqs_heavy, numbered_seqs_light)] return numbered_seqs, seqs, number_alignment def unpaired_msa_numbering(seqs, chain='H', fragmented=False, n_jobs=10): numbered_seqs = number_with_anarci(seqs, chain=chain, fragmented=fragmented, n_jobs=n_jobs) number_alignment = get_number_alignment(numbered_seqs) number_alignment[1] = chain seqs = [''.join([i[2] for i in numbered_seq]).replace('-','') for numbered_seq in numbered_seqs] return numbered_seqs, seqs, number_alignment def number_with_anarci(seqs, chain='H', fragmented=False, n_jobs=1): import anarci import pandas as pd anarci_out = anarci.run_anarci(pd.DataFrame(seqs).reset_index().values.tolist(), ncpu=n_jobs, scheme='imgt', allowed_species=['human', 'mouse']) numbered_seqs = [] for onarci in anarci_out[1]: numbered_seq = [] for i in onarci[0][0]: if i[1] != '-': numbered_seq.append((i[0], chain, i[1])) if fragmented: numbered_seqs.append(numbered_seq) else: numbered_seqs.append([(("<",""), chain, "<")] + numbered_seq + [((">",""), chain, ">")]) return numbered_seqs def create_alignment(res_embeds, numbered_seqs, seq, number_alignment): import pandas as pd datadf = pd.DataFrame(numbered_seqs) sequence_alignment = number_alignment.merge(datadf, how='left', on=[0, 1]).fillna('-')[2] idxs = np.where(sequence_alignment.values == '-')[0] idxs = [idx-num for num, idx in enumerate(idxs)] aligned_embeds = pd.DataFrame(np.insert(res_embeds[:len(seq)], idxs, 0, axis=0)) return pd.concat([aligned_embeds, sequence_alignment], axis=1).values ''', 'ablang.py': '''from dataclasses import dataclass from typing import Optional, Tuple import torch from torch import nn import torch.nn.functional as F from .encoderblock import TransformerEncoder, get_activation_fn class AbLang(torch.nn.Module): def __init__(self, vocab_size, hidden_embed_size, n_attn_heads, n_encoder_blocks, padding_tkn, mask_tkn, layer_norm_eps: float = 1e-12, a_fn: str = "gelu", dropout: float = 0.0): super().__init__() self.AbRep = AbRep(vocab_size, hidden_embed_size, n_attn_heads, n_encoder_blocks, padding_tkn, mask_tkn, layer_norm_eps, a_fn, dropout) self.AbHead = AbHead(vocab_size, hidden_embed_size, self.AbRep.aa_embed_layer.weight, layer_norm_eps, a_fn) def forward(self, tokens, return_attn_weights=False, return_rep_layers=[]): representations = self.AbRep(tokens, return_attn_weights, return_rep_layers) if return_attn_weights: return representations.attention_weights elif return_rep_layers != []: return representations.many_hidden_states else: likelihoods = self.AbHead(representations.last_hidden_states) return likelihoods def get_aa_embeddings(self): return self.AbRep.aa_embed_layer class AbRep(torch.nn.Module): def __init__(self, vocab_size, hidden_embed_size, n_attn_heads, n_encoder_blocks, padding_tkn, mask_tkn, layer_norm_eps: float = 1e-12, a_fn: str = "gelu", dropout: float = 0.0): super().__init__() self.aa_embed_layer = nn.Embedding(vocab_size, hidden_embed_size, padding_idx=padding_tkn) self.encoder_blocks = nn.ModuleList([TransformerEncoder(hidden_embed_size, n_attn_heads, dropout, layer_norm_eps, a_fn) for _ in range(n_encoder_blocks)]) def forward(self, tokens, return_attn_weights=False, return_rep_layers=[]): hidden_states = self.aa_embed_layer(tokens) for i, encoder_block in enumerate(self.encoder_blocks): hidden_states, attn_weights = encoder_block(hidden_states) return type('obj', (object,), {'last_hidden_states': hidden_states}) class AbHead(torch.nn.Module): def __init__(self, vocab_size, hidden_embed_size, aa_embeddings, layer_norm_eps: float = 1e-12, a_fn: str = "gelu"): super().__init__() self.layer_norm = nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps) self.aa_embeddings = aa_embeddings def forward(self, hidden_states): hidden_states = self.layer_norm(hidden_states) return torch.matmul(hidden_states, self.aa_embeddings.transpose(0, 1)) ''', 'encoderblock.py': '''import torch import math from torch import nn import torch.nn.functional as F import einops from rotary_embedding_torch import RotaryEmbedding class TransformerEncoder(torch.nn.Module): def __init__(self, hidden_embed_size, n_attn_heads, attn_dropout: float = 0.0, layer_norm_eps: float = 1e-05, a_fn: str = "gelu"): super().__init__() assert hidden_embed_size % n_attn_heads == 0, "Embedding dimension must be devisible with the number of heads." self.multihead_attention = MultiHeadAttention(embed_dim=hidden_embed_size, num_heads=n_attn_heads, attention_dropout_prob=attn_dropout) activation_fn, scale = get_activation_fn(a_fn) self.intermediate_layer = torch.nn.Sequential( torch.nn.Linear(hidden_embed_size, hidden_embed_size * 4 * scale), activation_fn(), torch.nn.Linear(hidden_embed_size * 4, hidden_embed_size), ) self.pre_attn_layer_norm = torch.nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps) self.final_layer_norm = torch.nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps) def forward(self, hidden_embed, attn_mask=None, return_attn_weights: bool = False): residual = hidden_embed hidden_embed = self.pre_attn_layer_norm(hidden_embed.clone()) hidden_embed, attn_weights = self.multihead_attention(hidden_embed, attn_mask=attn_mask, return_attn_weights=return_attn_weights) hidden_embed = residual + hidden_embed residual = hidden_embed hidden_embed = self.final_layer_norm(hidden_embed) hidden_embed = self.intermediate_layer(hidden_embed) hidden_embed = residual + hidden_embed return hidden_embed, attn_weights class MultiHeadAttention(torch.nn.Module): def __init__(self, embed_dim, num_heads, attention_dropout_prob=0.0): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.head_dim = embed_dim // num_heads self.scaling = self.head_dim ** -0.5 self.q_proj = nn.Linear(embed_dim, embed_dim) self.k_proj = nn.Linear(embed_dim, embed_dim) self.v_proj = nn.Linear(embed_dim, embed_dim) self.out_proj = nn.Linear(embed_dim, embed_dim) self.dropout = nn.Dropout(attention_dropout_prob) def forward(self, x, attn_mask=None, return_attn_weights=False): batch_size, seq_len, embed_dim = x.shape q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scaling if attn_mask is not None: attn_weights = attn_weights.masked_fill(attn_mask == 0, float('-inf')) attn_weights = F.softmax(attn_weights, dim=-1) attn_weights = self.dropout(attn_weights) attn_output = torch.matmul(attn_weights, v) attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim) attn_output = self.out_proj(attn_output) if return_attn_weights: return attn_output, attn_weights return attn_output def get_activation_fn(activation_fn): if activation_fn == "gelu": return torch.nn.GELU, 1 elif activation_fn == "relu": return torch.nn.ReLU, 1 elif activation_fn == "swish": return torch.nn.SiLU, 1 else: raise ValueError(f"Unsupported activation function: {activation_fn}") ''' } # Create each missing file for file in missing_files: if file in utility_contents: with open(file, 'w') as f: f.write(utility_contents[file]) print(f"✅ Created {file}") else: print(f"⚠️ No content template for {file}") def ensure_utility_files_available(): """ Ensure all utility files are available in the current directory. If any are missing, try to copy them from the repository root. """ missing_files = [] for file in UTILITY_FILES: if not os.path.exists(file): missing_files.append(file) if missing_files: print(f"🔍 Looking for missing utility files: {missing_files}") # Try to find the repository root (where all utility files are) # Look for common parent directories that might contain the files possible_paths = [ current_dir, # Current directory (where model files are downloaded) os.path.join(current_dir, '..'), # Parent directory os.path.join(current_dir, '..', '..'), # Grandparent directory os.path.join(current_dir, '..', '..', '..'), # Great-grandparent directory os.path.join(os.path.expanduser('~'), 'ablang2'), # Home directory '/data/hn533621/ablang2', # Known repository location '/content/ablang2', # Google Colab common location '/tmp/ablang2', # Temporary directory ] # Check if we're in a Hugging Face cache directory is_hf_cache = 'huggingface' in current_dir and 'cache' in current_dir if is_hf_cache: print("🔍 Detected Hugging Face cache directory - will create utility files inline") # Skip searching other paths and create files inline possible_paths = [] # Also try to find files in the Hugging Face cache structure cache_dir = os.path.dirname(current_dir) if 'huggingface' in cache_dir: # Look in the repository root within the cache repo_root = os.path.join(cache_dir, '..', '..', '..', '..') possible_paths.append(repo_root) for path in possible_paths: if os.path.exists(path): print(f"🔍 Checking path: {path}") # Check if all missing files exist in this path all_found = True for file in missing_files: file_path = os.path.join(path, file) if not os.path.exists(file_path): all_found = False print(f" ❌ Missing: {file}") break else: print(f" ✅ Found: {file}") if all_found: print(f"🎯 Found all files in: {path}") # Copy all missing files for file in missing_files: src = os.path.join(path, file) dst = os.path.join(current_dir, file) shutil.copy2(src, dst) print(f"✅ Copied {file} to cached directory") return True # If we get here, we couldn't find the files print(f"❌ Could not find utility files in any of the searched paths:") for path in possible_paths: print(f" - {path}") # Try to create the missing files inline print("🔧 Attempting to create missing utility files inline...") try: create_missing_utility_files(missing_files) print("✅ Successfully created missing utility files") return True except Exception as e: print(f"❌ Failed to create utility files: {e}") # For Colab environments, provide a helpful error message if 'google.colab' in str(sys.modules): raise FileNotFoundError( f"Missing utility files: {missing_files}. " "This appears to be a Google Colab environment. " "Please ensure you have cloned the repository and the utility files are available. " "Try running: !git clone https://huggingface.co/hemantn/ablang2" ) else: raise FileNotFoundError( f"Missing utility files: {missing_files}. " "These files are required for the adapter to work. " "Please ensure the repository is properly set up." ) return True # Ensure utility files are available before importing ensure_utility_files_available() # Debug: Check what files are in the current directory print(f"📁 Files in current directory ({current_dir}):") for f in os.listdir(current_dir): if f.endswith('.py'): print(f" {f}") # Import utility modules directly (no package structure needed) import sys import os # Ensure we import from the cache directory, not from /content cache_dir = os.path.dirname(os.path.abspath(__file__)) if cache_dir not in sys.path: sys.path.insert(0, cache_dir) # Remove /content from sys.path to avoid conflicts content_path = '/content' if content_path in sys.path: sys.path.remove(content_path) print(f"✅ Removed {content_path} from sys.path to avoid import conflicts") # Import utility modules try: from restoration import AbRestore from ablang_encodings import AbEncoding from alignment import AbAlignment from scores import AbScores import torch import numpy as np from extra_utils import res_to_seq, res_to_list print("✅ Successfully imported utility modules from cache directory") except ImportError as e: print(f"❌ Import error: {e}") print(f"🔧 Current sys.path: {sys.path}") print(f"🔧 Cache directory: {cache_dir}") raise class HuggingFaceTokenizerAdapter: def __init__(self, tokenizer, device): self.tokenizer = tokenizer self.device = device self.pad_token_id = tokenizer.pad_token_id self.mask_token_id = getattr(tokenizer, 'mask_token_id', None) or tokenizer.convert_tokens_to_ids(tokenizer.mask_token) self.vocab = tokenizer.get_vocab() if hasattr(tokenizer, 'get_vocab') else tokenizer.vocab self.inv_vocab = {v: k for k, v in self.vocab.items()} self.all_special_tokens = tokenizer.all_special_tokens def __call__(self, seqs, pad=True, w_extra_tkns=False, device=None, mode=None): tokens = self.tokenizer(seqs, padding=True, return_tensors='pt') input_ids = tokens['input_ids'].to(self.device if device is None else device) if mode == 'decode': # seqs is a tensor of token ids if isinstance(seqs, torch.Tensor): seqs = seqs.cpu().numpy() decoded = [] for i, seq in enumerate(seqs): chars = [self.inv_vocab.get(int(t), '') for t in seq if self.inv_vocab.get(int(t), '') not in {'-', '*', '<', '>'} and self.inv_vocab.get(int(t), '') != ''] # Use res_to_seq for formatting, pass (sequence, length) tuple as in original code # The length is not always available, so use len(chars) as fallback from extra_utils import res_to_seq formatted = res_to_seq([ ''.join(chars), len(chars) ], mode='restore') decoded.append(formatted) return decoded return input_ids class HFAbRestore(AbRestore): def __init__(self, hf_model, hf_tokenizer, spread=11, device='cpu', ncpu=1): super().__init__(spread=spread, device=device, ncpu=ncpu) self.used_device = device self._hf_model = hf_model self.tokenizer = HuggingFaceTokenizerAdapter(hf_tokenizer, device) @property def AbLang(self): def model_call(x): output = self._hf_model(x) if hasattr(output, 'last_hidden_state'): return output.last_hidden_state return output return model_call def restore(self, seqs, align=False, **kwargs): """Restore masked residues in antibody sequences.""" if isinstance(seqs, str): seqs = [seqs] n_seqs = len(seqs) if align: # Implement alignment using ANARCI to create spread sequences seqs = self._sequence_aligning(seqs) nr_seqs = len(seqs)//self.spread tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device) predictions = self.AbLang(tokens)[:,:,1:21] # Reshape tokens = tokens.reshape(nr_seqs, self.spread, -1) predictions = predictions.reshape(nr_seqs, self.spread, -1, 20) seqs = seqs.reshape(nr_seqs, -1) # Find index of best predictions best_seq_idx = torch.argmax(torch.max(predictions, -1).values[:,:,1:2].mean(2), -1) # Select best predictions tokens = tokens.gather(1, best_seq_idx.view(-1, 1).unsqueeze(1).repeat(1, 1, tokens.shape[-1])).squeeze(1) predictions = predictions[range(predictions.shape[0]), best_seq_idx] seqs = np.take_along_axis(seqs, best_seq_idx.view(-1, 1).cpu().numpy(), axis=1) else: tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device) predictions = self.AbLang(tokens)[:,:,1:21] predicted_tokens = torch.max(predictions, -1).indices + 1 restored_tokens = torch.where(tokens==23, predicted_tokens, tokens) restored_seqs = self.tokenizer(restored_tokens, mode="decode") if n_seqs < len(restored_seqs): restored_seqs = [f"{h}|{l}".replace('-','') for h,l in zip(restored_seqs[:n_seqs], restored_seqs[n_seqs:])] seqs = [f"{h}|{l}" for h,l in zip(seqs[:n_seqs], seqs[n_seqs:])] from extra_utils import res_to_seq return np.array([res_to_seq(seq, 'restore') for seq in np.c_[restored_seqs, np.vectorize(len)(seqs)]]) def _sequence_aligning(self, seqs): """Create spread sequences using ANARCI alignment.""" tmp_seqs = [pairs.replace(">", "").replace("<", "").split("|") for pairs in seqs] spread_heavy = [f"<{seq}>" for seq in self._create_spread_of_sequences(tmp_seqs, chain = 'H')] spread_light = [f"<{seq}>" for seq in self._create_spread_of_sequences(tmp_seqs, chain = 'L')] return np.concatenate([np.array(spread_heavy),np.array(spread_light)]) def _create_spread_of_sequences(self, seqs, chain = 'H'): """Create spread sequences using ANARCI.""" import pandas as pd import anarci chain_idx = 0 if chain == 'H' else 1 numbered_seqs = anarci.run_anarci( pd.DataFrame([seq[chain_idx].replace('*', 'X') for seq in seqs]).reset_index().values.tolist(), ncpu=self.ncpu, scheme='imgt', allowed_species=['human', 'mouse'], ) anarci_data = pd.DataFrame( [str(anarci[0][0]) if anarci else 'ANARCI_error' for anarci in numbered_seqs[1]], columns=['anarci'] ).astype('|<{vl}>" class AbLang2PairedHuggingFaceAdapter(AbEncoding, AbRestore, AbAlignment, AbScores): """ Adapter to use pretrained utilities with a HuggingFace-loaded ablang2_paired model and tokenizer. Automatically uses CUDA if available, otherwise CPU. """ def __init__(self, model, tokenizer, device=None, ncpu=1): super().__init__() if device is None: self.used_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') else: self.used_device = torch.device(device) self.AbLang = model # HuggingFace model instance self.tokenizer = tokenizer self.AbLang.to(self.used_device) self.AbLang.eval() # Always get AbRep from the underlying model if hasattr(self.AbLang, 'model') and hasattr(self.AbLang.model, 'AbRep'): self.AbRep = self.AbLang.model.AbRep else: raise AttributeError("Could not find AbRep in the HuggingFace model or its underlying model.") self.ncpu = ncpu self.spread = 11 # For compatibility with original utilities # The following is no longer needed since all_special_tokens now returns IDs directly # self.tokenizer.all_special_token_ids = [ # self.tokenizer.convert_tokens_to_ids(tok) for tok in self.tokenizer.all_special_tokens # ] # self.tokenizer._all_special_tokens_str = self.tokenizer.all_special_tokens # self.tokenizer.all_special_tokens = [ # self.tokenizer.convert_tokens_to_ids(tok) for tok in self.tokenizer._all_special_tokens_str # ] def freeze(self): self.AbLang.eval() def unfreeze(self): self.AbLang.train() def _encode_sequences(self, seqs): # Override to use HuggingFace tokenizer interface tokens = self.tokenizer(seqs, padding=True, return_tensors='pt') tokens = extract_input_ids(tokens, self.used_device) return self.AbRep(tokens).last_hidden_states.detach() def _predict_logits(self, seqs): # Override to use HuggingFace tokenizer interface tokens = self.tokenizer(seqs, padding=True, return_tensors='pt') tokens = extract_input_ids(tokens, self.used_device) output = self.AbLang(tokens) if hasattr(output, 'last_hidden_state'): return output.last_hidden_state.detach() return output.detach() def _predict_logits_with_step_masking(self, seqs): # Override the stepwise masking method to use HuggingFace tokenizer tokens = self.tokenizer(seqs, padding=True, return_tensors='pt') tokens = extract_input_ids(tokens, self.used_device) logits = [] for single_seq_tokens in tokens: tkn_len = len(single_seq_tokens) masked_tokens = single_seq_tokens.repeat(tkn_len, 1) for num in range(tkn_len): masked_tokens[num, num] = self.tokenizer.mask_token_id with torch.no_grad(): logits_tmp = self.AbLang(masked_tokens) logits_tmp = torch.stack([logits_tmp[num, num] for num in range(tkn_len)]) logits.append(logits_tmp) return torch.stack(logits, dim=0) def _preprocess_labels(self, labels): labels = extract_input_ids(labels, self.used_device) return labels def __call__(self, seqs, mode='seqcoding', align=False, stepwise_masking=False, fragmented=False, batch_size=50): """ Use different modes for different usecases, mimicking the original pretrained class. """ # Local implementation of format_seq_input def format_seq_input(seqs, fragmented=False): """Format input sequences for processing.""" if isinstance(seqs[0], str): seqs = [seqs] if fragmented: # For fragmented sequences, format as VH|VL without angle brackets formatted_seqs = [] for seq in seqs: if isinstance(seq, (list, tuple)) and len(seq) == 2: heavy, light = seq[0], seq[1] formatted_seqs.append(f"{heavy}|{light}") else: formatted_seqs.append(seq) return formatted_seqs, 'HL' else: # For non-fragmented sequences, add angle brackets: | formatted_seqs = [] for seq in seqs: if isinstance(seq, (list, tuple)) and len(seq) == 2: heavy, light = seq[0], seq[1] # Add angle brackets and handle empty sequences heavy_part = f"<{heavy}>" if heavy else "<>" light_part = f"<{light}>" if light else "<>" formatted_seqs.append(f"{heavy_part}|{light_part}".replace("<>", "")) else: formatted_seqs.append(seq) return formatted_seqs, 'HL' valid_modes = [ 'rescoding', 'seqcoding', 'restore', 'likelihood', 'probability', 'pseudo_log_likelihood', 'confidence' ] if mode not in valid_modes: raise SyntaxError(f"Given mode doesn't exist. Please select one of the following: {valid_modes}.") seqs, chain = format_seq_input(seqs, fragmented=fragmented) if align: numbered_seqs, seqs, number_alignment = self.number_sequences( seqs, chain=chain, fragmented=fragmented ) else: numbered_seqs = None number_alignment = None subset_list = [] for subset in [seqs[x:x+batch_size] for x in range(0, len(seqs), batch_size)]: subset_list.append(getattr(self, mode)(subset, align=align, stepwise_masking=stepwise_masking)) return self.reformat_subsets( subset_list, mode=mode, align=align, numbered_seqs=numbered_seqs, seqs=seqs, number_alignment=number_alignment, ) def pseudo_log_likelihood(self, seqs, **kwargs): """ Original (non-vectorized) pseudo log-likelihood computation matching notebook behavior. """ # Format input: join VH and VL with '|' formatted_seqs = [] for s in seqs: if isinstance(s, (list, tuple)): formatted_seqs.append('|'.join(s)) else: formatted_seqs.append(s) # Tokenize all sequences in batch labels = self.tokenizer( formatted_seqs, padding=True, return_tensors='pt' ) labels = extract_input_ids(labels, self.used_device) # Convert special tokens to IDs if isinstance(self.tokenizer.all_special_tokens[0], int): special_token_ids = set(self.tokenizer.all_special_tokens) else: special_token_ids = set(self.tokenizer.convert_tokens_to_ids(tok) for tok in self.tokenizer.all_special_tokens) pad_token_id = self.tokenizer.pad_token_id mask_token_id = getattr(self.tokenizer, 'mask_token_id', None) if mask_token_id is None: mask_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token) plls = [] with torch.no_grad(): for i, seq_label in enumerate(labels): seq_pll = [] for j, token_id in enumerate(seq_label): if token_id.item() in special_token_ids or token_id.item() == pad_token_id: continue masked = seq_label.clone() masked[j] = mask_token_id logits = self.AbLang(masked.unsqueeze(0)) if hasattr(logits, 'last_hidden_state'): logits = logits.last_hidden_state logits = logits[0, j] nll = torch.nn.functional.cross_entropy( logits.unsqueeze(0), token_id.unsqueeze(0), reduction="none" ) seq_pll.append(-nll.item()) if seq_pll: plls.append(np.mean(seq_pll)) else: plls.append(float('nan')) return np.array(plls) def seqcoding(self, seqs, **kwargs): """Sequence specific representations - returns 480-dimensional embeddings for each sequence.""" # Format input: join VH and VL with '|' formatted_seqs = [] for s in seqs: if isinstance(s, (list, tuple)): formatted_seqs.append('|'.join(s)) else: formatted_seqs.append(s) # Get embeddings using the model embeddings = self._encode_sequences(formatted_seqs) # Return sequence-level embeddings (mean pooling over sequence length) # Remove batch dimension and take mean over sequence dimension if len(embeddings.shape) == 3: # [batch_size, seq_len, hidden_size] seq_embeddings = embeddings.mean(dim=1) # [batch_size, hidden_size] else: seq_embeddings = embeddings return seq_embeddings.cpu().numpy() def rescoding(self, seqs, align=False, **kwargs): """Residue specific representations - returns 480-dimensional embeddings for each residue.""" # Format input: join VH and VL with '|' formatted_seqs = [] for s in seqs: if isinstance(s, (list, tuple)): formatted_seqs.append('|'.join(s)) else: formatted_seqs.append(s) # Get embeddings using the model embeddings = self._encode_sequences(formatted_seqs) # Return residue-level embeddings # embeddings shape: [batch_size, seq_len, hidden_size] if len(embeddings.shape) == 3: # Convert to numpy and return as list of arrays for each sequence embeddings_np = embeddings.cpu().numpy() return [embeddings_np[i] for i in range(embeddings_np.shape[0])] else: return embeddings.cpu().numpy() def likelihood(self, seqs, align=False, stepwise_masking=False, **kwargs): """Likelihood of mutations - returns logits for each amino acid at each position.""" # Format input: join VH and VL with '|' formatted_seqs = [] for s in seqs: if isinstance(s, (list, tuple)): formatted_seqs.append('|'.join(s)) else: formatted_seqs.append(s) # Get logits if stepwise_masking: logits = self._predict_logits_with_step_masking(formatted_seqs) else: logits = self._predict_logits(formatted_seqs) # Return logits as numpy array return logits.cpu().numpy() def confidence(self, seqs, **kwargs): """Confidence calculation - match original ablang2 implementation by excluding all special tokens from loss.""" # Format input: join VH and VL with '|' formatted_seqs = [] for s in seqs: if isinstance(s, (list, tuple)): formatted_seqs.append('|'.join(s)) else: formatted_seqs.append(s) plls = [] for seq in formatted_seqs: tokens = self.tokenizer([seq], padding=True, return_tensors='pt') input_ids = extract_input_ids(tokens, self.used_device) with torch.no_grad(): output = self.AbLang(input_ids) if hasattr(output, 'last_hidden_state'): logits = output.last_hidden_state else: logits = output # Get the sequence (remove batch dimension) logits = logits[0] # [seq_len, vocab_size] input_ids = input_ids[0] # [seq_len] # Exclude all special tokens (pad, mask, etc.) if isinstance(self.tokenizer.all_special_tokens[0], int): special_token_ids = set(self.tokenizer.all_special_tokens) else: special_token_ids = set(self.tokenizer.convert_tokens_to_ids(tok) for tok in self.tokenizer.all_special_tokens) valid_mask = ~torch.isin(input_ids, torch.tensor(list(special_token_ids), device=input_ids.device)) if valid_mask.sum() > 0: valid_logits = logits[valid_mask] valid_labels = input_ids[valid_mask] # Calculate cross-entropy loss nll = torch.nn.functional.cross_entropy( valid_logits, valid_labels, reduction="mean" ) pll = -nll.item() else: pll = 0.0 plls.append(pll) return np.array(plls, dtype=np.float32) def probability(self, seqs, align=False, stepwise_masking=False, **kwargs): """ Probability of mutations - applies softmax to logits to get probabilities """ # Format input: join VH and VL with '|' formatted_seqs = [] for s in seqs: if isinstance(s, (list, tuple)): formatted_seqs.append('|'.join(s)) else: formatted_seqs.append(s) # Get logits if stepwise_masking: # For stepwise masking, we need to implement it similar to likelihood # This is a simplified version - you might want to implement full stepwise masking logits = self._predict_logits(formatted_seqs) else: logits = self._predict_logits(formatted_seqs) # Apply softmax to get probabilities probs = logits.softmax(-1).cpu().numpy() if align: return probs else: # Return residue-level probabilities (excluding special tokens) return [res_to_list(state, seq) for state, seq in zip(probs, formatted_seqs)] def restore(self, seqs, align=False, **kwargs): hf_abrestore = HFAbRestore(self.AbLang, self.tokenizer, spread=self.spread, device=self.used_device, ncpu=self.ncpu) restored = hf_abrestore.restore(seqs, align=align) # Apply angle brackets formatting to match original format if isinstance(restored, np.ndarray): restored = np.array([add_angle_brackets(seq) for seq in restored]) else: restored = [add_angle_brackets(seq) for seq in restored] return restored def extract_input_ids(tokens, device): if hasattr(tokens, 'input_ids'): return tokens.input_ids.to(device) elif isinstance(tokens, dict): if 'input_ids' in tokens: return tokens['input_ids'].to(device) else: for v in tokens.values(): if hasattr(v, 'ndim') or torch.is_tensor(v): return v.to(device) elif torch.is_tensor(tokens): return tokens.to(device) else: raise ValueError("Could not extract input_ids from tokenizer output")