import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss


# Borrowed from https://github.com/jason9693/MusicTransformer-pytorch/blob/5f183374833ff6b7e17f3a24e3594dedd93a5fe5/custom/criterion.py#L28
class SmoothCrossEntropyLoss(_Loss):
    """
    Cross entropy with uniform label smoothing, as proposed in
    "Rethinking the Inception Architecture for Computer Vision"
    (https://arxiv.org/abs/1512.00567).
    """
    __constants__ = ['label_smoothing', 'vocab_size', 'ignore_index', 'reduction']

    def __init__(self, label_smoothing, vocab_size, ignore_index=-100, reduction='mean', is_logits=True):
        assert 0.0 <= label_smoothing <= 1.0
        super().__init__(reduction=reduction)
        self.label_smoothing = label_smoothing
        self.vocab_size = vocab_size
        self.ignore_index = ignore_index
        self.input_is_logits = is_logits

    def forward(self, input, target):
        """
        Args:
            input: [B * T, V] unnormalized logits
            target: [B * T] target class indices
        Returns:
            cross entropy: [1]
        """
        mask = (target == self.ignore_index).unsqueeze(-1)
        # F.one_hot requires non-negative indices, so map ignored positions
        # (e.g. the default ignore_index=-100) to class 0; their rows are
        # zeroed out via `mask` below, so the substitute class never contributes.
        safe_target = target.masked_fill(target == self.ignore_index, 0)
        q = F.one_hot(safe_target.long(), self.vocab_size).type(torch.float32)
        # Smoothed target distribution: q' = (1 - eps) * q + eps / V
        u = 1.0 / self.vocab_size
        q_prime = (1.0 - self.label_smoothing) * q + self.label_smoothing * u
        q_prime = q_prime.masked_fill(mask, 0)
        ce = self.cross_entropy_with_logits(q_prime, input)
        if self.reduction == 'mean':
            # Average over non-ignored positions only.
            lengths = torch.sum(target != self.ignore_index)
            return ce.sum() / lengths
        elif self.reduction == 'sum':
            return ce.sum()
        else:
            raise NotImplementedError

    def cross_entropy_with_logits(self, p, q):
        # Cross entropy H(p, softmax(q)), computed stably via
        # log_softmax(q) = q - logsumexp(q).
        return -torch.sum(p * (q - q.logsumexp(dim=-1, keepdim=True)), dim=-1)
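

# --- Usage sketch (not part of the borrowed file) ---
# A minimal, hedged example of driving the loss; the batch, sequence, and
# vocab sizes below are illustrative assumptions. On PyTorch >= 1.10 the
# result should closely match the built-in F.cross_entropy with its
# `label_smoothing` argument, since both smooth the one-hot target toward
# a uniform distribution over all V classes.
if __name__ == '__main__':
    B, T, V = 8, 16, 388  # hypothetical batch size, sequence length, vocab size
    criterion = SmoothCrossEntropyLoss(label_smoothing=0.1, vocab_size=V)
    logits = torch.randn(B * T, V)           # [B * T, V] unnormalized logits
    targets = torch.randint(0, V, (B * T,))  # [B * T] class indices
    targets[:5] = -100                       # mark a few positions as ignored padding
    loss = criterion(logits, targets)
    reference = F.cross_entropy(logits, targets, ignore_index=-100,
                                label_smoothing=0.1, reduction='mean')
    print(loss.item(), reference.item())     # the two values should agree closely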