Handle truncated image boundaries in `_convert` to avoid tensor size mismatch
#54
by maikezu · opened

processing_minicpmo.py  +12 -5
@@ -269,7 +269,7 @@ class MiniCPMOProcessor(ProcessorMixin):
         image_start_idx += 1
         image_end_idx = torch.where(end_cond)[0]
 
-        valid_image_nums = max(len(image_start_idx), len(image_end_idx))
+        valid_image_nums = min(len(image_start_idx), len(image_end_idx))
 
         image_bounds = torch.hstack(
             [
@@ -278,16 +278,23 @@ class MiniCPMOProcessor(ProcessorMixin):
             ]
         )
 
+
         ## audio bound
         audio_start_idx = torch.where(input_ids == self.tokenizer.audio_start_id)[0]
         audio_end_idx = torch.where(input_ids == self.tokenizer.audio_end_id)[0]
-        assert len(audio_start_idx) == len(audio_end_idx)
-        audio_bounds = torch.hstack([(audio_start_idx + 1).unsqueeze(-1), audio_end_idx.unsqueeze(-1)])
+        valid_audio_nums = min(len(audio_start_idx), len(audio_end_idx))
+        audio_bounds = torch.hstack([
+            (audio_start_idx[:valid_audio_nums] + 1).unsqueeze(-1),
+            audio_end_idx[:valid_audio_nums].unsqueeze(-1)
+        ])
 
         spk_start_idx = torch.where(input_ids == self.tokenizer.spk_start_id)[0]
         spk_end_idx = torch.where(input_ids == self.tokenizer.spk_end_id)[0]
-        assert len(spk_start_idx) == len(spk_end_idx)
-        spk_bounds = torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)])
+        valid_spk_nums = min(len(spk_start_idx), len(spk_end_idx))
+        spk_bounds = torch.hstack([
+            (spk_start_idx[:valid_spk_nums] + 1).unsqueeze(-1),
+            spk_end_idx[:valid_spk_nums].unsqueeze(-1)
+        ])
 
         return input_ids, image_bounds, audio_bounds, spk_bounds
 
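For reference, a minimal standalone sketch of the failure mode this patch addresses. The token ids and the `image_bounds` helper below are hypothetical (not the real `MiniCPMOProcessor` or its tokenizer); only the boundary-pairing logic mirrors the change above.

```python
import torch

# Made-up marker ids for illustration only.
IM_START, IM_END, TXT = 101, 102, 1

def image_bounds(input_ids: torch.Tensor) -> torch.Tensor:
    start_idx = torch.where(input_ids == IM_START)[0] + 1  # first token inside each image span
    end_idx = torch.where(input_ids == IM_END)[0]
    # If max_inp_length truncation dropped a closing marker, the two index tensors
    # have different lengths; clip both to the common count so torch.hstack
    # receives columns of equal height.
    valid = min(len(start_idx), len(end_idx))
    return torch.hstack([start_idx[:valid].unsqueeze(-1), end_idx[:valid].unsqueeze(-1)])

# Two images, but the sequence is cut off before the second IM_END:
ids = torch.tensor([TXT, IM_START, TXT, TXT, IM_END, TXT, IM_START, TXT])
print(image_bounds(ids))  # tensor([[2, 4]]) -- only the complete pair survives
```

Without the clipping (e.g. sizing the slice with `max` of the two lengths, as the removed line did), the same truncated input makes `torch.hstack` fail with a tensor size mismatch, which is the error the PR title refers to.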