import json
import os

from transformers import PreTrainedTokenizer


class AbLang2PairedTokenizer(PreTrainedTokenizer):
    """Character-level tokenizer for paired antibody sequences.

    Heavy and light chains are tokenized one residue at a time and joined
    with a "|" separator token.
    """

    vocab_files_names = {"vocab_file": "vocab.json"}
    model_input_names = ["input_ids"]

    def __init__(self, vocab_file=None, **kwargs):
        if vocab_file is None:
            vocab_file = "vocab.json"

        self.vocab_file = vocab_file
        with open(vocab_file, "r", encoding="utf-8") as f:
            self.vocab = json.load(f)
        # Build the inverse mapping once, rather than on every id lookup.
        self.ids_to_tokens = {v: k for k, v in self.vocab.items()}

        kwargs.setdefault("pad_token", "-")
        kwargs.setdefault("mask_token", "*")
        kwargs.setdefault("unk_token", "X")

        super().__init__(**kwargs)
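
    # For reference, vocab.json is expected to be a flat mapping from
    # single-character tokens to integer ids. The ids below are
    # hypothetical, for illustration only:
    #   {"A": 0, "C": 1, ..., "-": 20, "*": 21, "X": 22, "|": 23}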

    @property
    def vocab_size(self):
        # PreTrainedTokenizer expects subclasses to define vocab_size;
        # len(tokenizer) relies on it.
        return len(self.vocab)

    @property
    def pad_token_id(self):
        return self.vocab[self.pad_token]

    @property
    def mask_token_id(self):
        return self.vocab[self.mask_token]

    def _tokenize(self, text):
        # One token per character: each residue (or separator) is a token.
        return list(text)

    def tokenize(self, text, text_pair=None, **kwargs):
        """Tokenize a single sequence, or a heavy/light pair joined by "|"."""
        if text_pair is not None:
            combined_text = text + "|" + text_pair
            return self._tokenize(combined_text)
        return self._tokenize(text)

    def _convert_token_to_id(self, token):
        return self.vocab.get(token, self.vocab[self.unk_token])

    def _convert_id_to_token(self, index):
        return self.ids_to_tokens.get(index, self.unk_token)

    def get_vocab(self):
        # Return a copy so callers cannot mutate the tokenizer's vocabulary.
        return dict(self.vocab)

    def save_vocabulary(self, save_directory, filename_prefix=None):
        os.makedirs(save_directory, exist_ok=True)
        # Follow the transformers naming convention of "<prefix>-vocab.json".
        prefix = filename_prefix + "-" if filename_prefix else ""
        path = os.path.join(save_directory, prefix + "vocab.json")
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.vocab, f)
        return (path,)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        # Pop the download-related options so they are not forwarded to
        # __init__ through **kwargs below.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        local_files_only = kwargs.pop("local_files_only", False)

        try:
            from transformers.utils import cached_file

            vocab_file = cached_file(
                pretrained_model_name_or_path,
                "vocab.json",
                cache_dir=cache_dir,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                token=token,
                revision=revision,
                local_files_only=local_files_only,
            )
            if vocab_file is None or not os.path.exists(vocab_file):
                raise ValueError(
                    f"Vocabulary file vocab.json not found in {pretrained_model_name_or_path}"
                )
            return cls(vocab_file=vocab_file, **kwargs)
        except Exception:
            # Fall back to treating the path as a plain local directory.
            vocab_file = os.path.join(pretrained_model_name_or_path, "vocab.json")
            if not os.path.exists(vocab_file):
                raise ValueError(f"Vocabulary file {vocab_file} not found")
            return cls(vocab_file=vocab_file, **kwargs)

    def save_pretrained(self, save_directory, filename_prefix=None):
        os.makedirs(save_directory, exist_ok=True)
        vocab_files = self.save_vocabulary(save_directory, filename_prefix)

        # Persist a minimal tokenizer_config.json recording which class to
        # instantiate when the tokenizer is reloaded.
        tokenizer_config = {
            "tokenizer_class": f"{self.__class__.__module__}.{self.__class__.__name__}"
        }
        with open(os.path.join(save_directory, "tokenizer_config.json"), "w") as f:
            json.dump(tokenizer_config, f, indent=2)

        return vocab_files

    def __call__(self, sequences, padding=False, return_tensors=None, **kwargs):
        if isinstance(sequences, str):
            sequences = [sequences]
        elif isinstance(sequences, list) and len(sequences) > 0:
            if isinstance(sequences[0], list):
                # A batch of [heavy, light] pairs: join each pair with the
                # "|" separator before tokenizing.
                sequences = [f"{pair[0]}|{pair[1]}" for pair in sequences]

        input_ids = [
            [self._convert_token_to_id(tok) for tok in self._tokenize(seq)]
            for seq in sequences
        ]

        if padding:
            maxlen = max(len(ids) for ids in input_ids)
            input_ids = [ids + [self.pad_token_id] * (maxlen - len(ids)) for ids in input_ids]

        if return_tensors == "pt":
            # torch.tensor requires equal-length rows, so pass padding=True
            # when batching sequences of different lengths.
            import torch

            input_ids = torch.tensor(input_ids)
        return {"input_ids": input_ids}
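

# A minimal usage sketch, not part of the tokenizer itself. Assumptions:
# a local directory "./ablang2-paired" containing a vocab.json, and torch
# installed for the tensor output. The sequences are short placeholders,
# not real antibody data.
if __name__ == "__main__":
    tokenizer = AbLang2PairedTokenizer.from_pretrained("./ablang2-paired")

    # Each inner list is a [heavy, light] pair; __call__ joins the chains
    # with "|" before character-level tokenization.
    batch = tokenizer(
        [["EVQLVESGG", "DIQMTQSP"], ["QVQLQESGPG", "EIVLTQSP"]],
        padding=True,
        return_tensors="pt",
    )
    print(batch["input_ids"].shape)  # torch.Size([2, 19]) for these inputs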