import torch
import torch.nn.functional as F
from torch import Tensor
from typing import Optional, Tuple

def unpad_input(
    inputs: Tensor,
    attention_mask: Tensor,
    position_ids: Optional[Tensor] = None,
    labels: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor, int, Optional[Tensor], Optional[Tensor]]:
    """
    Remove padding from input sequences.

    Args:
        inputs: (batch, seqlen, ...) or (batch, seqlen)
        attention_mask: (batch, seqlen), bool / int, 1 means valid token and 0 means padding.
        position_ids: (batch, seqlen), int, position ids
        labels: (batch, seqlen), int, labels

    Returns:
        unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected by attention_mask.
        indices: (total_nnz,), indices of the selected tokens in the flattened (batch * seqlen) input.
        cu_seqlens: (batch + 1,), cumulative sequence lengths.
        max_seqlen_in_batch: int, length of the longest sequence in the batch.
        unpadded_position_ids: (total_nnz,) or None
        unpadded_labels: (total_nnz,) or None
    """
    # Per-sequence valid-token counts and the flat indices of all valid tokens.
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max().item())
    # Prepend 0 so cu_seqlens[i]:cu_seqlens[i + 1] spans sequence i in the unpadded tensor.
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    if inputs.dim() == 2:
        unpadded_inputs = inputs.flatten()[indices]
    else:
        # Collapse (batch, seqlen, ...) to (batch * seqlen, ...) before gathering;
        # reshape (rather than view) also handles non-contiguous inputs.
        batch, seqlen, *rest = inputs.shape
        unpadded_inputs = inputs.reshape(batch * seqlen, *rest)[indices]
    unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None
    unpadded_labels = labels.flatten()[indices] if labels is not None else None
    return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels
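
# Hedged shape walkthrough (illustrative only; all tensors below are invented
# for the example and are not part of the original module):
#
#   hidden = torch.randn(2, 4, 8)             # (batch=2, seqlen=4, hidden=8)
#   mask = torch.tensor([[1, 1, 1, 0],
#                        [1, 1, 0, 0]])        # 3 + 2 = 5 valid tokens
#   out, idx, cu, max_len, _, _ = unpad_input(hidden, mask)
#   # out.shape == (5, 8)
#   # idx       == tensor([0, 1, 2, 4, 5])     # positions in the flattened (batch * seqlen) axis
#   # cu        == tensor([0, 3, 5], dtype=torch.int32)
#   # max_len   == 3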

def pad_input(
    inputs: Tensor,
    indices: Tensor,
    batch: int,
    seqlen: int,
    labels: Optional[Tensor] = None,
    ignore_index: int = -100,
) -> Tuple[Tensor, Optional[Tensor]]:
| """ | |
| Add padding to sequences. | |
| Args: | |
| inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask. | |
| indices: (total_nnz) | |
| batch: int, batch size | |
| seqlen: int, max sequence length | |
| position_ids: (total_nnz) or None | |
| labels: (total_nnz) or None | |
| Returns: | |
| padded_inputs: (batch, seqlen, ...) or (batch, seqlen) | |
| padded_labels: (batch, seqlen) or None | |
| """ | |
    if inputs.dim() == 1:
        # Scatter the unpadded values back into a zero-filled (batch * seqlen,) buffer.
        output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
        output[indices] = inputs
        padded_inputs = output.view(batch, seqlen)
    else:
        _, *rest = inputs.shape
        output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
        output[indices] = inputs
        padded_inputs = output.view(batch, seqlen, *rest)
    padded_labels = None
    if labels is not None:
        # Padded label positions get ignore_index so loss functions skip them.
        padded_labels = torch.full((batch * seqlen,), fill_value=ignore_index, dtype=labels.dtype, device=labels.device)
        padded_labels[indices] = labels
        padded_labels = padded_labels.view(batch, seqlen)
    return padded_inputs, padded_labels
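
# Hedged usage sketch, not part of the original API: a round-trip check showing
# that pad_input restores the padded layout produced by unpad_input. All tensor
# values below are invented for illustration.
if __name__ == "__main__":
    batch, seqlen, hidden = 2, 4, 8
    hidden_states = torch.randn(batch, seqlen, hidden)
    attention_mask = torch.tensor([[1, 1, 1, 0],
                                   [1, 1, 0, 0]])
    labels = torch.randint(0, 10, (batch, seqlen))

    unpadded, indices, cu_seqlens, max_len, _, unpadded_labels = unpad_input(
        hidden_states, attention_mask, labels=labels
    )
    repadded, repadded_labels = pad_input(
        unpadded, indices, batch, seqlen, labels=unpadded_labels
    )

    # Valid positions survive the round trip; padded positions are zero / ignore_index.
    mask = attention_mask.bool()
    assert torch.equal(repadded[mask], hidden_states[mask])
    assert torch.equal(repadded_labels[mask], labels[mask])
    assert (repadded_labels[~mask] == -100).all()
    print("round-trip OK:", tuple(unpadded.shape), cu_seqlens.tolist(), max_len)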