import json
import os

from transformers import PreTrainedTokenizer


class AbLang2PairedTokenizer(PreTrainedTokenizer):
    """Character-level tokenizer for paired (VH|VL) antibody sequences."""

    vocab_files_names = {"vocab_file": "vocab.json"}
    model_input_names = ["input_ids"]

    def __init__(self, vocab_file=None, **kwargs):
        if vocab_file is None:
            # Fall back to a vocab file in the current working directory.
            vocab_file = "vocab.json"
        self.vocab_file = vocab_file
        with open(vocab_file, "r", encoding="utf-8") as f:
            self.vocab = json.load(f)

        # Set required special-token attributes (all as strings, standard for HF).
        kwargs.setdefault("pad_token", "-")
        kwargs.setdefault("mask_token", "*")
        kwargs.setdefault("unk_token", "X")
        super().__init__(**kwargs)

    @property
    def vocab_size(self):
        # Required by the PreTrainedTokenizer interface.
        return len(self.vocab)

    @property
    def pad_token_id(self):
        return self.vocab[self.pad_token]

    @property
    def mask_token_id(self):
        return self.vocab[self.mask_token]

    def _tokenize(self, text):
        # Amino-acid sequences are tokenized character by character.
        return list(text)

    def tokenize(self, text, text_pair=None, **kwargs):
        """Tokenize a single sequence or a heavy/light chain pair."""
        if text_pair is not None:
            # For paired sequences, join the two chains with the "|" separator.
            combined_text = text + "|" + text_pair
            return self._tokenize(combined_text)
        return self._tokenize(text)

    def _convert_token_to_id(self, token):
        return self.vocab.get(token, self.vocab[self.unk_token])

    def _convert_id_to_token(self, index):
        inv_vocab = {v: k for k, v in self.vocab.items()}
        return inv_vocab.get(index, self.unk_token)

    def get_vocab(self):
        return self.vocab

    def save_vocabulary(self, save_directory, filename_prefix=None):
        os.makedirs(save_directory, exist_ok=True)
        path = os.path.join(save_directory, (filename_prefix or "") + "vocab.json")
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.vocab, f)
        return (path,)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        try:
            from transformers.utils import cached_file

            vocab_file = cached_file(
                pretrained_model_name_or_path,
                "vocab.json",
                cache_dir=kwargs.get("cache_dir"),
                force_download=kwargs.get("force_download", False),
                resume_download=kwargs.get("resume_download", False),
                proxies=kwargs.get("proxies"),
                token=kwargs.get("token"),
                revision=kwargs.get("revision"),
                local_files_only=kwargs.get("local_files_only", False),
            )
            if vocab_file is None or not os.path.exists(vocab_file):
                raise ValueError(
                    f"Vocabulary file vocab.json not found in {pretrained_model_name_or_path}"
                )
            return cls(vocab_file=vocab_file, **kwargs)
        except Exception:
            # Fall back to treating the identifier as a local directory.
            vocab_file = os.path.join(pretrained_model_name_or_path, "vocab.json")
            if not os.path.exists(vocab_file):
                raise ValueError(f"Vocabulary file {vocab_file} not found")
            return cls(vocab_file=vocab_file, **kwargs)

    def save_pretrained(self, save_directory, filename_prefix=None):
        os.makedirs(save_directory, exist_ok=True)
        vocab_files = self.save_vocabulary(save_directory, filename_prefix)
        tokenizer_config = {
            "tokenizer_class": f"{self.__class__.__module__}.{self.__class__.__name__}"
        }
        with open(os.path.join(save_directory, "tokenizer_config.json"), "w") as f:
            json.dump(tokenizer_config, f, indent=2)
        return vocab_files

    def __call__(self, sequences, padding=False, return_tensors=None, **kwargs):
        # Normalise the accepted input formats to a list of "VH|VL" strings.
        if isinstance(sequences, str):
            # Single string: "VH|VL"
            sequences = [sequences]
        elif isinstance(sequences, list) and len(sequences) > 0:
            if isinstance(sequences[0], list):
                # List of lists: [['VH', 'VL'], ['VH2', 'VL2']]
                sequences = [f"{pair[0]}|{pair[1]}" for pair in sequences]
            # List of strings: ["VH|VL", "VH2|VL2"] is already the right format.

        # Tokenize each sequence.
        input_ids = [
            [self._convert_token_to_id(tok) for tok in self._tokenize(seq)]
            for seq in sequences
        ]

        # Pad every sequence to the length of the longest one.
        if padding:
            maxlen = max(len(ids) for ids in input_ids)
            input_ids = [ids + [self.pad_token_id] * (maxlen - len(ids)) for ids in input_ids]

        # Convert to tensors if requested.
        if return_tensors == "pt":
            import torch

            input_ids = torch.tensor(input_ids)
        return {"input_ids": input_ids}
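

# ---------------------------------------------------------------------------
# Usage sketch (not part of the tokenizer itself). It assumes a local
# "vocab.json" that maps single-character tokens -- the amino acids plus the
# special characters "-", "*", "X", and the chain separator "|" -- to integer
# ids. The file path and the example sequences below are placeholders; adjust
# them for your own checkpoint.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tokenizer = AbLang2PairedTokenizer(vocab_file="vocab.json")

    # Paired heavy/light chains can be passed either as "VH|VL" strings or as
    # [VH, VL] pairs; both are normalised to the "|"-joined form in __call__.
    batch = tokenizer(
        [["EVQLVESGGGLVQ", "DIQMTQSPSSLSA"], "EVQLLESGGGLVK|EIVLTQSPGTLSL"],
        padding=True,
        return_tensors="pt",
    )
    print(batch["input_ids"].shape)  # (batch_size, max_sequence_length)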