# ablang2 / tokenizer_ablang2paired.py
# Commit ed12887 (hemantn): Fix tokenizer and format_seq_input to properly
# handle paired sequences with angle brackets.
import json
import os
from transformers import PreTrainedTokenizer
class AbLang2PairedTokenizer(PreTrainedTokenizer):
    """Character-level tokenizer for paired ``"VH|VL"`` sequences.

    Every character of the input string is one token, looked up in a JSON
    vocabulary file (``vocab.json``) that maps token string -> integer id.
    Paired inputs are joined with ``"|"`` before tokenization, and characters
    missing from the vocabulary map to the id of ``unk_token``.
    """

    vocab_files_names = {"vocab_file": "vocab.json"}
    model_input_names = ["input_ids"]

    def __init__(self, vocab_file=None, **kwargs):
        """Load the vocabulary and initialise the Hugging Face base class.

        Args:
            vocab_file: Path to a JSON file mapping token -> id. Defaults to
                ``"vocab.json"`` resolved against the current working directory.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``.
                ``pad_token``, ``mask_token`` and ``unk_token`` default to
                ``"-"``, ``"*"`` and ``"X"`` respectively.
        """
        if vocab_file is None:
            # Fall back to a vocab.json in the current working directory.
            vocab_file = "vocab.json"
        self.vocab_file = vocab_file
        with open(vocab_file, "r", encoding="utf-8") as f:
            self.vocab = json.load(f)
        # Inverse mapping built once so id -> token lookups are O(1) instead of
        # rebuilding the whole dictionary for every id converted.
        self._ids_to_tokens = {idx: tok for tok, idx in self.vocab.items()}
        # Special tokens are plain strings, as is standard for HF tokenizers.
        kwargs.setdefault("pad_token", "-")
        kwargs.setdefault("mask_token", "*")
        kwargs.setdefault("unk_token", "X")
        # NOTE(review): the vocabulary is loaded before this call, presumably
        # because the base initialiser may call get_vocab() — keep this order.
        super().__init__(**kwargs)

    @property
    def vocab_size(self):
        """Number of entries in the loaded vocabulary file."""
        return len(self.vocab)

    @property
    def pad_token_id(self):
        """Id of ``pad_token``, read directly from the vocabulary (KeyError if absent)."""
        return self.vocab[self.pad_token]

    @property
    def mask_token_id(self):
        """Id of ``mask_token``, read directly from the vocabulary (KeyError if absent)."""
        return self.vocab[self.mask_token]

    def _tokenize(self, text):
        """Split ``text`` into single-character tokens."""
        return list(text)

    def tokenize(self, text, text_pair=None, **kwargs):
        """Tokenize ``text``, or ``text + "|" + text_pair`` when a pair is given.

        Extra keyword arguments are accepted and ignored.
        """
        if text_pair is not None:
            # Paired sequences are joined with the "|" separator.
            return self._tokenize(text + "|" + text_pair)
        return self._tokenize(text)

    def _convert_token_to_id(self, token):
        """Map a token to its id, falling back to the id of ``unk_token``."""
        return self.vocab.get(token, self.vocab[self.unk_token])

    def _convert_id_to_token(self, index):
        """Map an id back to its token, falling back to ``unk_token``."""
        return self._ids_to_tokens.get(index, self.unk_token)

    def get_vocab(self):
        """Return the token -> id vocabulary dictionary (not a copy)."""
        return self.vocab

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """Write the vocabulary as JSON to ``<save_directory>/<prefix>vocab.json``.

        Returns:
            tuple[str]: One-element tuple holding the written path.
        """
        os.makedirs(save_directory, exist_ok=True)
        path = os.path.join(save_directory, (filename_prefix or "") + "vocab.json")
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.vocab, f)
        return (path,)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Instantiate the tokenizer from a Hub repo id or a local directory.

        ``vocab.json`` is first resolved with ``transformers.utils.cached_file``.
        If that raises or yields no existing file, the method tries
        ``<pretrained_model_name_or_path>/vocab.json`` as a plain local path.
        All ``kwargs`` are also forwarded to the constructor.

        Raises:
            ValueError: If no ``vocab.json`` can be located. Any Hub lookup
                error is chained as the ``__cause__``.
        """
        hub_error = None
        vocab_file = None
        try:
            from transformers.utils import cached_file

            vocab_file = cached_file(
                pretrained_model_name_or_path,
                "vocab.json",
                cache_dir=kwargs.get("cache_dir"),
                force_download=kwargs.get("force_download", False),
                # None (the library default) rather than False: passing any
                # explicit value for this deprecated option can emit a
                # FutureWarning in recent huggingface_hub releases.
                resume_download=kwargs.get("resume_download"),
                proxies=kwargs.get("proxies"),
                token=kwargs.get("token"),
                revision=kwargs.get("revision"),
                local_files_only=kwargs.get("local_files_only", False),
            )
        except Exception as err:  # any Hub/cache failure falls back to a local path
            hub_error = err
        if vocab_file is None or not os.path.exists(vocab_file):
            # Fallback for plain local directories.
            vocab_file = os.path.join(pretrained_model_name_or_path, "vocab.json")
            if not os.path.exists(vocab_file):
                raise ValueError(f"Vocabulary file {vocab_file} not found") from hub_error
        # Constructed outside the try block so that a corrupt or unreadable
        # vocab file surfaces its own error instead of being masked as
        # "not found" by the fallback path.
        return cls(vocab_file=vocab_file, **kwargs)

    def save_pretrained(self, save_directory, filename_prefix=None):
        """Write ``vocab.json`` plus a minimal ``tokenizer_config.json``.

        The config records only ``tokenizer_class`` as ``"<module>.<ClassName>"``.

        Returns:
            tuple[str]: The path(s) returned by :meth:`save_vocabulary`.
        """
        os.makedirs(save_directory, exist_ok=True)
        vocab_files = self.save_vocabulary(save_directory, filename_prefix)
        tokenizer_config = {
            "tokenizer_class": f"{self.__class__.__module__}.{self.__class__.__name__}"
        }
        config_path = os.path.join(save_directory, "tokenizer_config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(tokenizer_config, f, indent=2)
        return vocab_files

    def __call__(self, sequences, padding=False, return_tensors=None, **kwargs):
        """Encode one or more sequences into token ids.

        Args:
            sequences: A single ``"VH|VL"`` string, an iterable of such strings,
                or an iterable of ``[VH, VL]`` pairs (lists or tuples). Each
                pair is joined into ``"VH|VL"``; strings and pairs may be mixed.
            padding: If truthy and not ``"do_not_pad"``, right-pad every row with
                ``pad_token_id`` to the length of the longest row in the batch.
            return_tensors: ``"pt"`` for a ``torch.Tensor``, ``"np"`` for a
                ``numpy.ndarray``; any other value returns nested Python lists.
                Tensor output requires equal-length rows (use ``padding=True``
                when lengths differ).
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            dict: ``{"input_ids": ...}``.
        """
        if isinstance(sequences, str):
            sequences = [sequences]
        # Join [VH, VL] pairs into "VH|VL". This is decided per element, not
        # from the first element only, so tuple pairs and mixed batches are not
        # split into unknown tokens or cut into characters.
        sequences = [
            f"{seq[0]}|{seq[1]}" if isinstance(seq, (list, tuple)) else seq
            for seq in sequences
        ]
        input_ids = [
            [self._convert_token_to_id(tok) for tok in self._tokenize(seq)]
            for seq in sequences
        ]
        # The `input_ids` guard avoids max() raising on an empty batch.
        if padding and padding != "do_not_pad" and input_ids:
            pad_id = self.pad_token_id
            maxlen = max(len(ids) for ids in input_ids)
            input_ids = [ids + [pad_id] * (maxlen - len(ids)) for ids in input_ids]
        if return_tensors == "pt":
            import torch

            input_ids = torch.tensor(input_ids)
        elif return_tensors == "np":
            import numpy as np

            input_ids = np.array(input_ids)
        return {"input_ids": input_ids}