|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Transformers XYTokenizer model.""" |
|
|
|
|
|
import math |
|
|
from collections import defaultdict |
|
|
from dataclasses import asdict, dataclass |
|
|
from typing import Optional, Tuple, Union, List |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.distributed as dist |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from einops import rearrange |
|
|
from torch.nn.utils.parametrizations import weight_norm |
|
|
from transformers.activations import ACT2FN |
|
|
from transformers.modeling_utils import PreTrainedAudioTokenizerBase |
|
|
from transformers.utils import ModelOutput, logging |
|
|
from transformers.feature_extraction_utils import BatchFeature |
|
|
|
|
|
from .configuration_xy_tokenizer import XYTokenizerConfig |
|
|
from .feature_extraction_xy_tokenizer import ExtractorIterator |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
@dataclass |
|
|
class XYTokenizerEncodeOutput(ModelOutput): |
|
|
""" |
|
|
Output type of [`XYTokenizerModel.encode`]. |
|
|
|
|
|
Args: |
|
|
quantized_representation (`torch.FloatTensor` of shape `(batch_size, hidden_dim, sequence_length)`): |
|
|
The quantized continuous representation of the input audio. This is the output of the quantizer. |
|
|
audio_codes (`torch.LongTensor` of shape `(num_codebooks, batch_size, sequence_length)`): |
|
|
The discrete codes from the quantizer for each codebook. |
|
|
codes_lengths (`torch.LongTensor` of shape `(batch_size,)`): |
|
|
The valid length of each sequence in `audio_codes`. |
|
|
commit_loss (`torch.FloatTensor`, *optional*): |
|
|
The commitment loss from the vector quantizer. |
|
|
overlap_seconds (`int`, *optional*): |
|
|
The duration of the overlap in seconds between adjacent audio chunks. |
|
|
""" |
|
|
quantized_representation: torch.FloatTensor = None |
|
|
audio_codes: torch.LongTensor = None |
|
|
codes_lengths: torch.LongTensor = None |
|
|
commit_loss: Optional[torch.FloatTensor] = None |
|
|
overlap_seconds: Optional[int] = None |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class XYTokenizerDecodeOutput(ModelOutput): |
|
|
""" |
|
|
Output type of [`XYTokenizerModel.decode`]. |
|
|
|
|
|
Args: |
|
|
        audio_values (`torch.FloatTensor` of shape `(batch_size, 1, sequence_length)` or `List[torch.FloatTensor]`):
            The reconstructed audio waveform. [`XYTokenizerModel.decode`] returns a list with one tensor of shape
            `(1, num_samples)` per input sequence, since sequences may differ in length after chunked decoding.
|
|
output_length (`torch.LongTensor` of shape `(batch_size,)`): |
|
|
The valid length of each sequence in `audio_values`. |
|
|
""" |
|
|
audio_values: torch.FloatTensor = None |
|
|
output_length: Optional[torch.LongTensor] = None |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class XYTokenizerModelOutput(ModelOutput): |
|
|
""" |
|
|
Output type of [`XYTokenizerModel`]'s forward pass. |
|
|
|
|
|
Args: |
|
|
        audio_values (`torch.FloatTensor` of shape `(batch_size, 1, sequence_length)` or `List[torch.FloatTensor]`):
            The reconstructed audio waveform. When decoding runs in chunks this is a list with one tensor of shape
            `(1, num_samples)` per input sequence.
|
|
output_length (`torch.LongTensor` of shape `(batch_size,)`): |
|
|
The valid length of each sequence in `audio_values`. |
|
|
quantized_representation (`torch.FloatTensor` of shape `(batch_size, hidden_dim, sequence_length)`): |
|
|
The quantized continuous representation of the input audio. This is the output of the quantizer. |
|
|
audio_codes (`torch.LongTensor` of shape `(num_codebooks, batch_size, sequence_length)`): |
|
|
The discrete codes from the quantizer for each codebook. |
|
|
codes_lengths (`torch.LongTensor` of shape `(batch_size,)`): |
|
|
The valid length of each sequence in `audio_codes`. |
|
|
commit_loss (`torch.FloatTensor`, *optional*): |
|
|
The commitment loss from the vector quantizer. |
|
|
""" |
|
|
audio_values: torch.FloatTensor = None |
|
|
output_length: torch.LongTensor = None |
|
|
quantized_representation: torch.FloatTensor = None |
|
|
audio_codes: torch.LongTensor = None |
|
|
codes_lengths: torch.LongTensor = None |
|
|
commit_loss: Optional[torch.FloatTensor] = None |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class VectorQuantizerConfig: |
|
|
"""Configuration for the VectorQuantize module.""" |
|
|
commitment: float = 1.0 |
|
|
decay: float = 0.99 |
|
|
epsilon: float = 1e-5 |
|
|
threshold_ema_dead: int = 2 |
|
|
kmeans_init: bool = True |
|
|
kmeans_iters: int = 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sinusoids(length, channels, max_timescale=10000, device=None):
    """Returns sinusoidal positional embeddings of shape `(length, channels)`."""
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    # Keep both factors on the requested device so the multiplication below cannot mix devices.
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2, device=device))
    scaled_time = torch.arange(length, device=device)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
|
|
|
|
|
|
|
def get_sequence_mask(inputs, inputs_length): |
|
|
if inputs.dim() == 3: |
|
|
bsz, tgt_len, _ = inputs.size() |
|
|
else: |
|
|
bsz, tgt_len = inputs_length.shape[0], torch.max(inputs_length) |
|
|
sequence_mask = torch.arange(0, tgt_len, device=inputs.device) |
|
|
sequence_mask = torch.lt(sequence_mask, inputs_length.reshape(bsz, 1)).view(bsz, tgt_len, 1) |
|
|
return sequence_mask |
|
|
|
|
|
|
|
|
class RMSNorm(nn.Module): |
|
|
def __init__(self, hidden_size, eps=1e-6): |
|
|
super().__init__() |
|
|
self.weight = nn.Parameter(torch.ones(hidden_size)) |
|
|
self.variance_epsilon = eps |
|
|
|
|
|
def forward(self, hidden_states): |
|
|
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) |
|
|
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) |
|
|
if self.weight.dtype in [torch.float16, torch.bfloat16]: |
|
|
hidden_states = hidden_states.to(self.weight.dtype) |
|
|
return self.weight * hidden_states |
|
|
|
|
|
|
|
|
class VarLenAttention(nn.Module): |
|
|
def __init__(self, embed_dim, num_heads, causal=False, dropout=0.0): |
|
|
super().__init__() |
|
|
self.embed_dim = embed_dim |
|
|
self.num_heads = num_heads |
|
|
self.head_dim = embed_dim // num_heads |
|
|
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" |
|
|
self.causal = causal |
|
|
self.dropout = nn.Dropout(dropout) |
|
|
self.scaling = self.head_dim ** -0.5 |
|
|
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False) |
|
|
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True) |
|
|
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True) |
|
|
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) |
|
|
|
|
|
def _create_attention_mask(self, seq_len, max_len, device, dtype): |
|
|
bsz = seq_len.size(0) |
|
|
mask = torch.ones(bsz, 1, max_len, max_len, device=device, dtype=dtype) |
|
|
seq_indices = torch.arange(max_len, device=device).unsqueeze(0) |
|
|
seq_len_expanded = seq_len.unsqueeze(1) |
|
|
valid_mask = seq_indices < seq_len_expanded.unsqueeze(-1) |
|
|
mask = mask * (valid_mask.unsqueeze(2) & valid_mask.unsqueeze(3)).to(dtype) |
|
|
if self.causal: |
|
|
causal_mask = torch.triu(torch.ones(max_len, max_len, device=device, dtype=torch.bool), diagonal=1) |
|
|
mask = mask * (~causal_mask.unsqueeze(0).unsqueeze(1)).to(dtype) |
|
|
mask = mask + (1.0 - mask) * torch.finfo(dtype).min |
|
|
return mask |
|
|
|
|
|
def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor: |
|
|
bsz, max_len, _ = hidden_states.size() |
|
|
query = self.q_proj(hidden_states) * self.scaling |
|
|
key = self.k_proj(hidden_states) |
|
|
value = self.v_proj(hidden_states) |
|
|
query = query.view(bsz, max_len, self.num_heads, self.head_dim).transpose(1, 2) |
|
|
key = key.view(bsz, max_len, self.num_heads, self.head_dim).transpose(1, 2) |
|
|
value = value.view(bsz, max_len, self.num_heads, self.head_dim).transpose(1, 2) |
|
|
attn_scores = torch.matmul(query, key.transpose(-1, -2)) |
|
|
attn_mask = self._create_attention_mask(seq_len, max_len, hidden_states.device, attn_scores.dtype) |
|
|
attn_scores = attn_scores + attn_mask |
|
|
attn_weights = F.softmax(attn_scores, dim=-1) |
|
|
attn_weights = self.dropout(attn_weights) |
|
|
attn_output = torch.matmul(attn_weights, value) |
|
|
attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, max_len, self.embed_dim) |
|
|
attn_output = self.out_proj(attn_output) |
|
|
return attn_output |
|
|
|
|
|
|
|
|
class OmniWhisperMLP(nn.Module): |
|
|
def __init__(self, activation_function="gelu", d_model=1280, ffn_dim=5120): |
|
|
super().__init__() |
|
|
self.activation_fn = ACT2FN[activation_function] |
|
|
self.fc1 = nn.Linear(d_model, ffn_dim) |
|
|
self.fc2 = nn.Linear(ffn_dim, d_model) |
|
|
|
|
|
def forward(self, hidden_states): |
|
|
hidden_states = self.activation_fn(self.fc1(hidden_states)) |
|
|
return self.fc2(hidden_states) |
|
|
|
|
|
|
|
|
class OmniWhisperTransformerLayer(nn.Module): |
|
|
def __init__(self, activation_function="gelu", d_model=1280, attention_heads=20, ffn_dim=5120, causal=False, ln_type="LayerNorm", attn_type="varlen"): |
|
|
super().__init__() |
|
|
self.embed_dim = d_model |
|
|
if attn_type != "varlen": |
|
|
raise ValueError(f"Unknown attn_type: {attn_type}. Only 'varlen' is supported.") |
|
|
self.self_attn = VarLenAttention(self.embed_dim, attention_heads, causal) |
|
|
if ln_type == "LayerNorm": |
|
|
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) |
|
|
elif ln_type == "RMSNorm": |
|
|
self.self_attn_layer_norm = RMSNorm(self.embed_dim) |
|
|
else: |
|
|
raise ValueError(f"Unknown ln_type: {ln_type}") |
|
|
|
|
|
self.mlp = OmniWhisperMLP(activation_function, d_model, ffn_dim) |
|
|
if ln_type == "LayerNorm": |
|
|
self.final_layer_norm = nn.LayerNorm(self.embed_dim) |
|
|
elif ln_type == "RMSNorm": |
|
|
self.final_layer_norm = RMSNorm(self.embed_dim) |
|
|
else: |
|
|
raise ValueError(f"Unknown ln_type: {ln_type}") |
|
|
|
|
|
def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor: |
|
|
residual = hidden_states |
|
|
hidden_states = self.self_attn_layer_norm(hidden_states) |
|
|
hidden_states = self.self_attn(hidden_states, seq_len) |
|
|
hidden_states = residual + hidden_states |
|
|
residual = hidden_states |
|
|
hidden_states = self.final_layer_norm(hidden_states) |
|
|
hidden_states = self.mlp(hidden_states) |
|
|
hidden_states = residual + hidden_states |
|
|
if (hidden_states.dtype == torch.float16 or hidden_states.dtype == torch.bfloat16) and \ |
|
|
(torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()): |
|
|
clamp_value = torch.finfo(hidden_states.dtype).max - 1000 |
|
|
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) |
|
|
return hidden_states |
|
|
|
|
|
|
|
|
class OmniAudioEncoder(nn.Module): |
|
|
def __init__( |
|
|
self, num_mel_bins=128, sampling_rate=16000, hop_length=160, stride_size=2, kernel_size=3, |
|
|
d_model=1280, scale_embedding=True, max_audio_seconds=30, encoder_layers=32, |
|
|
encoder_attention_heads=20, encoder_ffn_dim=5120, activation_function="gelu", attn_type="varlen" |
|
|
): |
|
|
super().__init__() |
|
|
self.max_source_positions = (max_audio_seconds * sampling_rate // hop_length) // stride_size |
|
|
self.embed_scale = math.sqrt(d_model) if scale_embedding else 1.0 |
|
|
self.num_mel_bins, self.d_model, self.stride_size = num_mel_bins, d_model, stride_size |
|
|
self.conv1 = nn.Conv1d(num_mel_bins, d_model, kernel_size=kernel_size, padding=1) |
|
|
self.conv2 = nn.Conv1d(d_model, d_model, kernel_size=kernel_size, stride=stride_size, padding=1) |
|
|
self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, d_model)) |
|
|
self.layers = nn.ModuleList([ |
|
|
OmniWhisperTransformerLayer(activation_function, d_model, encoder_attention_heads, encoder_ffn_dim, False, attn_type=attn_type) |
|
|
for _ in range(encoder_layers) |
|
|
]) |
|
|
self.layer_norm = nn.LayerNorm(d_model) |
|
|
|
|
|
def forward(self, input_features, input_length, output_hidden_states=False): |
|
|
input_features = input_features.to(self.conv1.weight.dtype) |
|
|
inputs_embeds = F.gelu(self.conv1(input_features)) |
|
|
inputs_embeds = F.gelu(self.conv2(inputs_embeds)) |
|
|
output_length = (input_length // self.stride_size).long() |
|
|
hidden_states = inputs_embeds.permute(0, 2, 1) |
|
|
bsz, tgt_len, _ = hidden_states.size() |
|
|
pos_embed = self.positional_embedding[:tgt_len] if tgt_len < self.positional_embedding.shape[0] else self.positional_embedding |
|
|
hidden_states = (hidden_states.to(torch.float32) + pos_embed).to(hidden_states.dtype) |
|
|
attention_mask = get_sequence_mask(hidden_states, output_length) |
|
|
all_hidden = () if output_hidden_states else None |
|
|
for layer in self.layers: |
|
|
if output_hidden_states: |
|
|
all_hidden += (hidden_states,) |
|
|
hidden_states = layer(hidden_states, output_length) |
|
|
hidden_states = self.layer_norm(hidden_states) |
|
|
if output_hidden_states: |
|
|
all_hidden += (hidden_states,) |
|
|
hidden_states = torch.where(attention_mask, hidden_states, 0).transpose(1, 2) |
|
|
if not output_hidden_states: |
|
|
return hidden_states, output_length |
|
|
return hidden_states, output_length, all_hidden |
|
|
|
|
|
|
|
|
class OmniAudioDecoder(nn.Module): |
|
|
def __init__( |
|
|
self, num_mel_bins=128, sampling_rate=16000, hop_length=160, stride_size=2, kernel_size=3, |
|
|
d_model=1280, scale_embedding=True, max_audio_seconds=30, decoder_layers=32, |
|
|
decoder_attention_heads=20, decoder_ffn_dim=5120, activation_function="gelu", attn_type="varlen" |
|
|
): |
|
|
super().__init__() |
|
|
self.max_source_positions = (max_audio_seconds * sampling_rate // hop_length) // stride_size |
|
|
self.embed_scale = math.sqrt(d_model) if scale_embedding else 1.0 |
|
|
self.num_mel_bins, self.d_model, self.stride_size = num_mel_bins, d_model, stride_size |
|
|
self.deconv1 = nn.ConvTranspose1d(d_model, d_model, kernel_size, stride_size, padding=0, output_padding=0) |
|
|
self.deconv2 = nn.ConvTranspose1d(d_model, num_mel_bins, kernel_size, stride=1, padding=0) |
|
|
self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, d_model)) |
|
|
self.layers = nn.ModuleList([ |
|
|
OmniWhisperTransformerLayer(activation_function, d_model, decoder_attention_heads, decoder_ffn_dim, False, attn_type=attn_type) |
|
|
for _ in range(decoder_layers) |
|
|
]) |
|
|
self.layer_norm = nn.LayerNorm(d_model) |
|
|
|
|
|
def forward(self, hidden_states, input_length): |
|
|
hidden_states = hidden_states.transpose(1, 2) |
|
|
bsz, tgt_len, _ = hidden_states.size() |
|
|
pos_embed = self.positional_embedding[:tgt_len] if tgt_len < self.positional_embedding.shape[0] else self.positional_embedding |
|
|
hidden_states = (hidden_states.to(torch.float32) + pos_embed).to(hidden_states.dtype) |
|
|
attention_mask = get_sequence_mask(hidden_states, input_length) |
|
|
for layer in self.layers: |
|
|
hidden_states = layer(hidden_states, input_length) |
|
|
hidden_states = self.layer_norm(hidden_states) |
|
|
hidden_states = torch.where(attention_mask, hidden_states, 0).permute(0, 2, 1) |
|
|
output_features = F.gelu(self.deconv1(hidden_states)) |
|
|
output_features = F.gelu(self.deconv2(output_features)) |
|
|
expected_length = tgt_len * self.stride_size |
|
|
if output_features.size(2) > expected_length: |
|
|
output_features = output_features[:, :, :expected_length] |
|
|
output_length = input_length * self.stride_size |
|
|
return output_features, output_length |
|
|
|
|
|
|
|
|
class ResidualDownConv(nn.Module): |
|
|
def __init__(self, d_model=1280, avg_pooler=4): |
|
|
super().__init__() |
|
|
self.d_model, self.avg_pooler = d_model, avg_pooler |
|
|
self.intermediate_dim = d_model * avg_pooler |
|
|
self.gate_proj = nn.Conv1d(d_model, self.intermediate_dim, avg_pooler, avg_pooler, bias=False) |
|
|
self.up_proj = nn.Conv1d(d_model, self.intermediate_dim, avg_pooler, avg_pooler, bias=False) |
|
|
self.down_proj = nn.Linear(self.intermediate_dim, self.intermediate_dim, bias=False) |
|
|
self.act_fn = ACT2FN['silu'] |
|
|
self.layer_norm = nn.LayerNorm(self.intermediate_dim) |
|
|
|
|
|
def forward(self, x, input_length): |
|
|
output_length = input_length // self.avg_pooler |
|
|
x = x.transpose(1, 2) |
|
|
batch_size, seq_len, _ = x.shape |
|
|
if seq_len % self.avg_pooler != 0: |
|
|
pad_size = self.avg_pooler - seq_len % self.avg_pooler |
|
|
x = F.pad(x, (0, 0, 0, pad_size), "constant", 0) |
|
|
xt = x.permute(0, 2, 1) |
|
|
g, u = self.gate_proj(xt).permute(0, 2, 1), self.up_proj(xt).permute(0, 2, 1) |
|
|
x = x.reshape(batch_size, -1, self.intermediate_dim) |
|
|
c = self.down_proj(self.act_fn(g) * u) |
|
|
res = self.layer_norm(c + x).transpose(1, 2) |
|
|
return res, output_length |
|
|
|
|
|
|
|
|
class UpConv(nn.Module): |
|
|
def __init__(self, d_model=1280, stride=4): |
|
|
super().__init__() |
|
|
self.d_model, self.stride = d_model, stride |
|
|
self.up_conv = nn.ConvTranspose1d(self.stride * d_model, d_model, stride, stride, bias=False) |
|
|
|
|
|
def forward(self, x, input_length): |
|
|
res = self.up_conv(x) |
|
|
output_length = input_length * self.stride |
|
|
return res, output_length |
|
|
|
|
|
|
|
|
class Transformer(nn.Module): |
|
|
def __init__( |
|
|
self, input_dim=1280, d_model=1280, output_dim=1280, max_source_positions=1500, |
|
|
encoder_layers=32, encoder_attention_heads=20, encoder_ffn_dim=5120, |
|
|
activation_function="gelu", attn_type="varlen" |
|
|
): |
|
|
super().__init__() |
|
|
self.input_dim, self.d_model, self.output_dim, self.max_source_positions = input_dim, d_model, output_dim, max_source_positions |
|
|
self.proj = nn.Linear(input_dim, d_model, bias=True) if input_dim != d_model else None |
|
|
self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, d_model)) |
|
|
self.layers = nn.ModuleList([ |
|
|
OmniWhisperTransformerLayer(activation_function, d_model, encoder_attention_heads, encoder_ffn_dim, False, attn_type=attn_type) |
|
|
for _ in range(encoder_layers) |
|
|
]) |
|
|
self.layer_norm = nn.LayerNorm(d_model) |
|
|
self.out_proj = nn.Linear(d_model, output_dim, bias=True) if output_dim != d_model else None |
|
|
|
|
|
def forward(self, input_features, input_length, output_hidden_states=False): |
|
|
output_length = input_length.long() |
|
|
hidden_states = self.proj(input_features.permute(0, 2, 1)).permute(0, 2, 1) if self.proj else input_features |
|
|
hidden_states = hidden_states.permute(0, 2, 1) |
|
|
bsz, tgt_len, _ = hidden_states.size() |
|
|
pos_embed = self.positional_embedding[:tgt_len] if tgt_len < self.positional_embedding.shape[0] else self.positional_embedding |
|
|
hidden_states = (hidden_states.to(torch.float32) + pos_embed).to(hidden_states.dtype) |
|
|
attention_mask = get_sequence_mask(hidden_states, output_length) |
|
|
all_hidden = () if output_hidden_states else None |
|
|
for layer in self.layers: |
|
|
if output_hidden_states: |
|
|
all_hidden += (hidden_states,) |
|
|
hidden_states = layer(hidden_states, output_length) |
|
|
hidden_states = self.layer_norm(hidden_states) |
|
|
if output_hidden_states: |
|
|
all_hidden += (hidden_states,) |
|
|
hidden_states = torch.where(attention_mask, hidden_states, 0).transpose(1, 2) |
|
|
if self.out_proj: |
|
|
hidden_states = self.out_proj(hidden_states.permute(0, 2, 1)).permute(0, 2, 1) |
|
|
if not output_hidden_states: |
|
|
return hidden_states, output_length |
|
|
return hidden_states, output_length, all_hidden |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ISTFT(nn.Module): |
|
|
def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"): |
|
|
super().__init__() |
|
|
if padding not in ["center", "same"]: |
|
|
raise ValueError("Padding must be 'center' or 'same'.") |
|
|
self.padding, self.n_fft, self.hop_length, self.win_length = padding, n_fft, hop_length, win_length |
|
|
self.register_buffer("window", torch.hann_window(win_length)) |
|
|
|
|
|
def forward(self, spec: torch.Tensor) -> torch.Tensor: |
|
|
if self.padding == "center": |
|
|
return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True) |
|
|
elif self.padding == "same": |
|
|
pad = (self.win_length - self.hop_length) // 2 |
|
|
else: |
|
|
raise ValueError("Padding must be 'center' or 'same'.") |
|
|
B, N, T = spec.shape |
|
|
ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward") * self.window[None, :, None] |
|
|
output_size = (T - 1) * self.hop_length + self.win_length |
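        # Overlap-add: `fold` sums the windowed frames at hop-length offsets; the result is then normalized by
        # the summed squared window (`window_envelope`) computed below to undo the analysis/synthesis windowing.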
|
|
|
|
|
y = F.fold(ifft, (1, output_size), (1, self.win_length), stride=(1, self.hop_length))[:, 0, 0, pad:-pad] |
|
|
window_sq = self.window.square().expand(1, T, -1).transpose(1, 2) |
|
|
window_envelope = torch.nn.functional.fold( |
|
|
window_sq, |
|
|
output_size=(1, output_size), |
|
|
kernel_size=(1, self.win_length), |
|
|
stride=(1, self.hop_length), |
|
|
).squeeze()[pad:-pad] |
|
|
assert (window_envelope > 1e-11).all() |
|
|
return y / window_envelope |
|
|
|
|
|
|
|
|
class FourierHead(nn.Module): |
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
raise NotImplementedError("Subclasses must implement the forward method.") |
|
|
|
|
|
|
|
|
class ISTFTHead(FourierHead): |
|
|
def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"): |
|
|
super().__init__() |
|
|
self.out = nn.Linear(dim, n_fft + 2) |
|
|
self.istft = ISTFT(n_fft, hop_length, n_fft, padding) |
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
x = self.out(x).transpose(1, 2) |
|
|
mag, p = x.chunk(2, dim=1) |
|
|
mag = torch.exp(mag).clip(max=1e2) |
|
|
s = mag.float() * (torch.cos(p).float() + 1j * torch.sin(p).float()) |
|
|
return self.istft(s).to(x.dtype) |
|
|
|
|
|
|
|
|
class AdaLayerNorm(nn.Module): |
|
|
def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6): |
|
|
super().__init__() |
|
|
self.eps, self.dim = eps, embedding_dim |
|
|
self.scale = nn.Embedding(num_embeddings, embedding_dim) |
|
|
self.shift = nn.Embedding(num_embeddings, embedding_dim) |
|
|
torch.nn.init.ones_(self.scale.weight) |
|
|
torch.nn.init.zeros_(self.shift.weight) |
|
|
|
|
|
def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor: |
|
|
scale, shift = self.scale(cond_embedding_id), self.shift(cond_embedding_id) |
|
|
x = F.layer_norm(x, (self.dim,), eps=self.eps) |
|
|
return x * scale + shift |
|
|
|
|
|
|
|
|
class ConvNeXtBlock(nn.Module): |
|
|
def __init__(self, dim, intermediate_dim, layer_scale_init_value, adanorm_num_embeddings=None): |
|
|
super().__init__() |
|
|
self.dwconv = nn.Conv1d(dim, dim, 7, 1, 3, groups=dim) |
|
|
self.adanorm = adanorm_num_embeddings is not None |
|
|
self.norm = AdaLayerNorm(adanorm_num_embeddings, dim) if self.adanorm else nn.LayerNorm(dim, eps=1e-6) |
|
|
self.pwconv1 = nn.Linear(dim, intermediate_dim) |
|
|
self.act = nn.GELU() |
|
|
self.pwconv2 = nn.Linear(intermediate_dim, dim) |
|
|
self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) if layer_scale_init_value > 0 else None |
|
|
|
|
|
def forward(self, x, cond_embedding_id=None): |
|
|
res = x |
|
|
x = self.dwconv(x).transpose(1, 2) |
|
|
x = self.norm(x, cond_embedding_id) if self.adanorm else self.norm(x) |
|
|
x = self.pwconv2(self.act(self.pwconv1(x))) |
|
|
if self.gamma is not None: |
|
|
x = self.gamma * x |
|
|
x = res + x.transpose(1, 2) |
|
|
return x |
|
|
|
|
|
|
|
|
class Backbone(nn.Module): |
|
|
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: |
|
|
raise NotImplementedError("Subclasses must implement the forward method.") |
|
|
|
|
|
|
|
|
class VocosBackbone(Backbone): |
|
|
def __init__(self, input_channels, dim, intermediate_dim, num_layers, layer_scale_init_value=None, adanorm_num_embeddings=None): |
|
|
super().__init__() |
|
|
self.input_channels, self.embed = input_channels, nn.Conv1d(input_channels, dim, 7, 1, 3) |
|
|
self.adanorm = adanorm_num_embeddings is not None |
|
|
self.norm = AdaLayerNorm(adanorm_num_embeddings, dim) if self.adanorm else nn.LayerNorm(dim, eps=1e-6) |
|
|
self.convnext = nn.ModuleList([ConvNeXtBlock(dim, intermediate_dim, layer_scale_init_value or 1/num_layers, adanorm_num_embeddings) for _ in range(num_layers)]) |
|
|
self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6) |
|
|
self.apply(self._init_weights) |
|
|
|
|
|
def _init_weights(self, m): |
|
|
if isinstance(m, (nn.Conv1d, nn.Linear)): |
|
|
nn.init.trunc_normal_(m.weight, std=0.02) |
|
|
if m.bias is not None: |
|
|
nn.init.constant_(m.bias, 0) |
|
|
|
|
|
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: |
|
|
x = self.embed(x).transpose(1, 2) |
|
|
x = self.norm(x, kwargs.get("bandwidth_id")) if self.adanorm else self.norm(x) |
|
|
x = x.transpose(1, 2) |
|
|
for block in self.convnext: |
|
|
x = block(x, kwargs.get("bandwidth_id")) |
|
|
return self.final_layer_norm(x.transpose(1, 2)) |
|
|
|
|
|
|
|
|
class Vocos(nn.Module): |
|
|
def __init__(self, input_channels=128, dim=512, intermediate_dim=4096, num_layers=30, n_fft=640, hop_size=160, padding="same", adanorm_num_embeddings=None): |
|
|
super().__init__() |
|
|
self.backbone = VocosBackbone(input_channels, dim, intermediate_dim, num_layers, adanorm_num_embeddings=adanorm_num_embeddings) |
|
|
self.head = ISTFTHead(dim, n_fft, hop_size, padding) |
|
|
self.hop_size = hop_size |
|
|
|
|
|
def forward(self, x, input_length): |
|
|
x = self.backbone(x) |
|
|
x = self.head(x) |
|
|
return x[:, None, :], input_length * self.hop_size |
|
|
|
|
|
|
|
|
def WNConv1d(*args, **kwargs): |
|
|
return weight_norm(nn.Conv1d(*args, **kwargs)) |
|
|
|
|
|
|
|
|
def ema_inplace(moving_avg, new, decay): |
|
|
moving_avg.data.mul_(decay).add_(new.float(), alpha=(1 - decay)) |
|
|
|
|
|
|
|
|
def sample_vectors(samples, num): |
|
|
num_samples, device = samples.shape[0], samples.device |
|
|
indices = torch.randperm(num_samples, device=device)[:num] if num_samples >= num else torch.randint(0, num_samples, (num,), device=device) |
|
|
return samples[indices].float() |
|
|
|
|
|
|
|
|
def kmeans(samples, num_clusters, num_iters=10): |
|
|
dim, means = samples.shape[-1], sample_vectors(samples, num_clusters).float() |
|
|
for _ in range(num_iters): |
|
|
dists = -(samples.float().pow(2).sum(1, keepdim=True) - 2 * samples.float() @ means.t() + means.t().float().pow(2).sum(0, keepdim=True)) |
|
|
buckets = dists.max(dim=-1).indices |
|
|
bins = torch.bincount(buckets, minlength=num_clusters) |
|
|
zero_mask = bins == 0 |
|
|
bins_min_clamped = bins.masked_fill(zero_mask, 1) |
|
|
new_means = buckets.new_zeros(num_clusters, dim, dtype=torch.float32).scatter_add_(0, buckets.unsqueeze(1).expand(-1, dim), samples.float()) / bins_min_clamped[..., None] |
|
|
means = torch.where(zero_mask[..., None], means, new_means) |
|
|
dists = -(samples.float().pow(2).sum(1, keepdim=True) - 2 * samples.float() @ means.t() + means.t().float().pow(2).sum(0, keepdim=True)) |
|
|
return means, torch.bincount(dists.max(dim=-1).indices, minlength=num_clusters).float() |
|
|
|
|
|
|
|
|
class VectorQuantize(nn.Module): |
|
|
def __init__(self, input_dim, codebook_size, codebook_dim, commitment=1.0, decay=0.99, epsilon=1e-5, threshold_ema_dead=2, kmeans_init=True, kmeans_iters=10): |
|
|
super().__init__() |
|
|
self.input_dim, self.codebook_size, self.codebook_dim = input_dim, codebook_size, codebook_dim |
|
|
self.commitment, self.decay, self.epsilon, self.threshold_ema_dead = commitment, decay, epsilon, threshold_ema_dead |
|
|
self.kmeans_init, self.kmeans_iters = kmeans_init, kmeans_iters |
|
|
self.in_project = WNConv1d(input_dim, codebook_dim, 1) if input_dim != codebook_dim else nn.Identity() |
|
|
self.out_project = WNConv1d(codebook_dim, input_dim, 1) if codebook_dim != input_dim else nn.Identity() |
|
|
self.register_buffer("codebook", torch.zeros(codebook_size, codebook_dim) if kmeans_init else torch.randn(codebook_size, codebook_dim)) |
|
|
self.register_buffer("inited", torch.tensor(not kmeans_init, dtype=torch.bool)) |
|
|
self.register_buffer("cluster_size", torch.zeros(codebook_size)) |
|
|
self.register_buffer("embed_avg", self.codebook.clone()) |
|
|
|
|
|
def ema_update(self, encodings, embed_onehot): |
|
|
encodings, embed_onehot = encodings.float(), embed_onehot.float() |
|
|
cluster_size_new, embed_sum = embed_onehot.sum(0), encodings.t() @ embed_onehot |
|
|
if dist.is_initialized(): |
|
|
dist.all_reduce(cluster_size_new) |
|
|
dist.all_reduce(embed_sum) |
|
|
ema_inplace(self.cluster_size, cluster_size_new, self.decay) |
|
|
ema_inplace(self.embed_avg, embed_sum.t(), self.decay) |
|
|
cluster_size = (self.cluster_size + self.epsilon) / (self.cluster_size.sum() + self.codebook_size * self.epsilon) * self.cluster_size.sum() |
|
|
self.codebook.copy_(self.embed_avg / cluster_size.unsqueeze(1)) |
|
|
|
|
|
def replace_dead_codes(self, encodings): |
|
|
if self.threshold_ema_dead == 0: return |
|
|
dead_mask = self.cluster_size < self.threshold_ema_dead |
|
|
if dead_mask.any(): |
|
|
samples = sample_vectors(encodings.float(), self.codebook_size) if not dist.is_initialized() or dist.get_rank() == 0 else torch.zeros_like(self.codebook) |
|
|
if dist.is_initialized(): dist.broadcast(samples, src=0) |
|
|
self.codebook[dead_mask] = samples[:dead_mask.sum()].to(self.codebook.dtype) |
|
|
|
|
|
def init_codebook(self, encodings): |
|
|
if self.inited.item(): return |
|
|
if not dist.is_initialized() or dist.get_rank() == 0: |
|
|
embed, cluster_sizes = kmeans(encodings.float(), self.codebook_size, self.kmeans_iters) |
|
|
else: |
|
|
embed, cluster_sizes = torch.zeros(self.codebook_size, self.codebook_dim, device=encodings.device), torch.zeros(self.codebook_size, device=encodings.device) |
|
|
if dist.is_initialized(): |
|
|
dist.broadcast(embed, src=0) |
|
|
dist.broadcast(cluster_sizes, src=0) |
|
|
self.codebook.copy_(embed) |
|
|
self.embed_avg.copy_(embed.clone()) |
|
|
self.cluster_size.copy_(cluster_sizes) |
|
|
self.inited.fill_(True) |
|
|
|
|
|
def forward(self, z): |
|
|
z_e = self.in_project(z.float()) |
|
|
encodings = rearrange(z_e, "b d t -> (b t) d") |
|
|
if self.kmeans_init and not self.inited.item(): self.init_codebook(encodings) |
|
|
        # Named `distances` so the local does not shadow the `torch.distributed as dist` import.
        distances = encodings.pow(2).sum(1, keepdim=True) - 2 * encodings @ self.codebook.float().t() + self.codebook.float().pow(2).sum(1, keepdim=True).t()
        indices = rearrange((-distances).max(1)[1], "(b t) -> b t", b=z.size(0))
|
|
z_q = self.decode_code(indices) |
|
|
commit_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) * self.commitment |
|
|
if self.training and torch.is_grad_enabled(): |
|
|
self.ema_update(encodings, F.one_hot(indices.view(-1), self.codebook_size)) |
|
|
self.replace_dead_codes(encodings) |
|
|
        # Straight-through estimator: the forward pass uses the quantized values, gradients flow to the encoder via z_e.
        z_q = self.out_project(z_e + (z_q - z_e).detach())
|
|
return z_q, commit_loss, torch.tensor(0.0, device=z.device), indices, z_e |
|
|
|
|
|
def decode_code(self, embed_id): |
|
|
return F.embedding(embed_id, self.codebook.float()).transpose(1, 2) |
|
|
|
|
|
|
|
|
class ResidualVQ(nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
input_dim: int = 1280, |
|
|
rvq_dim: int = None, |
|
|
output_dim: int = None, |
|
|
num_quantizers: int = 32, |
|
|
codebook_size: int = 1024, |
|
|
codebook_dim: int = 8, |
|
|
quantizer_dropout: float = 0.5, |
|
|
skip_rvq_ratio: float = 0.0, |
|
|
vq_config: VectorQuantizerConfig = None, |
|
|
**kwargs |
|
|
): |
|
|
super().__init__() |
|
|
self.input_dim, self.rvq_dim, self.output_dim = input_dim, rvq_dim, output_dim or input_dim |
|
|
self.num_quantizers, self.codebook_size, self.codebook_dim = num_quantizers, codebook_size, codebook_dim |
|
|
self.quantizer_dropout, self.skip_rvq_ratio = quantizer_dropout, skip_rvq_ratio |
|
|
self.input_proj = WNConv1d(input_dim, rvq_dim, 1) if input_dim != rvq_dim else nn.Identity() |
|
|
self.output_proj = WNConv1d(rvq_dim, self.output_dim, 1) if rvq_dim != self.output_dim else nn.Identity() |
|
|
if vq_config is None: |
|
|
vq_config = VectorQuantizerConfig() |
|
|
quantizer_kwargs = asdict(vq_config) |
|
|
self.quantizers = nn.ModuleList([VectorQuantize(rvq_dim, codebook_size, codebook_dim, **quantizer_kwargs, **kwargs) for _ in range(num_quantizers)]) |
|
|
|
|
|
|
|
|
def forward(self, z, input_length, n_quantizers: int = None): |
|
|
z = self.input_proj(z) |
|
|
|
|
|
with torch.autocast('cuda', enabled=False): |
|
|
batch_size, _, max_time = z.shape |
|
|
device = z.device |
|
|
mask = torch.arange(max_time, device=device).expand(batch_size, max_time) < input_length.unsqueeze(1) |
|
|
|
|
|
quantized_out = torch.zeros_like(z) |
|
|
residual = z.clone().float() |
|
|
|
|
|
all_commit_losses = [] |
|
|
all_indices = [] |
|
|
all_quantized = [] |
|
|
|
|
|
|
|
|
|
|
|
n_q_tensor = self._get_n_quantizers_tensor(batch_size, device, n_quantizers) |
|
|
skip_mask = self._get_skip_mask(batch_size, device) |
|
|
|
|
|
|
|
|
max_q_to_run = self.num_quantizers if self.training else (n_quantizers or self.num_quantizers) |
|
|
|
|
|
for i, quantizer in enumerate(self.quantizers[:max_q_to_run]): |
|
|
|
|
|
active_in_iteration_mask = (i < n_q_tensor) |
|
|
|
|
|
|
|
|
if not active_in_iteration_mask.any(): |
|
|
|
|
|
|
|
|
all_commit_losses.append(torch.tensor(0.0, device=device)) |
|
|
all_indices.append(torch.zeros(batch_size, max_time, dtype=torch.long, device=device)) |
|
|
all_quantized.append(torch.zeros_like(z)) |
|
|
continue |
|
|
|
|
|
masked_residual = residual * mask.unsqueeze(1) |
|
|
|
|
|
|
|
|
|
|
|
z_q_i, commit_loss_i, indices_i = self._quantize_step(quantizer, masked_residual, skip_mask) |
|
|
|
|
|
|
|
|
|
|
|
update_mask = (active_in_iteration_mask.view(-1, 1, 1) & mask.unsqueeze(1)) |
|
|
|
|
|
quantized_out += z_q_i * update_mask |
|
|
residual -= z_q_i * update_mask |
|
|
|
|
|
|
|
|
commit_loss_i = commit_loss_i[active_in_iteration_mask].mean() if active_in_iteration_mask.any() else torch.tensor(0.0, device=device) |
|
|
|
|
|
all_commit_losses.append(commit_loss_i) |
|
|
all_indices.append(indices_i) |
|
|
all_quantized.append(z_q_i) |
|
|
|
|
|
|
|
|
num_loops_done = len(all_commit_losses) |
|
|
if num_loops_done < self.num_quantizers: |
|
|
remaining = self.num_quantizers - num_loops_done |
|
|
all_commit_losses.extend([torch.tensor(0.0, device=device)] * remaining) |
|
|
all_indices.extend([torch.zeros(batch_size, max_time, dtype=torch.long, device=device)] * remaining) |
|
|
all_quantized.extend([torch.zeros_like(z)] * remaining) |
|
|
|
|
|
|
|
|
quantized_out = self.output_proj(quantized_out) |
|
|
all_indices_tensor = torch.stack(all_indices) |
|
|
all_commit_losses_tensor = torch.stack(all_commit_losses) |
|
|
all_quantized_tensor = torch.stack(all_quantized) |
|
|
|
|
|
return ( |
|
|
quantized_out, |
|
|
all_indices_tensor, |
|
|
all_commit_losses_tensor, |
|
|
all_quantized_tensor, |
|
|
input_length, |
|
|
) |
|
|
|
|
|
def decode_codes(self, codes): |
|
|
nq, B, T = codes.shape |
|
|
emb = torch.zeros(B, self.rvq_dim, T, device=codes.device, dtype=torch.float32) |
|
|
for i, quantizer in enumerate(self.quantizers[:nq]): |
|
|
emb += quantizer.decode_code(codes[i]) |
|
|
return self.output_proj(emb) |
|
|
|
|
|
def _get_n_quantizers_tensor(self, batch_size: int, device: torch.device, n_quantizers_override: Optional[int] = None) -> torch.Tensor: |
|
|
""" |
|
|
Determines the number of quantizers to use for each item in the batch, |
|
|
applying dropout during training. |
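
        For example, with `num_quantizers=32` and `quantizer_dropout=0.5`, roughly half of the batch items keep
        the full depth of 32 while the rest are assigned a random depth drawn uniformly from `[1, 32]`.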
|
|
""" |
|
|
|
|
|
is_training = self.training and torch.is_grad_enabled() |
|
|
if not is_training or self.quantizer_dropout == 0: |
|
|
num_q = n_quantizers_override or self.num_quantizers |
|
|
return torch.full((batch_size,), num_q, dtype=torch.long, device=device) |
|
|
|
|
|
|
|
|
n_q_tensor = torch.full((batch_size,), self.num_quantizers, device=device) |
|
|
n_dropout = int(batch_size * self.quantizer_dropout) |
|
|
if n_dropout > 0: |
|
|
dropout_indices = torch.randperm(batch_size, device=device)[:n_dropout] |
|
|
dropout_values = torch.randint(1, self.num_quantizers + 1, (n_dropout,), device=device) |
|
|
n_q_tensor[dropout_indices] = dropout_values |
|
|
|
|
|
return n_q_tensor |
|
|
|
|
|
def _get_skip_mask(self, batch_size: int, device: torch.device) -> Optional[torch.Tensor]: |
|
|
"""Generates a mask for skipping RVQ during training if skip_rvq_ratio > 0.""" |
|
|
is_training = self.training and torch.is_grad_enabled() |
|
|
if not is_training or self.skip_rvq_ratio <= 0: |
|
|
return None |
|
|
|
|
|
skip_mask = torch.rand(batch_size, device=device) < self.skip_rvq_ratio |
|
|
|
|
|
if skip_mask.all(): |
|
|
skip_mask[0] = False |
|
|
return skip_mask |
|
|
|
|
|
def _quantize_step(self, quantizer, residual, skip_mask): |
|
|
"""Helper to perform one step of quantization, handling the skip logic.""" |
|
|
|
|
|
z_q_i, commit_loss_i, _, indices_i, z_e_i = quantizer(residual.float()) |
|
|
|
|
|
|
|
|
if skip_mask is not None: |
|
|
|
|
|
|
|
|
skip_mask_expanded = skip_mask.view(-1, 1, 1) |
|
|
z_q_i = torch.where(skip_mask_expanded, residual, z_q_i) |
|
|
commit_loss_i = torch.where(skip_mask, torch.zeros_like(commit_loss_i), commit_loss_i) |
|
|
|
|
|
return z_q_i, commit_loss_i, indices_i |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class XYTokenizerPreTrainedModel(PreTrainedAudioTokenizerBase): |
|
|
""" |
|
|
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained |
|
|
models. |
|
|
""" |
|
|
config_class = XYTokenizerConfig |
|
|
base_model_prefix = "xy_tokenizer" |
|
|
main_input_name = "input_values" |
|
|
_supports_grad_checkpointing = True |
|
|
|
|
|
def _init_weights(self, module): |
|
|
"""Initialize the weights.""" |
|
|
if isinstance(module, (nn.Linear, nn.Conv1d, nn.ConvTranspose1d)): |
|
|
module.weight.data.normal_(mean=0.0, std=0.02) |
|
|
if module.bias is not None: |
|
|
module.bias.data.zero_() |
|
|
elif isinstance(module, nn.Embedding): |
|
|
module.weight.data.normal_(mean=0.0, std=0.02) |
|
|
if module.padding_idx is not None: |
|
|
module.weight.data[module.padding_idx].zero_() |
|
|
|
|
|
def _set_gradient_checkpointing(self, module, value=False): |
|
|
if isinstance(module, (OmniAudioEncoder, OmniAudioDecoder, Transformer)): |
|
|
module.gradient_checkpointing = value |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class XYTokenizerModel(XYTokenizerPreTrainedModel): |
|
|
def __init__(self, config: XYTokenizerConfig): |
|
|
super().__init__(config) |
|
|
|
|
|
|
|
|
|
|
|
self.config = config |
|
|
|
|
|
params = config.params |
|
|
self.semantic_encoder = OmniAudioEncoder(**params['semantic_encoder_kwargs']) |
|
|
self.semantic_encoder_adapter = Transformer(**params['semantic_encoder_adapter_kwargs']) |
|
|
self.acoustic_encoder = OmniAudioEncoder(**params['acoustic_encoder_kwargs']) |
|
|
self.pre_rvq_adapter = Transformer(**params['pre_rvq_adapter_kwargs']) |
|
|
self.downsample = ResidualDownConv(**params['downsample_kwargs']) |
|
|
self.quantizer = ResidualVQ(**params['quantizer_kwargs']) |
|
|
self.post_rvq_adapter = Transformer(**params['post_rvq_adapter_kwargs']) |
|
|
self.upsample = UpConv(**params['upsample_kwargs']) |
|
|
self.acoustic_decoder = OmniAudioDecoder(**params['acoustic_decoder_kwargs']) |
|
|
self.enhanced_vocos = Vocos(**params['vocos_kwargs']) |
|
|
self.feature_extractor = params['feature_extractor_kwargs'] |
|
|
|
|
|
self.encoder_downsample_rate = config.encoder_downsample_rate |
|
|
self.nq = params['quantizer_kwargs']['num_quantizers'] |
|
|
|
|
|
|
|
|
self.post_init() |
|
|
|
|
|
def _get_feat_extract_output_lengths(self, input_lengths: Optional[torch.Tensor]): |
|
|
""" |
|
|
Computes the output lengths of the feature extractor. |
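
        For example, with the illustrative values `n_fft=640` and `hop_length=160`, a 16000-sample input yields
        `(16000 - 640) // 160 + 1 = 97` frames.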
|
|
""" |
|
|
def _get_out_len(in_len): |
|
|
return (in_len - self.feature_extractor["n_fft"]) // self.feature_extractor["hop_length"] + 1 |
|
|
|
|
|
if input_lengths is None: |
|
|
return None |
|
|
|
|
|
return torch.tensor([_get_out_len(l) for l in input_lengths], device=self.device) |
|
|
|
|
|
def scale_window_size(self, boundaries, scaling_factor): |
|
|
scaling_range = [] |
|
|
scaling_boundaries = [] |
|
|
for left_boundary, right_boundary in boundaries: |
|
|
            scaling_left_boundary = left_boundary // scaling_factor
|
|
scaling_right_boundary = right_boundary // scaling_factor |
|
|
scaling_range.append(scaling_right_boundary-scaling_left_boundary) |
|
|
scaling_boundaries.append(slice(scaling_left_boundary, scaling_right_boundary)) |
|
|
return scaling_range, scaling_boundaries |
|
|
|
|
|
    @torch.inference_mode()
|
|
def encode( |
|
|
self, |
|
|
features: Union[BatchFeature, ExtractorIterator], |
|
|
n_quantizers: Optional[int] = None, |
|
|
return_dict: Optional[bool] = True, |
|
|
) -> Union[XYTokenizerEncodeOutput, Tuple]: |
|
|
r""" |
|
|
Encodes the input audio waveform into discrete codes. |
|
|
|
|
|
Args: |
|
|
features (`BatchFeature` or `ExtractorIterator`): |
|
|
A single batch of features or an iterator that yields batches of chunks for long audio files. |
|
|
                The iterator is expected to yield `BatchFeature` dicts which contain a `chunk_seq_no` tensor of
                shape `(batch_size,)` mapping each chunk to its original sequence, together with per-chunk
                `input_lengths` boundary pairs used to trim the overlapping regions.
|
|
n_quantizers (`int`, *optional*): |
|
|
The number of quantizers to use. If not specified, all quantizers are used. |
|
|
return_dict (`bool`, *optional*): |
|
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
|
|
Returns: |
|
|
[`XYTokenizerEncodeOutput`] or `tuple(torch.FloatTensor)` |
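
        Example:

        A minimal, illustrative sketch (it assumes the feature extractor saved with the checkpoint accepts a 1-D
        waveform at 16 kHz and returns a `BatchFeature` with `input_features` and `attention_mask`; the model id
        is a placeholder):

        ```python
        >>> import torch
        >>> from transformers import AutoModel, AutoFeatureExtractor

        >>> model = AutoModel.from_pretrained("your-namespace/xy-tokenizer-model")
        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("your-namespace/xy-tokenizer-model")

        >>> waveform = torch.randn(16000)  # one second of synthetic audio at 16 kHz
        >>> features = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")
        >>> enc = model.encode(features)
        >>> codes = enc.audio_codes  # (num_codebooks, batch_size, num_frames)
        ```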
|
|
""" |
|
|
assert isinstance(features, (BatchFeature, ExtractorIterator)) |
|
|
|
|
|
if isinstance(features, BatchFeature): |
|
|
return self._encode(features, n_quantizers, return_dict) |
|
|
|
|
|
|
|
|
else: |
|
|
|
|
|
encodings = defaultdict(lambda: {"zq": [], "codes": [], "length": 0}) |
|
|
commit_losses = [] |
|
|
total_frames = 0 |
|
|
|
|
|
|
|
|
for chunk_features in features: |
|
|
|
|
|
chunk_output = self._encode(chunk_features, n_quantizers, return_dict=True) |
|
|
valid_code_lengths, valid_code_ranges = self.scale_window_size(chunk_features["input_lengths"], self.encoder_downsample_rate) |
|
|
|
|
|
|
|
|
chunk_length = chunk_output.codes_lengths.sum().item() |
|
|
valid_chunk_length = sum(valid_code_lengths) |
|
|
if chunk_output.commit_loss is not None and valid_chunk_length > 0: |
|
|
commit_loss = chunk_output.commit_loss / chunk_length * valid_chunk_length |
|
|
commit_losses.append((commit_loss.cpu(), valid_chunk_length)) |
|
|
total_frames += valid_chunk_length |
|
|
|
|
|
|
|
|
for i, seq_id in enumerate(chunk_features["chunk_seq_no"].tolist()): |
|
|
valid_code_range = valid_code_ranges[i] |
|
|
if valid_code_range.stop > 0: |
|
|
encodings[seq_id]["zq"].append(chunk_output.quantized_representation[i:i+1, :, valid_code_range]) |
|
|
encodings[seq_id]["codes"].append(chunk_output.audio_codes[:, i:i+1, valid_code_range]) |
|
|
|
|
|
encodings[seq_id]["length"] += valid_code_lengths[i] |
|
|
|
|
|
final_outputs = [] |
|
|
for seq_id, seq_data in encodings.items(): |
|
|
final_outputs.append({ |
|
|
"zq": torch.cat(seq_data["zq"], dim=2), |
|
|
"codes": torch.cat(seq_data["codes"], dim=2), |
|
|
"length": seq_data["length"] |
|
|
}) |
|
|
|
|
|
|
|
|
max_len = max(seq["zq"].shape[2] for seq in final_outputs) |
|
|
|
|
|
batch_zq = [] |
|
|
batch_codes = [] |
|
|
batch_lengths = [] |
|
|
|
|
|
for seq in final_outputs: |
|
|
pad_amount = max_len - seq["zq"].shape[2] |
|
|
|
|
|
padded_zq = F.pad(seq["zq"], (0, pad_amount)) |
|
|
padded_codes = F.pad(seq["codes"], (0, pad_amount)) |
|
|
|
|
|
batch_zq.append(padded_zq) |
|
|
batch_codes.append(padded_codes) |
|
|
batch_lengths.append(seq["length"]) |
|
|
|
|
|
|
|
|
quantized_representation = torch.cat(batch_zq, dim=0) |
|
|
audio_codes = torch.cat(batch_codes, dim=0) |
|
|
codes_lengths = torch.tensor(batch_lengths, dtype=torch.long, device=self.device) |
|
|
|
|
|
|
|
|
if total_frames > 0: |
|
|
|
|
|
commit_loss = sum(loss * length for loss, length in commit_losses) / total_frames |
|
|
commit_loss = commit_loss.to(self.device) |
|
|
else: |
|
|
commit_loss = torch.tensor(0.0, device=self.device) |
|
|
|
|
|
if not return_dict: |
|
|
return (quantized_representation, audio_codes, codes_lengths, commit_loss) |
|
|
|
|
|
return XYTokenizerEncodeOutput( |
|
|
quantized_representation=quantized_representation, |
|
|
audio_codes=audio_codes, |
|
|
codes_lengths=codes_lengths, |
|
|
commit_loss=commit_loss, |
|
|
overlap_seconds=features.overlap_seconds, |
|
|
) |
|
|
|
|
|
def _encode( |
|
|
self, |
|
|
features: BatchFeature, |
|
|
n_quantizers: Optional[int] = None, |
|
|
return_dict: Optional[bool] = True, |
|
|
) -> Union[XYTokenizerEncodeOutput, Tuple]: |
|
|
input_mel = features['input_features'].to(self.device, dtype=self.dtype) |
|
|
mel_attention_mask = features['attention_mask'].to(self.device) |
|
|
mel_output_length = mel_attention_mask.sum(dim=-1).long() |
|
|
|
|
|
|
|
|
semantic_encoder_output, semantic_encoder_output_length = self.semantic_encoder(input_mel, mel_output_length) |
|
|
semantic_adapter_output, _ = self.semantic_encoder_adapter(semantic_encoder_output, semantic_encoder_output_length) |
|
|
acoustic_encoder_output, acoustic_encoder_output_length = self.acoustic_encoder(input_mel, mel_output_length) |
|
|
|
|
|
concated_channel = torch.cat([semantic_adapter_output, acoustic_encoder_output], dim=1) |
|
|
|
|
|
pre_rvq_adapter_output, pre_rvq_adapter_output_length = self.pre_rvq_adapter(concated_channel, acoustic_encoder_output_length) |
|
|
downsample_output, downsample_output_length = self.downsample(pre_rvq_adapter_output, pre_rvq_adapter_output_length) |
|
|
|
|
|
n_quantizers = n_quantizers or self.quantizer.num_quantizers |
|
|
zq, codes, vq_loss, _, quantizer_output_length = self.quantizer(downsample_output, downsample_output_length, n_quantizers=n_quantizers) |
|
|
|
|
|
if not return_dict: |
|
|
return (zq, codes, quantizer_output_length, vq_loss) |
|
|
|
|
|
return XYTokenizerEncodeOutput( |
|
|
quantized_representation=zq, |
|
|
audio_codes=codes, |
|
|
codes_lengths=quantizer_output_length, |
|
|
commit_loss=vq_loss.mean() |
|
|
) |
|
|
|
|
|
    @torch.inference_mode()
|
|
def decode( |
|
|
self, |
|
|
audio_codes: Union[torch.Tensor, XYTokenizerEncodeOutput], |
|
|
overlap_seconds: int = 10, |
|
|
return_dict: Optional[bool] = True, |
|
|
) -> Union[XYTokenizerDecodeOutput, Tuple]: |
|
|
r""" |
|
|
Decodes discrete codes back into an audio waveform. |
|
|
|
|
|
Args: |
|
|
            audio_codes (`torch.LongTensor` of shape `(num_codebooks, batch_size, sequence_length)` or [`XYTokenizerEncodeOutput`]):
                The discrete codes from the quantizer for each codebook, or the full output of
                [`XYTokenizerModel.encode`].
            overlap_seconds (`int`, *optional*, defaults to 10):
                Overlap in seconds between adjacent chunks used during chunked encoding. When an
                [`XYTokenizerEncodeOutput`] is passed, the `overlap_seconds` recorded by `encode` takes precedence.
|
|
return_dict (`bool`, *optional*): |
|
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
|
|
Returns: |
|
|
[`XYTokenizerDecodeOutput`] or `tuple(torch.FloatTensor)` |
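
        Example:

        A minimal, illustrative sketch of the code-to-waveform round trip, reusing `model` and `features` from
        the example in [`~XYTokenizerModel.encode`]:

        ```python
        >>> enc = model.encode(features)
        >>> dec = model.decode(enc)  # also accepts the raw `audio_codes` tensor
        >>> waveforms = dec.audio_values  # list with one `(1, num_samples)` tensor per input sequence
        ```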
|
|
""" |
|
|
assert not isinstance(audio_codes, tuple), "try to set param `return_dict=True` for `codec.encode()` function" |
|
|
assert isinstance(audio_codes, (torch.Tensor, XYTokenizerEncodeOutput)), \ |
|
|
"only accept `torch.Tensor` or `XYTokenizerEncodeOutput` for `codec.decode()` function" |
|
|
        if isinstance(audio_codes, XYTokenizerEncodeOutput):
            # Read the overlap recorded by `encode` before unwrapping the raw code tensor.
            if audio_codes.overlap_seconds is not None:
                overlap_seconds = audio_codes.overlap_seconds
            audio_codes = audio_codes.audio_codes
        if overlap_seconds is None:
            overlap_seconds = 0
|
|
chunk_length = self.feature_extractor["chunk_length"] |
|
|
duration_seconds = chunk_length - overlap_seconds |
|
|
chunk_code_length = int(chunk_length * self.feature_extractor["sampling_rate"] // self.config.encoder_downsample_rate) |
|
|
duration_code_length = int(duration_seconds * self.feature_extractor["sampling_rate"] // self.config.encoder_downsample_rate) |
|
|
duration_wav_length = duration_code_length * self.config.decoder_upsample_rate |
|
|
|
|
|
|
|
|
batch_size = audio_codes.shape[1] |
|
|
codes_list = [audio_codes[:, i, :] for i in range(batch_size)] |
|
|
max_code_length = max(codes.shape[-1] for codes in codes_list) |
|
|
batch_size = len(codes_list) |
|
|
codes_tensor = torch.zeros(self.nq, batch_size, max_code_length, device=self.device, dtype=torch.long) |
|
|
code_lengths = torch.zeros(batch_size, dtype=torch.long, device=self.device) |
|
|
for i, codes in enumerate(codes_list): |
|
|
codes_tensor[:, i, :codes.shape[-1]] = codes.to(self.device) |
|
|
code_lengths[i] = codes.shape[-1] |
|
|
|
|
|
|
|
|
max_chunks = (max_code_length + duration_code_length - 1) // duration_code_length |
|
|
wav_list = [] |
|
|
|
|
|
|
|
|
for chunk_idx in range(max_chunks): |
|
|
start = chunk_idx * duration_code_length |
|
|
end = min(start + chunk_code_length, max_code_length) |
|
|
chunk_codes = codes_tensor[:, :, start:end] |
|
|
chunk_code_lengths = torch.clamp(code_lengths - start, 0, end - start) |
|
|
|
|
|
|
|
|
if chunk_code_lengths.max() == 0: |
|
|
continue |
|
|
|
|
|
|
|
|
result = self._decode(chunk_codes, chunk_code_lengths) |
|
|
chunk_wav = result["audio_values"] |
|
|
chunk_wav_lengths = result["output_length"] |
|
|
|
|
|
|
|
|
valid_wav_lengths = torch.clamp(chunk_wav_lengths, 0, duration_wav_length) |
|
|
valid_chunk_wav = torch.zeros(batch_size, 1, duration_wav_length, device=self.device) |
|
|
for b in range(batch_size): |
|
|
if valid_wav_lengths[b] > 0: |
|
|
valid_chunk_wav[b, :, :valid_wav_lengths[b]] = chunk_wav[b, :, :valid_wav_lengths[b]] |
|
|
|
|
|
wav_list.append(valid_chunk_wav) |
|
|
|
|
|
|
|
|
if wav_list: |
|
|
wav_tensor = torch.cat(wav_list, dim=-1) |
|
|
syn_wav_list = [wav_tensor[i, :, :code_lengths[i] * self.config.decoder_upsample_rate] for i in range(batch_size)] |
|
|
else: |
|
|
syn_wav_list = [torch.zeros(1, 0, device=self.device) for _ in range(batch_size)] |
|
|
|
|
|
        output_length = torch.tensor([wav.shape[-1] for wav in syn_wav_list], dtype=torch.long, device=self.device)

        if not return_dict:
            return (syn_wav_list, output_length)

        return XYTokenizerDecodeOutput(
            audio_values=syn_wav_list,
            output_length=output_length,
        )
|
|
|
|
|
def _decode( |
|
|
self, |
|
|
audio_codes: torch.Tensor, |
|
|
codes_lengths: Optional[torch.Tensor] = None, |
|
|
return_dict: Optional[bool] = True, |
|
|
) -> Union[XYTokenizerDecodeOutput, Tuple]: |
|
|
r""" |
|
|
Decodes discrete codes back into an audio waveform. |
|
|
|
|
|
Args: |
|
|
audio_codes (`torch.LongTensor` of shape `(num_codebooks, batch_size, sequence_length)`): |
|
|
The discrete codes from the quantizer for each codebook. |
|
|
codes_lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*): |
|
|
The valid length of each sequence in `audio_codes`. If not provided, it's assumed to be the full length. |
|
|
return_dict (`bool`, *optional*): |
|
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
|
|
Returns: |
|
|
[`XYTokenizerDecodeOutput`] or `tuple(torch.FloatTensor)` |
|
|
""" |
|
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
|
|
if codes_lengths is None: |
|
|
codes_lengths = torch.full((audio_codes.shape[1],), audio_codes.shape[2], device=self.device) |
|
|
|
|
|
|
|
|
zq = self.quantizer.decode_codes(audio_codes) |
|
|
|
|
|
post_rvq_adapter_output, post_rvq_adapter_output_length = self.post_rvq_adapter(zq, codes_lengths) |
|
|
upsample_output, upsample_output_length = self.upsample(post_rvq_adapter_output, post_rvq_adapter_output_length) |
|
|
acoustic_decoder_output, acoustic_decoder_output_length = self.acoustic_decoder(upsample_output, upsample_output_length) |
|
|
y, vocos_output_length = self.enhanced_vocos(acoustic_decoder_output, acoustic_decoder_output_length) |
|
|
|
|
|
if not return_dict: |
|
|
return (y, vocos_output_length) |
|
|
|
|
|
return XYTokenizerDecodeOutput( |
|
|
audio_values=y, |
|
|
output_length=vocos_output_length |
|
|
) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
input_values: torch.Tensor, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
n_quantizers: Optional[int] = None, |
|
|
return_dict: Optional[bool] = True, |
|
|
) -> Union[XYTokenizerModelOutput, Tuple]: |
|
|
r""" |
|
|
The forward method that handles the full encoding and decoding process. |
|
|
|
|
|
Args: |
|
|
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): |
|
|
Float values of the input audio waveform. |
|
|
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
|
Mask to avoid performing attention on padding token indices. |
|
|
n_quantizers (`int`, *optional*): |
|
|
The number of quantizers to use for encoding. If not specified, all quantizers are used. |
|
|
return_dict (`bool`, *optional*): |
|
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
|
|
|
|
|
Examples: |
|
|
|
|
|
```python |
|
|
>>> from transformers import AutoModel, AutoFeatureExtractor |
|
|
>>> from datasets import load_dataset, Audio |
|
|
>>> import torch |
|
|
|
|
|
>>> # This is a placeholder model name, replace with the actual one on the Hub |
|
|
>>> model_id = "your-namespace/xy-tokenizer-model" |
|
|
>>> model = AutoModel.from_pretrained(model_id) |
|
|
>>> # The feature extractor config is part of the model config, so it can be loaded this way |
|
|
>>> feature_extractor = AutoFeatureExtractor.from_pretrained(model_id) |
|
|
|
|
|
>>> # Load a dummy audio dataset |
|
|
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") |
|
|
>>> audio_sample = ds[0]["audio"]["array"] |
|
|
>>> sampling_rate = ds[0]["audio"]["sampling_rate"] |
|
|
|
|
|
>>> # Process audio |
|
|
        >>> inputs = feature_extractor(audio_sample, sampling_rate=sampling_rate, return_tensors="pt")
|
|
|
|
|
>>> # Encode to get codes |
|
|
>>> with torch.no_grad(): |
|
|
        ...     encoder_output = model.encode(inputs)
|
|
... audio_codes = encoder_output.audio_codes |
|
|
|
|
|
>>> # Decode from codes |
|
|
>>> with torch.no_grad(): |
|
|
... decoder_output = model.decode(audio_codes) |
|
|
... reconstructed_audio = decoder_output.audio_values |
|
|
|
|
|
>>> # Full forward pass |
|
|
>>> with torch.no_grad(): |
|
|
... model_output = model(**inputs) |
|
|
... reconstructed_audio_fwd = model_output.audio_values |
|
|
|
|
|
        >>> # `decode` returns a list with one waveform tensor per input sequence
        >>> print(reconstructed_audio[0].shape)
        torch.Size([1, 147200])
        >>> print(torch.allclose(reconstructed_audio[0], reconstructed_audio_fwd[0]))
        True
|
|
``` |
|
|
|
|
|
Returns: |
|
|
[`XYTokenizerModelOutput`] or `tuple(torch.FloatTensor)` |
|
|
""" |
|
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
|
|
        # `encode` expects a `BatchFeature` keyed like the feature extractor output
        # (`input_features` / `attention_mask`), so wrap the provided tensors accordingly.
        encoder_outputs = self.encode(
            BatchFeature({"input_features": input_values, "attention_mask": attention_mask}),
            n_quantizers=n_quantizers,
            return_dict=True,
        )
|
|
|
|
|
decoder_outputs = self.decode( |
|
|
audio_codes=encoder_outputs, |
|
|
return_dict=True |
|
|
) |
|
|
|
|
|
if not return_dict: |
|
|
return ( |
|
|
decoder_outputs.audio_values, |
|
|
decoder_outputs.output_length, |
|
|
encoder_outputs.quantized_representation, |
|
|
encoder_outputs.audio_codes, |
|
|
encoder_outputs.codes_lengths, |
|
|
encoder_outputs.commit_loss |
|
|
) |
|
|
|
|
|
return XYTokenizerModelOutput( |
|
|
audio_values=decoder_outputs.audio_values, |
|
|
output_length=decoder_outputs.output_length, |
|
|
quantized_representation=encoder_outputs.quantized_representation, |
|
|
audio_codes=encoder_outputs.audio_codes, |
|
|
codes_lengths=encoder_outputs.codes_lengths, |
|
|
commit_loss=encoder_outputs.commit_loss |
|
|
) |
|
|
|