import numpy as np
import torch

from extra_utils import res_to_list, res_to_seq


class AbEncoding:
    """
    Encodings of antibody sequences from an AbLang model. The model and
    tokenizer are attached after construction via `_initiate_abencoding`.
    """

    def __init__(self, device='cpu', ncpu=1):
        self.used_device = device  # device the tokenizer targets and the model runs on
        self.ncpu = ncpu

    def _initiate_abencoding(self, model, tokenizer):
        self.AbLang = model
        self.tokenizer = tokenizer

    def _encode_sequences(self, seqs):
        """Tokenize sequences and return the last hidden states from AbRep."""
        tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)
        with torch.no_grad():
            return self.AbLang.AbRep(tokens).last_hidden_states

    def _predict_logits(self, seqs):
        """Tokenize sequences and return per-position logits from a single forward pass."""
        tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)
        with torch.no_grad():
            return self.AbLang(tokens)

    def _predict_logits_with_step_masking(self, seqs):
        """
        Predict logits for each position by masking one position at a time
        and reading the logits at the masked position.
        """
        tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)

        logits = []
        for single_seq_tokens in tokens:
            tkn_len = len(single_seq_tokens)
            # Row i of the batch is the full sequence with position i masked
            masked_tokens = single_seq_tokens.repeat(tkn_len, 1)
            for num in range(tkn_len):
                masked_tokens[num, num] = self.tokenizer.mask_token

            with torch.no_grad():
                logits_tmp = self.AbLang(masked_tokens)

            # Keep only each row's logits at its masked position
            logits_tmp = torch.stack([logits_tmp[num, num] for num in range(tkn_len)])
            logits.append(logits_tmp)

        return torch.stack(logits, dim=0)
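
    # Illustration of the masking pattern above with hypothetical token IDs
    # (assuming mask_token is 23): for tokens [4, 7, 9], the masked batch is
    #   [[23,  7,  9],
    #    [ 4, 23,  9],
    #    [ 4,  7, 23]]
    # and row i contributes logits_tmp[i, i], the prediction for the masked position.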

    def seqcoding(self, seqs, **kwargs):
        """
        Sequence-specific representations: one vector per sequence, reduced
        from the per-residue states by `res_to_seq` over each sequence's length.
        """
        encodings = self._encode_sequences(seqs).cpu().numpy()

        # Append each sequence's length so res_to_seq can ignore padded positions
        lens = np.vectorize(len)(seqs)
        lens = np.tile(lens.reshape(-1, 1, 1), (encodings.shape[2], 1))

        return np.apply_along_axis(res_to_seq, 2, np.c_[np.swapaxes(encodings, 1, 2), lens])
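
    # Hypothetical shapes: for 2 sequences and hidden size 480, _encode_sequences
    # yields a (2, max_len, 480) array and seqcoding reduces it to (2, 480),
    # one row per input sequence.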

    def rescoding(self, seqs, align=False, **kwargs):
        """
        Residue-specific representations: one vector per residue.
        """
        encodings = self._encode_sequences(seqs).cpu().numpy()

        if align:
            return encodings  # padded array with a fixed length per sequence
        else:
            return [res_to_list(state, seq) for state, seq in zip(encodings, seqs)]

    def likelihood(self, seqs, align=False, stepwise_masking=False, **kwargs):
        """
        Likelihood of mutations: raw per-residue logits over the vocabulary.
        """
        if stepwise_masking:
            logits = self._predict_logits_with_step_masking(seqs).cpu().numpy()
        else:
            logits = self._predict_logits(seqs).cpu().numpy()

        if align:
            return logits
        else:
            return [res_to_list(state, seq) for state, seq in zip(logits, seqs)]
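
    # Hypothetical use of the logits above: with a mapping `aa_to_idx` from amino
    # acids to vocabulary indices (not defined in this file), the log-odds of Y
    # over S at position 10 of the first sequence would be
    #   logits[0][10, aa_to_idx['Y']] - logits[0][10, aa_to_idx['S']]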

    def probability(self, seqs, align=False, stepwise_masking=False, **kwargs):
        """
        Probability of mutations: the per-residue logits softmaxed over the vocabulary.
        """
        if stepwise_masking:
            logits = self._predict_logits_with_step_masking(seqs)
        else:
            logits = self._predict_logits(seqs)
        probs = logits.softmax(-1).cpu().numpy()

        if align:
            return probs
        else:
            return [res_to_list(state, seq) for state, seq in zip(probs, seqs)]
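
# A minimal usage sketch, assuming a loaded AbLang-style model and tokenizer;
# `load_model` is a hypothetical stand-in for however the package loads weights.
#
#   model, tokenizer = load_model(device='cpu')
#   encoder = AbEncoding(device='cpu')
#   encoder._initiate_abencoding(model, tokenizer)
#
#   seqs = ['EVQLVESGGGLVQPGG', 'QVQLQQSGAELVRPGA']
#   seq_embeddings = encoder.seqcoding(seqs)   # one vector per sequence
#   res_embeddings = encoder.rescoding(seqs)   # list of per-residue matrices
#   mut_probs = encoder.probability(seqs)      # softmaxed logits per residue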