import math
from typing import Any, Callable, Optional, Tuple, Union

import numpy as np
import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler

from .vits_config import VitsConfig, VitsPreTrainedModel
from .flow import VitsResidualCouplingBlock
from .duration_predictor import VitsDurationPredictor, VitsStochasticDurationPredictor
from .encoder import VitsTextEncoder
from .decoder import VitsHifiGan
from .posterior_encoder import VitsPosteriorEncoder
from .discriminator import VitsDiscriminator
from .vits_output import VitsModelOutput, VitsTrainingOutput

_CONFIG_FOR_DOC = "VitsConfig"

VITS_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
    heads, etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`VitsConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

VITS_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
            1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        speaker_id (`int`, *optional*):
            Which speaker embedding to use. Only used for multispeaker models.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


class Vits_models_only_decoder(VitsPreTrainedModel):
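    """
    VITS inference model that stops before vocoding: `forward` runs the text encoder, duration
    predictor and inverse flow and returns the masked latents ("spectrogram"). The HiFi-GAN
    decoding step is kept as commented-out code so vocoding can be applied separately.
    """
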
    def __init__(self, config: VitsConfig):
        super().__init__(config)
        self.config = config
        self.text_encoder = VitsTextEncoder(config)
        self.flow = VitsResidualCouplingBlock(config)
        self.decoder = VitsHifiGan(config)

        if config.use_stochastic_duration_prediction:
            self.duration_predictor = VitsStochasticDurationPredictor(config)
        else:
            self.duration_predictor = VitsDurationPredictor(config)

        if config.num_speakers > 1:
            self.embed_speaker = nn.Embedding(config.num_speakers, config.speaker_embedding_size)

        # This is used only for training.
        # self.posterior_encoder = VitsPosteriorEncoder(config)

        # These parameters control the synthesised speech properties
        self.speaking_rate = config.speaking_rate
        self.noise_scale = config.noise_scale
        self.noise_scale_duration = config.noise_scale_duration

        # Initialize weights and apply final processing
        self.post_init()

    def get_encoder(self):
        return self.text_encoder
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        speaker_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.FloatTensor] = None,
    ) -> Union[torch.FloatTensor, Tuple[Any], VitsModelOutput]:
| r""" | |
| labels (`torch.FloatTensor` of shape `(batch_size, config.spectrogram_bins, sequence_length)`, *optional*): | |
| Float values of target spectrogram. Timesteps set to `-100.0` are ignored (masked) for the loss | |
| computation. | |
| Returns: | |
| Example: | |
| ```python | |
| >>> from transformers import VitsTokenizer, VitsModel, set_seed | |
| >>> import torch | |
| >>> tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng") | |
| >>> model = VitsModel.from_pretrained("facebook/mms-tts-eng") | |
| >>> inputs = tokenizer(text="Hello - my dog is cute", return_tensors="pt") | |
| >>> set_seed(555) # make deterministic | |
| >>> with torch.no_grad(): | |
| ... outputs = model(inputs["input_ids"]) | |
| >>> outputs.waveform.shape | |
| torch.Size([1, 45824]) | |
| ``` | |
| """ | |
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            raise NotImplementedError("Training of VITS is not supported yet.")

        if attention_mask is not None:
            input_padding_mask = attention_mask.unsqueeze(-1).float()
        else:
            input_padding_mask = torch.ones_like(input_ids).unsqueeze(-1).float()

        if self.config.num_speakers > 1 and speaker_id is not None:
            if not 0 <= speaker_id < self.config.num_speakers:
                raise ValueError(f"Set `speaker_id` in the range 0-{self.config.num_speakers - 1}.")
            if isinstance(speaker_id, int):
                speaker_id = torch.full(size=(1,), fill_value=speaker_id, device=self.device)
            speaker_embeddings = self.embed_speaker(speaker_id).unsqueeze(-1)
        else:
            speaker_embeddings = None
        text_encoder_output = self.text_encoder(
            input_ids=input_ids,
            padding_mask=input_padding_mask,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = text_encoder_output[0] if not return_dict else text_encoder_output.last_hidden_state
        hidden_states = hidden_states.transpose(1, 2)
        input_padding_mask = input_padding_mask.transpose(1, 2)
        prior_means = text_encoder_output[1] if not return_dict else text_encoder_output.prior_means
        prior_log_variances = text_encoder_output[2] if not return_dict else text_encoder_output.prior_log_variances
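
        # Predict a log-duration for every input token. The stochastic predictor is itself a
        # normalizing flow and is run in reverse here, with `noise_scale_duration` controlling the
        # sampling noise; the deterministic predictor is a plain regressor over the hidden states.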
        if self.config.use_stochastic_duration_prediction:
            log_duration = self.duration_predictor(
                hidden_states,
                input_padding_mask,
                speaker_embeddings,
                reverse=True,
                noise_scale=self.noise_scale_duration,
            )
        else:
            log_duration = self.duration_predictor(hidden_states, input_padding_mask, speaker_embeddings)
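
        # Length regulation: exponentiate the log-durations, scale by 1 / speaking_rate, round up to
        # whole output frames, then derive the total predicted length of each sequence in the batch.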
        length_scale = 1.0 / self.speaking_rate
        duration = torch.ceil(torch.exp(log_duration) * input_padding_mask * length_scale)
        predicted_lengths = torch.clamp_min(torch.sum(duration, [1, 2]), 1).long()

        # Create a padding mask for the output lengths of shape (batch, 1, max_output_length)
        indices = torch.arange(predicted_lengths.max(), dtype=predicted_lengths.dtype, device=predicted_lengths.device)
        output_padding_mask = indices.unsqueeze(0) < predicted_lengths.unsqueeze(1)
        output_padding_mask = output_padding_mask.unsqueeze(1).to(input_padding_mask.dtype)

        # Reconstruct an attention tensor of shape (batch, 1, out_length, in_length)
        attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(output_padding_mask, -1)
        batch_size, _, output_length, input_length = attn_mask.shape
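
        # Hard monotonic alignment from cumulative durations: output frame j is assigned to input
        # token i when cumsum(duration)[i - 1] <= j < cumsum(duration)[i], computed below as the
        # difference between the valid-index mask and its copy shifted by one token.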
        cum_duration = torch.cumsum(duration, -1).view(batch_size * input_length, 1)
        indices = torch.arange(output_length, dtype=duration.dtype, device=duration.device)
        valid_indices = indices.unsqueeze(0) < cum_duration
        valid_indices = valid_indices.to(attn_mask.dtype).view(batch_size, input_length, output_length)
        padded_indices = valid_indices - nn.functional.pad(valid_indices, [0, 0, 1, 0, 0, 0])[:, :-1]
        attn = padded_indices.unsqueeze(1).transpose(2, 3) * attn_mask

        # Expand prior distribution
        prior_means = torch.matmul(attn.squeeze(1), prior_means).transpose(1, 2)
        prior_log_variances = torch.matmul(attn.squeeze(1), prior_log_variances).transpose(1, 2)
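
        # Sample latents from the expanded prior (mean plus exp(log-variance)-scaled Gaussian noise,
        # tempered by `noise_scale`), then run the residual coupling flow in reverse to obtain the
        # latents the HiFi-GAN decoder expects.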
        prior_latents = prior_means + torch.randn_like(prior_means) * torch.exp(prior_log_variances) * self.noise_scale
        latents = self.flow(prior_latents, output_padding_mask, speaker_embeddings, reverse=True)
        spectrogram = latents * output_padding_mask
        return spectrogram

        # waveform = self.decoder(spectrogram, speaker_embeddings)
        # waveform = waveform.squeeze(1)
        # sequence_lengths = predicted_lengths * np.prod(self.config.upsample_rates)

        # if not return_dict:
        #     outputs = (waveform, sequence_lengths, spectrogram) + text_encoder_output[3:]
        #     return outputs

        # return VitsModelOutput(
        #     waveform=waveform,
        #     sequence_lengths=sequence_lengths,
        #     spectrogram=spectrogram,
        #     hidden_states=text_encoder_output.hidden_states,
        #     attentions=text_encoder_output.attentions,
        # )
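

# Illustrative usage sketch. It assumes a checkpoint whose config is compatible with this class
# (e.g. "facebook/mms-tts-eng"; weights for modules not defined here are simply skipped by
# `from_pretrained`), and that vocoding is done separately with `model.decoder`, mirroring the
# commented-out code in `forward`:
#
#     import torch
#     from transformers import VitsTokenizer
#
#     tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")
#     model = Vits_models_only_decoder.from_pretrained("facebook/mms-tts-eng")
#
#     inputs = tokenizer(text="Hello, my dog is cute", return_tensors="pt")
#     with torch.no_grad():
#         latents = model(inputs["input_ids"])          # masked latents, shape (batch, channels, frames)
#         waveform = model.decoder(latents).squeeze(1)  # run the HiFi-GAN vocoder separately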