Fix tokenizer and format_seq_input to properly handle paired sequences with angle brackets
- adapter.py: +24 -22
- tokenizer_ablang2paired.py: +8 -1
adapter.py CHANGED

@@ -215,31 +215,33 @@ class AbLang2PairedHuggingFaceAdapter(AbEncoding, AbRestore, AbAlignment, AbScor
         # Local implementation of format_seq_input
         def format_seq_input(seqs, fragmented=False):
             """Format input sequences for processing."""
+            if isinstance(seqs[0], str):
+                seqs = [seqs]
+
             if fragmented:
-                # For fragmented sequences,
-
-
-
-                formatted_seqs = []
-                for seq in seqs:
-                    if isinstance(seq, (list, tuple)):
-                        if len(seq) == 2:
-                            # Heavy and light chain
-                            heavy, light = seq[0], seq[1]
-
-                            formatted_seqs.append(f"{heavy}|{light}")
-                        elif heavy:
-                            formatted_seqs.append(heavy)
-                        elif light:
-                            formatted_seqs.append(light)
-                        else:
-                            formatted_seqs.append("")
-                    else:
-                        formatted_seqs.append(seq
-
-
-
-
+                # For fragmented sequences, format as VH|VL without angle brackets
+                formatted_seqs = []
+                for seq in seqs:
+                    if isinstance(seq, (list, tuple)) and len(seq) == 2:
+                        heavy, light = seq[0], seq[1]
+                        formatted_seqs.append(f"{heavy}|{light}")
+                    else:
+                        formatted_seqs.append(seq)
+                return formatted_seqs, 'HL'
+            else:
+                # For non-fragmented sequences, add angle brackets: <VH>|<VL>
+                formatted_seqs = []
+                for seq in seqs:
+                    if isinstance(seq, (list, tuple)) and len(seq) == 2:
+                        heavy, light = seq[0], seq[1]
+                        # Add angle brackets and handle empty sequences
+                        heavy_part = f"<{heavy}>" if heavy else "<>"
+                        light_part = f"<{light}>" if light else "<>"
+                        formatted_seqs.append(f"{heavy_part}|{light_part}".replace("<>", ""))
+                    else:
+                        formatted_seqs.append(seq)
+
+                return formatted_seqs, 'HL'

         valid_modes = [
             'rescoding', 'seqcoding', 'restore', 'likelihood', 'probability',
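For reference, a minimal sketch of how the reworked format_seq_input is expected to behave on paired input. The sequences are made-up placeholders, not real antibody chains, and the snippet assumes the function exactly as defined in this diff:

    # Placeholder VH/VL fragments, not real antibody sequences
    pairs = [("EVQLVESGG", "DIQMTQSPS"),   # heavy and light chain
             ("EVQLVESGG", "")]            # light chain missing

    # Default (non-fragmented): wrap each chain in angle brackets, drop empty chains
    formatted, chain_order = format_seq_input(pairs)
    # formatted == ["<EVQLVESGG>|<DIQMTQSPS>", "<EVQLVESGG>|"], chain_order == 'HL'

    # Fragmented: plain VH|VL without angle brackets
    formatted, chain_order = format_seq_input(pairs, fragmented=True)
    # formatted == ["EVQLVESGG|DIQMTQSPS", "EVQLVESGG|"], chain_order == 'HL'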
tokenizer_ablang2paired.py CHANGED

@@ -100,9 +100,16 @@ class AbLang2PairedTokenizer(PreTrainedTokenizer):
         return vocab_files

     def __call__(self, sequences, padding=False, return_tensors=None, **kwargs):
-        #
+        # Handle different input formats
         if isinstance(sequences, str):
+            # Single string: "VH|VL"
             sequences = [sequences]
+        elif isinstance(sequences, list) and len(sequences) > 0:
+            if isinstance(sequences[0], list):
+                # List of lists: [['VH', 'VL'], ['VH2', 'VL2']]
+                sequences = [f"{pair[0]}|{pair[1]}" for pair in sequences]
+            # List of strings: ["VH|VL", "VH2|VL2"] - already correct format
+
         # Tokenize each sequence
         input_ids = [[self._convert_token_to_id(tok) for tok in self._tokenize(seq)] for seq in sequences]
         # Padding
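Similarly, the input normalization added to __call__ can be exercised on its own. A standalone sketch under the assumption that it behaves as in the diff above; normalize is a hypothetical helper name and the sequences are placeholders:

    # Hypothetical standalone copy of the normalization logic from __call__ above
    def normalize(sequences):
        if isinstance(sequences, str):
            # Single string: "VH|VL"
            return [sequences]
        if isinstance(sequences, list) and sequences and isinstance(sequences[0], list):
            # List of [VH, VL] pairs -> join into "VH|VL" strings
            return [f"{pair[0]}|{pair[1]}" for pair in sequences]
        # List of "VH|VL" strings is already in the expected format
        return sequences

    assert normalize("EVQLVESGG|DIQMTQSPS") == ["EVQLVESGG|DIQMTQSPS"]
    assert normalize([["EVQLVESGG", "DIQMTQSPS"]]) == ["EVQLVESGG|DIQMTQSPS"]
    assert normalize(["EVQLVESGG|DIQMTQSPS"]) == ["EVQLVESGG|DIQMTQSPS"]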