Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- trained_30_percents/.gitignore +16 -0
- trained_30_percents/BiCodec/config.yaml +60 -0
- trained_30_percents/BiCodec/model.safetensors +3 -0
- trained_30_percents/Readme.md +130 -0
- trained_30_percents/Readme_zh.md +130 -0
- trained_30_percents/__init__.py +0 -0
- trained_30_percents/__pycache__/spark_llm.cpython-311.pyc +0 -0
- trained_30_percents/__pycache__/utilities.cpython-311.pyc +0 -0
- trained_30_percents/added_tokens.json +3 -0
- trained_30_percents/config.json +66 -0
- trained_30_percents/config.yaml +7 -0
- trained_30_percents/configuration_rwkv7.py +91 -0
- trained_30_percents/generation_config.json +6 -0
- trained_30_percents/hf_rwkv_tokenizer.py +280 -0
- trained_30_percents/kafka.wav +3 -0
- trained_30_percents/model.safetensors +3 -0
- trained_30_percents/modeling_rwkvspeech.py +6 -0
- trained_30_percents/output.wav +3 -0
- trained_30_percents/rwkv_vocab_v20230424.txt +0 -0
- trained_30_percents/spark_llm.py +202 -0
- trained_30_percents/sparktts/models/__pycache__/audio_tokenizer.cpython-311.pyc +0 -0
- trained_30_percents/sparktts/models/__pycache__/bicodec.cpython-311.pyc +0 -0
- trained_30_percents/sparktts/models/audio_tokenizer.py +167 -0
- trained_30_percents/sparktts/models/bicodec.py +247 -0
- trained_30_percents/sparktts/modules/blocks/__pycache__/layers.cpython-311.pyc +0 -0
- trained_30_percents/sparktts/modules/blocks/__pycache__/samper.cpython-311.pyc +0 -0
- trained_30_percents/sparktts/modules/blocks/__pycache__/vocos.cpython-311.pyc +0 -0
- trained_30_percents/sparktts/modules/blocks/layers.py +73 -0
- trained_30_percents/sparktts/modules/blocks/samper.py +115 -0
- trained_30_percents/sparktts/modules/blocks/vocos.py +373 -0
- trained_30_percents/sparktts/modules/encoder_decoder/__pycache__/feat_decoder.cpython-311.pyc +0 -0
- trained_30_percents/sparktts/modules/encoder_decoder/__pycache__/feat_encoder.cpython-311.pyc +0 -0
- trained_30_percents/sparktts/modules/encoder_decoder/__pycache__/wave_generator.cpython-311.pyc +0 -0
- trained_30_percents/sparktts/modules/encoder_decoder/feat_decoder.py +115 -0
- trained_30_percents/sparktts/modules/encoder_decoder/feat_encoder.py +105 -0
- trained_30_percents/sparktts/modules/encoder_decoder/wave_generator.py +88 -0
- trained_30_percents/sparktts/modules/fsq/__pycache__/finite_scalar_quantization.cpython-311.pyc +0 -0
- trained_30_percents/sparktts/modules/fsq/__pycache__/residual_fsq.cpython-311.pyc +0 -0
- trained_30_percents/sparktts/modules/fsq/finite_scalar_quantization.py +251 -0
- trained_30_percents/sparktts/modules/fsq/residual_fsq.py +355 -0
- trained_30_percents/sparktts/modules/speaker/__pycache__/ecapa_tdnn.cpython-311.pyc +0 -0
- trained_30_percents/sparktts/modules/speaker/__pycache__/perceiver_encoder.cpython-311.pyc +0 -0
- trained_30_percents/sparktts/modules/speaker/__pycache__/pooling_layers.cpython-311.pyc +0 -0
- trained_30_percents/sparktts/modules/speaker/__pycache__/speaker_encoder.cpython-311.pyc +0 -0
- trained_30_percents/sparktts/modules/speaker/ecapa_tdnn.py +267 -0
- trained_30_percents/sparktts/modules/speaker/perceiver_encoder.py +360 -0
- trained_30_percents/sparktts/modules/speaker/pooling_layers.py +298 -0
- trained_30_percents/sparktts/modules/speaker/speaker_encoder.py +136 -0
- trained_30_percents/sparktts/modules/vq/__pycache__/factorized_vector_quantize.cpython-311.pyc +0 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
trained_30_percents/kafka.wav filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
trained_30_percents/output.wav filter=lfs diff=lfs merge=lfs -text
|
trained_30_percents/.gitignore
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python build artifacts
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.pyc
|
| 4 |
+
|
| 5 |
+
# Environment variables
|
| 6 |
+
.env
|
| 7 |
+
|
| 8 |
+
# Virtual environment
|
| 9 |
+
venv/
|
| 10 |
+
|
| 11 |
+
# Model backups and outputs
|
| 12 |
+
model.fp32.safetensors
|
| 13 |
+
output.wav
|
| 14 |
+
|
| 15 |
+
# Temporary scripts
|
| 16 |
+
check_dtype.py
|
trained_30_percents/BiCodec/config.yaml
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
audio_tokenizer:
|
| 2 |
+
mel_params:
|
| 3 |
+
sample_rate: 16000
|
| 4 |
+
n_fft: 1024
|
| 5 |
+
win_length: 640
|
| 6 |
+
hop_length: 320
|
| 7 |
+
mel_fmin: 10
|
| 8 |
+
mel_fmax: null
|
| 9 |
+
num_mels: 128
|
| 10 |
+
|
| 11 |
+
encoder:
|
| 12 |
+
input_channels: 1024
|
| 13 |
+
vocos_dim: 384
|
| 14 |
+
vocos_intermediate_dim: 2048
|
| 15 |
+
vocos_num_layers: 12
|
| 16 |
+
out_channels: 1024
|
| 17 |
+
sample_ratios: [1,1]
|
| 18 |
+
|
| 19 |
+
decoder:
|
| 20 |
+
input_channel: 1024
|
| 21 |
+
channels: 1536
|
| 22 |
+
rates: [8, 5, 4, 2]
|
| 23 |
+
kernel_sizes: [16,11,8,4]
|
| 24 |
+
|
| 25 |
+
quantizer:
|
| 26 |
+
input_dim: 1024
|
| 27 |
+
codebook_size: 8192
|
| 28 |
+
codebook_dim: 8
|
| 29 |
+
commitment: 0.25
|
| 30 |
+
codebook_loss_weight: 2.0
|
| 31 |
+
use_l2_normlize: True
|
| 32 |
+
threshold_ema_dead_code: 0.2
|
| 33 |
+
|
| 34 |
+
speaker_encoder:
|
| 35 |
+
input_dim: 128
|
| 36 |
+
out_dim: 1024
|
| 37 |
+
latent_dim: 128
|
| 38 |
+
token_num: 32
|
| 39 |
+
fsq_levels: [4, 4, 4, 4, 4, 4]
|
| 40 |
+
fsq_num_quantizers: 1
|
| 41 |
+
|
| 42 |
+
prenet:
|
| 43 |
+
input_channels: 1024
|
| 44 |
+
vocos_dim: 384
|
| 45 |
+
vocos_intermediate_dim: 2048
|
| 46 |
+
vocos_num_layers: 12
|
| 47 |
+
out_channels: 1024
|
| 48 |
+
condition_dim: 1024
|
| 49 |
+
sample_ratios: [1,1]
|
| 50 |
+
use_tanh_at_final: False
|
| 51 |
+
|
| 52 |
+
postnet:
|
| 53 |
+
input_channels: 1024
|
| 54 |
+
vocos_dim: 384
|
| 55 |
+
vocos_intermediate_dim: 2048
|
| 56 |
+
vocos_num_layers: 6
|
| 57 |
+
out_channels: 1024
|
| 58 |
+
use_tanh_at_final: False
|
| 59 |
+
|
| 60 |
+
|
trained_30_percents/BiCodec/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e9940cd48d4446e4340ced82d234bf5618350dd9f5db900ebe47a4fdb03867ec
|
| 3 |
+
size 625518756
|
trained_30_percents/Readme.md
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
---
|
| 4 |
+
|
| 5 |
+
# ReSpark TTS Model
|
| 6 |
+
|
| 7 |
+
This repository contains the ReSpark Text-to-Speech (TTS) model, a powerful and efficient model for generating high-quality speech from text. It is based on the RWKV architecture and utilizes the BiCodec tokenizer for audio processing.
|
| 8 |
+
|
| 9 |
+
## Installation
|
| 10 |
+
|
| 11 |
+
First, install the required dependencies:
|
| 12 |
+
|
| 13 |
+
```bash
|
| 14 |
+
pip install transformers rwkv-fla torch torchaudio torchvision transformers soundfile numpy librosa omegaconf soxr soundfile einx librosa
|
| 15 |
+
```
|
| 16 |
+
|
| 17 |
+
## Usage
|
| 18 |
+
|
| 19 |
+
The `tts.py` script provides a complete example of how to use this model for text-to-speech synthesis with voice cloning.
|
| 20 |
+
|
| 21 |
+
### Running the Test Script
|
| 22 |
+
|
| 23 |
+
To generate speech, simply run the script:
|
| 24 |
+
|
| 25 |
+
```bash
|
| 26 |
+
python tts.py
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
### How it Works
|
| 30 |
+
|
| 31 |
+
The script performs the following steps:
|
| 32 |
+
1. Loads the pre-trained `AutoModelForCausalLM` and `AutoTokenizer` from the current directory.
|
| 33 |
+
2. Initializes the `BiCodecTokenizer` for audio encoding and decoding.
|
| 34 |
+
3. Loads a reference audio file (`kafka.wav`) and its corresponding transcript (`prompt_text`) to provide a voice prompt.
|
| 35 |
+
4. Resamples the reference audio to match the model's expected sample rate (24000 Hz).
|
| 36 |
+
5. Takes a target text (`text`) to be synthesized.
|
| 37 |
+
6. Calls the `generate_speech` function, which generates audio based on the target text and the voice from the reference audio.
|
| 38 |
+
7. Saves the generated audio to `output.wav`.
|
| 39 |
+
|
| 40 |
+
You can modify the `prompt_text`, `prompt_audio_file`, and `text` variables in `tts.py` to synthesize different text with different voices.
|
| 41 |
+
|
| 42 |
+
### Example Code (`tts.py`)
|
| 43 |
+
|
| 44 |
+
```python
|
| 45 |
+
import os
|
| 46 |
+
import sys
|
| 47 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 48 |
+
print('add current dir to sys.path', current_dir)
|
| 49 |
+
sys.path.append(current_dir)
|
| 50 |
+
from sparktts.models.audio_tokenizer import BiCodecTokenizer
|
| 51 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 52 |
+
import soundfile as sf
|
| 53 |
+
import numpy as np
|
| 54 |
+
import torch
|
| 55 |
+
from utilities import generate_embeddings
|
| 56 |
+
|
| 57 |
+
def generate_speech(model, tokenizer, text, bicodec, prompt_text=None, prompt_audio=None,
|
| 58 |
+
max_new_tokens=3000, do_sample=True, top_k=50, top_p=0.95,
|
| 59 |
+
temperature=1.0, device="cuda:0"):
|
| 60 |
+
"""
|
| 61 |
+
Function to generate speech.
|
| 62 |
+
"""
|
| 63 |
+
eos_token_id = model.config.vocab_size - 1
|
| 64 |
+
|
| 65 |
+
embeddings = generate_embeddings(
|
| 66 |
+
model=model,
|
| 67 |
+
tokenizer=tokenizer,
|
| 68 |
+
text=text,
|
| 69 |
+
bicodec=bicodec,
|
| 70 |
+
prompt_text=prompt_text,
|
| 71 |
+
prompt_audio=prompt_audio
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
global_tokens = embeddings['global_tokens'].unsqueeze(0)
|
| 75 |
+
model.eval()
|
| 76 |
+
|
| 77 |
+
with torch.no_grad():
|
| 78 |
+
generated_outputs = model.generate(
|
| 79 |
+
inputs_embeds=embeddings['input_embs'],
|
| 80 |
+
attention_mask=torch.ones((1, embeddings['input_embs'].shape[1]),dtype=torch.long,device=device),
|
| 81 |
+
max_new_tokens=max_new_tokens,
|
| 82 |
+
do_sample=do_sample,
|
| 83 |
+
top_k=top_k,
|
| 84 |
+
top_p=top_p,
|
| 85 |
+
temperature=temperature,
|
| 86 |
+
eos_token_id=eos_token_id,
|
| 87 |
+
pad_token_id=tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else tokenizer.eos_token_id,
|
| 88 |
+
use_cache=True
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
semantic_tokens_tensor = generated_outputs[:,:-1]
|
| 92 |
+
|
| 93 |
+
with torch.no_grad():
|
| 94 |
+
wav = bicodec.detokenize(global_tokens, semantic_tokens_tensor)
|
| 95 |
+
|
| 96 |
+
return wav
|
| 97 |
+
|
| 98 |
+
# --- Main execution ---
|
| 99 |
+
device = 'cuda:0'
|
| 100 |
+
|
| 101 |
+
# Initialize tokenizers and model
|
| 102 |
+
audio_tokenizer = BiCodecTokenizer(model_dir=current_dir, device=device)
|
| 103 |
+
tokenizer = AutoTokenizer.from_pretrained(current_dir, trust_remote_code=True)
|
| 104 |
+
model = AutoModelForCausalLM.from_pretrained(current_dir, trust_remote_code=True)
|
| 105 |
+
|
| 106 |
+
model = model.bfloat16().to(device)
|
| 107 |
+
model.eval()
|
| 108 |
+
|
| 109 |
+
# Prepare prompt audio and text for voice cloning
|
| 110 |
+
prompt_text = "我们并不是通过物理移动手段找到星河的。"
|
| 111 |
+
prompt_audio_file = os.path.join(current_dir, 'kafka.wav')
|
| 112 |
+
prompt_audio, sampling_rate = sf.read(prompt_audio_file)
|
| 113 |
+
|
| 114 |
+
# Resample audio if necessary
|
| 115 |
+
target_sample_rate = audio_tokenizer.config['sample_rate']
|
| 116 |
+
if sampling_rate != target_sample_rate:
|
| 117 |
+
from librosa import resample
|
| 118 |
+
prompt_audio = resample(prompt_audio, orig_sr=sampling_rate, target_sr=target_sample_rate)
|
| 119 |
+
prompt_audio = np.array(prompt_audio, dtype=np.float32)
|
| 120 |
+
|
| 121 |
+
# Text to synthesize
|
| 122 |
+
text = "科学技术是第一生产力,最近 AI的迅猛发展让我们看到了迈向星辰大海的希望。"
|
| 123 |
+
|
| 124 |
+
# Generate speech
|
| 125 |
+
wav = generate_speech(model, tokenizer, text, audio_tokenizer, prompt_audio=prompt_audio, device=device)
|
| 126 |
+
|
| 127 |
+
# Save the output
|
| 128 |
+
sf.write('output.wav', wav, target_sample_rate)
|
| 129 |
+
print("Generated audio saved to output.wav")
|
| 130 |
+
```
|
trained_30_percents/Readme_zh.md
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
---
|
| 4 |
+
|
| 5 |
+
# ReSpark TTS 模型
|
| 6 |
+
|
| 7 |
+
本仓库包含 ReSpark 文本转语音 (TTS) 模型,这是一个强大而高效的模型,可以从文本生成高质量的语音。它基于 RWKV 架构,并利用 BiCodec-Tokenizer 进行音频处理。
|
| 8 |
+
|
| 9 |
+
## 安装
|
| 10 |
+
|
| 11 |
+
首先,请安装所需的依赖库:
|
| 12 |
+
|
| 13 |
+
```bash
|
| 14 |
+
pip install transformers rwkv-fla torch torchaudio torchvision transformers soundfile numpy librosa omegaconf soxr soundfile einx librosa
|
| 15 |
+
```
|
| 16 |
+
|
| 17 |
+
## 使用方法
|
| 18 |
+
|
| 19 |
+
`tts.py` 脚本提供了一个完整的使用该模型进行文本转语音合成(带声音克隆功能)的示例。
|
| 20 |
+
|
| 21 |
+
### 运行测试脚本
|
| 22 |
+
|
| 23 |
+
要生成语音,只需运行以下脚本:
|
| 24 |
+
|
| 25 |
+
```bash
|
| 26 |
+
python tts.py
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
### 工作原理
|
| 30 |
+
|
| 31 |
+
该脚本执行以下步骤:
|
| 32 |
+
1. 从当前目录加载预训练的 `AutoModelForCausalLM` 和 `AutoTokenizer`。
|
| 33 |
+
2. 初始化用于音频编码和解码的 `BiCodecTokenizer`。
|
| 34 |
+
3. 加载一个参考音频文件 (`kafka.wav`) 及其对应的文本 (`prompt_text`) 以提供声音提示(voice prompt)。
|
| 35 |
+
4. 如果需要,将参考音频重采样以匹配模型期望的采样率 (24000 Hz)。
|
| 36 |
+
5. 指定一个需要被合成的目标文本 (`text`)。
|
| 37 |
+
6. 调用 `generate_speech` 函数,该函数会根据目标文本和参考音频中的声音生成音频。
|
| 38 |
+
7. 将生成的音频保存到 `output.wav`。
|
| 39 |
+
|
| 40 |
+
您可以修改 `tts.py` 文件中的 `prompt_text`、`prompt_audio_file` 和 `text` 变量,以使用不同的声音合成不同的文本。
|
| 41 |
+
|
| 42 |
+
### 示例代码 (`tts.py`)
|
| 43 |
+
|
| 44 |
+
```python
|
| 45 |
+
import os
|
| 46 |
+
import sys
|
| 47 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 48 |
+
print('add current dir to sys.path', current_dir)
|
| 49 |
+
sys.path.append(current_dir)
|
| 50 |
+
from sparktts.models.audio_tokenizer import BiCodecTokenizer
|
| 51 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 52 |
+
import soundfile as sf
|
| 53 |
+
import numpy as np
|
| 54 |
+
import torch
|
| 55 |
+
from utilities import generate_embeddings
|
| 56 |
+
|
| 57 |
+
def generate_speech(model, tokenizer, text, bicodec, prompt_text=None, prompt_audio=None,
|
| 58 |
+
max_new_tokens=3000, do_sample=True, top_k=50, top_p=0.95,
|
| 59 |
+
temperature=1.0, device="cuda:0"):
|
| 60 |
+
"""
|
| 61 |
+
生成语音的函数
|
| 62 |
+
"""
|
| 63 |
+
eos_token_id = model.config.vocab_size - 1
|
| 64 |
+
|
| 65 |
+
embeddings = generate_embeddings(
|
| 66 |
+
model=model,
|
| 67 |
+
tokenizer=tokenizer,
|
| 68 |
+
text=text,
|
| 69 |
+
bicodec=bicodec,
|
| 70 |
+
prompt_text=prompt_text,
|
| 71 |
+
prompt_audio=prompt_audio
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
global_tokens = embeddings['global_tokens'].unsqueeze(0)
|
| 75 |
+
model.eval()
|
| 76 |
+
|
| 77 |
+
with torch.no_grad():
|
| 78 |
+
generated_outputs = model.generate(
|
| 79 |
+
inputs_embeds=embeddings['input_embs'],
|
| 80 |
+
attention_mask=torch.ones((1, embeddings['input_embs'].shape[1]),dtype=torch.long,device=device),
|
| 81 |
+
max_new_tokens=max_new_tokens,
|
| 82 |
+
do_sample=do_sample,
|
| 83 |
+
top_k=top_k,
|
| 84 |
+
top_p=top_p,
|
| 85 |
+
temperature=temperature,
|
| 86 |
+
eos_token_id=eos_token_id,
|
| 87 |
+
pad_token_id=tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else tokenizer.eos_token_id,
|
| 88 |
+
use_cache=True
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
semantic_tokens_tensor = generated_outputs[:,:-1]
|
| 92 |
+
|
| 93 |
+
with torch.no_grad():
|
| 94 |
+
wav = bicodec.detokenize(global_tokens, semantic_tokens_tensor)
|
| 95 |
+
|
| 96 |
+
return wav
|
| 97 |
+
|
| 98 |
+
# --- 主程序 ---
|
| 99 |
+
device = 'cuda:0'
|
| 100 |
+
|
| 101 |
+
# 初始化分词器和模型
|
| 102 |
+
audio_tokenizer = BiCodecTokenizer(model_dir=current_dir, device=device)
|
| 103 |
+
tokenizer = AutoTokenizer.from_pretrained(current_dir, trust_remote_code=True)
|
| 104 |
+
model = AutoModelForCausalLM.from_pretrained(current_dir, trust_remote_code=True)
|
| 105 |
+
|
| 106 |
+
model = model.bfloat16().to(device)
|
| 107 |
+
model.eval()
|
| 108 |
+
|
| 109 |
+
# 准备用于声音克隆的提示音频和文本
|
| 110 |
+
prompt_text = "我们并不是通过物理移动手段找到星河的。"
|
| 111 |
+
prompt_audio_file = os.path.join(current_dir, 'kafka.wav')
|
| 112 |
+
prompt_audio, sampling_rate = sf.read(prompt_audio_file)
|
| 113 |
+
|
| 114 |
+
# 如果需要,重采样音频
|
| 115 |
+
target_sample_rate = audio_tokenizer.config['sample_rate']
|
| 116 |
+
if sampling_rate != target_sample_rate:
|
| 117 |
+
from librosa import resample
|
| 118 |
+
prompt_audio = resample(prompt_audio, orig_sr=sampling_rate, target_sr=target_sample_rate)
|
| 119 |
+
prompt_audio = np.array(prompt_audio, dtype=np.float32)
|
| 120 |
+
|
| 121 |
+
# 要合成的文本
|
| 122 |
+
text = "科学技术是第一生产力,最近 AI的迅猛发展让我们看到了迈向星辰大海的希望。"
|
| 123 |
+
|
| 124 |
+
# 生成语音
|
| 125 |
+
wav = generate_speech(model, tokenizer, text, audio_tokenizer, prompt_audio=prompt_audio, device=device)
|
| 126 |
+
|
| 127 |
+
# 保存输出
|
| 128 |
+
sf.write('output.wav', wav, target_sample_rate)
|
| 129 |
+
print("生成的音频已保存到 output.wav")
|
| 130 |
+
```
|
trained_30_percents/__init__.py
ADDED
|
File without changes
|
trained_30_percents/__pycache__/spark_llm.cpython-311.pyc
ADDED
|
Binary file (10.6 kB). View file
|
|
|
trained_30_percents/__pycache__/utilities.cpython-311.pyc
ADDED
|
Binary file (3.79 kB). View file
|
|
|
trained_30_percents/added_tokens.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"<|rwkv_tokenizer_end_of_text|>": 0
|
| 3 |
+
}
|
trained_30_percents/config.json
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"a_low_rank_dim": 64,
|
| 3 |
+
"architectures": [
|
| 4 |
+
"RWKV7ForSpeech"
|
| 5 |
+
],
|
| 6 |
+
"attn": null,
|
| 7 |
+
"attn_mode": "chunk",
|
| 8 |
+
"audio_global_vocab_size": 4096,
|
| 9 |
+
"auto_map": {
|
| 10 |
+
"AutoConfig": "modeling_rwkvspeech.RWKV7SpeechConfig",
|
| 11 |
+
"AutoModel": "modeling_rwkvspeech.RWKV7Model",
|
| 12 |
+
"AutoModelForCausalLM": "modeling_rwkvspeech.RWKV7ForSpeech"
|
| 13 |
+
},
|
| 14 |
+
"bos_token_id": 0,
|
| 15 |
+
"decay_low_rank_dim": 64,
|
| 16 |
+
"eos_token_id": 0,
|
| 17 |
+
"fuse_cross_entropy": true,
|
| 18 |
+
"fuse_norm": false,
|
| 19 |
+
"gate_low_rank_dim": 128,
|
| 20 |
+
"head_dim": 64,
|
| 21 |
+
"hidden_act": "sqrelu",
|
| 22 |
+
"hidden_ratio": 4.0,
|
| 23 |
+
"hidden_size": 1024,
|
| 24 |
+
"initializer_range": 0.006,
|
| 25 |
+
"intermediate_size": 4096,
|
| 26 |
+
"max_position_embeddings": 2048,
|
| 27 |
+
"model_type": "rwkv7",
|
| 28 |
+
"norm_bias": true,
|
| 29 |
+
"norm_eps": 1e-05,
|
| 30 |
+
"norm_first": true,
|
| 31 |
+
"num_heads": 32,
|
| 32 |
+
"num_hidden_layers": 24,
|
| 33 |
+
"text_vocab_size": 65536,
|
| 34 |
+
"tie_word_embeddings": false,
|
| 35 |
+
"torch_dtype": "float32",
|
| 36 |
+
"transformers_version": "4.52.4",
|
| 37 |
+
"use_cache": true,
|
| 38 |
+
"v_low_rank_dim": 32,
|
| 39 |
+
"value_dim": [
|
| 40 |
+
1024,
|
| 41 |
+
1024,
|
| 42 |
+
1024,
|
| 43 |
+
1024,
|
| 44 |
+
1024,
|
| 45 |
+
1024,
|
| 46 |
+
1024,
|
| 47 |
+
1024,
|
| 48 |
+
1024,
|
| 49 |
+
1024,
|
| 50 |
+
1024,
|
| 51 |
+
1024,
|
| 52 |
+
1024,
|
| 53 |
+
1024,
|
| 54 |
+
1024,
|
| 55 |
+
1024,
|
| 56 |
+
1024,
|
| 57 |
+
1024,
|
| 58 |
+
1024,
|
| 59 |
+
1024,
|
| 60 |
+
1024,
|
| 61 |
+
1024,
|
| 62 |
+
1024,
|
| 63 |
+
1024
|
| 64 |
+
],
|
| 65 |
+
"vocab_size": 8193
|
| 66 |
+
}
|
trained_30_percents/config.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
highpass_cutoff_freq: 40
|
| 2 |
+
sample_rate: 16000
|
| 3 |
+
segment_duration: 2.4 # (s)
|
| 4 |
+
max_val_duration: 12 # (s)
|
| 5 |
+
latent_hop_length: 320
|
| 6 |
+
ref_segment_duration: 6
|
| 7 |
+
volume_normalize: true
|
trained_30_percents/configuration_rwkv7.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
|
| 3 |
+
from typing import Dict, Optional
|
| 4 |
+
|
| 5 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class RWKV7Config(PretrainedConfig):
|
| 9 |
+
|
| 10 |
+
model_type = 'rwkv7'
|
| 11 |
+
keys_to_ignore_at_inference = ['past_key_values']
|
| 12 |
+
|
| 13 |
+
def __init__(
|
| 14 |
+
self,
|
| 15 |
+
attn_mode: str = "chunk",
|
| 16 |
+
hidden_size: int = 2048,
|
| 17 |
+
hidden_ratio: Optional[int] = 4,
|
| 18 |
+
intermediate_size: Optional[int] = None,
|
| 19 |
+
num_hidden_layers: int = 24,
|
| 20 |
+
head_dim: Optional[int] = 64,
|
| 21 |
+
num_heads: Optional[int] = None,
|
| 22 |
+
decay_low_rank_dim: int = 64,
|
| 23 |
+
gate_low_rank_dim: int = 128,
|
| 24 |
+
a_low_rank_dim: int = 64,
|
| 25 |
+
v_low_rank_dim: int = 16,
|
| 26 |
+
hidden_act: str = "sqrelu",
|
| 27 |
+
max_position_embeddings: int = 2048,
|
| 28 |
+
norm_first: bool = True,
|
| 29 |
+
norm_bias: bool = True,
|
| 30 |
+
norm_eps: float = 1e-5,
|
| 31 |
+
attn: Optional[Dict] = None,
|
| 32 |
+
use_cache: bool = True,
|
| 33 |
+
pad_token_id: int = None,
|
| 34 |
+
bos_token_id: int = 1,
|
| 35 |
+
eos_token_id: int = 2,
|
| 36 |
+
tie_word_embeddings: bool = False,
|
| 37 |
+
initializer_range: float = 0.006,
|
| 38 |
+
fuse_norm: bool = True,
|
| 39 |
+
fuse_cross_entropy: bool = True,
|
| 40 |
+
vocab_size: int = 32000,
|
| 41 |
+
**kwargs
|
| 42 |
+
):
|
| 43 |
+
self.attn_mode = attn_mode
|
| 44 |
+
self.hidden_size = hidden_size
|
| 45 |
+
self.hidden_ratio = hidden_ratio
|
| 46 |
+
self.intermediate_size = intermediate_size
|
| 47 |
+
self.norm_first = norm_first
|
| 48 |
+
self.num_hidden_layers = num_hidden_layers
|
| 49 |
+
|
| 50 |
+
if head_dim is None and num_heads is not None:
|
| 51 |
+
head_dim = int(hidden_size // num_heads)
|
| 52 |
+
elif head_dim is not None and num_heads is None:
|
| 53 |
+
num_heads = int(hidden_size // head_dim)
|
| 54 |
+
|
| 55 |
+
self.head_dim = head_dim
|
| 56 |
+
self.num_heads = num_heads
|
| 57 |
+
|
| 58 |
+
self.decay_low_rank_dim = decay_low_rank_dim
|
| 59 |
+
self.gate_low_rank_dim = gate_low_rank_dim
|
| 60 |
+
self.a_low_rank_dim = a_low_rank_dim
|
| 61 |
+
self.v_low_rank_dim = v_low_rank_dim
|
| 62 |
+
self.hidden_act = hidden_act
|
| 63 |
+
self.max_position_embeddings = max_position_embeddings
|
| 64 |
+
self.norm_bias = norm_bias
|
| 65 |
+
self.norm_eps = norm_eps
|
| 66 |
+
self.attn = attn
|
| 67 |
+
self.use_cache = use_cache
|
| 68 |
+
self.initializer_range = initializer_range
|
| 69 |
+
self.fuse_norm = fuse_norm
|
| 70 |
+
self.fuse_cross_entropy = fuse_cross_entropy
|
| 71 |
+
self.vocab_size = vocab_size
|
| 72 |
+
|
| 73 |
+
if attn is not None:
|
| 74 |
+
if not isinstance(attn, Dict):
|
| 75 |
+
raise ValueError("attn must be a dictionary")
|
| 76 |
+
if 'layers' not in attn:
|
| 77 |
+
raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
|
| 78 |
+
if 'num_heads' not in attn:
|
| 79 |
+
raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
|
| 80 |
+
attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
|
| 81 |
+
attn['qkv_bias'] = attn.get('qkv_bias', False)
|
| 82 |
+
attn['window_size'] = attn.get('window_size', None)
|
| 83 |
+
attn['rope_theta'] = attn.get('rope_theta', 10000.)
|
| 84 |
+
|
| 85 |
+
super().__init__(
|
| 86 |
+
pad_token_id=pad_token_id,
|
| 87 |
+
bos_token_id=bos_token_id,
|
| 88 |
+
eos_token_id=eos_token_id,
|
| 89 |
+
tie_word_embeddings=tie_word_embeddings,
|
| 90 |
+
**kwargs,
|
| 91 |
+
)
|
trained_30_percents/generation_config.json
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_from_model_config": true,
|
| 3 |
+
"bos_token_id": 0,
|
| 4 |
+
"eos_token_id": 0,
|
| 5 |
+
"transformers_version": "4.52.4"
|
| 6 |
+
}
|
trained_30_percents/hf_rwkv_tokenizer.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Tokenization classes for RWKV."""
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
import re
|
| 19 |
+
from typing import TYPE_CHECKING, List, Optional, Tuple
|
| 20 |
+
|
| 21 |
+
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
|
| 22 |
+
from transformers.utils import logging
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
if TYPE_CHECKING:
|
| 26 |
+
pass
|
| 27 |
+
|
| 28 |
+
logger = logging.get_logger(__name__)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
VOCAB_FILES_NAMES = {
|
| 32 |
+
"vocab_file": "rwkv_vocab_v20230424.txt",
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
class TRIE:
|
| 36 |
+
__slots__ = tuple("ch,to,values,front".split(","))
|
| 37 |
+
to: list
|
| 38 |
+
values: set
|
| 39 |
+
|
| 40 |
+
def __init__(self, front=None, ch=None):
|
| 41 |
+
self.ch = ch
|
| 42 |
+
self.to = [None for ch in range(256)]
|
| 43 |
+
self.values = set()
|
| 44 |
+
self.front = front
|
| 45 |
+
|
| 46 |
+
def __repr__(self):
|
| 47 |
+
fr = self
|
| 48 |
+
ret = []
|
| 49 |
+
while fr != None:
|
| 50 |
+
if fr.ch != None:
|
| 51 |
+
ret.append(fr.ch)
|
| 52 |
+
fr = fr.front
|
| 53 |
+
return "<TRIE %s %s>" % (ret[::-1], self.values)
|
| 54 |
+
|
| 55 |
+
def add(self, key: bytes, idx: int = 0, val=None):
|
| 56 |
+
if idx == len(key):
|
| 57 |
+
if val is None:
|
| 58 |
+
val = key
|
| 59 |
+
self.values.add(val)
|
| 60 |
+
return self
|
| 61 |
+
ch = key[idx]
|
| 62 |
+
if self.to[ch] is None:
|
| 63 |
+
self.to[ch] = TRIE(front=self, ch=ch)
|
| 64 |
+
return self.to[ch].add(key, idx=idx + 1, val=val)
|
| 65 |
+
|
| 66 |
+
def find_longest(self, key: bytes, idx: int = 0):
|
| 67 |
+
u: TRIE = self
|
| 68 |
+
ch: int = key[idx]
|
| 69 |
+
|
| 70 |
+
while u.to[ch] is not None:
|
| 71 |
+
u = u.to[ch]
|
| 72 |
+
idx += 1
|
| 73 |
+
if u.values:
|
| 74 |
+
ret = idx, u, u.values
|
| 75 |
+
if idx == len(key):
|
| 76 |
+
break
|
| 77 |
+
ch = key[idx]
|
| 78 |
+
return ret
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class RWKV_TOKENIZER:
|
| 82 |
+
def __init__(self, file_name):
|
| 83 |
+
self.idx2token = {}
|
| 84 |
+
sorted = [] # must be already sorted
|
| 85 |
+
with open(file_name, "r", encoding="utf-8") as f:
|
| 86 |
+
lines = f.readlines()
|
| 87 |
+
for l in lines:
|
| 88 |
+
idx = int(l[: l.index(" ")])
|
| 89 |
+
x = eval(l[l.index(" ") : l.rindex(" ")])
|
| 90 |
+
x = x.encode("utf-8") if isinstance(x, str) else x
|
| 91 |
+
assert isinstance(x, bytes)
|
| 92 |
+
|
| 93 |
+
assert len(x) == int(l[l.rindex(" ") :])
|
| 94 |
+
sorted += [x]
|
| 95 |
+
self.idx2token[idx] = x
|
| 96 |
+
|
| 97 |
+
self.token2idx = {}
|
| 98 |
+
for k, v in self.idx2token.items():
|
| 99 |
+
self.token2idx[v] = int(k)
|
| 100 |
+
|
| 101 |
+
self.root = TRIE()
|
| 102 |
+
for t, i in self.token2idx.items():
|
| 103 |
+
_ = self.root.add(t, val=(t, i))
|
| 104 |
+
|
| 105 |
+
def encodeBytes(self, src: bytes):
|
| 106 |
+
idx: int = 0
|
| 107 |
+
tokens = []
|
| 108 |
+
while idx < len(src):
|
| 109 |
+
_idx: int = idx
|
| 110 |
+
idx, _, values = self.root.find_longest(src, idx)
|
| 111 |
+
assert idx != _idx
|
| 112 |
+
_, token = next(iter(values))
|
| 113 |
+
tokens.append(token)
|
| 114 |
+
return tokens
|
| 115 |
+
|
| 116 |
+
def decodeBytes(self, tokens):
|
| 117 |
+
return b"".join(map(lambda i: self.idx2token[i], tokens))
|
| 118 |
+
|
| 119 |
+
def encode(self, src):
|
| 120 |
+
if isinstance(src, str):
|
| 121 |
+
return [self.encodeBytes(src.encode("utf-8"))]
|
| 122 |
+
elif isinstance(src, list):
|
| 123 |
+
return [self.encodeBytes(s.encode("utf-8")) for s in src]
|
| 124 |
+
|
| 125 |
+
def decode(self, tokens):
|
| 126 |
+
return [self.decodeBytes(batch).decode("utf-8") for batch in tokens]
|
| 127 |
+
# try:
|
| 128 |
+
# return self.decodeBytes(tokens).decode('utf-8')
|
| 129 |
+
# except:
|
| 130 |
+
# return '\ufffd' # bad utf-8
|
| 131 |
+
|
| 132 |
+
def printTokens(self, tokens):
|
| 133 |
+
for i in tokens:
|
| 134 |
+
s = self.idx2token[i]
|
| 135 |
+
try:
|
| 136 |
+
s = s.decode("utf-8")
|
| 137 |
+
except:
|
| 138 |
+
pass
|
| 139 |
+
print(f"{repr(s)}{i}", end=" ")
|
| 140 |
+
print()
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class RwkvTokenizer(PreTrainedTokenizer):
|
| 144 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
| 145 |
+
model_input_names = ["input_ids", "attention_mask"]
|
| 146 |
+
|
| 147 |
+
def __init__(
|
| 148 |
+
self, vocab_file, bos_token="<|rwkv_tokenizer_end_of_text|>", eos_token="<|rwkv_tokenizer_end_of_text|>", unk_token="<|rwkv_tokenizer_end_of_text|>", **kwargs
|
| 149 |
+
):
|
| 150 |
+
if not os.path.isfile(vocab_file):
|
| 151 |
+
raise ValueError(
|
| 152 |
+
f"Can't find a vocabulary file at path '{vocab_file}'."
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
with open(vocab_file, "r", encoding="utf-8") as reader:
|
| 156 |
+
tokens = reader.readlines()
|
| 157 |
+
|
| 158 |
+
if "add_bos_token" in kwargs:
|
| 159 |
+
self.add_bos_token = kwargs["add_bos_token"]
|
| 160 |
+
else:
|
| 161 |
+
self.add_bos_token = False
|
| 162 |
+
self.trie_tokenizer = RWKV_TOKENIZER(vocab_file)
|
| 163 |
+
vocab = self.trie_tokenizer.token2idx
|
| 164 |
+
self.encoder = vocab
|
| 165 |
+
self.decoder = {v: k for k, v in vocab.items()}
|
| 166 |
+
self._added_tokens_decoder = {0: AddedToken(str(bos_token))}
|
| 167 |
+
super().__init__(
|
| 168 |
+
bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
@property
|
| 172 |
+
def vocab_size(self):
|
| 173 |
+
return len(self.encoder)
|
| 174 |
+
|
| 175 |
+
def get_vocab(self):
|
| 176 |
+
vocab = self.encoder
|
| 177 |
+
vocab.update(self.added_tokens_encoder)
|
| 178 |
+
vocab = dict(sorted(vocab.items(), key=lambda item: item[1]))
|
| 179 |
+
return vocab
|
| 180 |
+
|
| 181 |
+
def _tokenize(self, text, split_special_tokens=False):
|
| 182 |
+
# return self.wordpiece_tokenizer.tokenize(text.encode("utf-8"))
|
| 183 |
+
return self.trie_tokenizer.encode(text)[0]
|
| 184 |
+
|
| 185 |
+
def _convert_token_to_id(self, token):
|
| 186 |
+
return token
|
| 187 |
+
|
| 188 |
+
def _convert_id_to_token(self, index):
|
| 189 |
+
"""Converts an index (integer) in a token (byte) using the vocab."""
|
| 190 |
+
token = self.decoder.get(index, self.unk_token)
|
| 191 |
+
if isinstance(token, (bytes)):
|
| 192 |
+
token = token.decode("utf-8", errors="replace")
|
| 193 |
+
return token
|
| 194 |
+
|
| 195 |
+
def convert_tokens_to_string(self, tokens):
|
| 196 |
+
"""Converts a sequence of tokens (bytes) in a single string. Additional tokens are encoded to bytes"""
|
| 197 |
+
out_string = b"".join(
|
| 198 |
+
[k.encode(errors="replace") if isinstance(k, str) else k for k in tokens]
|
| 199 |
+
).decode("utf-8")
|
| 200 |
+
return out_string
|
| 201 |
+
|
| 202 |
+
def save_vocabulary(
|
| 203 |
+
self, save_directory: str, filename_prefix: Optional[str] = None
|
| 204 |
+
) -> Tuple[str]:
|
| 205 |
+
index = 0
|
| 206 |
+
if os.path.isdir(save_directory):
|
| 207 |
+
vocab_file = os.path.join(
|
| 208 |
+
save_directory,
|
| 209 |
+
(filename_prefix + "-" if filename_prefix else "") + "vocab.txt",
|
| 210 |
+
)
|
| 211 |
+
else:
|
| 212 |
+
vocab_file = (
|
| 213 |
+
filename_prefix + "-" if filename_prefix else ""
|
| 214 |
+
) + save_directory
|
| 215 |
+
with open(vocab_file, "w", encoding="utf-8") as writer:
|
| 216 |
+
for token, token_index in sorted(
|
| 217 |
+
self.encoder.items(), key=lambda kv: kv[1]
|
| 218 |
+
):
|
| 219 |
+
if index != token_index:
|
| 220 |
+
logger.warning(
|
| 221 |
+
f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
|
| 222 |
+
" Please check that the vocabulary is not corrupted!"
|
| 223 |
+
)
|
| 224 |
+
index = token_index
|
| 225 |
+
writer.write(str(token) + "\n")
|
| 226 |
+
index += 1
|
| 227 |
+
return (vocab_file,)
|
| 228 |
+
|
| 229 |
+
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
| 230 |
+
if self.add_bos_token:
|
| 231 |
+
bos_token_ids = [self.bos_token_id]
|
| 232 |
+
else:
|
| 233 |
+
bos_token_ids = []
|
| 234 |
+
|
| 235 |
+
output = bos_token_ids + token_ids_0
|
| 236 |
+
|
| 237 |
+
if token_ids_1 is None:
|
| 238 |
+
return output
|
| 239 |
+
|
| 240 |
+
return output + bos_token_ids + token_ids_1
|
| 241 |
+
|
| 242 |
+
def get_special_tokens_mask(
|
| 243 |
+
self,
|
| 244 |
+
token_ids_0: List[int],
|
| 245 |
+
token_ids_1: Optional[List[int]] = None,
|
| 246 |
+
already_has_special_tokens: bool = False,
|
| 247 |
+
) -> List[int]:
|
| 248 |
+
"""
|
| 249 |
+
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
| 250 |
+
special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.
|
| 251 |
+
|
| 252 |
+
Args:
|
| 253 |
+
token_ids_0 (`List[int]`):
|
| 254 |
+
List of IDs.
|
| 255 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 256 |
+
Optional second list of IDs for sequence pairs.
|
| 257 |
+
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
| 258 |
+
Whether or not the token list is already formatted with special tokens for the model.
|
| 259 |
+
|
| 260 |
+
Returns:
|
| 261 |
+
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
| 262 |
+
"""
|
| 263 |
+
if already_has_special_tokens:
|
| 264 |
+
return super().get_special_tokens_mask(
|
| 265 |
+
token_ids_0=token_ids_0,
|
| 266 |
+
token_ids_1=token_ids_1,
|
| 267 |
+
already_has_special_tokens=True,
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
if not self.add_bos_token:
|
| 271 |
+
return super().get_special_tokens_mask(
|
| 272 |
+
token_ids_0=token_ids_0,
|
| 273 |
+
token_ids_1=token_ids_1,
|
| 274 |
+
already_has_special_tokens=False,
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
if token_ids_1 is None:
|
| 278 |
+
return [1] + ([0] * len(token_ids_0))
|
| 279 |
+
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
|
| 280 |
+
|
trained_30_percents/kafka.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b7928aeaf90600d6a014a5fececdc59cdf0e2971db327a0cf56b922b7cd8f8a7
|
| 3 |
+
size 265524
|
trained_30_percents/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f7c0f6731a59dd80dbfdf692ce77e11a65af53b9c55795285aa5cfee11a97fae
|
| 3 |
+
size 809355976
|
trained_30_percents/modeling_rwkvspeech.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from spark_llm import RWKV7SpeechConfig,RWKV7ForSpeech
|
| 2 |
+
from rwkvfla.models.rwkv7 import RWKV7Model
|
| 3 |
+
|
| 4 |
+
RWKV7ForCausalLM = RWKV7ForSpeech
|
| 5 |
+
RWKV7Model = RWKV7Model
|
| 6 |
+
RWKV7Config = RWKV7SpeechConfig
|
trained_30_percents/output.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:54c3ce8433267b9d2e8812fed1c87af29cdcc66a017b3700016b40b243930a34
|
| 3 |
+
size 180524
|
trained_30_percents/rwkv_vocab_v20230424.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
trained_30_percents/spark_llm.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from typing import Optional, Union, Tuple, Dict, Unpack
|
| 4 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 5 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 6 |
+
from transformers.utils.deprecation import deprecate_kwarg
|
| 7 |
+
from rwkvfla.models.rwkv7.modeling_rwkv7 import RWKV7Model, RWKV7PreTrainedModel, Cache,RWKV7ForCausalLM
|
| 8 |
+
from rwkvfla.models.rwkv7.modeling_rwkv7 import FusedLinearCrossEntropyLoss, FusedCrossEntropyLoss
|
| 9 |
+
from transformers.generation.utils import GenerationMixin
|
| 10 |
+
|
| 11 |
+
from rwkvfla.models.rwkv7.configuration_rwkv7 import RWKV7Config
|
| 12 |
+
|
| 13 |
+
class RWKV7SpeechConfig(RWKV7Config):
|
| 14 |
+
def __init__(self, **kwargs):
|
| 15 |
+
super().__init__(**kwargs)
|
| 16 |
+
self.text_vocab_size = kwargs.get("text_vocab_size", kwargs.get("text_vocab_size"))
|
| 17 |
+
self.audio_global_vocab_size = kwargs.get("audio_global_vocab_size", kwargs.get("audio_global_vocab_size"))
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class RWKV7ForSpeech(RWKV7ForCausalLM):
|
| 21 |
+
config_class = RWKV7SpeechConfig
|
| 22 |
+
def __init__(self, config: RWKV7SpeechConfig):
|
| 23 |
+
super().__init__(config)
|
| 24 |
+
self.model = RWKV7Model(config)
|
| 25 |
+
self.vocab_size = config.vocab_size
|
| 26 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)#Spark 0.5B vocab size is 8192 + 1 for eos resulting in 8193
|
| 27 |
+
self.criterion = None
|
| 28 |
+
self.text_embedder = nn.Embedding(config.text_vocab_size, config.hidden_size)
|
| 29 |
+
self.global_embedder = nn.Embedding(config.audio_global_vocab_size, config.hidden_size)#Spark 0.5B global token size is 4096
|
| 30 |
+
#TTS Tag includes GLOBAL=0, SEMANTIC=1,START_TTS=2
|
| 31 |
+
self.tts_tag_embedder = nn.Embedding(3, config.hidden_size)
|
| 32 |
+
# Initialize weights and apply final processing
|
| 33 |
+
self.post_init()
|
| 34 |
+
self.dropout = torch.nn.Dropout(0.02)
|
| 35 |
+
|
| 36 |
+
def get_input_embeddings(self):
|
| 37 |
+
return self.model.embeddings
|
| 38 |
+
|
| 39 |
+
def set_input_embeddings(self, value):
|
| 40 |
+
self.model.embeddings = value
|
| 41 |
+
|
| 42 |
+
def get_output_embeddings(self):
|
| 43 |
+
return self.lm_head
|
| 44 |
+
|
| 45 |
+
def set_output_embeddings(self, new_embeddings):
|
| 46 |
+
self.lm_head = new_embeddings
|
| 47 |
+
|
| 48 |
+
def set_decoder(self, decoder):
|
| 49 |
+
self.model = decoder
|
| 50 |
+
|
| 51 |
+
def get_decoder(self):
|
| 52 |
+
return self.model
|
| 53 |
+
|
| 54 |
+
def generate(self, *args, **kwargs):
|
| 55 |
+
try:
|
| 56 |
+
return super().generate(*args, **kwargs)
|
| 57 |
+
except AttributeError as exception:
|
| 58 |
+
if 'past_key_values' in str(exception):
|
| 59 |
+
raise AttributeError(
|
| 60 |
+
f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
|
| 61 |
+
f"which is not supported for {self.__class__.__name__}. "
|
| 62 |
+
f"Try another generation strategy instead. "
|
| 63 |
+
f"For the available generation strategies, check this doc: "
|
| 64 |
+
f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
|
| 65 |
+
)
|
| 66 |
+
else:
|
| 67 |
+
raise exception
|
| 68 |
+
|
| 69 |
+
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
| 70 |
+
def prepare_inputs_for_generation(
|
| 71 |
+
self,
|
| 72 |
+
input_ids: torch.LongTensor = None,
|
| 73 |
+
past_key_values: Optional[Cache] = None,
|
| 74 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 75 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 76 |
+
use_cache: bool = True,
|
| 77 |
+
logits_to_keep: Optional[int] = None,
|
| 78 |
+
**kwargs
|
| 79 |
+
):
|
| 80 |
+
# only last token for `inputs_ids` if the `past_key_values` is not empty.
|
| 81 |
+
if past_key_values is not None and len(past_key_values) > 0:
|
| 82 |
+
input_ids = input_ids[:, -1:]
|
| 83 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
| 84 |
+
if inputs_embeds is not None and len(past_key_values) == 0:
|
| 85 |
+
model_inputs = {'inputs_embeds': inputs_embeds}
|
| 86 |
+
else:
|
| 87 |
+
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
|
| 88 |
+
# recompiles graphs as the stride of the inputs is a guard.
|
| 89 |
+
# Ref: https://github.com/huggingface/transformers/pull/29114
|
| 90 |
+
# TODO: use `next_tokens` directly instead.
|
| 91 |
+
model_inputs = {'input_ids': input_ids.contiguous()}
|
| 92 |
+
|
| 93 |
+
if logits_to_keep is not None:
|
| 94 |
+
model_inputs['logits_to_keep'] = logits_to_keep
|
| 95 |
+
|
| 96 |
+
model_inputs.update({
|
| 97 |
+
'past_key_values': past_key_values,
|
| 98 |
+
'use_cache': use_cache,
|
| 99 |
+
'attention_mask': attention_mask,
|
| 100 |
+
'logits_to_keep': logits_to_keep,
|
| 101 |
+
})
|
| 102 |
+
return model_inputs
|
| 103 |
+
|
| 104 |
+
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
| 105 |
+
def forward(
|
| 106 |
+
self,
|
| 107 |
+
input_ids: torch.LongTensor = None,
|
| 108 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 109 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 110 |
+
past_key_values: Optional[Cache] = None,
|
| 111 |
+
labels: Optional[torch.LongTensor] = None,
|
| 112 |
+
use_cache: Optional[bool] = None,
|
| 113 |
+
output_attentions: Optional[bool] = None,
|
| 114 |
+
output_hidden_states: Optional[bool] = None,
|
| 115 |
+
return_dict: Optional[bool] = None,
|
| 116 |
+
logits_to_keep: Optional[int] = 0,
|
| 117 |
+
**kwargs: Unpack[Dict]
|
| 118 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 119 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 120 |
+
output_hidden_states = (
|
| 121 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 122 |
+
)
|
| 123 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 124 |
+
if self.training and inputs_embeds is not None:
|
| 125 |
+
inputs_embeds = self.dropout(inputs_embeds)
|
| 126 |
+
outputs = self.model(
|
| 127 |
+
input_ids=input_ids,
|
| 128 |
+
attention_mask=attention_mask,
|
| 129 |
+
inputs_embeds=inputs_embeds,
|
| 130 |
+
past_key_values=past_key_values,
|
| 131 |
+
use_cache=use_cache,
|
| 132 |
+
output_attentions=output_attentions,
|
| 133 |
+
output_hidden_states=output_hidden_states,
|
| 134 |
+
return_dict=return_dict,
|
| 135 |
+
**kwargs
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
hidden_states = outputs[0]
|
| 139 |
+
fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
|
| 140 |
+
|
| 141 |
+
loss, logits = None, None
|
| 142 |
+
if not fuse_linear_and_cross_entropy or labels is None:
|
| 143 |
+
logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
|
| 144 |
+
if labels is not None:
|
| 145 |
+
if getattr(self, 'criterion', None) is None:
|
| 146 |
+
if fuse_linear_and_cross_entropy:
|
| 147 |
+
criterion = FusedLinearCrossEntropyLoss()
|
| 148 |
+
elif self.config.fuse_cross_entropy:
|
| 149 |
+
criterion = FusedCrossEntropyLoss(inplace_backward=True)
|
| 150 |
+
else:
|
| 151 |
+
criterion = nn.CrossEntropyLoss()
|
| 152 |
+
else:
|
| 153 |
+
criterion = self.criterion
|
| 154 |
+
# Enable model parallelism
|
| 155 |
+
labels = labels.to(hidden_states.device)
|
| 156 |
+
labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
|
| 157 |
+
if fuse_linear_and_cross_entropy:
|
| 158 |
+
loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
|
| 159 |
+
else:
|
| 160 |
+
loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
|
| 161 |
+
|
| 162 |
+
if not return_dict:
|
| 163 |
+
output = (logits,) + outputs[1:]
|
| 164 |
+
return (loss,) + output if loss is not None else output
|
| 165 |
+
|
| 166 |
+
return CausalLMOutputWithPast(
|
| 167 |
+
loss=loss,
|
| 168 |
+
logits=logits,
|
| 169 |
+
past_key_values=outputs.past_key_values,
|
| 170 |
+
hidden_states=outputs.hidden_states,
|
| 171 |
+
attentions=outputs.attentions,
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
def copy_state_dict(self, state_dict: dict):
|
| 175 |
+
"""从源 state dict 复制参数到当前模型,排除 embeddings 和 lm_head
|
| 176 |
+
The state dict is from original RWKV7 language model
|
| 177 |
+
Args:
|
| 178 |
+
state_dict: 源 state dict
|
| 179 |
+
"""
|
| 180 |
+
# 获取当前模型的 state dict
|
| 181 |
+
target_dict = self.state_dict()
|
| 182 |
+
|
| 183 |
+
# 创建新的 state dict 用于存储要复制的参数
|
| 184 |
+
new_state_dict = {}
|
| 185 |
+
|
| 186 |
+
# 遍历源 state dict 的键
|
| 187 |
+
for key in state_dict.keys():
|
| 188 |
+
# 跳过 embeddings 和 lm_head 相关的参数
|
| 189 |
+
if key == 'model.embeddings.weight':
|
| 190 |
+
new_state_dict['text_embedder.weight'] = state_dict[key]
|
| 191 |
+
continue
|
| 192 |
+
if 'embeddings' in key or 'lm_head' in key:
|
| 193 |
+
continue
|
| 194 |
+
# 如果键在当前模型中存在,则复制参数
|
| 195 |
+
if key in target_dict:
|
| 196 |
+
new_state_dict[key] = state_dict[key]
|
| 197 |
+
|
| 198 |
+
# 加载新的 state dict 到当前模型
|
| 199 |
+
info = self.load_state_dict(new_state_dict, strict=False)
|
| 200 |
+
print(info)
|
| 201 |
+
return self
|
| 202 |
+
|
trained_30_percents/sparktts/models/__pycache__/audio_tokenizer.cpython-311.pyc
ADDED
|
Binary file (8.95 kB). View file
|
|
|
trained_30_percents/sparktts/models/__pycache__/bicodec.cpython-311.pyc
ADDED
|
Binary file (11 kB). View file
|
|
|
trained_30_percents/sparktts/models/audio_tokenizer.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 SparkAudio
|
| 2 |
+
# 2025 Xinsheng Wang ([email protected])
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import numpy as np
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
from typing import Any, Dict, Tuple, Optional, Union
|
| 21 |
+
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
|
| 22 |
+
|
| 23 |
+
from sparktts.utils.file import load_config
|
| 24 |
+
from sparktts.utils.audio import load_audio
|
| 25 |
+
from sparktts.models.bicodec import BiCodec
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class BiCodecTokenizer:
|
| 29 |
+
"""BiCodec tokenizer for handling audio input and tokenization."""
|
| 30 |
+
|
| 31 |
+
def __init__(self, model_dir: Path, device: torch.device = None, **kwargs):
|
| 32 |
+
super().__init__()
|
| 33 |
+
"""
|
| 34 |
+
Args:
|
| 35 |
+
model_dir: Path to the model directory.
|
| 36 |
+
device: Device to run the model on (default is GPU if available).
|
| 37 |
+
"""
|
| 38 |
+
self.device = device
|
| 39 |
+
self.model_dir = model_dir
|
| 40 |
+
self.config = load_config(f"{model_dir}/config.yaml")
|
| 41 |
+
self._initialize_model()
|
| 42 |
+
|
| 43 |
+
def _initialize_model(self):
|
| 44 |
+
"""Load and initialize the BiCodec model and Wav2Vec2 feature extractor."""
|
| 45 |
+
self.model = BiCodec.load_from_checkpoint(f"{self.model_dir}/BiCodec").to(
|
| 46 |
+
self.device
|
| 47 |
+
)
|
| 48 |
+
self.processor = Wav2Vec2FeatureExtractor.from_pretrained(
|
| 49 |
+
f"{self.model_dir}/wav2vec2-large-xlsr-53"
|
| 50 |
+
)
|
| 51 |
+
self.feature_extractor = Wav2Vec2Model.from_pretrained(
|
| 52 |
+
f"{self.model_dir}/wav2vec2-large-xlsr-53"
|
| 53 |
+
).to(self.device)
|
| 54 |
+
self.feature_extractor.config.output_hidden_states = True
|
| 55 |
+
|
| 56 |
+
def get_ref_clip(self, wav: np.ndarray) -> np.ndarray:
|
| 57 |
+
"""Get reference audio clip for speaker embedding."""
|
| 58 |
+
ref_segment_length = (
|
| 59 |
+
int(self.config["sample_rate"] * self.config["ref_segment_duration"])
|
| 60 |
+
// self.config["latent_hop_length"]
|
| 61 |
+
* self.config["latent_hop_length"]
|
| 62 |
+
)
|
| 63 |
+
wav_length = len(wav)
|
| 64 |
+
|
| 65 |
+
if ref_segment_length > wav_length:
|
| 66 |
+
# Repeat and truncate to handle insufficient length
|
| 67 |
+
wav = np.tile(wav, ref_segment_length // wav_length + 1)
|
| 68 |
+
|
| 69 |
+
return wav[:ref_segment_length]
|
| 70 |
+
|
| 71 |
+
def process_audio(self, wav_path: Optional[Union[Path, np.ndarray]]) -> Tuple[np.ndarray, torch.Tensor]:
|
| 72 |
+
"""load auido and get reference audio from wav path"""
|
| 73 |
+
if isinstance(wav_path, Path):
|
| 74 |
+
wav = load_audio(
|
| 75 |
+
wav_path,
|
| 76 |
+
sampling_rate=self.config["sample_rate"],
|
| 77 |
+
volume_normalize=self.config["volume_normalize"],
|
| 78 |
+
)
|
| 79 |
+
elif isinstance(wav_path, np.ndarray):
|
| 80 |
+
wav = wav_path
|
| 81 |
+
else:
|
| 82 |
+
raise ValueError(f"Unsupported audio type: {type(wav_path)}")
|
| 83 |
+
|
| 84 |
+
wav_ref = self.get_ref_clip(wav)
|
| 85 |
+
|
| 86 |
+
wav_ref = torch.from_numpy(wav_ref).unsqueeze(0).float()
|
| 87 |
+
return wav, wav_ref
|
| 88 |
+
|
| 89 |
+
def extract_wav2vec2_features(self, wavs: torch.Tensor) -> torch.Tensor:
|
| 90 |
+
"""extract wav2vec2 features"""
|
| 91 |
+
inputs = self.processor(
|
| 92 |
+
wavs,
|
| 93 |
+
sampling_rate=16000,
|
| 94 |
+
return_tensors="pt",
|
| 95 |
+
padding=True,
|
| 96 |
+
output_hidden_states=True,
|
| 97 |
+
).input_values.to(self.feature_extractor.dtype)
|
| 98 |
+
feat = self.feature_extractor(inputs.to(self.feature_extractor.device))
|
| 99 |
+
feats_mix = (
|
| 100 |
+
feat.hidden_states[11] + feat.hidden_states[14] + feat.hidden_states[16]
|
| 101 |
+
) / 3
|
| 102 |
+
|
| 103 |
+
return feats_mix
|
| 104 |
+
|
| 105 |
+
def tokenize_batch(self, batch: Dict[str, Any]) -> torch.Tensor:
|
| 106 |
+
"""tokenize the batch of audio
|
| 107 |
+
|
| 108 |
+
Args:
|
| 109 |
+
batch:
|
| 110 |
+
wavs (List[np.ndarray]): batch of audio
|
| 111 |
+
ref_wavs (torch.Tensor): reference audio. shape: (batch_size, seq_len)
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
semantic_tokens: semantic tokens. shape: (batch_size, seq_len, latent_dim)
|
| 115 |
+
global_tokens: global tokens. shape: (batch_size, seq_len, global_dim)
|
| 116 |
+
"""
|
| 117 |
+
feats = self.extract_wav2vec2_features(batch["wav"])
|
| 118 |
+
batch["feat"] = feats
|
| 119 |
+
semantic_tokens, global_tokens = self.model.tokenize(batch)
|
| 120 |
+
|
| 121 |
+
return global_tokens, semantic_tokens
|
| 122 |
+
|
| 123 |
+
def tokenize(self, audio_path: str) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 124 |
+
"""tokenize the audio"""
|
| 125 |
+
wav, ref_wav = self.process_audio(audio_path)
|
| 126 |
+
feat = self.extract_wav2vec2_features(wav)
|
| 127 |
+
batch = {
|
| 128 |
+
"wav": torch.from_numpy(wav).unsqueeze(0).float().to(self.device),
|
| 129 |
+
"ref_wav": ref_wav.to(self.device),
|
| 130 |
+
"feat": feat.to(self.device),
|
| 131 |
+
}
|
| 132 |
+
semantic_tokens, global_tokens = self.model.tokenize(batch)
|
| 133 |
+
|
| 134 |
+
return global_tokens, semantic_tokens
|
| 135 |
+
|
| 136 |
+
def detokenize(
|
| 137 |
+
self, global_tokens: torch.Tensor, semantic_tokens: torch.Tensor
|
| 138 |
+
) -> np.array:
|
| 139 |
+
"""detokenize the tokens to waveform
|
| 140 |
+
|
| 141 |
+
Args:
|
| 142 |
+
global_tokens: global tokens. shape: (batch_size, global_dim)
|
| 143 |
+
semantic_tokens: semantic tokens. shape: (batch_size, latent_dim)
|
| 144 |
+
|
| 145 |
+
Returns:
|
| 146 |
+
wav_rec: waveform. shape: (batch_size, seq_len) for batch or (seq_len,) for single
|
| 147 |
+
"""
|
| 148 |
+
global_tokens = global_tokens.unsqueeze(1)
|
| 149 |
+
wav_rec = self.model.detokenize(semantic_tokens, global_tokens)
|
| 150 |
+
return wav_rec.detach().squeeze().cpu().numpy()
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# test
|
| 154 |
+
if __name__ == "__main__":
|
| 155 |
+
import soundfile as sf
|
| 156 |
+
|
| 157 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 158 |
+
tokenizer = BiCodecTokenizer(
|
| 159 |
+
model_dir="pretrained_models/Spark-TTS-0.5B",
|
| 160 |
+
device=device,
|
| 161 |
+
)
|
| 162 |
+
wav_path = "example/prompt_audio.wav"
|
| 163 |
+
|
| 164 |
+
global_tokens, semantic_tokens = tokenizer.tokenize(wav_path)
|
| 165 |
+
|
| 166 |
+
wav_rec = tokenizer.detokenize(global_tokens.squeeze(0), semantic_tokens)
|
| 167 |
+
sf.write("example/prompt_recon.wav", wav_rec, 16000)
|
trained_30_percents/sparktts/models/bicodec.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 SparkAudio
|
| 2 |
+
# 2025 Xinsheng Wang ([email protected])
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import Dict, Any
|
| 20 |
+
from omegaconf import DictConfig
|
| 21 |
+
from safetensors.torch import load_file
|
| 22 |
+
|
| 23 |
+
from sparktts.utils.file import load_config
|
| 24 |
+
from sparktts.modules.speaker.speaker_encoder import SpeakerEncoder
|
| 25 |
+
from sparktts.modules.encoder_decoder.feat_encoder import Encoder
|
| 26 |
+
from sparktts.modules.encoder_decoder.feat_decoder import Decoder
|
| 27 |
+
from sparktts.modules.encoder_decoder.wave_generator import WaveGenerator
|
| 28 |
+
from sparktts.modules.vq.factorized_vector_quantize import FactorizedVectorQuantize
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class BiCodec(nn.Module):
|
| 32 |
+
"""
|
| 33 |
+
BiCodec model for speech synthesis, incorporating a speaker encoder, feature encoder/decoder,
|
| 34 |
+
quantizer, and wave generator.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
def __init__(
|
| 38 |
+
self,
|
| 39 |
+
mel_params: Dict[str, Any],
|
| 40 |
+
encoder: nn.Module,
|
| 41 |
+
decoder: nn.Module,
|
| 42 |
+
quantizer: nn.Module,
|
| 43 |
+
speaker_encoder: nn.Module,
|
| 44 |
+
prenet: nn.Module,
|
| 45 |
+
postnet: nn.Module,
|
| 46 |
+
**kwargs
|
| 47 |
+
) -> None:
|
| 48 |
+
"""
|
| 49 |
+
Initializes the BiCodec model with the required components.
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
mel_params (dict): Parameters for the mel-spectrogram transformer.
|
| 53 |
+
encoder (nn.Module): Encoder module.
|
| 54 |
+
decoder (nn.Module): Decoder module.
|
| 55 |
+
quantizer (nn.Module): Quantizer module.
|
| 56 |
+
speaker_encoder (nn.Module): Speaker encoder module.
|
| 57 |
+
prenet (nn.Module): Prenet network.
|
| 58 |
+
postnet (nn.Module): Postnet network.
|
| 59 |
+
"""
|
| 60 |
+
super().__init__()
|
| 61 |
+
self.encoder = encoder
|
| 62 |
+
self.decoder = decoder
|
| 63 |
+
self.quantizer = quantizer
|
| 64 |
+
self.speaker_encoder = speaker_encoder
|
| 65 |
+
self.prenet = prenet
|
| 66 |
+
self.postnet = postnet
|
| 67 |
+
self.init_mel_transformer(mel_params)
|
| 68 |
+
|
| 69 |
+
@classmethod
|
| 70 |
+
def load_from_checkpoint(cls, model_dir: Path, **kwargs) -> "BiCodec":
|
| 71 |
+
"""
|
| 72 |
+
Loads the model from a checkpoint.
|
| 73 |
+
|
| 74 |
+
Args:
|
| 75 |
+
model_dir (Path): Path to the model directory containing checkpoint and config.
|
| 76 |
+
|
| 77 |
+
Returns:
|
| 78 |
+
BiCodec: The initialized BiCodec model.
|
| 79 |
+
"""
|
| 80 |
+
ckpt_path = f'{model_dir}/model.safetensors'
|
| 81 |
+
config = load_config(f'{model_dir}/config.yaml')['audio_tokenizer']
|
| 82 |
+
mel_params = config["mel_params"]
|
| 83 |
+
encoder = Encoder(**config["encoder"])
|
| 84 |
+
quantizer = FactorizedVectorQuantize(**config["quantizer"])
|
| 85 |
+
prenet = Decoder(**config["prenet"])
|
| 86 |
+
postnet = Decoder(**config["postnet"])
|
| 87 |
+
decoder = WaveGenerator(**config["decoder"])
|
| 88 |
+
speaker_encoder = SpeakerEncoder(**config["speaker_encoder"])
|
| 89 |
+
|
| 90 |
+
model = cls(
|
| 91 |
+
mel_params=mel_params,
|
| 92 |
+
encoder=encoder,
|
| 93 |
+
decoder=decoder,
|
| 94 |
+
quantizer=quantizer,
|
| 95 |
+
speaker_encoder=speaker_encoder,
|
| 96 |
+
prenet=prenet,
|
| 97 |
+
postnet=postnet,
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
state_dict = load_file(ckpt_path)
|
| 101 |
+
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
| 102 |
+
|
| 103 |
+
for key in missing_keys:
|
| 104 |
+
print(f"Missing tensor: {key}")
|
| 105 |
+
for key in unexpected_keys:
|
| 106 |
+
print(f"Unexpected tensor: {key}")
|
| 107 |
+
|
| 108 |
+
model.eval()
|
| 109 |
+
model.remove_weight_norm()
|
| 110 |
+
|
| 111 |
+
return model
|
| 112 |
+
|
| 113 |
+
def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
|
| 114 |
+
"""
|
| 115 |
+
Performs a forward pass through the model.
|
| 116 |
+
|
| 117 |
+
Args:
|
| 118 |
+
batch (dict): A dictionary containing features, reference waveform, and target waveform.
|
| 119 |
+
|
| 120 |
+
Returns:
|
| 121 |
+
dict: A dictionary containing the reconstruction, features, and other metrics.
|
| 122 |
+
"""
|
| 123 |
+
feat = batch["feat"]
|
| 124 |
+
mel = self.mel_transformer(batch["ref_wav"]).squeeze(1)
|
| 125 |
+
|
| 126 |
+
z = self.encoder(feat.transpose(1, 2))
|
| 127 |
+
vq_outputs = self.quantizer(z)
|
| 128 |
+
|
| 129 |
+
x_vector, d_vector = self.speaker_encoder(mel.transpose(1, 2))
|
| 130 |
+
|
| 131 |
+
conditions = d_vector
|
| 132 |
+
with_speaker_loss = False
|
| 133 |
+
|
| 134 |
+
x = self.prenet(vq_outputs["z_q"], conditions)
|
| 135 |
+
pred_feat = self.postnet(x)
|
| 136 |
+
x = x + conditions.unsqueeze(-1)
|
| 137 |
+
wav_recon = self.decoder(x)
|
| 138 |
+
|
| 139 |
+
return {
|
| 140 |
+
"vq_loss": vq_outputs["vq_loss"],
|
| 141 |
+
"perplexity": vq_outputs["perplexity"],
|
| 142 |
+
"cluster_size": vq_outputs["active_num"],
|
| 143 |
+
"recons": wav_recon,
|
| 144 |
+
"pred_feat": pred_feat,
|
| 145 |
+
"x_vector": x_vector,
|
| 146 |
+
"d_vector": d_vector,
|
| 147 |
+
"audios": batch["wav"].unsqueeze(1),
|
| 148 |
+
"with_speaker_loss": with_speaker_loss,
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
@torch.no_grad()
|
| 152 |
+
def tokenize(self, batch: Dict[str, Any]):
|
| 153 |
+
"""
|
| 154 |
+
Tokenizes the input audio into semantic and global tokens.
|
| 155 |
+
|
| 156 |
+
Args:
|
| 157 |
+
batch (dict): The input audio features and reference waveform.
|
| 158 |
+
|
| 159 |
+
Returns:
|
| 160 |
+
tuple: Semantic tokens and global tokens.
|
| 161 |
+
"""
|
| 162 |
+
feat = batch["feat"]
|
| 163 |
+
mel = self.mel_transformer(batch["ref_wav"]).squeeze(1)
|
| 164 |
+
|
| 165 |
+
z = self.encoder(feat.transpose(1, 2))
|
| 166 |
+
semantic_tokens = self.quantizer.tokenize(z)
|
| 167 |
+
global_tokens = self.speaker_encoder.tokenize(mel.transpose(1, 2))
|
| 168 |
+
|
| 169 |
+
return semantic_tokens, global_tokens
|
| 170 |
+
|
| 171 |
+
@torch.no_grad()
|
| 172 |
+
def detokenize(self, semantic_tokens, global_tokens):
|
| 173 |
+
"""
|
| 174 |
+
Detokenizes the semantic and global tokens into a waveform.
|
| 175 |
+
|
| 176 |
+
Args:
|
| 177 |
+
semantic_tokens (tensor): Semantic tokens.
|
| 178 |
+
global_tokens (tensor): Global tokens.
|
| 179 |
+
|
| 180 |
+
Returns:
|
| 181 |
+
tensor: Reconstructed waveform.
|
| 182 |
+
"""
|
| 183 |
+
z_q = self.quantizer.detokenize(semantic_tokens)
|
| 184 |
+
d_vector = self.speaker_encoder.detokenize(global_tokens)
|
| 185 |
+
x = self.prenet(z_q, d_vector)
|
| 186 |
+
x = x + d_vector.unsqueeze(-1)
|
| 187 |
+
wav_recon = self.decoder(x)
|
| 188 |
+
|
| 189 |
+
return wav_recon
|
| 190 |
+
|
| 191 |
+
def init_mel_transformer(self, config: Dict[str, Any]):
|
| 192 |
+
"""
|
| 193 |
+
Initializes the MelSpectrogram transformer based on the provided configuration.
|
| 194 |
+
|
| 195 |
+
Args:
|
| 196 |
+
config (dict): Configuration parameters for MelSpectrogram.
|
| 197 |
+
"""
|
| 198 |
+
import torchaudio.transforms as TT
|
| 199 |
+
|
| 200 |
+
self.mel_transformer = TT.MelSpectrogram(
|
| 201 |
+
config["sample_rate"],
|
| 202 |
+
config["n_fft"],
|
| 203 |
+
config["win_length"],
|
| 204 |
+
config["hop_length"],
|
| 205 |
+
config["mel_fmin"],
|
| 206 |
+
config["mel_fmax"],
|
| 207 |
+
n_mels=config["num_mels"],
|
| 208 |
+
power=1,
|
| 209 |
+
norm="slaney",
|
| 210 |
+
mel_scale="slaney",
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
def remove_weight_norm(self):
|
| 214 |
+
"""Removes weight normalization from all layers."""
|
| 215 |
+
def _remove_weight_norm(m):
|
| 216 |
+
try:
|
| 217 |
+
torch.nn.utils.remove_weight_norm(m)
|
| 218 |
+
except ValueError:
|
| 219 |
+
pass # The module didn't have weight norm
|
| 220 |
+
|
| 221 |
+
self.apply(_remove_weight_norm)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
# Test the model
|
| 225 |
+
if __name__ == "__main__":
|
| 226 |
+
|
| 227 |
+
config = load_config("pretrained_models/SparkTTS-0.5B/BiCodec/config.yaml")
|
| 228 |
+
model = BiCodec.load_from_checkpoint(
|
| 229 |
+
model_dir="pretrained_models/SparkTTS-0.5B/BiCodec",
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
# Generate random inputs for testing
|
| 233 |
+
duration = 0.96
|
| 234 |
+
x = torch.randn(20, 1, int(duration * 16000))
|
| 235 |
+
feat = torch.randn(20, int(duration * 50), 1024)
|
| 236 |
+
inputs = {"feat": feat, "wav": x, "ref_wav": x}
|
| 237 |
+
|
| 238 |
+
# Forward pass
|
| 239 |
+
outputs = model(inputs)
|
| 240 |
+
semantic_tokens, global_tokens = model.tokenize(inputs)
|
| 241 |
+
wav_recon = model.detokenize(semantic_tokens, global_tokens)
|
| 242 |
+
|
| 243 |
+
# Verify if the reconstruction matches
|
| 244 |
+
if torch.allclose(outputs["recons"].detach(), wav_recon):
|
| 245 |
+
print("Test successful")
|
| 246 |
+
else:
|
| 247 |
+
print("Test failed")
|
trained_30_percents/sparktts/modules/blocks/__pycache__/layers.cpython-311.pyc
ADDED
|
Binary file (4.17 kB). View file
|
|
|
trained_30_percents/sparktts/modules/blocks/__pycache__/samper.cpython-311.pyc
ADDED
|
Binary file (4.52 kB). View file
|
|
|
trained_30_percents/sparktts/modules/blocks/__pycache__/vocos.cpython-311.pyc
ADDED
|
Binary file (17.6 kB). View file
|
|
|
trained_30_percents/sparktts/modules/blocks/layers.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 SparkAudio
|
| 2 |
+
# 2025 Xinsheng Wang ([email protected])
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
# Adapted from https://github.com/descriptinc/descript-audio-codec under the Apache License 2.0
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
from torch.nn.utils import weight_norm
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def WNConv1d(*args, **kwargs):
|
| 25 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def WNConvTranspose1d(*args, **kwargs):
|
| 29 |
+
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Scripting this brings model speed up 1.4x
|
| 33 |
+
@torch.jit.script
|
| 34 |
+
def snake(x, alpha):
|
| 35 |
+
shape = x.shape
|
| 36 |
+
x = x.reshape(shape[0], shape[1], -1)
|
| 37 |
+
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
|
| 38 |
+
x = x.reshape(shape)
|
| 39 |
+
return x
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class Snake1d(nn.Module):
|
| 43 |
+
def __init__(self, channels):
|
| 44 |
+
super().__init__()
|
| 45 |
+
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
|
| 46 |
+
|
| 47 |
+
def forward(self, x):
|
| 48 |
+
return snake(x, self.alpha)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class ResidualUnit(nn.Module):
|
| 52 |
+
def __init__(self, dim: int = 16, dilation: int = 1):
|
| 53 |
+
super().__init__()
|
| 54 |
+
pad = ((7 - 1) * dilation) // 2
|
| 55 |
+
self.block = nn.Sequential(
|
| 56 |
+
Snake1d(dim),
|
| 57 |
+
WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
|
| 58 |
+
Snake1d(dim),
|
| 59 |
+
WNConv1d(dim, dim, kernel_size=1),
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
def forward(self, x):
|
| 63 |
+
y = self.block(x)
|
| 64 |
+
pad = (x.shape[-1] - y.shape[-1]) // 2
|
| 65 |
+
if pad > 0:
|
| 66 |
+
x = x[..., pad:-pad]
|
| 67 |
+
return x + y
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def init_weights(m):
|
| 71 |
+
if isinstance(m, nn.Conv1d):
|
| 72 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
| 73 |
+
nn.init.constant_(m.bias, 0)
|
trained_30_percents/sparktts/modules/blocks/samper.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 SparkAudio
|
| 2 |
+
# 2025 Xinsheng Wang ([email protected])
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class SamplingBlock(nn.Module):
|
| 23 |
+
"""Sampling block for upsampling or downsampling"""
|
| 24 |
+
|
| 25 |
+
def __init__(
|
| 26 |
+
self,
|
| 27 |
+
dim: int,
|
| 28 |
+
groups: int = 1,
|
| 29 |
+
upsample_scale: int = 1,
|
| 30 |
+
downsample_scale: int = 1,
|
| 31 |
+
) -> None:
|
| 32 |
+
"""
|
| 33 |
+
Args:
|
| 34 |
+
dim: input dimension
|
| 35 |
+
groups: number of groups
|
| 36 |
+
upsample_scale: upsampling scale
|
| 37 |
+
downsample_scale: downsampling scale
|
| 38 |
+
"""
|
| 39 |
+
super(SamplingBlock, self).__init__()
|
| 40 |
+
|
| 41 |
+
self.upsample_scale = upsample_scale
|
| 42 |
+
self.downsample_scale = downsample_scale
|
| 43 |
+
|
| 44 |
+
if self.upsample_scale > 1:
|
| 45 |
+
self.de_conv_upsampler = nn.Sequential(
|
| 46 |
+
nn.LeakyReLU(0.2),
|
| 47 |
+
nn.ConvTranspose1d(
|
| 48 |
+
dim,
|
| 49 |
+
dim,
|
| 50 |
+
kernel_size=upsample_scale * 2,
|
| 51 |
+
stride=upsample_scale,
|
| 52 |
+
padding=upsample_scale // 2 + upsample_scale % 2,
|
| 53 |
+
output_padding=upsample_scale % 2,
|
| 54 |
+
groups=groups,
|
| 55 |
+
),
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
if self.downsample_scale > 1:
|
| 59 |
+
self.conv_downsampler = nn.Sequential(
|
| 60 |
+
nn.LeakyReLU(0.2),
|
| 61 |
+
nn.Conv1d(
|
| 62 |
+
dim,
|
| 63 |
+
dim,
|
| 64 |
+
kernel_size=2 * downsample_scale,
|
| 65 |
+
stride=downsample_scale,
|
| 66 |
+
padding=downsample_scale // 2 + downsample_scale % 2,
|
| 67 |
+
groups=groups,
|
| 68 |
+
),
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
@staticmethod
|
| 72 |
+
def repeat_upsampler(x, upsample_scale):
|
| 73 |
+
return x.repeat_interleave(upsample_scale, dim=2)
|
| 74 |
+
|
| 75 |
+
@staticmethod
|
| 76 |
+
def skip_downsampler(x, downsample_scale):
|
| 77 |
+
return F.avg_pool1d(x, kernel_size=downsample_scale, stride=downsample_scale)
|
| 78 |
+
|
| 79 |
+
def forward(self, x):
|
| 80 |
+
x = x.transpose(1, 2)
|
| 81 |
+
if self.upsample_scale > 1:
|
| 82 |
+
repeat_res = self.repeat_upsampler(x, self.upsample_scale)
|
| 83 |
+
deconv_res = self.de_conv_upsampler(x)
|
| 84 |
+
upmerge_res = repeat_res + deconv_res
|
| 85 |
+
else:
|
| 86 |
+
upmerge_res = x
|
| 87 |
+
repeat_res = x
|
| 88 |
+
|
| 89 |
+
if self.downsample_scale > 1:
|
| 90 |
+
conv_res = self.conv_downsampler(upmerge_res)
|
| 91 |
+
skip2_res = self.skip_downsampler(upmerge_res, self.downsample_scale)
|
| 92 |
+
skip1_res = self.skip_downsampler(repeat_res, self.downsample_scale)
|
| 93 |
+
else:
|
| 94 |
+
conv_res = upmerge_res
|
| 95 |
+
skip2_res = upmerge_res
|
| 96 |
+
skip1_res = repeat_res
|
| 97 |
+
|
| 98 |
+
final_res = conv_res + skip1_res + skip2_res
|
| 99 |
+
|
| 100 |
+
return final_res
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# test
|
| 104 |
+
if __name__ == "__main__":
|
| 105 |
+
test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50
|
| 106 |
+
model = SamplingBlock(1024, 1024, upsample_scale=2)
|
| 107 |
+
model_down = SamplingBlock(1024, 1024, downsample_scale=2)
|
| 108 |
+
output = model(test_input)
|
| 109 |
+
output_down = model_down(test_input)
|
| 110 |
+
print("shape after upsample * 2", output.shape) # torch.Size([8, 1024, 100])
|
| 111 |
+
print("shape after downsample * 2", output_down.shape) # torch.Size([8, 1024, 25])
|
| 112 |
+
if output.shape == torch.Size([8, 1024, 100]) and output_down.shape == torch.Size(
|
| 113 |
+
[8, 1024, 25]
|
| 114 |
+
):
|
| 115 |
+
print("test successful")
|
trained_30_percents/sparktts/modules/blocks/vocos.py
ADDED
|
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 SparkAudio
|
| 2 |
+
# 2025 Xinsheng Wang ([email protected])
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
|
| 20 |
+
from typing import Tuple
|
| 21 |
+
from torch.nn.utils import weight_norm, remove_weight_norm
|
| 22 |
+
|
| 23 |
+
from typing import Optional
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class ConvNeXtBlock(nn.Module):
|
| 27 |
+
"""ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
dim (int): Number of input channels.
|
| 31 |
+
intermediate_dim (int): Dimensionality of the intermediate layer.
|
| 32 |
+
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
|
| 33 |
+
Defaults to None.
|
| 34 |
+
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
|
| 35 |
+
None means non-conditional LayerNorm. Defaults to None.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
dim: int,
|
| 41 |
+
intermediate_dim: int,
|
| 42 |
+
layer_scale_init_value: float,
|
| 43 |
+
condition_dim: Optional[int] = None,
|
| 44 |
+
):
|
| 45 |
+
super().__init__()
|
| 46 |
+
self.dwconv = nn.Conv1d(
|
| 47 |
+
dim, dim, kernel_size=7, padding=3, groups=dim
|
| 48 |
+
) # depthwise conv
|
| 49 |
+
self.adanorm = condition_dim is not None
|
| 50 |
+
if condition_dim:
|
| 51 |
+
self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6)
|
| 52 |
+
else:
|
| 53 |
+
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
| 54 |
+
self.pwconv1 = nn.Linear(
|
| 55 |
+
dim, intermediate_dim
|
| 56 |
+
) # pointwise/1x1 convs, implemented with linear layers
|
| 57 |
+
self.act = nn.GELU()
|
| 58 |
+
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
| 59 |
+
self.gamma = (
|
| 60 |
+
nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
|
| 61 |
+
if layer_scale_init_value > 0
|
| 62 |
+
else None
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
def forward(
|
| 66 |
+
self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
|
| 67 |
+
) -> torch.Tensor:
|
| 68 |
+
residual = x
|
| 69 |
+
x = self.dwconv(x)
|
| 70 |
+
x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
|
| 71 |
+
if self.adanorm:
|
| 72 |
+
assert cond_embedding_id is not None
|
| 73 |
+
x = self.norm(x, cond_embedding_id)
|
| 74 |
+
else:
|
| 75 |
+
x = self.norm(x)
|
| 76 |
+
x = self.pwconv1(x)
|
| 77 |
+
x = self.act(x)
|
| 78 |
+
x = self.pwconv2(x)
|
| 79 |
+
if self.gamma is not None:
|
| 80 |
+
x = self.gamma * x
|
| 81 |
+
x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
|
| 82 |
+
|
| 83 |
+
x = residual + x
|
| 84 |
+
return x
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class AdaLayerNorm(nn.Module):
|
| 88 |
+
"""
|
| 89 |
+
Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
|
| 90 |
+
|
| 91 |
+
Args:
|
| 92 |
+
condition_dim (int): Dimension of the condition.
|
| 93 |
+
embedding_dim (int): Dimension of the embeddings.
|
| 94 |
+
"""
|
| 95 |
+
|
| 96 |
+
def __init__(self, condition_dim: int, embedding_dim: int, eps: float = 1e-6):
|
| 97 |
+
super().__init__()
|
| 98 |
+
self.eps = eps
|
| 99 |
+
self.dim = embedding_dim
|
| 100 |
+
self.scale = nn.Linear(condition_dim, embedding_dim)
|
| 101 |
+
self.shift = nn.Linear(condition_dim, embedding_dim)
|
| 102 |
+
torch.nn.init.ones_(self.scale.weight)
|
| 103 |
+
torch.nn.init.zeros_(self.shift.weight)
|
| 104 |
+
|
| 105 |
+
def forward(self, x: torch.Tensor, cond_embedding: torch.Tensor) -> torch.Tensor:
|
| 106 |
+
scale = self.scale(cond_embedding)
|
| 107 |
+
shift = self.shift(cond_embedding)
|
| 108 |
+
x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
|
| 109 |
+
x = x * scale.unsqueeze(1) + shift.unsqueeze(1)
|
| 110 |
+
return x
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class ResBlock1(nn.Module):
|
| 114 |
+
"""
|
| 115 |
+
ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
|
| 116 |
+
but without upsampling layers.
|
| 117 |
+
|
| 118 |
+
Args:
|
| 119 |
+
dim (int): Number of input channels.
|
| 120 |
+
kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
|
| 121 |
+
dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
|
| 122 |
+
Defaults to (1, 3, 5).
|
| 123 |
+
lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
|
| 124 |
+
Defaults to 0.1.
|
| 125 |
+
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
|
| 126 |
+
Defaults to None.
|
| 127 |
+
"""
|
| 128 |
+
|
| 129 |
+
def __init__(
|
| 130 |
+
self,
|
| 131 |
+
dim: int,
|
| 132 |
+
kernel_size: int = 3,
|
| 133 |
+
dilation: Tuple[int, int, int] = (1, 3, 5),
|
| 134 |
+
lrelu_slope: float = 0.1,
|
| 135 |
+
layer_scale_init_value: Optional[float] = None,
|
| 136 |
+
):
|
| 137 |
+
super().__init__()
|
| 138 |
+
self.lrelu_slope = lrelu_slope
|
| 139 |
+
self.convs1 = nn.ModuleList(
|
| 140 |
+
[
|
| 141 |
+
weight_norm(
|
| 142 |
+
nn.Conv1d(
|
| 143 |
+
dim,
|
| 144 |
+
dim,
|
| 145 |
+
kernel_size,
|
| 146 |
+
1,
|
| 147 |
+
dilation=dilation[0],
|
| 148 |
+
padding=self.get_padding(kernel_size, dilation[0]),
|
| 149 |
+
)
|
| 150 |
+
),
|
| 151 |
+
weight_norm(
|
| 152 |
+
nn.Conv1d(
|
| 153 |
+
dim,
|
| 154 |
+
dim,
|
| 155 |
+
kernel_size,
|
| 156 |
+
1,
|
| 157 |
+
dilation=dilation[1],
|
| 158 |
+
padding=self.get_padding(kernel_size, dilation[1]),
|
| 159 |
+
)
|
| 160 |
+
),
|
| 161 |
+
weight_norm(
|
| 162 |
+
nn.Conv1d(
|
| 163 |
+
dim,
|
| 164 |
+
dim,
|
| 165 |
+
kernel_size,
|
| 166 |
+
1,
|
| 167 |
+
dilation=dilation[2],
|
| 168 |
+
padding=self.get_padding(kernel_size, dilation[2]),
|
| 169 |
+
)
|
| 170 |
+
),
|
| 171 |
+
]
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
self.convs2 = nn.ModuleList(
|
| 175 |
+
[
|
| 176 |
+
weight_norm(
|
| 177 |
+
nn.Conv1d(
|
| 178 |
+
dim,
|
| 179 |
+
dim,
|
| 180 |
+
kernel_size,
|
| 181 |
+
1,
|
| 182 |
+
dilation=1,
|
| 183 |
+
padding=self.get_padding(kernel_size, 1),
|
| 184 |
+
)
|
| 185 |
+
),
|
| 186 |
+
weight_norm(
|
| 187 |
+
nn.Conv1d(
|
| 188 |
+
dim,
|
| 189 |
+
dim,
|
| 190 |
+
kernel_size,
|
| 191 |
+
1,
|
| 192 |
+
dilation=1,
|
| 193 |
+
padding=self.get_padding(kernel_size, 1),
|
| 194 |
+
)
|
| 195 |
+
),
|
| 196 |
+
weight_norm(
|
| 197 |
+
nn.Conv1d(
|
| 198 |
+
dim,
|
| 199 |
+
dim,
|
| 200 |
+
kernel_size,
|
| 201 |
+
1,
|
| 202 |
+
dilation=1,
|
| 203 |
+
padding=self.get_padding(kernel_size, 1),
|
| 204 |
+
)
|
| 205 |
+
),
|
| 206 |
+
]
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
self.gamma = nn.ParameterList(
|
| 210 |
+
[
|
| 211 |
+
(
|
| 212 |
+
nn.Parameter(
|
| 213 |
+
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
|
| 214 |
+
)
|
| 215 |
+
if layer_scale_init_value is not None
|
| 216 |
+
else None
|
| 217 |
+
),
|
| 218 |
+
(
|
| 219 |
+
nn.Parameter(
|
| 220 |
+
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
|
| 221 |
+
)
|
| 222 |
+
if layer_scale_init_value is not None
|
| 223 |
+
else None
|
| 224 |
+
),
|
| 225 |
+
(
|
| 226 |
+
nn.Parameter(
|
| 227 |
+
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
|
| 228 |
+
)
|
| 229 |
+
if layer_scale_init_value is not None
|
| 230 |
+
else None
|
| 231 |
+
),
|
| 232 |
+
]
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 236 |
+
for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
|
| 237 |
+
xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
|
| 238 |
+
xt = c1(xt)
|
| 239 |
+
xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
|
| 240 |
+
xt = c2(xt)
|
| 241 |
+
if gamma is not None:
|
| 242 |
+
xt = gamma * xt
|
| 243 |
+
x = xt + x
|
| 244 |
+
return x
|
| 245 |
+
|
| 246 |
+
def remove_weight_norm(self):
|
| 247 |
+
for l in self.convs1:
|
| 248 |
+
remove_weight_norm(l)
|
| 249 |
+
for l in self.convs2:
|
| 250 |
+
remove_weight_norm(l)
|
| 251 |
+
|
| 252 |
+
@staticmethod
|
| 253 |
+
def get_padding(kernel_size: int, dilation: int = 1) -> int:
|
| 254 |
+
return int((kernel_size * dilation - dilation) / 2)
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
class Backbone(nn.Module):
|
| 258 |
+
"""Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
|
| 259 |
+
|
| 260 |
+
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
|
| 261 |
+
"""
|
| 262 |
+
Args:
|
| 263 |
+
x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
|
| 264 |
+
C denotes output features, and L is the sequence length.
|
| 265 |
+
|
| 266 |
+
Returns:
|
| 267 |
+
Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
|
| 268 |
+
and H denotes the model dimension.
|
| 269 |
+
"""
|
| 270 |
+
raise NotImplementedError("Subclasses must implement the forward method.")
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class VocosBackbone(Backbone):
|
| 274 |
+
"""
|
| 275 |
+
Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
|
| 276 |
+
|
| 277 |
+
Args:
|
| 278 |
+
input_channels (int): Number of input features channels.
|
| 279 |
+
dim (int): Hidden dimension of the model.
|
| 280 |
+
intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
|
| 281 |
+
num_layers (int): Number of ConvNeXtBlock layers.
|
| 282 |
+
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
|
| 283 |
+
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
|
| 284 |
+
None means non-conditional model. Defaults to None.
|
| 285 |
+
"""
|
| 286 |
+
|
| 287 |
+
def __init__(
|
| 288 |
+
self,
|
| 289 |
+
input_channels: int,
|
| 290 |
+
dim: int,
|
| 291 |
+
intermediate_dim: int,
|
| 292 |
+
num_layers: int,
|
| 293 |
+
layer_scale_init_value: Optional[float] = None,
|
| 294 |
+
condition_dim: Optional[int] = None,
|
| 295 |
+
):
|
| 296 |
+
super().__init__()
|
| 297 |
+
self.input_channels = input_channels
|
| 298 |
+
self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
|
| 299 |
+
self.adanorm = condition_dim is not None
|
| 300 |
+
if condition_dim:
|
| 301 |
+
self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6)
|
| 302 |
+
else:
|
| 303 |
+
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
| 304 |
+
layer_scale_init_value = layer_scale_init_value or 1 / num_layers
|
| 305 |
+
self.convnext = nn.ModuleList(
|
| 306 |
+
[
|
| 307 |
+
ConvNeXtBlock(
|
| 308 |
+
dim=dim,
|
| 309 |
+
intermediate_dim=intermediate_dim,
|
| 310 |
+
layer_scale_init_value=layer_scale_init_value,
|
| 311 |
+
condition_dim=condition_dim,
|
| 312 |
+
)
|
| 313 |
+
for _ in range(num_layers)
|
| 314 |
+
]
|
| 315 |
+
)
|
| 316 |
+
self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
|
| 317 |
+
self.apply(self._init_weights)
|
| 318 |
+
|
| 319 |
+
def _init_weights(self, m):
|
| 320 |
+
if isinstance(m, (nn.Conv1d, nn.Linear)):
|
| 321 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
| 322 |
+
nn.init.constant_(m.bias, 0)
|
| 323 |
+
|
| 324 |
+
def forward(self, x: torch.Tensor, condition: torch.Tensor = None) -> torch.Tensor:
|
| 325 |
+
x = self.embed(x)
|
| 326 |
+
if self.adanorm:
|
| 327 |
+
assert condition is not None
|
| 328 |
+
x = self.norm(x.transpose(1, 2), condition)
|
| 329 |
+
else:
|
| 330 |
+
x = self.norm(x.transpose(1, 2))
|
| 331 |
+
x = x.transpose(1, 2)
|
| 332 |
+
for conv_block in self.convnext:
|
| 333 |
+
x = conv_block(x, condition)
|
| 334 |
+
x = self.final_layer_norm(x.transpose(1, 2))
|
| 335 |
+
return x
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
class VocosResNetBackbone(Backbone):
|
| 339 |
+
"""
|
| 340 |
+
Vocos backbone module built with ResBlocks.
|
| 341 |
+
|
| 342 |
+
Args:
|
| 343 |
+
input_channels (int): Number of input features channels.
|
| 344 |
+
dim (int): Hidden dimension of the model.
|
| 345 |
+
num_blocks (int): Number of ResBlock1 blocks.
|
| 346 |
+
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
|
| 347 |
+
"""
|
| 348 |
+
|
| 349 |
+
def __init__(
|
| 350 |
+
self,
|
| 351 |
+
input_channels,
|
| 352 |
+
dim,
|
| 353 |
+
num_blocks,
|
| 354 |
+
layer_scale_init_value=None,
|
| 355 |
+
):
|
| 356 |
+
super().__init__()
|
| 357 |
+
self.input_channels = input_channels
|
| 358 |
+
self.embed = weight_norm(
|
| 359 |
+
nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
|
| 360 |
+
)
|
| 361 |
+
layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
|
| 362 |
+
self.resnet = nn.Sequential(
|
| 363 |
+
*[
|
| 364 |
+
ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
|
| 365 |
+
for _ in range(num_blocks)
|
| 366 |
+
]
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
|
| 370 |
+
x = self.embed(x)
|
| 371 |
+
x = self.resnet(x)
|
| 372 |
+
x = x.transpose(1, 2)
|
| 373 |
+
return x
|
trained_30_percents/sparktts/modules/encoder_decoder/__pycache__/feat_decoder.cpython-311.pyc
ADDED
|
Binary file (4.26 kB). View file
|
|
|
trained_30_percents/sparktts/modules/encoder_decoder/__pycache__/feat_encoder.cpython-311.pyc
ADDED
|
Binary file (3.44 kB). View file
|
|
|
trained_30_percents/sparktts/modules/encoder_decoder/__pycache__/wave_generator.cpython-311.pyc
ADDED
|
Binary file (3.36 kB). View file
|
|
|
trained_30_percents/sparktts/modules/encoder_decoder/feat_decoder.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 SparkAudio
|
| 2 |
+
# 2025 Xinsheng Wang ([email protected])
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
|
| 20 |
+
from typing import List
|
| 21 |
+
|
| 22 |
+
from sparktts.modules.blocks.vocos import VocosBackbone
|
| 23 |
+
from sparktts.modules.blocks.samper import SamplingBlock
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class Decoder(nn.Module):
|
| 27 |
+
"""Decoder module with convnext and upsampling blocks
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
sample_ratios (List[int]): sample ratios
|
| 31 |
+
example: [2, 2] means downsample by 2x and then upsample by 2x
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
def __init__(
|
| 35 |
+
self,
|
| 36 |
+
input_channels: int,
|
| 37 |
+
vocos_dim: int,
|
| 38 |
+
vocos_intermediate_dim: int,
|
| 39 |
+
vocos_num_layers: int,
|
| 40 |
+
out_channels: int,
|
| 41 |
+
condition_dim: int = None,
|
| 42 |
+
sample_ratios: List[int] = [1, 1],
|
| 43 |
+
use_tanh_at_final: bool = False,
|
| 44 |
+
):
|
| 45 |
+
super().__init__()
|
| 46 |
+
|
| 47 |
+
self.linear_pre = nn.Linear(input_channels, vocos_dim)
|
| 48 |
+
modules = [
|
| 49 |
+
nn.Sequential(
|
| 50 |
+
SamplingBlock(
|
| 51 |
+
dim=vocos_dim,
|
| 52 |
+
groups=vocos_dim,
|
| 53 |
+
upsample_scale=ratio,
|
| 54 |
+
),
|
| 55 |
+
VocosBackbone(
|
| 56 |
+
input_channels=vocos_dim,
|
| 57 |
+
dim=vocos_dim,
|
| 58 |
+
intermediate_dim=vocos_intermediate_dim,
|
| 59 |
+
num_layers=2,
|
| 60 |
+
condition_dim=None,
|
| 61 |
+
),
|
| 62 |
+
)
|
| 63 |
+
for ratio in sample_ratios
|
| 64 |
+
]
|
| 65 |
+
|
| 66 |
+
self.downsample = nn.Sequential(*modules)
|
| 67 |
+
|
| 68 |
+
self.vocos_backbone = VocosBackbone(
|
| 69 |
+
input_channels=vocos_dim,
|
| 70 |
+
dim=vocos_dim,
|
| 71 |
+
intermediate_dim=vocos_intermediate_dim,
|
| 72 |
+
num_layers=vocos_num_layers,
|
| 73 |
+
condition_dim=condition_dim,
|
| 74 |
+
)
|
| 75 |
+
self.linear = nn.Linear(vocos_dim, out_channels)
|
| 76 |
+
self.use_tanh_at_final = use_tanh_at_final
|
| 77 |
+
|
| 78 |
+
def forward(self, x: torch.Tensor, c: torch.Tensor = None):
|
| 79 |
+
"""encoder forward.
|
| 80 |
+
|
| 81 |
+
Args:
|
| 82 |
+
x (torch.Tensor): (batch_size, input_channels, length)
|
| 83 |
+
|
| 84 |
+
Returns:
|
| 85 |
+
x (torch.Tensor): (batch_size, encode_channels, length)
|
| 86 |
+
"""
|
| 87 |
+
x = self.linear_pre(x.transpose(1, 2))
|
| 88 |
+
x = self.downsample(x).transpose(1, 2)
|
| 89 |
+
x = self.vocos_backbone(x, condition=c)
|
| 90 |
+
x = self.linear(x).transpose(1, 2)
|
| 91 |
+
if self.use_tanh_at_final:
|
| 92 |
+
x = torch.tanh(x)
|
| 93 |
+
|
| 94 |
+
return x
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# test
|
| 98 |
+
if __name__ == "__main__":
|
| 99 |
+
test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50
|
| 100 |
+
condition = torch.randn(8, 256)
|
| 101 |
+
decoder = Decoder(
|
| 102 |
+
input_channels=1024,
|
| 103 |
+
vocos_dim=384,
|
| 104 |
+
vocos_intermediate_dim=2048,
|
| 105 |
+
vocos_num_layers=12,
|
| 106 |
+
out_channels=256,
|
| 107 |
+
condition_dim=256,
|
| 108 |
+
sample_ratios=[2, 2],
|
| 109 |
+
)
|
| 110 |
+
output = decoder(test_input, condition)
|
| 111 |
+
print(output.shape) # torch.Size([8, 256, 200])
|
| 112 |
+
if output.shape == torch.Size([8, 256, 200]):
|
| 113 |
+
print("Decoder test passed")
|
| 114 |
+
else:
|
| 115 |
+
print("Decoder test failed")
|
trained_30_percents/sparktts/modules/encoder_decoder/feat_encoder.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 SparkAudio
|
| 2 |
+
# 2025 Xinsheng Wang ([email protected])
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
|
| 20 |
+
from typing import List
|
| 21 |
+
|
| 22 |
+
from sparktts.modules.blocks.vocos import VocosBackbone
|
| 23 |
+
from sparktts.modules.blocks.samper import SamplingBlock
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class Encoder(nn.Module):
|
| 27 |
+
"""Encoder module with convnext and downsampling blocks"""
|
| 28 |
+
|
| 29 |
+
def __init__(
|
| 30 |
+
self,
|
| 31 |
+
input_channels: int,
|
| 32 |
+
vocos_dim: int,
|
| 33 |
+
vocos_intermediate_dim: int,
|
| 34 |
+
vocos_num_layers: int,
|
| 35 |
+
out_channels: int,
|
| 36 |
+
sample_ratios: List[int] = [1, 1],
|
| 37 |
+
):
|
| 38 |
+
super().__init__()
|
| 39 |
+
"""
|
| 40 |
+
Encoder module with VocosBackbone and sampling blocks.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
sample_ratios (List[int]): sample ratios
|
| 44 |
+
example: [2, 2] means downsample by 2x and then upsample by 2x
|
| 45 |
+
"""
|
| 46 |
+
self.encoder = VocosBackbone(
|
| 47 |
+
input_channels=input_channels,
|
| 48 |
+
dim=vocos_dim,
|
| 49 |
+
intermediate_dim=vocos_intermediate_dim,
|
| 50 |
+
num_layers=vocos_num_layers,
|
| 51 |
+
condition_dim=None,
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
modules = [
|
| 55 |
+
nn.Sequential(
|
| 56 |
+
SamplingBlock(
|
| 57 |
+
dim=vocos_dim,
|
| 58 |
+
groups=vocos_dim,
|
| 59 |
+
downsample_scale=ratio,
|
| 60 |
+
),
|
| 61 |
+
VocosBackbone(
|
| 62 |
+
input_channels=vocos_dim,
|
| 63 |
+
dim=vocos_dim,
|
| 64 |
+
intermediate_dim=vocos_intermediate_dim,
|
| 65 |
+
num_layers=2,
|
| 66 |
+
condition_dim=None,
|
| 67 |
+
),
|
| 68 |
+
)
|
| 69 |
+
for ratio in sample_ratios
|
| 70 |
+
]
|
| 71 |
+
|
| 72 |
+
self.downsample = nn.Sequential(*modules)
|
| 73 |
+
|
| 74 |
+
self.project = nn.Linear(vocos_dim, out_channels)
|
| 75 |
+
|
| 76 |
+
def forward(self, x: torch.Tensor, *args):
|
| 77 |
+
"""
|
| 78 |
+
Args:
|
| 79 |
+
x (torch.Tensor): (batch_size, input_channels, length)
|
| 80 |
+
|
| 81 |
+
Returns:
|
| 82 |
+
x (torch.Tensor): (batch_size, encode_channels, length)
|
| 83 |
+
"""
|
| 84 |
+
x = self.encoder(x)
|
| 85 |
+
x = self.downsample(x)
|
| 86 |
+
x = self.project(x)
|
| 87 |
+
return x.transpose(1, 2)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# test
|
| 91 |
+
if __name__ == "__main__":
|
| 92 |
+
test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50
|
| 93 |
+
encoder = Encoder(
|
| 94 |
+
input_channels=1024,
|
| 95 |
+
vocos_dim=384,
|
| 96 |
+
vocos_intermediate_dim=2048,
|
| 97 |
+
vocos_num_layers=12,
|
| 98 |
+
out_channels=256,
|
| 99 |
+
sample_ratios=[2, 2],
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
output = encoder(test_input)
|
| 103 |
+
print(output.shape) # torch.Size([8, 256, 12])
|
| 104 |
+
if output.shape == torch.Size([8, 256, 12]):
|
| 105 |
+
print("test successful")
|
trained_30_percents/sparktts/modules/encoder_decoder/wave_generator.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Xinsheng Wang ([email protected])
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# Adapted from https://github.com/descriptinc/descript-audio-codec under the Apache License 2.0
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
|
| 20 |
+
from sparktts.modules.blocks.layers import (
|
| 21 |
+
Snake1d,
|
| 22 |
+
WNConv1d,
|
| 23 |
+
ResidualUnit,
|
| 24 |
+
WNConvTranspose1d,
|
| 25 |
+
init_weights,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class DecoderBlock(nn.Module):
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
input_dim: int = 16,
|
| 33 |
+
output_dim: int = 8,
|
| 34 |
+
kernel_size: int = 2,
|
| 35 |
+
stride: int = 1,
|
| 36 |
+
):
|
| 37 |
+
super().__init__()
|
| 38 |
+
self.block = nn.Sequential(
|
| 39 |
+
Snake1d(input_dim),
|
| 40 |
+
WNConvTranspose1d(
|
| 41 |
+
input_dim,
|
| 42 |
+
output_dim,
|
| 43 |
+
kernel_size=kernel_size,
|
| 44 |
+
stride=stride,
|
| 45 |
+
padding=(kernel_size - stride) // 2,
|
| 46 |
+
),
|
| 47 |
+
ResidualUnit(output_dim, dilation=1),
|
| 48 |
+
ResidualUnit(output_dim, dilation=3),
|
| 49 |
+
ResidualUnit(output_dim, dilation=9),
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
def forward(self, x):
|
| 53 |
+
return self.block(x)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class WaveGenerator(nn.Module):
|
| 57 |
+
def __init__(
|
| 58 |
+
self,
|
| 59 |
+
input_channel,
|
| 60 |
+
channels,
|
| 61 |
+
rates,
|
| 62 |
+
kernel_sizes,
|
| 63 |
+
d_out: int = 1,
|
| 64 |
+
):
|
| 65 |
+
super().__init__()
|
| 66 |
+
|
| 67 |
+
# Add first conv layer
|
| 68 |
+
layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
|
| 69 |
+
|
| 70 |
+
# Add upsampling + MRF blocks
|
| 71 |
+
for i, (kernel_size, stride) in enumerate(zip(kernel_sizes, rates)):
|
| 72 |
+
input_dim = channels // 2**i
|
| 73 |
+
output_dim = channels // 2 ** (i + 1)
|
| 74 |
+
layers += [DecoderBlock(input_dim, output_dim, kernel_size, stride)]
|
| 75 |
+
|
| 76 |
+
# Add final conv layer
|
| 77 |
+
layers += [
|
| 78 |
+
Snake1d(output_dim),
|
| 79 |
+
WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
|
| 80 |
+
nn.Tanh(),
|
| 81 |
+
]
|
| 82 |
+
|
| 83 |
+
self.model = nn.Sequential(*layers)
|
| 84 |
+
|
| 85 |
+
self.apply(init_weights)
|
| 86 |
+
|
| 87 |
+
def forward(self, x):
|
| 88 |
+
return self.model(x)
|
trained_30_percents/sparktts/modules/fsq/__pycache__/finite_scalar_quantization.cpython-311.pyc
ADDED
|
Binary file (11.2 kB). View file
|
|
|
trained_30_percents/sparktts/modules/fsq/__pycache__/residual_fsq.cpython-311.pyc
ADDED
|
Binary file (15.1 kB). View file
|
|
|
trained_30_percents/sparktts/modules/fsq/finite_scalar_quantization.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
|
| 3 |
+
Code adapted from Jax version in Appendix A.1
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
from functools import wraps, partial
|
| 8 |
+
from contextlib import nullcontext
|
| 9 |
+
from typing import List, Tuple
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
from torch.nn import Module
|
| 14 |
+
from torch import Tensor, int32
|
| 15 |
+
from torch.amp import autocast
|
| 16 |
+
|
| 17 |
+
from einops import rearrange, pack, unpack
|
| 18 |
+
|
| 19 |
+
# helper functions
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def exists(v):
|
| 23 |
+
return v is not None
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def default(*args):
|
| 27 |
+
for arg in args:
|
| 28 |
+
if exists(arg):
|
| 29 |
+
return arg
|
| 30 |
+
return None
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def maybe(fn):
|
| 34 |
+
@wraps(fn)
|
| 35 |
+
def inner(x, *args, **kwargs):
|
| 36 |
+
if not exists(x):
|
| 37 |
+
return x
|
| 38 |
+
return fn(x, *args, **kwargs)
|
| 39 |
+
|
| 40 |
+
return inner
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def pack_one(t, pattern):
|
| 44 |
+
return pack([t], pattern)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def unpack_one(t, ps, pattern):
|
| 48 |
+
return unpack(t, ps, pattern)[0]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# tensor helpers
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def round_ste(z: Tensor) -> Tensor:
|
| 55 |
+
"""Round with straight through gradients."""
|
| 56 |
+
zhat = z.round()
|
| 57 |
+
return z + (zhat - z).detach()
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# main class
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class FSQ(Module):
|
| 64 |
+
def __init__(
|
| 65 |
+
self,
|
| 66 |
+
levels: List[int],
|
| 67 |
+
dim: int | None = None,
|
| 68 |
+
num_codebooks=1,
|
| 69 |
+
keep_num_codebooks_dim: bool | None = None,
|
| 70 |
+
scale: float | None = None,
|
| 71 |
+
allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64),
|
| 72 |
+
channel_first: bool = False,
|
| 73 |
+
projection_has_bias: bool = True,
|
| 74 |
+
return_indices=True,
|
| 75 |
+
force_quantization_f32=True,
|
| 76 |
+
):
|
| 77 |
+
super().__init__()
|
| 78 |
+
_levels = torch.tensor(levels, dtype=int32)
|
| 79 |
+
self.register_buffer("_levels", _levels, persistent=False)
|
| 80 |
+
|
| 81 |
+
_basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
|
| 82 |
+
self.register_buffer("_basis", _basis, persistent=False)
|
| 83 |
+
|
| 84 |
+
self.scale = scale
|
| 85 |
+
|
| 86 |
+
codebook_dim = len(levels)
|
| 87 |
+
self.codebook_dim = codebook_dim
|
| 88 |
+
|
| 89 |
+
effective_codebook_dim = codebook_dim * num_codebooks
|
| 90 |
+
self.num_codebooks = num_codebooks
|
| 91 |
+
self.effective_codebook_dim = effective_codebook_dim
|
| 92 |
+
|
| 93 |
+
keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
|
| 94 |
+
assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
|
| 95 |
+
self.keep_num_codebooks_dim = keep_num_codebooks_dim
|
| 96 |
+
|
| 97 |
+
self.dim = default(dim, len(_levels) * num_codebooks)
|
| 98 |
+
|
| 99 |
+
self.channel_first = channel_first
|
| 100 |
+
|
| 101 |
+
has_projections = self.dim != effective_codebook_dim
|
| 102 |
+
self.project_in = (
|
| 103 |
+
nn.Linear(self.dim, effective_codebook_dim, bias=projection_has_bias)
|
| 104 |
+
if has_projections
|
| 105 |
+
else nn.Identity()
|
| 106 |
+
)
|
| 107 |
+
self.project_out = (
|
| 108 |
+
nn.Linear(effective_codebook_dim, self.dim, bias=projection_has_bias)
|
| 109 |
+
if has_projections
|
| 110 |
+
else nn.Identity()
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
self.has_projections = has_projections
|
| 114 |
+
|
| 115 |
+
self.return_indices = return_indices
|
| 116 |
+
if return_indices:
|
| 117 |
+
self.codebook_size = self._levels.prod().item()
|
| 118 |
+
implicit_codebook = self._indices_to_codes(torch.arange(self.codebook_size))
|
| 119 |
+
self.register_buffer(
|
| 120 |
+
"implicit_codebook", implicit_codebook, persistent=False
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
self.allowed_dtypes = allowed_dtypes
|
| 124 |
+
self.force_quantization_f32 = force_quantization_f32
|
| 125 |
+
|
| 126 |
+
def bound(self, z, eps: float = 1e-3):
|
| 127 |
+
"""Bound `z`, an array of shape (..., d)."""
|
| 128 |
+
half_l = (self._levels - 1) * (1 + eps) / 2
|
| 129 |
+
offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
|
| 130 |
+
shift = (offset / half_l).atanh()
|
| 131 |
+
return (z + shift).tanh() * half_l - offset
|
| 132 |
+
|
| 133 |
+
def quantize(self, z):
|
| 134 |
+
"""Quantizes z, returns quantized zhat, same shape as z."""
|
| 135 |
+
quantized = round_ste(self.bound(z))
|
| 136 |
+
half_width = self._levels // 2 # Renormalize to [-1, 1].
|
| 137 |
+
return quantized / half_width
|
| 138 |
+
|
| 139 |
+
def _scale_and_shift(self, zhat_normalized):
|
| 140 |
+
half_width = self._levels // 2
|
| 141 |
+
return (zhat_normalized * half_width) + half_width
|
| 142 |
+
|
| 143 |
+
def _scale_and_shift_inverse(self, zhat):
|
| 144 |
+
half_width = self._levels // 2
|
| 145 |
+
return (zhat - half_width) / half_width
|
| 146 |
+
|
| 147 |
+
def _indices_to_codes(self, indices):
|
| 148 |
+
level_indices = self.indices_to_level_indices(indices)
|
| 149 |
+
codes = self._scale_and_shift_inverse(level_indices)
|
| 150 |
+
return codes
|
| 151 |
+
|
| 152 |
+
def codes_to_indices(self, zhat):
|
| 153 |
+
"""Converts a `code` to an index in the codebook."""
|
| 154 |
+
assert zhat.shape[-1] == self.codebook_dim
|
| 155 |
+
zhat = self._scale_and_shift(zhat)
|
| 156 |
+
return (zhat * self._basis).sum(dim=-1).to(int32)
|
| 157 |
+
|
| 158 |
+
def indices_to_level_indices(self, indices):
|
| 159 |
+
"""Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings"""
|
| 160 |
+
indices = rearrange(indices, "... -> ... 1")
|
| 161 |
+
codes_non_centered = (indices // self._basis) % self._levels
|
| 162 |
+
return codes_non_centered
|
| 163 |
+
|
| 164 |
+
def indices_to_codes(self, indices):
|
| 165 |
+
"""Inverse of `codes_to_indices`."""
|
| 166 |
+
assert exists(indices)
|
| 167 |
+
|
| 168 |
+
is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
|
| 169 |
+
|
| 170 |
+
codes = self._indices_to_codes(indices)
|
| 171 |
+
|
| 172 |
+
if self.keep_num_codebooks_dim:
|
| 173 |
+
codes = rearrange(codes, "... c d -> ... (c d)")
|
| 174 |
+
|
| 175 |
+
codes = self.project_out(codes)
|
| 176 |
+
|
| 177 |
+
if is_img_or_video or self.channel_first:
|
| 178 |
+
codes = rearrange(codes, "b ... d -> b d ...")
|
| 179 |
+
|
| 180 |
+
return codes
|
| 181 |
+
|
| 182 |
+
def forward(self, z):
|
| 183 |
+
"""
|
| 184 |
+
einstein notation
|
| 185 |
+
b - batch
|
| 186 |
+
n - sequence (or flattened spatial dimensions)
|
| 187 |
+
d - feature dimension
|
| 188 |
+
c - number of codebook dim
|
| 189 |
+
"""
|
| 190 |
+
|
| 191 |
+
is_img_or_video = z.ndim >= 4
|
| 192 |
+
need_move_channel_last = is_img_or_video or self.channel_first
|
| 193 |
+
|
| 194 |
+
# standardize image or video into (batch, seq, dimension)
|
| 195 |
+
|
| 196 |
+
if need_move_channel_last:
|
| 197 |
+
z = rearrange(z, "b d ... -> b ... d")
|
| 198 |
+
z, ps = pack_one(z, "b * d")
|
| 199 |
+
|
| 200 |
+
assert (
|
| 201 |
+
z.shape[-1] == self.dim
|
| 202 |
+
), f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"
|
| 203 |
+
|
| 204 |
+
z = self.project_in(z)
|
| 205 |
+
|
| 206 |
+
z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks)
|
| 207 |
+
|
| 208 |
+
# whether to force quantization step to be full precision or not
|
| 209 |
+
|
| 210 |
+
force_f32 = self.force_quantization_f32
|
| 211 |
+
quantization_context = (
|
| 212 |
+
partial(autocast, "cuda", enabled=False) if force_f32 else nullcontext
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
with quantization_context():
|
| 216 |
+
orig_dtype = z.dtype
|
| 217 |
+
|
| 218 |
+
if force_f32 and orig_dtype not in self.allowed_dtypes:
|
| 219 |
+
z = z.float()
|
| 220 |
+
|
| 221 |
+
codes = self.quantize(z)
|
| 222 |
+
|
| 223 |
+
# returning indices could be optional
|
| 224 |
+
|
| 225 |
+
indices = None
|
| 226 |
+
|
| 227 |
+
if self.return_indices:
|
| 228 |
+
indices = self.codes_to_indices(codes)
|
| 229 |
+
|
| 230 |
+
codes = rearrange(codes, "b n c d -> b n (c d)")
|
| 231 |
+
|
| 232 |
+
codes = codes.type(orig_dtype)
|
| 233 |
+
|
| 234 |
+
# project out
|
| 235 |
+
|
| 236 |
+
out = self.project_out(codes)
|
| 237 |
+
|
| 238 |
+
# reconstitute image or video dimensions
|
| 239 |
+
|
| 240 |
+
if need_move_channel_last:
|
| 241 |
+
out = unpack_one(out, ps, "b * d")
|
| 242 |
+
out = rearrange(out, "b ... d -> b d ...")
|
| 243 |
+
|
| 244 |
+
indices = maybe(unpack_one)(indices, ps, "b * c")
|
| 245 |
+
|
| 246 |
+
if not self.keep_num_codebooks_dim and self.return_indices:
|
| 247 |
+
indices = maybe(rearrange)(indices, "... 1 -> ...")
|
| 248 |
+
|
| 249 |
+
# return quantized output and indices
|
| 250 |
+
|
| 251 |
+
return out, indices
|
trained_30_percents/sparktts/modules/fsq/residual_fsq.py
ADDED
|
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torch.distributed as dist
|
| 5 |
+
|
| 6 |
+
from typing import List
|
| 7 |
+
from torch import nn
|
| 8 |
+
from torch.nn import Module
|
| 9 |
+
from torch.amp import autocast
|
| 10 |
+
from einx import get_at
|
| 11 |
+
from einops import rearrange, reduce, pack, unpack
|
| 12 |
+
|
| 13 |
+
from sparktts.modules.fsq.finite_scalar_quantization import FSQ
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def exists(val):
|
| 17 |
+
return val is not None
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def first(l):
|
| 21 |
+
return l[0]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def default(val, d):
|
| 25 |
+
return val if exists(val) else d
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def round_up_multiple(num, mult):
|
| 29 |
+
return ceil(num / mult) * mult
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# distributed helpers
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def is_distributed():
|
| 36 |
+
return dist.is_initialized() and dist.get_world_size() > 1
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def get_maybe_sync_seed(device, max_size=10_000):
|
| 40 |
+
rand_int = torch.randint(0, max_size, (), device=device)
|
| 41 |
+
|
| 42 |
+
if is_distributed():
|
| 43 |
+
dist.all_reduce(rand_int)
|
| 44 |
+
|
| 45 |
+
return rand_int.item()
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class ResidualFSQ(Module):
|
| 49 |
+
"""Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf"""
|
| 50 |
+
|
| 51 |
+
def __init__(
|
| 52 |
+
self,
|
| 53 |
+
*,
|
| 54 |
+
levels: List[int],
|
| 55 |
+
num_quantizers,
|
| 56 |
+
dim=None,
|
| 57 |
+
is_channel_first=False,
|
| 58 |
+
quantize_dropout=False,
|
| 59 |
+
quantize_dropout_cutoff_index=0,
|
| 60 |
+
quantize_dropout_multiple_of=1,
|
| 61 |
+
**kwargs,
|
| 62 |
+
):
|
| 63 |
+
super().__init__()
|
| 64 |
+
codebook_dim = len(levels)
|
| 65 |
+
dim = default(dim, codebook_dim)
|
| 66 |
+
|
| 67 |
+
requires_projection = codebook_dim != dim
|
| 68 |
+
self.project_in = (
|
| 69 |
+
nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
|
| 70 |
+
)
|
| 71 |
+
self.project_out = (
|
| 72 |
+
nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
|
| 73 |
+
)
|
| 74 |
+
self.has_projections = requires_projection
|
| 75 |
+
|
| 76 |
+
self.is_channel_first = is_channel_first
|
| 77 |
+
self.num_quantizers = num_quantizers
|
| 78 |
+
|
| 79 |
+
self.levels = levels
|
| 80 |
+
self.layers = nn.ModuleList([])
|
| 81 |
+
|
| 82 |
+
levels_tensor = torch.Tensor(levels)
|
| 83 |
+
|
| 84 |
+
scales = []
|
| 85 |
+
|
| 86 |
+
for ind in range(num_quantizers):
|
| 87 |
+
scales.append((levels_tensor - 1) ** -ind)
|
| 88 |
+
|
| 89 |
+
fsq = FSQ(levels=levels, dim=codebook_dim, **kwargs)
|
| 90 |
+
|
| 91 |
+
self.layers.append(fsq)
|
| 92 |
+
|
| 93 |
+
assert all([not fsq.has_projections for fsq in self.layers])
|
| 94 |
+
|
| 95 |
+
self.codebook_size = self.layers[0].codebook_size
|
| 96 |
+
|
| 97 |
+
self.register_buffer("scales", torch.stack(scales), persistent=False)
|
| 98 |
+
|
| 99 |
+
self.quantize_dropout = quantize_dropout and num_quantizers > 1
|
| 100 |
+
|
| 101 |
+
assert quantize_dropout_cutoff_index >= 0
|
| 102 |
+
|
| 103 |
+
self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
|
| 104 |
+
self.quantize_dropout_multiple_of = quantize_dropout_multiple_of # encodec paper proposes structured dropout, believe this was set to 4
|
| 105 |
+
|
| 106 |
+
@property
|
| 107 |
+
def codebooks(self):
|
| 108 |
+
codebooks = [layer.implicit_codebook for layer in self.layers]
|
| 109 |
+
codebooks = torch.stack(codebooks, dim=0)
|
| 110 |
+
return codebooks
|
| 111 |
+
|
| 112 |
+
def get_codes_from_indices(self, indices):
|
| 113 |
+
|
| 114 |
+
batch, quantize_dim = indices.shape[0], indices.shape[-1]
|
| 115 |
+
|
| 116 |
+
# may also receive indices in the shape of 'b h w q' (accept_image_fmap)
|
| 117 |
+
|
| 118 |
+
indices, ps = pack([indices], "b * q")
|
| 119 |
+
|
| 120 |
+
# because of quantize dropout, one can pass in indices that are coarse
|
| 121 |
+
# and the network should be able to reconstruct
|
| 122 |
+
|
| 123 |
+
if quantize_dim < self.num_quantizers:
|
| 124 |
+
assert (
|
| 125 |
+
self.quantize_dropout > 0.0
|
| 126 |
+
), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
|
| 127 |
+
indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1)
|
| 128 |
+
|
| 129 |
+
# take care of quantizer dropout
|
| 130 |
+
|
| 131 |
+
mask = indices == -1
|
| 132 |
+
indices = indices.masked_fill(
|
| 133 |
+
mask, 0
|
| 134 |
+
) # have it fetch a dummy code to be masked out later
|
| 135 |
+
|
| 136 |
+
all_codes = get_at("q [c] d, b n q -> q b n d", self.codebooks, indices)
|
| 137 |
+
|
| 138 |
+
# mask out any codes that were dropout-ed
|
| 139 |
+
|
| 140 |
+
all_codes = all_codes.masked_fill(rearrange(mask, "b n q -> q b n 1"), 0.0)
|
| 141 |
+
|
| 142 |
+
# scale the codes
|
| 143 |
+
|
| 144 |
+
scales = rearrange(self.scales, "q d -> q 1 1 d")
|
| 145 |
+
all_codes = all_codes * scales
|
| 146 |
+
|
| 147 |
+
# if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)
|
| 148 |
+
|
| 149 |
+
(all_codes,) = unpack(all_codes, ps, "q b * d")
|
| 150 |
+
|
| 151 |
+
return all_codes
|
| 152 |
+
|
| 153 |
+
def get_output_from_indices(self, indices):
|
| 154 |
+
codes = self.get_codes_from_indices(indices)
|
| 155 |
+
codes_summed = reduce(codes, "q ... -> ...", "sum")
|
| 156 |
+
return self.project_out(codes_summed)
|
| 157 |
+
|
| 158 |
+
def forward(self, x, return_all_codes=False, rand_quantize_dropout_fixed_seed=None):
|
| 159 |
+
num_quant, quant_dropout_multiple_of, device = (
|
| 160 |
+
self.num_quantizers,
|
| 161 |
+
self.quantize_dropout_multiple_of,
|
| 162 |
+
x.device,
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
# handle channel first
|
| 166 |
+
|
| 167 |
+
if self.is_channel_first:
|
| 168 |
+
x = rearrange(x, "b d ... -> b ... d")
|
| 169 |
+
x, ps = pack([x], "b * d")
|
| 170 |
+
|
| 171 |
+
# maybe project in
|
| 172 |
+
|
| 173 |
+
x = self.project_in(x)
|
| 174 |
+
|
| 175 |
+
quantized_out = 0.0
|
| 176 |
+
residual = x
|
| 177 |
+
|
| 178 |
+
all_indices = []
|
| 179 |
+
|
| 180 |
+
should_quantize_dropout = self.training and self.quantize_dropout
|
| 181 |
+
|
| 182 |
+
# sample a layer index at which to dropout further residual quantization
|
| 183 |
+
# also prepare null indices
|
| 184 |
+
|
| 185 |
+
if should_quantize_dropout:
|
| 186 |
+
|
| 187 |
+
# check if seed is manually passed in
|
| 188 |
+
|
| 189 |
+
if not exists(rand_quantize_dropout_fixed_seed):
|
| 190 |
+
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)
|
| 191 |
+
|
| 192 |
+
rand = random.Random(rand_quantize_dropout_fixed_seed)
|
| 193 |
+
|
| 194 |
+
rand_quantize_dropout_index = rand.randrange(
|
| 195 |
+
self.quantize_dropout_cutoff_index, num_quant
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
if quant_dropout_multiple_of != 1:
|
| 199 |
+
rand_quantize_dropout_index = (
|
| 200 |
+
round_up_multiple(
|
| 201 |
+
rand_quantize_dropout_index + 1, quant_dropout_multiple_of
|
| 202 |
+
)
|
| 203 |
+
- 1
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
null_indices = torch.full(
|
| 207 |
+
x.shape[:2], -1.0, device=device, dtype=torch.long
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
# go through the layers
|
| 211 |
+
|
| 212 |
+
with autocast("cuda", enabled=False):
|
| 213 |
+
for quantizer_index, (layer, scale) in enumerate(
|
| 214 |
+
zip(self.layers, self.scales)
|
| 215 |
+
):
|
| 216 |
+
|
| 217 |
+
if (
|
| 218 |
+
should_quantize_dropout
|
| 219 |
+
and quantizer_index > rand_quantize_dropout_index
|
| 220 |
+
):
|
| 221 |
+
all_indices.append(null_indices)
|
| 222 |
+
continue
|
| 223 |
+
|
| 224 |
+
quantized, indices = layer(residual / scale)
|
| 225 |
+
|
| 226 |
+
quantized = quantized * scale
|
| 227 |
+
|
| 228 |
+
residual = residual - quantized.detach()
|
| 229 |
+
quantized_out = quantized_out + quantized
|
| 230 |
+
|
| 231 |
+
all_indices.append(indices)
|
| 232 |
+
|
| 233 |
+
# project out, if needed
|
| 234 |
+
|
| 235 |
+
quantized_out = self.project_out(quantized_out)
|
| 236 |
+
|
| 237 |
+
# stack all indices
|
| 238 |
+
|
| 239 |
+
all_indices = torch.stack(all_indices, dim=-1)
|
| 240 |
+
|
| 241 |
+
# channel first out
|
| 242 |
+
|
| 243 |
+
if self.is_channel_first:
|
| 244 |
+
(quantized_out,) = unpack(quantized_out, ps, "b * d")
|
| 245 |
+
(all_indices,) = unpack(all_indices, ps, "b * d")
|
| 246 |
+
|
| 247 |
+
quantized_out = rearrange(quantized_out, "b ... d -> b d ...")
|
| 248 |
+
all_indices = rearrange(all_indices, "b ... d -> b d ...")
|
| 249 |
+
|
| 250 |
+
# return
|
| 251 |
+
|
| 252 |
+
ret = (quantized_out, all_indices)
|
| 253 |
+
|
| 254 |
+
if not return_all_codes:
|
| 255 |
+
return ret
|
| 256 |
+
|
| 257 |
+
# whether to return all codes from all codebooks across layers
|
| 258 |
+
|
| 259 |
+
all_codes = self.get_codes_from_indices(all_indices)
|
| 260 |
+
|
| 261 |
+
# will return all codes in shape (quantizer, batch, sequence length, codebook dimension)
|
| 262 |
+
|
| 263 |
+
return (*ret, all_codes)
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
# grouped residual fsq
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
class GroupedResidualFSQ(Module):
|
| 270 |
+
def __init__(self, *, dim, groups=1, accept_image_fmap=False, **kwargs):
|
| 271 |
+
super().__init__()
|
| 272 |
+
self.dim = dim
|
| 273 |
+
self.groups = groups
|
| 274 |
+
assert (dim % groups) == 0
|
| 275 |
+
dim_per_group = dim // groups
|
| 276 |
+
|
| 277 |
+
self.accept_image_fmap = accept_image_fmap
|
| 278 |
+
|
| 279 |
+
self.rvqs = nn.ModuleList([])
|
| 280 |
+
|
| 281 |
+
for _ in range(groups):
|
| 282 |
+
self.rvqs.append(ResidualFSQ(dim=dim_per_group, **kwargs))
|
| 283 |
+
|
| 284 |
+
self.codebook_size = self.rvqs[0].codebook_size
|
| 285 |
+
|
| 286 |
+
@property
|
| 287 |
+
def codebooks(self):
|
| 288 |
+
return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs))
|
| 289 |
+
|
| 290 |
+
@property
|
| 291 |
+
def split_dim(self):
|
| 292 |
+
return 1 if self.accept_image_fmap else -1
|
| 293 |
+
|
| 294 |
+
def get_codes_from_indices(self, indices):
|
| 295 |
+
codes = tuple(
|
| 296 |
+
rvq.get_codes_from_indices(chunk_indices)
|
| 297 |
+
for rvq, chunk_indices in zip(self.rvqs, indices)
|
| 298 |
+
)
|
| 299 |
+
return torch.stack(codes)
|
| 300 |
+
|
| 301 |
+
def get_output_from_indices(self, indices):
|
| 302 |
+
outputs = tuple(
|
| 303 |
+
rvq.get_output_from_indices(chunk_indices)
|
| 304 |
+
for rvq, chunk_indices in zip(self.rvqs, indices)
|
| 305 |
+
)
|
| 306 |
+
return torch.cat(outputs, dim=self.split_dim)
|
| 307 |
+
|
| 308 |
+
def forward(self, x, return_all_codes=False):
|
| 309 |
+
shape, split_dim, device = x.shape, self.split_dim, x.device
|
| 310 |
+
assert shape[split_dim] == self.dim
|
| 311 |
+
|
| 312 |
+
# split the feature dimension into groups
|
| 313 |
+
|
| 314 |
+
x = x.chunk(self.groups, dim=split_dim)
|
| 315 |
+
|
| 316 |
+
forward_kwargs = dict(
|
| 317 |
+
return_all_codes=return_all_codes,
|
| 318 |
+
rand_quantize_dropout_fixed_seed=(
|
| 319 |
+
get_maybe_sync_seed(device) if self.training else None
|
| 320 |
+
),
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
# invoke residual vq on each group
|
| 324 |
+
|
| 325 |
+
out = tuple(rvq(chunk, **forward_kwargs) for rvq, chunk in zip(self.rvqs, x))
|
| 326 |
+
out = tuple(zip(*out))
|
| 327 |
+
|
| 328 |
+
# otherwise, get all the zipped outputs and combine them
|
| 329 |
+
|
| 330 |
+
quantized, all_indices, *maybe_all_codes = out
|
| 331 |
+
|
| 332 |
+
quantized = torch.cat(quantized, dim=split_dim)
|
| 333 |
+
all_indices = torch.stack(all_indices)
|
| 334 |
+
|
| 335 |
+
ret = (quantized, all_indices, *maybe_all_codes)
|
| 336 |
+
return ret
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
if __name__ == "__main__":
|
| 340 |
+
model = ResidualFSQ(
|
| 341 |
+
levels=[4, 4, 4, 4, 4, 4],
|
| 342 |
+
num_quantizers=1,
|
| 343 |
+
dim=30,
|
| 344 |
+
is_channel_first=True,
|
| 345 |
+
quantize_dropout=False,
|
| 346 |
+
)
|
| 347 |
+
x = torch.randn(2, 30, 10)
|
| 348 |
+
quantize, embed_ind = model(x)
|
| 349 |
+
|
| 350 |
+
emb_from_ind = model.get_output_from_indices(embed_ind.transpose(1, 2))
|
| 351 |
+
|
| 352 |
+
print(quantize == emb_from_ind.transpose(1, 2))
|
| 353 |
+
|
| 354 |
+
print("quantize shape", quantize.shape)
|
| 355 |
+
print("embed_ind", embed_ind)
|
trained_30_percents/sparktts/modules/speaker/__pycache__/ecapa_tdnn.cpython-311.pyc
ADDED
|
Binary file (11.5 kB). View file
|
|
|
trained_30_percents/sparktts/modules/speaker/__pycache__/perceiver_encoder.cpython-311.pyc
ADDED
|
Binary file (17.7 kB). View file
|
|
|
trained_30_percents/sparktts/modules/speaker/__pycache__/pooling_layers.cpython-311.pyc
ADDED
|
Binary file (15.5 kB). View file
|
|
|
trained_30_percents/sparktts/modules/speaker/__pycache__/speaker_encoder.cpython-311.pyc
ADDED
|
Binary file (7.26 kB). View file
|
|
|
trained_30_percents/sparktts/modules/speaker/ecapa_tdnn.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021 Zhengyang Chen ([email protected])
|
| 2 |
+
# 2022 Hongji Wang ([email protected])
|
| 3 |
+
# 2023 Bing Han ([email protected])
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
""" This implementation is adapted from github repo:
|
| 18 |
+
https://github.com/lawlict/ECAPA-TDNN.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
|
| 25 |
+
import sparktts.modules.speaker.pooling_layers as pooling_layers
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class Res2Conv1dReluBn(nn.Module):
|
| 29 |
+
"""
|
| 30 |
+
in_channels == out_channels == channels
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(
|
| 34 |
+
self,
|
| 35 |
+
channels,
|
| 36 |
+
kernel_size=1,
|
| 37 |
+
stride=1,
|
| 38 |
+
padding=0,
|
| 39 |
+
dilation=1,
|
| 40 |
+
bias=True,
|
| 41 |
+
scale=4,
|
| 42 |
+
):
|
| 43 |
+
super().__init__()
|
| 44 |
+
assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
|
| 45 |
+
self.scale = scale
|
| 46 |
+
self.width = channels // scale
|
| 47 |
+
self.nums = scale if scale == 1 else scale - 1
|
| 48 |
+
|
| 49 |
+
self.convs = []
|
| 50 |
+
self.bns = []
|
| 51 |
+
for i in range(self.nums):
|
| 52 |
+
self.convs.append(
|
| 53 |
+
nn.Conv1d(
|
| 54 |
+
self.width,
|
| 55 |
+
self.width,
|
| 56 |
+
kernel_size,
|
| 57 |
+
stride,
|
| 58 |
+
padding,
|
| 59 |
+
dilation,
|
| 60 |
+
bias=bias,
|
| 61 |
+
)
|
| 62 |
+
)
|
| 63 |
+
self.bns.append(nn.BatchNorm1d(self.width))
|
| 64 |
+
self.convs = nn.ModuleList(self.convs)
|
| 65 |
+
self.bns = nn.ModuleList(self.bns)
|
| 66 |
+
|
| 67 |
+
def forward(self, x):
|
| 68 |
+
out = []
|
| 69 |
+
spx = torch.split(x, self.width, 1)
|
| 70 |
+
sp = spx[0]
|
| 71 |
+
for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
|
| 72 |
+
# Order: conv -> relu -> bn
|
| 73 |
+
if i >= 1:
|
| 74 |
+
sp = sp + spx[i]
|
| 75 |
+
sp = conv(sp)
|
| 76 |
+
sp = bn(F.relu(sp))
|
| 77 |
+
out.append(sp)
|
| 78 |
+
if self.scale != 1:
|
| 79 |
+
out.append(spx[self.nums])
|
| 80 |
+
out = torch.cat(out, dim=1)
|
| 81 |
+
|
| 82 |
+
return out
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
""" Conv1d + BatchNorm1d + ReLU
|
| 86 |
+
"""
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class Conv1dReluBn(nn.Module):
|
| 90 |
+
|
| 91 |
+
def __init__(
|
| 92 |
+
self,
|
| 93 |
+
in_channels,
|
| 94 |
+
out_channels,
|
| 95 |
+
kernel_size=1,
|
| 96 |
+
stride=1,
|
| 97 |
+
padding=0,
|
| 98 |
+
dilation=1,
|
| 99 |
+
bias=True,
|
| 100 |
+
):
|
| 101 |
+
super().__init__()
|
| 102 |
+
self.conv = nn.Conv1d(
|
| 103 |
+
in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias
|
| 104 |
+
)
|
| 105 |
+
self.bn = nn.BatchNorm1d(out_channels)
|
| 106 |
+
|
| 107 |
+
def forward(self, x):
|
| 108 |
+
return self.bn(F.relu(self.conv(x)))
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
""" The SE connection of 1D case.
|
| 112 |
+
"""
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class SE_Connect(nn.Module):
|
| 116 |
+
|
| 117 |
+
def __init__(self, channels, se_bottleneck_dim=128):
|
| 118 |
+
super().__init__()
|
| 119 |
+
self.linear1 = nn.Linear(channels, se_bottleneck_dim)
|
| 120 |
+
self.linear2 = nn.Linear(se_bottleneck_dim, channels)
|
| 121 |
+
|
| 122 |
+
def forward(self, x):
|
| 123 |
+
out = x.mean(dim=2)
|
| 124 |
+
out = F.relu(self.linear1(out))
|
| 125 |
+
out = torch.sigmoid(self.linear2(out))
|
| 126 |
+
out = x * out.unsqueeze(2)
|
| 127 |
+
|
| 128 |
+
return out
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
""" SE-Res2Block of the ECAPA-TDNN architecture.
|
| 132 |
+
"""
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class SE_Res2Block(nn.Module):
|
| 136 |
+
|
| 137 |
+
def __init__(self, channels, kernel_size, stride, padding, dilation, scale):
|
| 138 |
+
super().__init__()
|
| 139 |
+
self.se_res2block = nn.Sequential(
|
| 140 |
+
Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0),
|
| 141 |
+
Res2Conv1dReluBn(
|
| 142 |
+
channels, kernel_size, stride, padding, dilation, scale=scale
|
| 143 |
+
),
|
| 144 |
+
Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0),
|
| 145 |
+
SE_Connect(channels),
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
def forward(self, x):
|
| 149 |
+
return x + self.se_res2block(x)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class ECAPA_TDNN(nn.Module):
|
| 153 |
+
|
| 154 |
+
def __init__(
|
| 155 |
+
self,
|
| 156 |
+
channels=512,
|
| 157 |
+
feat_dim=80,
|
| 158 |
+
embed_dim=192,
|
| 159 |
+
pooling_func="ASTP",
|
| 160 |
+
global_context_att=False,
|
| 161 |
+
emb_bn=False,
|
| 162 |
+
):
|
| 163 |
+
super().__init__()
|
| 164 |
+
|
| 165 |
+
self.layer1 = Conv1dReluBn(feat_dim, channels, kernel_size=5, padding=2)
|
| 166 |
+
self.layer2 = SE_Res2Block(
|
| 167 |
+
channels, kernel_size=3, stride=1, padding=2, dilation=2, scale=8
|
| 168 |
+
)
|
| 169 |
+
self.layer3 = SE_Res2Block(
|
| 170 |
+
channels, kernel_size=3, stride=1, padding=3, dilation=3, scale=8
|
| 171 |
+
)
|
| 172 |
+
self.layer4 = SE_Res2Block(
|
| 173 |
+
channels, kernel_size=3, stride=1, padding=4, dilation=4, scale=8
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
cat_channels = channels * 3
|
| 177 |
+
out_channels = 512 * 3
|
| 178 |
+
self.conv = nn.Conv1d(cat_channels, out_channels, kernel_size=1)
|
| 179 |
+
self.pool = getattr(pooling_layers, pooling_func)(
|
| 180 |
+
in_dim=out_channels, global_context_att=global_context_att
|
| 181 |
+
)
|
| 182 |
+
self.pool_out_dim = self.pool.get_out_dim()
|
| 183 |
+
self.bn = nn.BatchNorm1d(self.pool_out_dim)
|
| 184 |
+
self.linear = nn.Linear(self.pool_out_dim, embed_dim)
|
| 185 |
+
self.emb_bn = emb_bn
|
| 186 |
+
if emb_bn: # better in SSL for SV
|
| 187 |
+
self.bn2 = nn.BatchNorm1d(embed_dim)
|
| 188 |
+
else:
|
| 189 |
+
self.bn2 = nn.Identity()
|
| 190 |
+
|
| 191 |
+
def forward(self, x, return_latent=False):
|
| 192 |
+
x = x.permute(0, 2, 1) # (B,T,F) -> (B,F,T)
|
| 193 |
+
|
| 194 |
+
out1 = self.layer1(x)
|
| 195 |
+
out2 = self.layer2(out1)
|
| 196 |
+
out3 = self.layer3(out2)
|
| 197 |
+
out4 = self.layer4(out3)
|
| 198 |
+
|
| 199 |
+
out = torch.cat([out2, out3, out4], dim=1)
|
| 200 |
+
latent = F.relu(self.conv(out))
|
| 201 |
+
out = self.bn(self.pool(latent))
|
| 202 |
+
out = self.linear(out)
|
| 203 |
+
if self.emb_bn:
|
| 204 |
+
out = self.bn2(out)
|
| 205 |
+
|
| 206 |
+
if return_latent:
|
| 207 |
+
return out, latent
|
| 208 |
+
return out
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def ECAPA_TDNN_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
|
| 212 |
+
return ECAPA_TDNN(
|
| 213 |
+
channels=1024,
|
| 214 |
+
feat_dim=feat_dim,
|
| 215 |
+
embed_dim=embed_dim,
|
| 216 |
+
pooling_func=pooling_func,
|
| 217 |
+
emb_bn=emb_bn,
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def ECAPA_TDNN_GLOB_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
|
| 222 |
+
return ECAPA_TDNN(
|
| 223 |
+
channels=1024,
|
| 224 |
+
feat_dim=feat_dim,
|
| 225 |
+
embed_dim=embed_dim,
|
| 226 |
+
pooling_func=pooling_func,
|
| 227 |
+
global_context_att=True,
|
| 228 |
+
emb_bn=emb_bn,
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def ECAPA_TDNN_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
|
| 233 |
+
return ECAPA_TDNN(
|
| 234 |
+
channels=512,
|
| 235 |
+
feat_dim=feat_dim,
|
| 236 |
+
embed_dim=embed_dim,
|
| 237 |
+
pooling_func=pooling_func,
|
| 238 |
+
emb_bn=emb_bn,
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def ECAPA_TDNN_GLOB_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
|
| 243 |
+
return ECAPA_TDNN(
|
| 244 |
+
channels=512,
|
| 245 |
+
feat_dim=feat_dim,
|
| 246 |
+
embed_dim=embed_dim,
|
| 247 |
+
pooling_func=pooling_func,
|
| 248 |
+
global_context_att=True,
|
| 249 |
+
emb_bn=emb_bn,
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
if __name__ == "__main__":
|
| 254 |
+
x = torch.zeros(1, 200, 100)
|
| 255 |
+
model = ECAPA_TDNN_GLOB_c512(feat_dim=100, embed_dim=256, pooling_func="ASTP")
|
| 256 |
+
model.eval()
|
| 257 |
+
out, latent = model(x, True)
|
| 258 |
+
print(out.shape)
|
| 259 |
+
print(latent.shape)
|
| 260 |
+
|
| 261 |
+
num_params = sum(param.numel() for param in model.parameters())
|
| 262 |
+
print("{} M".format(num_params / 1e6))
|
| 263 |
+
|
| 264 |
+
# from thop import profile
|
| 265 |
+
# x_np = torch.randn(1, 200, 80)
|
| 266 |
+
# flops, params = profile(model, inputs=(x_np, ))
|
| 267 |
+
# print("FLOPs: {} G, Params: {} M".format(flops / 1e9, params / 1e6))
|
trained_30_percents/sparktts/modules/speaker/perceiver_encoder.py
ADDED
|
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 SparkAudio
|
| 2 |
+
# 2025 Xinsheng Wang ([email protected])
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
# Adapted from https://github.com/lucidrains/naturalspeech2-pytorch/blob/659bec7f7543e7747e809e950cc2f84242fbeec7/naturalspeech2_pytorch/naturalspeech2_pytorch.py#L532
|
| 17 |
+
|
| 18 |
+
from collections import namedtuple
|
| 19 |
+
from functools import wraps
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
from einops import rearrange, repeat
|
| 24 |
+
from einops.layers.torch import Rearrange
|
| 25 |
+
from packaging import version
|
| 26 |
+
from torch import einsum, nn
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def exists(val):
|
| 30 |
+
return val is not None
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def once(fn):
|
| 34 |
+
called = False
|
| 35 |
+
|
| 36 |
+
@wraps(fn)
|
| 37 |
+
def inner(x):
|
| 38 |
+
nonlocal called
|
| 39 |
+
if called:
|
| 40 |
+
return
|
| 41 |
+
called = True
|
| 42 |
+
return fn(x)
|
| 43 |
+
|
| 44 |
+
return inner
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
print_once = once(print)
|
| 48 |
+
|
| 49 |
+
# main class
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class Attend(nn.Module):
|
| 53 |
+
def __init__(self, dropout=0.0, causal=False, use_flash=False):
|
| 54 |
+
super().__init__()
|
| 55 |
+
self.dropout = dropout
|
| 56 |
+
self.attn_dropout = nn.Dropout(dropout)
|
| 57 |
+
|
| 58 |
+
self.causal = causal
|
| 59 |
+
self.register_buffer("mask", None, persistent=False)
|
| 60 |
+
|
| 61 |
+
self.use_flash = use_flash
|
| 62 |
+
assert not (
|
| 63 |
+
use_flash and version.parse(torch.__version__) < version.parse("2.0.0")
|
| 64 |
+
), "in order to use flash attention, you must be using pytorch 2.0 or above"
|
| 65 |
+
|
| 66 |
+
# determine efficient attention configs for cuda and cpu
|
| 67 |
+
self.config = namedtuple(
|
| 68 |
+
"EfficientAttentionConfig",
|
| 69 |
+
["enable_flash", "enable_math", "enable_mem_efficient"],
|
| 70 |
+
)
|
| 71 |
+
self.cpu_config = self.config(True, True, True)
|
| 72 |
+
self.cuda_config = None
|
| 73 |
+
|
| 74 |
+
if not torch.cuda.is_available() or not use_flash:
|
| 75 |
+
return
|
| 76 |
+
|
| 77 |
+
device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
|
| 78 |
+
|
| 79 |
+
if device_properties.major == 8 and device_properties.minor == 0:
|
| 80 |
+
print_once(
|
| 81 |
+
"A100 GPU detected, using flash attention if input tensor is on cuda"
|
| 82 |
+
)
|
| 83 |
+
self.cuda_config = self.config(True, False, False)
|
| 84 |
+
else:
|
| 85 |
+
print_once(
|
| 86 |
+
"Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda"
|
| 87 |
+
)
|
| 88 |
+
self.cuda_config = self.config(False, True, True)
|
| 89 |
+
|
| 90 |
+
def get_mask(self, n, device):
|
| 91 |
+
if exists(self.mask) and self.mask.shape[-1] >= n:
|
| 92 |
+
return self.mask[:n, :n]
|
| 93 |
+
|
| 94 |
+
mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
|
| 95 |
+
self.register_buffer("mask", mask, persistent=False)
|
| 96 |
+
return mask
|
| 97 |
+
|
| 98 |
+
def flash_attn(self, q, k, v, mask=None):
|
| 99 |
+
_, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda
|
| 100 |
+
|
| 101 |
+
# Recommended for multi-query single-key-value attention by Tri Dao
|
| 102 |
+
# kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
|
| 103 |
+
|
| 104 |
+
if k.ndim == 3:
|
| 105 |
+
k = rearrange(k, "b ... -> b 1 ...").expand_as(q)
|
| 106 |
+
|
| 107 |
+
if v.ndim == 3:
|
| 108 |
+
v = rearrange(v, "b ... -> b 1 ...").expand_as(q)
|
| 109 |
+
|
| 110 |
+
# Check if mask exists and expand to compatible shape
|
| 111 |
+
# The mask is B L, so it would have to be expanded to B H N L
|
| 112 |
+
|
| 113 |
+
if exists(mask):
|
| 114 |
+
mask = rearrange(mask, "b j -> b 1 1 j")
|
| 115 |
+
mask = mask.expand(-1, heads, q_len, -1)
|
| 116 |
+
|
| 117 |
+
# Check if there is a compatible device for flash attention
|
| 118 |
+
|
| 119 |
+
config = self.cuda_config if is_cuda else self.cpu_config
|
| 120 |
+
|
| 121 |
+
# pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
|
| 122 |
+
|
| 123 |
+
with torch.backends.cuda.sdp_kernel(**config._asdict()):
|
| 124 |
+
out = F.scaled_dot_product_attention(
|
| 125 |
+
q,
|
| 126 |
+
k,
|
| 127 |
+
v,
|
| 128 |
+
attn_mask=mask,
|
| 129 |
+
dropout_p=self.dropout if self.training else 0.0,
|
| 130 |
+
is_causal=self.causal,
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
return out
|
| 134 |
+
|
| 135 |
+
def forward(self, q, k, v, mask=None):
|
| 136 |
+
"""
|
| 137 |
+
einstein notation
|
| 138 |
+
b - batch
|
| 139 |
+
h - heads
|
| 140 |
+
n, i, j - sequence length (base sequence length, source, target)
|
| 141 |
+
d - feature dimension
|
| 142 |
+
"""
|
| 143 |
+
|
| 144 |
+
n, device = q.shape[-2], q.device
|
| 145 |
+
|
| 146 |
+
scale = q.shape[-1] ** -0.5
|
| 147 |
+
|
| 148 |
+
if self.use_flash:
|
| 149 |
+
return self.flash_attn(q, k, v, mask=mask)
|
| 150 |
+
|
| 151 |
+
kv_einsum_eq = "b j d" if k.ndim == 3 else "b h j d"
|
| 152 |
+
|
| 153 |
+
# similarity
|
| 154 |
+
|
| 155 |
+
sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale
|
| 156 |
+
|
| 157 |
+
# key padding mask
|
| 158 |
+
|
| 159 |
+
if exists(mask):
|
| 160 |
+
mask = rearrange(mask, "b j -> b 1 1 j")
|
| 161 |
+
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
|
| 162 |
+
|
| 163 |
+
# causal mask
|
| 164 |
+
|
| 165 |
+
if self.causal:
|
| 166 |
+
causal_mask = self.get_mask(n, device)
|
| 167 |
+
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
|
| 168 |
+
|
| 169 |
+
# attention
|
| 170 |
+
|
| 171 |
+
attn = sim.softmax(dim=-1)
|
| 172 |
+
attn = self.attn_dropout(attn)
|
| 173 |
+
|
| 174 |
+
# aggregate values
|
| 175 |
+
|
| 176 |
+
out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)
|
| 177 |
+
|
| 178 |
+
return out
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def Sequential(*mods):
|
| 182 |
+
return nn.Sequential(*filter(exists, mods))
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def exists(x):
|
| 186 |
+
return x is not None
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def default(val, d):
|
| 190 |
+
if exists(val):
|
| 191 |
+
return val
|
| 192 |
+
return d() if callable(d) else d
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class RMSNorm(nn.Module):
|
| 196 |
+
def __init__(self, dim, scale=True, dim_cond=None):
|
| 197 |
+
super().__init__()
|
| 198 |
+
self.cond = exists(dim_cond)
|
| 199 |
+
self.to_gamma_beta = nn.Linear(dim_cond, dim * 2) if self.cond else None
|
| 200 |
+
|
| 201 |
+
self.scale = dim**0.5
|
| 202 |
+
self.gamma = nn.Parameter(torch.ones(dim)) if scale else None
|
| 203 |
+
|
| 204 |
+
def forward(self, x, cond=None):
|
| 205 |
+
gamma = default(self.gamma, 1)
|
| 206 |
+
out = F.normalize(x, dim=-1) * self.scale * gamma
|
| 207 |
+
|
| 208 |
+
if not self.cond:
|
| 209 |
+
return out
|
| 210 |
+
|
| 211 |
+
assert exists(cond)
|
| 212 |
+
gamma, beta = self.to_gamma_beta(cond).chunk(2, dim=-1)
|
| 213 |
+
gamma, beta = map(lambda t: rearrange(t, "b d -> b 1 d"), (gamma, beta))
|
| 214 |
+
return out * gamma + beta
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class CausalConv1d(nn.Conv1d):
|
| 218 |
+
def __init__(self, *args, **kwargs):
|
| 219 |
+
super().__init__(*args, **kwargs)
|
| 220 |
+
(kernel_size,) = self.kernel_size
|
| 221 |
+
(dilation,) = self.dilation
|
| 222 |
+
(stride,) = self.stride
|
| 223 |
+
|
| 224 |
+
assert stride == 1
|
| 225 |
+
self.causal_padding = dilation * (kernel_size - 1)
|
| 226 |
+
|
| 227 |
+
def forward(self, x):
|
| 228 |
+
causal_padded_x = F.pad(x, (self.causal_padding, 0), value=0.0)
|
| 229 |
+
return super().forward(causal_padded_x)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
class GEGLU(nn.Module):
|
| 233 |
+
def forward(self, x):
|
| 234 |
+
x, gate = x.chunk(2, dim=-1)
|
| 235 |
+
return F.gelu(gate) * x
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def FeedForward(dim, mult=4, causal_conv=False):
|
| 239 |
+
dim_inner = int(dim * mult * 2 / 3)
|
| 240 |
+
|
| 241 |
+
conv = None
|
| 242 |
+
if causal_conv:
|
| 243 |
+
conv = nn.Sequential(
|
| 244 |
+
Rearrange("b n d -> b d n"),
|
| 245 |
+
CausalConv1d(dim_inner, dim_inner, 3),
|
| 246 |
+
Rearrange("b d n -> b n d"),
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
return Sequential(
|
| 250 |
+
nn.Linear(dim, dim_inner * 2), GEGLU(), conv, nn.Linear(dim_inner, dim)
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
class Attention(nn.Module):
|
| 255 |
+
def __init__(
|
| 256 |
+
self,
|
| 257 |
+
dim,
|
| 258 |
+
*,
|
| 259 |
+
dim_context=None,
|
| 260 |
+
causal=False,
|
| 261 |
+
dim_head=64,
|
| 262 |
+
heads=8,
|
| 263 |
+
dropout=0.0,
|
| 264 |
+
use_flash=False,
|
| 265 |
+
cross_attn_include_queries=False,
|
| 266 |
+
):
|
| 267 |
+
super().__init__()
|
| 268 |
+
self.scale = dim_head**-0.5
|
| 269 |
+
self.heads = heads
|
| 270 |
+
self.cross_attn_include_queries = cross_attn_include_queries
|
| 271 |
+
|
| 272 |
+
dim_inner = dim_head * heads
|
| 273 |
+
dim_context = default(dim_context, dim)
|
| 274 |
+
|
| 275 |
+
self.attend = Attend(causal=causal, dropout=dropout, use_flash=use_flash)
|
| 276 |
+
self.to_q = nn.Linear(dim, dim_inner, bias=False)
|
| 277 |
+
self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias=False)
|
| 278 |
+
self.to_out = nn.Linear(dim_inner, dim, bias=False)
|
| 279 |
+
|
| 280 |
+
def forward(self, x, context=None, mask=None):
|
| 281 |
+
h, has_context = self.heads, exists(context)
|
| 282 |
+
|
| 283 |
+
context = default(context, x)
|
| 284 |
+
|
| 285 |
+
if has_context and self.cross_attn_include_queries:
|
| 286 |
+
context = torch.cat((x, context), dim=-2)
|
| 287 |
+
|
| 288 |
+
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1))
|
| 289 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
|
| 290 |
+
|
| 291 |
+
out = self.attend(q, k, v, mask=mask)
|
| 292 |
+
|
| 293 |
+
out = rearrange(out, "b h n d -> b n (h d)")
|
| 294 |
+
return self.to_out(out)
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
class PerceiverResampler(nn.Module):
|
| 298 |
+
def __init__(
|
| 299 |
+
self,
|
| 300 |
+
*,
|
| 301 |
+
dim,
|
| 302 |
+
depth=2,
|
| 303 |
+
dim_context=None,
|
| 304 |
+
num_latents=32,
|
| 305 |
+
dim_head=64,
|
| 306 |
+
heads=8,
|
| 307 |
+
ff_mult=4,
|
| 308 |
+
use_flash_attn=False,
|
| 309 |
+
):
|
| 310 |
+
super().__init__()
|
| 311 |
+
dim_context = default(dim_context, dim)
|
| 312 |
+
|
| 313 |
+
self.proj_context = (
|
| 314 |
+
nn.Linear(dim_context, dim) if dim_context != dim else nn.Identity()
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
self.latents = nn.Parameter(torch.randn(num_latents, dim))
|
| 318 |
+
nn.init.normal_(self.latents, std=0.02)
|
| 319 |
+
|
| 320 |
+
self.layers = nn.ModuleList([])
|
| 321 |
+
for _ in range(depth):
|
| 322 |
+
self.layers.append(
|
| 323 |
+
nn.ModuleList(
|
| 324 |
+
[
|
| 325 |
+
Attention(
|
| 326 |
+
dim=dim,
|
| 327 |
+
dim_head=dim_head,
|
| 328 |
+
heads=heads,
|
| 329 |
+
use_flash=use_flash_attn,
|
| 330 |
+
cross_attn_include_queries=True,
|
| 331 |
+
),
|
| 332 |
+
FeedForward(dim=dim, mult=ff_mult),
|
| 333 |
+
]
|
| 334 |
+
)
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
self.norm = RMSNorm(dim)
|
| 338 |
+
|
| 339 |
+
def forward(self, x, mask=None):
|
| 340 |
+
batch = x.shape[0]
|
| 341 |
+
|
| 342 |
+
x = self.proj_context(x)
|
| 343 |
+
|
| 344 |
+
latents = repeat(self.latents, "n d -> b n d", b=batch)
|
| 345 |
+
|
| 346 |
+
for attn, ff in self.layers:
|
| 347 |
+
latents = attn(latents, x, mask=mask) + latents
|
| 348 |
+
latents = ff(latents) + latents
|
| 349 |
+
|
| 350 |
+
return self.norm(latents)
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
if __name__ == "__main__":
|
| 354 |
+
model = PerceiverResampler(dim=256, dim_context=80)
|
| 355 |
+
x = torch.randn(8, 200, 80)
|
| 356 |
+
out = model(x)
|
| 357 |
+
print(out.shape) # [8, 32, 80]
|
| 358 |
+
|
| 359 |
+
num_params = sum(param.numel() for param in model.parameters())
|
| 360 |
+
print("{} M".format(num_params / 1e6))
|
trained_30_percents/sparktts/modules/speaker/pooling_layers.py
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021 Shuai Wang ([email protected])
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
Pooling functions to aggregate frame-level deep features
|
| 16 |
+
into segment-level speaker embeddings
|
| 17 |
+
|
| 18 |
+
High-order statistics are surprisingly effective, TSDP acts similarly as TSTP,
|
| 19 |
+
even though we remove the mean statistic, on Voxceleb.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn as nn
|
| 24 |
+
import torch.nn.functional as F
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class TAP(nn.Module):
|
| 28 |
+
"""
|
| 29 |
+
Temporal average pooling, only first-order mean is considered
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
def __init__(self, in_dim=0, **kwargs):
|
| 33 |
+
super(TAP, self).__init__()
|
| 34 |
+
self.in_dim = in_dim
|
| 35 |
+
|
| 36 |
+
def forward(self, x):
|
| 37 |
+
pooling_mean = x.mean(dim=-1)
|
| 38 |
+
# To be compatable with 2D input
|
| 39 |
+
pooling_mean = pooling_mean.flatten(start_dim=1)
|
| 40 |
+
return pooling_mean
|
| 41 |
+
|
| 42 |
+
def get_out_dim(self):
|
| 43 |
+
self.out_dim = self.in_dim
|
| 44 |
+
return self.out_dim
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class TSDP(nn.Module):
|
| 48 |
+
"""
|
| 49 |
+
Temporal standard deviation pooling, only second-order std is considered
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
def __init__(self, in_dim=0, **kwargs):
|
| 53 |
+
super(TSDP, self).__init__()
|
| 54 |
+
self.in_dim = in_dim
|
| 55 |
+
|
| 56 |
+
def forward(self, x):
|
| 57 |
+
# The last dimension is the temporal axis
|
| 58 |
+
pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7)
|
| 59 |
+
pooling_std = pooling_std.flatten(start_dim=1)
|
| 60 |
+
return pooling_std
|
| 61 |
+
|
| 62 |
+
def get_out_dim(self):
|
| 63 |
+
self.out_dim = self.in_dim
|
| 64 |
+
return self.out_dim
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class TSTP(nn.Module):
|
| 68 |
+
"""
|
| 69 |
+
Temporal statistics pooling, concatenate mean and std, which is used in
|
| 70 |
+
x-vector
|
| 71 |
+
Comment: simple concatenation can not make full use of both statistics
|
| 72 |
+
"""
|
| 73 |
+
|
| 74 |
+
def __init__(self, in_dim=0, **kwargs):
|
| 75 |
+
super(TSTP, self).__init__()
|
| 76 |
+
self.in_dim = in_dim
|
| 77 |
+
|
| 78 |
+
def forward(self, x):
|
| 79 |
+
# The last dimension is the temporal axis
|
| 80 |
+
pooling_mean = x.mean(dim=-1)
|
| 81 |
+
pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7)
|
| 82 |
+
pooling_mean = pooling_mean.flatten(start_dim=1)
|
| 83 |
+
pooling_std = pooling_std.flatten(start_dim=1)
|
| 84 |
+
stats = torch.cat((pooling_mean, pooling_std), 1)
|
| 85 |
+
return stats
|
| 86 |
+
|
| 87 |
+
def get_out_dim(self):
|
| 88 |
+
self.out_dim = self.in_dim * 2
|
| 89 |
+
return self.out_dim
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class ASTP(nn.Module):
|
| 93 |
+
""" Attentive statistics pooling: Channel- and context-dependent
|
| 94 |
+
statistics pooling, first used in ECAPA_TDNN.
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
def __init__(self,
|
| 98 |
+
in_dim,
|
| 99 |
+
bottleneck_dim=128,
|
| 100 |
+
global_context_att=False,
|
| 101 |
+
**kwargs):
|
| 102 |
+
super(ASTP, self).__init__()
|
| 103 |
+
self.in_dim = in_dim
|
| 104 |
+
self.global_context_att = global_context_att
|
| 105 |
+
|
| 106 |
+
# Use Conv1d with stride == 1 rather than Linear, then we don't
|
| 107 |
+
# need to transpose inputs.
|
| 108 |
+
if global_context_att:
|
| 109 |
+
self.linear1 = nn.Conv1d(
|
| 110 |
+
in_dim * 3, bottleneck_dim,
|
| 111 |
+
kernel_size=1) # equals W and b in the paper
|
| 112 |
+
else:
|
| 113 |
+
self.linear1 = nn.Conv1d(
|
| 114 |
+
in_dim, bottleneck_dim,
|
| 115 |
+
kernel_size=1) # equals W and b in the paper
|
| 116 |
+
self.linear2 = nn.Conv1d(bottleneck_dim, in_dim,
|
| 117 |
+
kernel_size=1) # equals V and k in the paper
|
| 118 |
+
|
| 119 |
+
def forward(self, x):
|
| 120 |
+
"""
|
| 121 |
+
x: a 3-dimensional tensor in tdnn-based architecture (B,F,T)
|
| 122 |
+
or a 4-dimensional tensor in resnet architecture (B,C,F,T)
|
| 123 |
+
0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
|
| 124 |
+
"""
|
| 125 |
+
if len(x.shape) == 4:
|
| 126 |
+
x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
|
| 127 |
+
assert len(x.shape) == 3
|
| 128 |
+
|
| 129 |
+
if self.global_context_att:
|
| 130 |
+
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
|
| 131 |
+
context_std = torch.sqrt(
|
| 132 |
+
torch.var(x, dim=-1, keepdim=True) + 1e-7).expand_as(x)
|
| 133 |
+
x_in = torch.cat((x, context_mean, context_std), dim=1)
|
| 134 |
+
else:
|
| 135 |
+
x_in = x
|
| 136 |
+
|
| 137 |
+
# DON'T use ReLU here! ReLU may be hard to converge.
|
| 138 |
+
alpha = torch.tanh(
|
| 139 |
+
self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
|
| 140 |
+
alpha = torch.softmax(self.linear2(alpha), dim=2)
|
| 141 |
+
mean = torch.sum(alpha * x, dim=2)
|
| 142 |
+
var = torch.sum(alpha * (x**2), dim=2) - mean**2
|
| 143 |
+
std = torch.sqrt(var.clamp(min=1e-7))
|
| 144 |
+
return torch.cat([mean, std], dim=1)
|
| 145 |
+
|
| 146 |
+
def get_out_dim(self):
|
| 147 |
+
self.out_dim = 2 * self.in_dim
|
| 148 |
+
return self.out_dim
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class MHASTP(torch.nn.Module):
|
| 152 |
+
""" Multi head attentive statistics pooling
|
| 153 |
+
Reference:
|
| 154 |
+
Self Multi-Head Attention for Speaker Recognition
|
| 155 |
+
https://arxiv.org/pdf/1906.09890.pdf
|
| 156 |
+
"""
|
| 157 |
+
|
| 158 |
+
def __init__(self,
|
| 159 |
+
in_dim,
|
| 160 |
+
layer_num=2,
|
| 161 |
+
head_num=2,
|
| 162 |
+
d_s=1,
|
| 163 |
+
bottleneck_dim=64,
|
| 164 |
+
**kwargs):
|
| 165 |
+
super(MHASTP, self).__init__()
|
| 166 |
+
assert (in_dim % head_num
|
| 167 |
+
) == 0 # make sure that head num can be divided by input_dim
|
| 168 |
+
self.in_dim = in_dim
|
| 169 |
+
self.head_num = head_num
|
| 170 |
+
d_model = int(in_dim / head_num)
|
| 171 |
+
channel_dims = [bottleneck_dim for i in range(layer_num + 1)]
|
| 172 |
+
if d_s > 1:
|
| 173 |
+
d_s = d_model
|
| 174 |
+
else:
|
| 175 |
+
d_s = 1
|
| 176 |
+
self.d_s = d_s
|
| 177 |
+
channel_dims[0], channel_dims[-1] = d_model, d_s
|
| 178 |
+
heads_att_trans = []
|
| 179 |
+
for i in range(self.head_num):
|
| 180 |
+
att_trans = nn.Sequential()
|
| 181 |
+
for i in range(layer_num - 1):
|
| 182 |
+
att_trans.add_module(
|
| 183 |
+
'att_' + str(i),
|
| 184 |
+
nn.Conv1d(channel_dims[i], channel_dims[i + 1], 1, 1))
|
| 185 |
+
att_trans.add_module('tanh' + str(i), nn.Tanh())
|
| 186 |
+
att_trans.add_module(
|
| 187 |
+
'att_' + str(layer_num - 1),
|
| 188 |
+
nn.Conv1d(channel_dims[layer_num - 1], channel_dims[layer_num],
|
| 189 |
+
1, 1))
|
| 190 |
+
heads_att_trans.append(att_trans)
|
| 191 |
+
self.heads_att_trans = nn.ModuleList(heads_att_trans)
|
| 192 |
+
|
| 193 |
+
def forward(self, input):
|
| 194 |
+
"""
|
| 195 |
+
input: a 3-dimensional tensor in xvector architecture
|
| 196 |
+
or a 4-dimensional tensor in resnet architecture
|
| 197 |
+
0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
|
| 198 |
+
"""
|
| 199 |
+
if len(input.shape) == 4: # B x F x T
|
| 200 |
+
input = input.reshape(input.shape[0],
|
| 201 |
+
input.shape[1] * input.shape[2],
|
| 202 |
+
input.shape[3])
|
| 203 |
+
assert len(input.shape) == 3
|
| 204 |
+
bs, f_dim, t_dim = input.shape
|
| 205 |
+
chunks = torch.chunk(input, self.head_num, 1)
|
| 206 |
+
# split
|
| 207 |
+
chunks_out = []
|
| 208 |
+
# for i in range(self.head_num):
|
| 209 |
+
# att_score = self.heads_att_trans[i](chunks[i])
|
| 210 |
+
for i, layer in enumerate(self.heads_att_trans):
|
| 211 |
+
att_score = layer(chunks[i])
|
| 212 |
+
alpha = F.softmax(att_score, dim=-1)
|
| 213 |
+
mean = torch.sum(alpha * chunks[i], dim=2)
|
| 214 |
+
var = torch.sum(alpha * chunks[i]**2, dim=2) - mean**2
|
| 215 |
+
std = torch.sqrt(var.clamp(min=1e-7))
|
| 216 |
+
chunks_out.append(torch.cat((mean, std), dim=1))
|
| 217 |
+
out = torch.cat(chunks_out, dim=1)
|
| 218 |
+
return out
|
| 219 |
+
|
| 220 |
+
def get_out_dim(self):
|
| 221 |
+
self.out_dim = 2 * self.in_dim
|
| 222 |
+
return self.out_dim
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class MQMHASTP(torch.nn.Module):
|
| 226 |
+
""" An attentive pooling
|
| 227 |
+
Reference:
|
| 228 |
+
multi query multi head attentive statistics pooling
|
| 229 |
+
https://arxiv.org/pdf/2110.05042.pdf
|
| 230 |
+
Args:
|
| 231 |
+
in_dim: the feature dimension of input
|
| 232 |
+
layer_num: the number of layer in the pooling layer
|
| 233 |
+
query_num: the number of querys
|
| 234 |
+
head_num: the number of heads
|
| 235 |
+
bottleneck_dim: the bottleneck dimension
|
| 236 |
+
|
| 237 |
+
SA (H = 1, Q = 1, n = 2, d_s = 1) ref:
|
| 238 |
+
https://www.danielpovey.com/files/2018_interspeech_xvector_attention.pdf
|
| 239 |
+
MHA (H > 1, Q = 1, n = 1, d_s = 1) ref:
|
| 240 |
+
https://arxiv.org/pdf/1906.09890.pdf
|
| 241 |
+
AS (H = 1, Q > 1, n = 2, d_s = 1) ref:
|
| 242 |
+
https://arxiv.org/pdf/1803.10963.pdf
|
| 243 |
+
VSA (H = 1, Q > 1, n = 2, d_s = d_h) ref:
|
| 244 |
+
http://www.interspeech2020.org/uploadfile/pdf/Mon-2-10-5.pdf
|
| 245 |
+
"""
|
| 246 |
+
|
| 247 |
+
def __init__(self,
|
| 248 |
+
in_dim,
|
| 249 |
+
layer_num=2,
|
| 250 |
+
query_num=2,
|
| 251 |
+
head_num=8,
|
| 252 |
+
d_s=2,
|
| 253 |
+
bottleneck_dim=64,
|
| 254 |
+
**kwargs):
|
| 255 |
+
super(MQMHASTP, self).__init__()
|
| 256 |
+
self.n_query = nn.ModuleList([
|
| 257 |
+
MHASTP(in_dim,
|
| 258 |
+
layer_num=layer_num,
|
| 259 |
+
head_num=head_num,
|
| 260 |
+
d_s=d_s,
|
| 261 |
+
bottleneck_dim=bottleneck_dim) for i in range(query_num)
|
| 262 |
+
])
|
| 263 |
+
self.query_num = query_num
|
| 264 |
+
self.in_dim = in_dim
|
| 265 |
+
|
| 266 |
+
def forward(self, input):
|
| 267 |
+
"""
|
| 268 |
+
input: a 3-dimensional tensor in xvector architecture
|
| 269 |
+
or a 4-dimensional tensor in resnet architecture
|
| 270 |
+
0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
|
| 271 |
+
"""
|
| 272 |
+
if len(input.shape) == 4: # B x F x T
|
| 273 |
+
input = input.reshape(input.shape[0],
|
| 274 |
+
input.shape[1] * input.shape[2],
|
| 275 |
+
input.shape[3])
|
| 276 |
+
assert len(input.shape) == 3
|
| 277 |
+
res = []
|
| 278 |
+
for i, layer in enumerate(self.n_query):
|
| 279 |
+
res.append(layer(input))
|
| 280 |
+
out = torch.cat(res, dim=-1)
|
| 281 |
+
return out
|
| 282 |
+
|
| 283 |
+
def get_out_dim(self):
|
| 284 |
+
self.out_dim = self.in_dim * 2 * self.query_num
|
| 285 |
+
return self.out_dim
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
if __name__ == '__main__':
|
| 289 |
+
data = torch.randn(16, 512, 10, 35)
|
| 290 |
+
# model = StatisticsPooling()
|
| 291 |
+
model = MQMHASTP(512 * 10)
|
| 292 |
+
model = MHASTP(512 * 10)
|
| 293 |
+
model = MQMHASTP(512 * 10, context=False)
|
| 294 |
+
print(model)
|
| 295 |
+
|
| 296 |
+
out = model(data)
|
| 297 |
+
print(out.shape)
|
| 298 |
+
print(model.get_out_dim())
|
trained_30_percents/sparktts/modules/speaker/speaker_encoder.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 SparkAudio
|
| 2 |
+
# 2025 Xinsheng Wang ([email protected])
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
|
| 19 |
+
from typing import List, Tuple
|
| 20 |
+
from sparktts.modules.fsq.residual_fsq import ResidualFSQ
|
| 21 |
+
from sparktts.modules.speaker.ecapa_tdnn import ECAPA_TDNN_GLOB_c512
|
| 22 |
+
from sparktts.modules.speaker.perceiver_encoder import PerceiverResampler
|
| 23 |
+
|
| 24 |
+
"""
|
| 25 |
+
x-vector + d-vector
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class SpeakerEncoder(nn.Module):
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
input_dim (int): acoustic feature dimension
|
| 34 |
+
out_dim (int): output dimension of x-vector and d-vector
|
| 35 |
+
latent_dim (int): latent dimension before quantization
|
| 36 |
+
token_num (int): sequence length of speaker tokens
|
| 37 |
+
fsq_levels (List[int]): number of levels for each quantizer
|
| 38 |
+
fsq_num_quantizers (int): number of quantizers
|
| 39 |
+
|
| 40 |
+
Return:
|
| 41 |
+
speaker_embs: (B, T2, out_dim)
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
def __init__(
|
| 45 |
+
self,
|
| 46 |
+
input_dim: int = 100,
|
| 47 |
+
out_dim: int = 512,
|
| 48 |
+
latent_dim: int = 128,
|
| 49 |
+
token_num: int = 32,
|
| 50 |
+
fsq_levels: List[int] = [4, 4, 4, 4, 4, 4],
|
| 51 |
+
fsq_num_quantizers: int = 1,
|
| 52 |
+
):
|
| 53 |
+
super(SpeakerEncoder, self).__init__()
|
| 54 |
+
|
| 55 |
+
self.speaker_encoder = ECAPA_TDNN_GLOB_c512(
|
| 56 |
+
feat_dim=input_dim, embed_dim=out_dim
|
| 57 |
+
)
|
| 58 |
+
self.perceiver_sampler = PerceiverResampler(
|
| 59 |
+
dim=latent_dim, dim_context=512 * 3, num_latents=token_num
|
| 60 |
+
)
|
| 61 |
+
self.quantizer = ResidualFSQ(
|
| 62 |
+
levels=fsq_levels,
|
| 63 |
+
num_quantizers=fsq_num_quantizers,
|
| 64 |
+
dim=latent_dim,
|
| 65 |
+
is_channel_first=True,
|
| 66 |
+
quantize_dropout=False,
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
self.project = nn.Linear(latent_dim * token_num, out_dim)
|
| 70 |
+
|
| 71 |
+
def get_codes_from_indices(self, indices: torch.Tensor) -> torch.Tensor:
|
| 72 |
+
zq = self.quantizer.get_codes_from_indices(indices.transpose(1, 2))
|
| 73 |
+
return zq.transpose(1, 2)
|
| 74 |
+
|
| 75 |
+
def get_indices(self, mels: torch.Tensor) -> torch.Tensor:
|
| 76 |
+
mels = mels.transpose(1, 2)
|
| 77 |
+
x = self.perceiver_sampler(mels).transpose(1, 2)
|
| 78 |
+
zq, indices = self.quantizer(x)
|
| 79 |
+
return indices
|
| 80 |
+
|
| 81 |
+
def forward(self, mels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 82 |
+
"""
|
| 83 |
+
Args:
|
| 84 |
+
mels: (B, D_mel, T1)
|
| 85 |
+
|
| 86 |
+
Return:
|
| 87 |
+
x_vector: (B, out_dim)
|
| 88 |
+
d_vector: (B, out_dim)
|
| 89 |
+
"""
|
| 90 |
+
# mels = mels.transpose(1,2)
|
| 91 |
+
|
| 92 |
+
x_vector, features = self.speaker_encoder(mels, True)
|
| 93 |
+
x = self.perceiver_sampler(features.transpose(1, 2)).transpose(1, 2)
|
| 94 |
+
zq, indices = self.quantizer(x) # zq: (B, latent_dim, T2, latent_dim)
|
| 95 |
+
x = zq.reshape(zq.shape[0], -1)
|
| 96 |
+
d_vector = self.project(x)
|
| 97 |
+
|
| 98 |
+
return x_vector, d_vector
|
| 99 |
+
|
| 100 |
+
def tokenize(self, mels: torch.Tensor) -> torch.Tensor:
|
| 101 |
+
"""tokenize the input mel spectrogram"""
|
| 102 |
+
_, features = self.speaker_encoder(mels, True)
|
| 103 |
+
x = self.perceiver_sampler(features.transpose(1, 2)).transpose(1, 2)
|
| 104 |
+
zq, indices = self.quantizer(x)
|
| 105 |
+
return indices
|
| 106 |
+
|
| 107 |
+
def detokenize(self, indices: torch.Tensor) -> torch.Tensor:
|
| 108 |
+
"""detokenize the input indices to d-vector"""
|
| 109 |
+
zq = self.quantizer.get_output_from_indices(indices.transpose(1, 2)).transpose(1, 2)
|
| 110 |
+
x = zq.reshape(zq.shape[0], -1)
|
| 111 |
+
d_vector = self.project(x)
|
| 112 |
+
return d_vector
|
| 113 |
+
|
| 114 |
+
if __name__ == "__main__":
|
| 115 |
+
model = SpeakerEncoder(
|
| 116 |
+
input_dim=100,
|
| 117 |
+
latent_dim=128,
|
| 118 |
+
token_num=32,
|
| 119 |
+
fsq_levels=[4, 4, 4, 4, 4, 4],
|
| 120 |
+
fsq_num_quantizers=1,
|
| 121 |
+
)
|
| 122 |
+
mel = torch.randn(8, 200, 100)
|
| 123 |
+
x_vector, d_vector = model(mel)
|
| 124 |
+
print("x-vector shape", x_vector.shape)
|
| 125 |
+
print("d-vector shape", d_vector.shape)
|
| 126 |
+
|
| 127 |
+
indices = model.tokenize(mel)
|
| 128 |
+
print("indices shape", indices.shape)
|
| 129 |
+
d_vector_post = model.detokenize(indices)
|
| 130 |
+
print("d-vector shape", d_vector_post.shape)
|
| 131 |
+
if d_vector_post.all() == d_vector.all():
|
| 132 |
+
print("d-vector post and d-vector are the same")
|
| 133 |
+
else:
|
| 134 |
+
print("d-vector post and d-vector are different")
|
| 135 |
+
num_params = sum(param.numel() for param in model.parameters())
|
| 136 |
+
print("{} M".format(num_params / 1e6))
|
trained_30_percents/sparktts/modules/vq/__pycache__/factorized_vector_quantize.cpython-311.pyc
ADDED
|
Binary file (9.15 kB). View file
|
|
|