import os
from typing import List, Optional, Union

from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer

MASK = "#"
MSA_PAD = "!"
UL_ALPHABET_PLUS = "ACDEFGHIKLMNPQRSTVWYBZXJOU-*#@!/[]{}"  # amino acids, gap, and special tokens
MSA_AAS = "ACDEFGHIKLMNPQRSTVWYBZXJOU-"  # amino acids plus gap
GAP = "-"
START = "@"
STOP = "*"
SEP = "/"
END_AL = "]"
END_UL = "}"
START_AL = "["
START_UL = "{"

class ProteinTokenizer(PreTrainedTokenizer):

    def __init__(
        self,
        protein_alphabet: str = UL_ALPHABET_PLUS,
        model_max_length: int = 2048,
        pad_token=MSA_PAD,
        mask_token=MASK,
        all_aas=MSA_AAS,
        gap_token=GAP,
        bos_token=START,
        eos_token=STOP,
        sep_token=SEP,
        **kwargs
    ):
        """Character tokenizer for Hugging Face transformers.

        model_max_length (int): Model maximum sequence length.
        """
        self.alphabet = list(protein_alphabet)
        self.all_aas = list(all_aas)
        # Character <-> id lookup tables; ids are positions in the alphabet string.
        self.a_to_i = {u: i for i, u in enumerate(self.alphabet)}
        self.i_to_a = {i: u for i, u in enumerate(self.alphabet)}
        self.gap_token = gap_token

        # Normalize special tokens to AddedToken so the base class handles them consistently.
        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
        eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
        sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
        mask_token = AddedToken(mask_token, lstrip=False, rstrip=False) if isinstance(mask_token, str) else mask_token 
        pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
        gap_token = AddedToken(gap_token, lstrip=False, rstrip=False) if isinstance(gap_token, str) else gap_token

        super().__init__(
            pad_token=pad_token,
            mask_token=mask_token,
            eos_token=eos_token,
            bos_token=bos_token,
            sep_token=sep_token,
            model_max_length=model_max_length,
            **kwargs
        )

    @property
    def vocab_size(self):
        return len(self.alphabet)
    
    @property
    def gap_token_id(self):
        return self.convert_tokens_to_ids(self.gap_token)

    def get_vocab(self):
        return self.a_to_i

    def _tokenize(self, text: str) -> List[str]:
        return list(text)

    def _convert_token_to_id(self, token) -> int:
        return self.a_to_i[token]

    def _convert_id_to_token(self, index) -> str:
        return self.i_to_a[index]

    def convert_tokens_to_string(self, tokens):
        return "".join(tokens)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        result = token_ids_0
        if token_ids_1 is not None:
            raise NotImplementedError("This tokenizer does not support two sequences")
        return result

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )

        result = [0] * len(token_ids_0)
        if token_ids_1 is not None:
            raise NotImplementedError("This tokenizer does not support two sequences")

        return result

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Identifies the type of token. 0 for the first sentence, 1 for the second sentence if it exists
        """

        result = len(token_ids_0) * [0]

        if token_ids_1 is not None:
            raise NotImplementedError("This tokenizer does not support two sequences")
        return result

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        return super().save_pretrained(save_directory, **kwargs)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None):
        # The vocabulary is fully determined by the alphabet stored in the tokenizer
        # config, so there is no separate vocabulary file to write.
        return ()
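

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the tokenizer API).
# It assumes this module is run directly with `transformers` installed; the
# example sequence below is arbitrary.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tokenizer = ProteinTokenizer()
    encoded = tokenizer("MKTAYIAKQR")
    print(encoded["input_ids"])                    # per-character ids from the alphabet
    print(tokenizer.decode(encoded["input_ids"]))  # round-trips back to "MKTAYIAKQR"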