Spaces:
Running
on
Zero
Running
on
Zero
| import math | |
| from dataclasses import dataclass | |
| from typing import Union, Tuple, Literal | |
| import torch as T | |
| import torch.nn as nn | |
| from torch.nn.utils.parametrizations import weight_norm | |
| from utils import load_ckpt | |
| from utils.interp import print_colored | |
| from utils import si_module, get_activation | |
| # Adapted from https://github.com/facebookresearch/AudioDec | |
def Conv1d1x1(in_channels, out_channels, bias=True):
    """Build a pointwise (kernel size 1) 1D convolution.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        bias: Whether the convolution has a learnable bias.

    Returns:
        An ``nn.Conv1d`` with ``kernel_size=1``.
    """
    pointwise = nn.Conv1d(
        in_channels,
        out_channels,
        kernel_size=1,
        bias=bias,
    )
    return pointwise
class NonCausalConv1d(nn.Module):
    """1D non-causal convolution with padding applied on both sides."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=-1,
        dilation=1,
        groups=1,
        bias=True,
    ):
        """Create the wrapped ``nn.Conv1d``.

        A negative ``padding`` selects the default
        ``(kernel_size - 1) // 2 * dilation``; any other value is passed
        through unchanged.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.dilation = dilation

        if padding >= 0:
            resolved_padding = padding
        else:
            resolved_padding = (kernel_size - 1) // 2 * dilation

        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=resolved_padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """Apply the convolution.

        Args:
            x (Tensor): Float tensor with the shape (B, C, T).

        Returns:
            Tensor: Output of the wrapped convolution.
        """
        return self.conv(x)
class NonCausalConvTranspose1d(nn.Module):
    """1D non-causal transposed convolution."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding=-1,
        output_padding=-1,
        groups=1,
        bias=True,
    ):
        """Create the wrapped ``nn.ConvTranspose1d``.

        Negative ``padding`` / ``output_padding`` values select defaults
        derived from ``stride``: ``(stride + 1) // 2`` for the padding, and
        an output padding of 1 for odd strides, 0 for even strides.
        """
        super().__init__()

        resolved_padding = padding if padding >= 0 else (stride + 1) // 2
        if output_padding >= 0:
            resolved_output_padding = output_padding
        else:
            resolved_output_padding = 1 if stride % 2 else 0

        self.deconv = nn.ConvTranspose1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=resolved_padding,
            output_padding=resolved_output_padding,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """Apply the transposed convolution.

        Args:
            x (Tensor): Float tensor with the shape (B, C, T).

        Returns:
            Tensor: Float tensor with the shape (B, C', T').
        """
        return self.deconv(x)
class CausalConv1d(NonCausalConv1d):
    """1D causal convolution: the input is padded on the left only.

    Besides the full-sequence ``forward``, this layer offers a streaming
    ``inference`` path plus ``reset_buffer``; other classes in this file
    call both methods on ``CausalConv1d`` instances.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        groups=1,
        bias=True
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.stride = stride
        # Number of past samples each output depends on.
        self.pad_length = (kernel_size - 1) * dilation
        # Streaming history. Registered as non-persistent so it follows
        # .to()/.bfloat16() but adds no key to the state dict.
        self.register_buffer(
            "pad_buffer",
            T.zeros(1, in_channels, self.pad_length),
            persistent=False,
        )

    def forward(self, x):
        """Convolve a whole sequence, zero-padding ``pad_length`` samples on the left."""
        x = nn.functional.pad(x, (self.pad_length, 0), mode="constant", value=0.0)
        return self.conv(x)

    def inference(self, x):
        """Convolve one streaming chunk, using stored history as left context.

        The last ``pad_length`` samples of the concatenated input are kept
        for the next call.
        """
        if self.pad_length == 0:
            # No history needed; also avoids the ``x[..., -0:]`` slice, which
            # would keep the whole chunk instead of nothing.
            return self.conv(x)
        x = T.cat((self.pad_buffer, x), -1)
        self.pad_buffer = x[:, :, -self.pad_length:]
        return self.conv(x)

    def reset_buffer(self):
        """Zero the streaming history in place."""
        self.pad_buffer.zero_()
class CausalConvTranspose1d(NonCausalConvTranspose1d):
    """Causal transposed convolution with a streaming inference path."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        bias=True,
        pad_buffer=None,
    ):
        """Build the layer without built-in padding and register the streaming buffer.

        Args:
            pad_buffer: Optional initial history tensor; when omitted a zero
                tensor of shape ``(1, in_channels, pad_length)`` is used.
        """
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            output_padding=0,
            bias=bias,
        )
        self.stride = stride
        self.pad_length = math.ceil(kernel_size / stride) - 1
        initial_history = (
            T.zeros(1, in_channels, self.pad_length)
            if pad_buffer is None
            else pad_buffer
        )
        self.register_buffer("pad_buffer", initial_history)

    def _deconv_and_trim(self, padded):
        """Run the transposed convolution and drop ``stride`` samples from each end."""
        upsampled = self.deconv(padded)
        return upsampled[:, :, self.stride : -self.stride]

    def forward(self, x):
        """Process a full sequence, replicating the first sample as left context."""
        padded = nn.functional.pad(x, (self.pad_length, 0), mode="replicate")
        return self._deconv_and_trim(padded)

    def inference(self, x):
        """Process one streaming chunk, using the stored history as left context."""
        with_history = T.cat((self.pad_buffer, x), -1)
        # Keep the tail of the input as history for the next chunk.
        self.pad_buffer = with_history[:, :, -self.pad_length:]
        return self._deconv_and_trim(with_history)

    def reset_buffer(self):
        """Zero the streaming history in place."""
        self.pad_buffer.zero_()
class NonCausalResUnit(nn.Module):
    """Residual unit: ELU -> dilated conv -> ELU -> 1x1 conv, added to the input."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=7,
        dilation=1,
        bias=False,
    ):
        super().__init__()
        self.activation = nn.ELU()
        self.conv1 = NonCausalConv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            dilation=dilation,
            bias=bias,
        )
        self.conv2 = Conv1d1x1(out_channels, out_channels, bias)

    def forward(self, x):
        """Return ``x`` plus the output of the two-convolution branch."""
        hidden = self.activation(x)
        hidden = self.conv1(hidden)
        hidden = self.activation(hidden)
        branch = self.conv2(hidden)
        return x + branch
class CausalResUnit(NonCausalResUnit):
    """Residual unit whose first convolution is causal."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=7,
        dilation=1,
        bias=False,
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            bias=bias,
        )
        # Swap the non-causal convolution created by the parent for a causal one.
        causal_kwargs = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            dilation=dilation,
            bias=bias,
        )
        self.conv1 = CausalConv1d(**causal_kwargs)

    def inference(self, x):
        """Streaming counterpart of ``forward``; uses ``conv1.inference``."""
        hidden = self.activation(x)
        hidden = self.conv1.inference(hidden)
        hidden = self.activation(hidden)
        branch = self.conv2(hidden)
        return x + branch
class ResNetBlock(nn.Module):
    """One encoder or decoder stage: residual units plus a resampling step.

    When ``in_channels != out_channels`` the resampling is a strided causal
    (transposed) convolution with kernel size ``2 * stride``. When the
    channel counts match, it is average pooling (encoder) or nearest
    upsampling (decoder) followed by an identity placeholder.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 kernel_size=7,
                 dilations=(1, 3, 9),
                 bias=True,
                 mode='encoder',
                 ):
        super().__init__()
        assert mode in ('encoder', 'decoder'), f"Mode ({mode}) is not supported!"
        self.mode = mode
        self.stride = stride

        is_encoder = mode == 'encoder'
        same_channels = in_channels == out_channels
        # Residual units run on the side of the stage where the data lives:
        # before downsampling in the encoder, after upsampling in the decoder.
        res_channels = in_channels if is_encoder else out_channels
        res_units = [CausalResUnit(
            res_channels,
            res_channels,
            kernel_size=kernel_size,
            dilation=dilation,
        ) for dilation in dilations]

        if same_channels:
            if is_encoder:
                self.pool = nn.AvgPool1d(kernel_size=stride, stride=stride)
            else:
                self.upsample = nn.Upsample(scale_factor=stride, mode='nearest')
            # No channel change is needed in this branch. The identity keeps
            # the layer count, and therefore the positions of the residual
            # units inside the Sequential, the same as before.
            conv_unit = nn.Identity()
        else:
            ConvUnit = CausalConv1d if is_encoder else CausalConvTranspose1d
            conv_unit = ConvUnit(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(2 * stride),
                stride=stride,
                bias=bias,
            )

        if is_encoder:
            if same_channels:
                self.res_block = nn.Sequential(*res_units, self.pool, conv_unit)
            else:
                self.res_block = nn.Sequential(*res_units, conv_unit)
        else:
            if same_channels:
                self.res_block = nn.Sequential(self.upsample, conv_unit, *res_units)
            else:
                self.res_block = nn.Sequential(conv_unit, *res_units)

    def forward(self, x):
        """Run every layer of the stage in order."""
        out = x
        for unit in self.res_block:
            out = unit(out)
        return out

    def inference(self, x):
        """Streaming pass over the stage.

        Layers that define ``inference`` use it. Layers without one
        (pooling, upsampling, identity) are called normally instead of
        raising ``AttributeError``.
        """
        for unit in self.res_block:
            step = getattr(unit, 'inference', None)
            x = step(x) if callable(step) else unit(x)
        return x
class ResNetStack(nn.Module):
    """
    ResNet encoder or decoder stack. Channel ratios
    and strides are given in order from the
    data/io-layer to the middle of the model.
    """
    class Config:
        input_channels: int = 1
        output_channels: int = 1
        encode_channels: int = 32
        decode_channel_multiplier: int = 1
        latent_dim: int = None
        kernel_size: int = 7
        bias: bool = True
        channel_ratios: Tuple[int, ...] = (2, 4, 8, 16)
        strides: Tuple[int, ...] = (3, 4, 5, 5)
        mode: Literal['encoder', 'decoder'] = 'encoder'

    def __init__(self, c: Config):
        super().__init__()
        assert c.mode in ('encoder', 'decoder'), f"Mode ({c.mode}) is not supported!"
        # forward()/inference() read ``self.c``; keep the config on the module.
        # NOTE(review): a project decorator may already do this (``si_module``
        # is imported but not visibly applied here); the assignment is
        # harmless either way.
        self.c = c
        self.mode = c.mode
        assert len(c.channel_ratios) == len(c.strides)
        channel_ratios = (1,) + c.channel_ratios
        strides = c.strides
        self.middle_channels = c.encode_channels * channel_ratios[-1]
        if c.mode == 'decoder':
            # The decoder walks from the middle of the model back to the data.
            channel_ratios = tuple(reversed(channel_ratios))
            strides = tuple(reversed(strides))
        self.multiplier = c.decode_channel_multiplier if c.mode == 'decoder' else 1
        res_blocks = [ResNetBlock(
            c.encode_channels * channel_ratios[s_idx] * self.multiplier,
            c.encode_channels * channel_ratios[s_idx+1] * self.multiplier,
            stride,
            kernel_size=c.kernel_size,
            bias=c.bias,
            mode=c.mode,
        ) for s_idx, stride in enumerate(strides)]
        data_conv = CausalConv1d(
            in_channels=c.input_channels if c.mode == 'encoder' else c.encode_channels * self.multiplier,
            out_channels=c.encode_channels if c.mode == 'encoder' else c.output_channels,
            kernel_size=c.kernel_size,
            stride=1,
            bias=False,
        )
        if c.mode == 'encoder':
            self.res_stack = nn.Sequential(data_conv, *res_blocks)
        elif c.mode == 'decoder':
            self.res_stack = nn.Sequential(*res_blocks, data_conv)
        if c.latent_dim is not None:
            self.latent_proj = Conv1d1x1(self.middle_channels, c.latent_dim, bias=c.bias) if c.mode == 'encoder' else Conv1d1x1(c.latent_dim, self.middle_channels, bias=c.bias)
        if self.multiplier != 1:
            self.multiplier_proj = Conv1d1x1(self.middle_channels, self.middle_channels * self.multiplier, bias=c.bias)

    def _project_in(self, x):
        """Apply the projections that precede the residual stack."""
        if self.c.latent_dim is not None and self.mode == 'decoder':
            x = self.latent_proj(x)
        # Independent of latent_dim: the first decoder block expects
        # ``middle_channels * multiplier`` input channels.
        if self.multiplier != 1:
            x = self.multiplier_proj(x)
        return x

    def _has_output_projection(self):
        """Return True when the encoder ends with a latent projection."""
        return self.c.latent_dim is not None and self.mode == 'encoder'

    def forward(self, x, return_feats=False):
        """Run the stack on a full sequence.

        Args:
            x: Input tensor, channels on dim 1.
            return_feats: When true, return the list of intermediate outputs
                (one per layer, plus the latent projection when present)
                instead of only the final tensor.
        """
        x = self._project_in(x)
        feats = []
        for block in self.res_stack:
            x = block(x)
            if return_feats:
                feats.append(x)
        if self._has_output_projection():
            x = self.latent_proj(x)
            if return_feats:
                feats.append(x)
        if return_feats:
            return feats
        return x

    def inference(self, x):
        """Streaming pass, applying the same projections as ``forward``.

        Layers that expose ``inference`` use it; any layer without one is
        called normally.
        """
        x = self._project_in(x)
        for block in self.res_stack:
            step = getattr(block, 'inference', None)
            x = step(x) if callable(step) else block(x)
        if self._has_output_projection():
            x = self.latent_proj(x)
        return x

    def reset_buffer(self):
        """Reset the streaming buffers of every causal convolution that has one."""
        def _reset_buffer(m):
            if isinstance(m, (CausalConv1d, CausalConvTranspose1d)):
                reset = getattr(m, 'reset_buffer', None)
                if callable(reset):
                    reset()
        self.apply(_reset_buffer)

    def reset_parameters(self):
        """Re-draw convolution weights from N(0, 0.01)."""
        def _reset_parameters(m):
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
                # NOTE(review): if weight norm has already been applied via
                # parametrization, ``m.weight`` is recomputed on access and
                # this in-place draw may not reach the stored parameters --
                # confirm the intended call order.
                m.weight.data.normal_(0.0, 0.01)
        self.apply(_reset_parameters)

    def apply_weight_norm(self):
        """Attach parametrization-based weight norm to every (transposed) convolution."""
        def _apply_weight_norm(m):
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
                nn.utils.parametrizations.weight_norm(m)
        self.apply(_apply_weight_norm)

    def remove_weight_norm(self):
        """Remove weight norm from every submodule that has it.

        ``apply_weight_norm`` uses the parametrization API, which the legacy
        ``nn.utils.remove_weight_norm`` cannot undo, so parametrized weights
        are handled first. The legacy remover is kept as a fallback for
        hook-based weight norm.
        """
        from torch.nn.utils import parametrize

        def _remove_weight_norm(m):
            if parametrize.is_parametrized(m, 'weight'):
                # Bake the current normalised weight back into a plain parameter.
                parametrize.remove_parametrizations(m, 'weight', leave_parametrized=True)
                return
            try:
                nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return
        self.apply(_remove_weight_norm)
| class GaussianZ(nn.Module): | |
| class Config: | |
| dim: int | |
| latent_dim: int | |
| bias: bool = False | |
| use_weight_norm: bool = False | |
| def __init__(self, c: Config): | |
| super().__init__() | |
| self.proj_in = nn.Linear(c.dim, c.latent_dim * 2, bias=c.bias) | |
| self.proj_out = nn.Linear(c.latent_dim, c.dim, bias=c.bias) | |
| if c.use_weight_norm: | |
| self.proj_in = weight_norm(self.proj_in) | |
| self.proj_out = weight_norm(self.proj_out) | |
| def reparam(self, mu, logvar): | |
| std = T.exp(logvar / 2) | |
| eps = T.randn_like(std) | |
| return mu + eps * std | |
| def kl_divergence(self, mu, logvar): | |
| return T.mean(-0.5 * T.sum( | |
| 1 + logvar - mu.pow(2) - logvar.exp(), | |
| dim=(1, 2)) | |
| ) | |
| def repr_from_latent(self, latent: Union[dict, T.Tensor]): | |
| if isinstance(latent, T.Tensor): | |
| z = latent | |
| else: | |
| z = self.reparam(latent['mu'], latent['logvar']) | |
| l = self.proj_out(z) | |
| return l | |
| def forward(self, x: T.Tensor) -> Tuple[T.Tensor, dict]: | |
| mu, logvar = self.proj_in(x).chunk(2, dim=-1) | |
| kl_div = self.kl_divergence(mu, logvar) | |
| z = self.reparam(mu, logvar) | |
| xhat = self.proj_out(z) | |
| latent = {'mu': mu, 'logvar': logvar, 'z': z, 'kl_divergence': kl_div} | |
| return xhat, latent | |
class WaveCodec(nn.Module):
    """Waveform codec: ResNet encoder -> latent compressor -> ResNet decoder."""
    class Config:
        resnet_config: ResNetStack.Config = None
        sample_rate: int = 16_000
        use_weight_norm: bool = False
        compressor_config: dataclass = None
        norm_stddev: float = 1.0

    def __init__(self, c: Config):
        super().__init__()
        self.norm_stddev = c.norm_stddev
        self.sample_rate = c.sample_rate
        # Overall downsampling factor is the product of the per-stage strides.
        self.total_stride = math.prod(c.resnet_config.strides, start=1)
        self.tokens_per_second = self.sample_rate / self.total_stride

        # Sub-modules are built by calling the config objects.
        self.encoder = c.resnet_config(mode='encoder')
        self.compressor = c.compressor_config(dim=self.encoder.middle_channels)
        self.decoder = c.resnet_config(mode='decoder')

        if c.use_weight_norm:
            for stack in (self.encoder, self.decoder):
                stack.apply_weight_norm()
        for stack in (self.encoder, self.decoder):
            stack.reset_parameters()

    def encode(self, data):
        """Scale ``data`` by ``1 / norm_stddev`` and run the encoder."""
        scaled = data / self.norm_stddev
        return self.encoder(scaled)

    def decode(self, latent):
        """Swap dims 1 and 2 of ``latent``, run the decoder, and rescale by ``norm_stddev``."""
        decoded = self.decoder(latent.transpose(1, 2))
        return decoded * self.norm_stddev

    def _compress(self, data):
        """Encode ``data`` and run the compressor on the transposed features."""
        features = self.encode(data)
        return self.compressor(features.transpose(1, 2))

    def latent_from_data(self, data, get_parameters=False):
        """Return the sampled latent ``z``, or a dict with ``mu``/``logvar``/``z``."""
        _, latent = self._compress(data)
        if get_parameters:
            return {
                'mu': latent['mu'],
                'logvar': latent['logvar'],
                'z': latent['z'],
            }
        return latent['z']

    def data_from_latent(self, latent):
        """Decode a latent (as accepted by ``compressor.repr_from_latent``)."""
        features = self.compressor.repr_from_latent(latent)
        return self.decode(features)

    def process(self, x):
        """Alias for ``latent_from_data`` with default arguments."""
        return self.latent_from_data(x)

    def unprocess(self, latent):
        """Alias for ``data_from_latent``."""
        return self.data_from_latent(latent)

    def forward(self, audio_input):
        """Run the full codec round trip and return ``(xhat, latent)``."""
        features, latent = self._compress(audio_input)
        xhat = self.decode(features)
        return xhat, latent
def make_tokenizer(device='cuda'):
    """Build the WaveCodec tokenizer, load its checkpoint, and freeze it.

    Args:
        device: Target device, as a string (``'cuda'``, ``'cuda:0'``,
            ``'cpu'``, ...) or a ``torch.device``. The model is cast to
            bfloat16 only when the device type is CUDA.

    Returns:
        The tokenizer in eval mode on ``device``, with gradients disabled
        for all of its parameters.
    """
    generator_config = WaveCodec.Config(
        resnet_config=ResNetStack.Config(
            input_channels=1,
            output_channels=1,
            encode_channels=16,
            decode_channel_multiplier=4,
            kernel_size=7,
            bias=True,
            channel_ratios=(4, 8, 16, 16, 16, 16),
            strides=(2, 2, 4, 5, 5, 5),
            mode=None,
        ),
        use_weight_norm=True,
        compressor_config=GaussianZ.Config(
            dim=None,
            latent_dim=32,
            bias=True,
            use_weight_norm=True
        ),
        norm_stddev=0.05,
    )
    checkpoint = load_ckpt("inference_apatosaurus_95000", expected_hash="ba876edb97b988e9196e449dd176ca97")
    # Calling the config builds the module (project config helper behaviour).
    tokenizer = generator_config()
    # strict=False: mismatched keys are reported below rather than raising.
    load_result = tokenizer.load_state_dict(checkpoint, strict=False)
    print_colored(f"Loaded tokenizer state dict: {load_result}", "grey")
    tokenizer = tokenizer.eval()

    # Only convert to bfloat16 if using CUDA. Compare the device *type* so
    # that 'cuda:0' and torch.device objects are recognised as well.
    if device is not None and T.device(device).type == 'cuda':
        tokenizer = tokenizer.bfloat16()
    tokenizer = tokenizer.to(device)

    # Call the method: assigning to ``requires_grad_`` would merely replace
    # it with a bool and leave every parameter trainable.
    tokenizer.requires_grad_(False)
    return tokenizer