import numpy as np
import torch

from extra_utils import res_to_seq, get_sequences_from_anarci


class AbRestore:
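    """
    Restore masked residues in antibody sequences with a pretrained AbLang
    model, optionally selecting the best ANARCI-aligned variant per input.
    """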
    def __init__(self, spread = 11, device = 'cpu', ncpu = 1):
        self.spread = spread # aligned variants generated per chain when align=True
        self.used_device = device # torch device used for tokenization and inference
        self.ncpu = ncpu # CPUs given to ANARCI for numbering
        
    def _initiate_abrestore(self, model, tokenizer):
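        # Attach the pretrained AbLang model and its matching tokenizer;
        # must be called before restore().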
        self.AbLang = model
        self.tokenizer = tokenizer

    def restore(self, seqs, align = False, **kwargs):
        """
        Restore sequences
        """
        n_seqs = len(seqs)
        
        if align:
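            # Expand each 'heavy|light' input into self.spread ANARCI-aligned
            # variants per chain (all heavy-chain variants first, then light)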
            
            seqs = self._sequence_aligning(seqs)
            nr_seqs = len(seqs)//self.spread
            
            tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)          
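            # Slice [:,:,1:21] keeps only the logits for the 20 amino acids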
            predictions = self.AbLang(tokens)[:,:,1:21]

            # Reshape
            tokens = tokens.reshape(nr_seqs, self.spread, -1)
            predictions = predictions.reshape(nr_seqs, self.spread, -1, 20)
            seqs = seqs.reshape(nr_seqs, -1)

            # Find index of best predictions: for each input, pick the variant
            # whose maximum per-position logits are highest on average
            # (start/end positions excluded)
            best_seq_idx = torch.argmax(torch.max(predictions, -1).values[:,:,1:-1].mean(2), -1)

            # Select best predictions           
            tokens = tokens.gather(1, best_seq_idx.view(-1, 1).unsqueeze(1).repeat(1, 1, tokens.shape[-1])).squeeze(1)
            predictions = predictions[range(predictions.shape[0]), best_seq_idx]
            seqs = np.take_along_axis(seqs, best_seq_idx.view(-1, 1).cpu().numpy(), axis=1)

        else:
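            # No alignment: tokenize the inputs as given and predict directly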
            tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)
            predictions = self.AbLang(tokens)[:,:,1:21]

        # Take the most likely residue per position (+1 maps the 20-way class
        # index back to its vocab index), then replace only mask tokens (id 23)
        # with predictions, keeping all known residues unchanged
        predicted_tokens = torch.max(predictions, -1).indices + 1
        restored_tokens = torch.where(tokens==23, predicted_tokens, tokens)

        restored_seqs = self.tokenizer(restored_tokens, mode="decode")

        # If the list doubled (paired inputs), the first half holds heavy
        # chains and the second half light chains; stitch them back into
        # 'heavy|light' pairs
        if n_seqs < len(restored_seqs):
            restored_seqs = [f"{h}|{l}".replace('-','') for h,l in zip(restored_seqs[:n_seqs], restored_seqs[n_seqs:])]
            seqs = [f"{h}|{l}" for h,l in zip(seqs[:n_seqs], seqs[n_seqs:])]
        
        # Pair each restored sequence with the length of its input so
        # res_to_seq can trim padding back to the original length
        return np.array([res_to_seq(seq, 'restore') for seq in np.c_[restored_seqs, np.vectorize(len)(seqs)]])
    
    def _create_spread_of_sequences(self, seqs, chain = 'H'):
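        # pandas and anarci are imported lazily, as they are only needed
        # when restore() is called with align=True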
        import pandas as pd
        import anarci
        
        chain_idx = 0 if chain == 'H' else 1
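        # Number the selected chain with ANARCI (IMGT scheme); run_anarci
        # expects (id, sequence) pairs, and '*' (mask) is replaced with 'X'
        # so ANARCI can parse the sequence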
        numbered_seqs = anarci.run_anarci(
            pd.DataFrame([seq[chain_idx].replace('*', 'X') for seq in seqs]).reset_index().values.tolist(), 
            ncpu=self.ncpu, 
            scheme='imgt',
            allowed_species=['human', 'mouse'],
        )
        
        # Keep the numbered domain per sequence, or a marker when ANARCI fails
        # (the comprehension variable is named `num` to avoid shadowing the
        # anarci module)
        anarci_data = pd.DataFrame(
            [str(num[0][0]) if num else 'ANARCI_error' for num in numbered_seqs[1]], 
            columns=['anarci']
        ).astype('<U90')
        
        max_position = 128 if chain == 'H' else 127 # highest IMGT position used for this chain type
        
        # Expand each ANARCI numbering into self.spread gap-padded variants
        seqs = anarci_data.apply(
            lambda x: get_sequences_from_anarci(
                x.anarci, 
                max_position, 
                self.spread
            ), axis=1, result_type='expand'
        ).to_numpy().reshape(-1)
        
        return seqs
        
    
    def _sequence_aligning(self, seqs):
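        # Strip any existing start/end markers, split 'heavy|light' pairs, and
        # build a spread of aligned variants per chain, re-wrapped in the
        # '<' and '>' start/end tokens used by the tokenizer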

        tmp_seqs = [pairs.replace(">", "").replace("<", "").split("|") for pairs in seqs]
        
        spread_heavy = [f"<{seq}>" for seq in self._create_spread_of_sequences(tmp_seqs, chain = 'H')]
        spread_light = [f"<{seq}>" for seq in self._create_spread_of_sequences(tmp_seqs, chain = 'L')]
        
        return np.concatenate([np.array(spread_heavy),np.array(spread_light)])
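

# Example usage (a sketch, not part of the module above): `model` and
# `tokenizer` stand in for the pretrained AbLang model and tokenizer that the
# surrounding package loads elsewhere; '*' marks the residues to restore.
#
#   restorer = AbRestore(spread=11, device='cpu', ncpu=1)
#   restorer._initiate_abrestore(model, tokenizer)
#   restored = restorer.restore(
#       ["EVQLVESGG*LVQPGGSLRL|DIQMTQSP*SLSASVGDRV"],  # 'heavy|light' pairs
#       align=False,
#   )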