hemantn commited on
Commit
ed12887
·
1 Parent(s): 056b066

Fix tokenizer and format_seq_input to properly handle paired sequences with angle brackets

Browse files
Files changed (2) hide show
  1. adapter.py +24 -22
  2. tokenizer_ablang2paired.py +8 -1
adapter.py CHANGED
@@ -215,31 +215,33 @@ class AbLang2PairedHuggingFaceAdapter(AbEncoding, AbRestore, AbAlignment, AbScor
215
  # Local implementation of format_seq_input
216
  def format_seq_input(seqs, fragmented=False):
217
  """Format input sequences for processing."""
 
 
 
218
  if fragmented:
219
- # For fragmented sequences, assume they're already in the right format
220
- return seqs, 'HL'
221
-
222
- # For paired sequences, format them as VH|VL
223
- formatted_seqs = []
224
- for seq in seqs:
225
- if isinstance(seq, (list, tuple)):
226
- if len(seq) == 2:
227
- # Heavy and light chain
228
  heavy, light = seq[0], seq[1]
229
- if heavy and light:
230
- formatted_seqs.append(f"{heavy}|{light}")
231
- elif heavy:
232
- formatted_seqs.append(heavy)
233
- elif light:
234
- formatted_seqs.append(light)
235
- else:
236
- formatted_seqs.append("")
237
  else:
238
- formatted_seqs.append(seq[0] if seq else "")
239
- else:
240
- formatted_seqs.append(seq)
241
-
242
- return formatted_seqs, 'HL'
 
 
 
 
 
 
 
 
 
 
 
243
 
244
  valid_modes = [
245
  'rescoding', 'seqcoding', 'restore', 'likelihood', 'probability',
 
215
  # Local implementation of format_seq_input
216
  def format_seq_input(seqs, fragmented=False):
217
  """Format input sequences for processing."""
218
+ if isinstance(seqs[0], str):
219
+ seqs = [seqs]
220
+
221
  if fragmented:
222
+ # For fragmented sequences, format as VH|VL without angle brackets
223
+ formatted_seqs = []
224
+ for seq in seqs:
225
+ if isinstance(seq, (list, tuple)) and len(seq) == 2:
 
 
 
 
 
226
  heavy, light = seq[0], seq[1]
227
+ formatted_seqs.append(f"{heavy}|{light}")
 
 
 
 
 
 
 
228
  else:
229
+ formatted_seqs.append(seq)
230
+ return formatted_seqs, 'HL'
231
+ else:
232
+ # For non-fragmented sequences, add angle brackets: <VH>|<VL>
233
+ formatted_seqs = []
234
+ for seq in seqs:
235
+ if isinstance(seq, (list, tuple)) and len(seq) == 2:
236
+ heavy, light = seq[0], seq[1]
237
+ # Add angle brackets and handle empty sequences
238
+ heavy_part = f"<{heavy}>" if heavy else "<>"
239
+ light_part = f"<{light}>" if light else "<>"
240
+ formatted_seqs.append(f"{heavy_part}|{light_part}".replace("<>", ""))
241
+ else:
242
+ formatted_seqs.append(seq)
243
+
244
+ return formatted_seqs, 'HL'
245
 
246
  valid_modes = [
247
  'rescoding', 'seqcoding', 'restore', 'likelihood', 'probability',
tokenizer_ablang2paired.py CHANGED
@@ -100,9 +100,16 @@ class AbLang2PairedTokenizer(PreTrainedTokenizer):
100
  return vocab_files
101
 
102
  def __call__(self, sequences, padding=False, return_tensors=None, **kwargs):
103
- # Accepts a string or a list of strings
104
  if isinstance(sequences, str):
 
105
  sequences = [sequences]
 
 
 
 
 
 
106
  # Tokenize each sequence
107
  input_ids = [[self._convert_token_to_id(tok) for tok in self._tokenize(seq)] for seq in sequences]
108
  # Padding
 
100
  return vocab_files
101
 
102
  def __call__(self, sequences, padding=False, return_tensors=None, **kwargs):
103
+ # Handle different input formats
104
  if isinstance(sequences, str):
105
+ # Single string: "VH|VL"
106
  sequences = [sequences]
107
+ elif isinstance(sequences, list) and len(sequences) > 0:
108
+ if isinstance(sequences[0], list):
109
+ # List of lists: [['VH', 'VL'], ['VH2', 'VL2']]
110
+ sequences = [f"{pair[0]}|{pair[1]}" for pair in sequences]
111
+ # List of strings: ["VH|VL", "VH2|VL2"] - already correct format
112
+
113
  # Tokenize each sequence
114
  input_ids = [[self._convert_token_to_id(tok) for tok in self._tokenize(seq)] for seq in sequences]
115
  # Padding