Hugo Flores committed · commit 04c5b94 (1 parent: 534a89c)

remove wavenet, readability
Files changed:
- vampnet/modules/layers.py +18 -0
- vampnet/modules/transformer.py +1 -1
- vampnet/modules/wavenet.py +0 -90
vampnet/modules/layers.py
CHANGED
@@ -8,6 +8,24 @@ import torch.nn.functional as F
 from einops import rearrange
 from torch.nn.utils import weight_norm
 
+# Scripting this brings model speed up 1.4x
+@torch.jit.script
+def snake(x, alpha):
+    shape = x.shape
+    x = x.reshape(shape[0], shape[1], -1)
+    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
+    x = x.reshape(shape)
+    return x
+
+
+class Snake1d(nn.Module):
+    def __init__(self, channels):
+        super().__init__()
+        self.alpha = nn.Parameter(torch.ones(1, channels, 1))
+
+    def forward(self, x):
+        return snake(x, self.alpha)
+
 
 def num_params(model):
     return sum(p.numel() for p in model.parameters() if p.requires_grad)
vampnet/modules/transformer.py
CHANGED
@@ -377,7 +377,7 @@ class TransformerStack(nn.Module):
                     n_heads,
                     bidirectional,
                     is_decoder,
-                    has_relative_attention_bias=(i == 0),
+                    has_relative_attention_bias=True if (i == 0) else False,
                     flash_attn=flash_attn,
                     dropout=dropout,
                 )
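Note on the transformer change: both forms evaluate to the same boolean, so this is purely the "readability" part of the commit; either way, only the first layer in the stack (i == 0) owns a relative attention bias. This resembles the T5-style scheme where one layer holds the bias table and the computed position bias is reused by later layers, though the sharing itself happens elsewhere in the code. A self-contained sketch of that pattern (hypothetical TinyLayer, not VampNet's actual attention module):

# Sketch of first-layer-only relative attention bias (all names hypothetical).
import torch
import torch.nn as nn

class TinyLayer(nn.Module):
    def __init__(self, n_heads: int, has_relative_attention_bias: bool):
        super().__init__()
        # Only the owning layer allocates a bias embedding table.
        self.rel_bias = nn.Embedding(32, n_heads) if has_relative_attention_bias else None

    def position_bias(self, q_len: int, k_len: int) -> torch.Tensor:
        # Clamp relative distances into 32 buckets, look up a per-head bias.
        rel = torch.arange(k_len)[None, :] - torch.arange(q_len)[:, None]
        buckets = rel.clamp(-15, 16) + 15               # values in [0, 31]
        return self.rel_bias(buckets).permute(2, 0, 1)  # n_heads x q_len x k_len

layers = [TinyLayer(n_heads=4, has_relative_attention_bias=(i == 0)) for i in range(3)]
bias = layers[0].position_bias(8, 8)  # computed once by layer 0
# Layers 1..n-1 would add this same bias to their attention logits
# instead of owning their own embedding table.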
vampnet/modules/wavenet.py
DELETED
@@ -1,90 +0,0 @@
-import torch.nn as nn
-from einops import rearrange
-
-from voicegpt.nn import WaveNet
-
-class AutoregMLP(nn.Module):
-    """Implements an autoregressive ConvNet decoder
-    Refer to SampleRNN (https://arxiv.org/abs/1612.07837) for motivation
-    """
-
-    def __init__(
-        self,
-        vocab_size: int,
-        d_model: int,
-        n_layers: int,
-        n_fine_tokens: int = 6,
-        n_tokens: int = 9,
-        dropout: float = 0.1,
-        activation: str = "gelu",
-        causal: bool = True,
-    ):
-        super().__init__()
-        self.n_fine = n_fine_tokens
-        self.n_layers = n_layers
-        self.upsampler = nn.Linear(d_model, d_model * n_fine_tokens)
-
-        self.wavenet = WaveNet(
-            d_model,
-            d_model,
-            d_model,
-            n_layers,
-            n_fine_tokens,
-            dropout=dropout,
-            activation=activation,
-            causal=causal,
-        )
-        self.ff_output = nn.Linear(d_model, vocab_size * n_tokens, bias=False)
-
-    def time_upsample(self, h_t_coarse):
-        """Upsamples the conditioning hidden states to match the time resolution
-        of output tokens
-        Parameters
-        ----------
-        h_t_coarse : Tensor[B x T_coarse x D]
-            Conditioning hidden states in coarse time-scale
-        Returns
-        -------
-        Tensor[B x T_fine x D]
-            Conditioning hidden states in fine time-scale
-        """
-        # Upsample the transformer hidden states to fine scale
-        h_t_fine = rearrange(
-            self.upsampler(h_t_coarse), "b t (n d) -> b (t n) d", n=self.n_fine
-        )
-        return h_t_fine
-
-    def decode_logits(self, x_tm1, h_t_fine):
-        """Decodes output logits conditioned on previous output
-        tokens (upto timestep t-1) and conditioning hidden states
-        using an autoregressive WaveNet
-        Parameters
-        ----------
-        x_tm1 : Tensor[B x T x D]
-        h_t_fine : Tensor[B x T x D]
-        Returns
-        -------
-        Tensor[B x T x vocab_size]
-            Predicted logits
-        """
-
-        # Compute wavenet layers and predict logits
-        o_t = self.wavenet(x_tm1, h_t_fine)
-        return self.ff_output(o_t)
-
-    def forward(self, x_tm1, h_t_coarse):
-        """Computes autoregressive conditional probability distribution
-        using a WaveNet decoder
-        Parameters
-        ----------
-        x_tm1 : Tensor[B x T_fine x D]
-            Embeddings of tokens at fine time-scale
-        h_t_coarse : Tensor[B x T_coarse x D]
-            Hidden states at coarse time scale
-        Returns
-        -------
-        Tensor[B x T_fine x vocab_size]
-            Predicted logits at fine time-scale
-        """
-        h_t_fine = self.time_upsample(h_t_coarse)
-        return self.decode_logits(x_tm1, h_t_fine)