yueyulin committed on
Commit 4fe3e8c · verified · 1 Parent(s): 1ee4524

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +2 -0
  2. trained_30_percents/.gitignore +16 -0
  3. trained_30_percents/BiCodec/config.yaml +60 -0
  4. trained_30_percents/BiCodec/model.safetensors +3 -0
  5. trained_30_percents/Readme.md +130 -0
  6. trained_30_percents/Readme_zh.md +130 -0
  7. trained_30_percents/__init__.py +0 -0
  8. trained_30_percents/__pycache__/spark_llm.cpython-311.pyc +0 -0
  9. trained_30_percents/__pycache__/utilities.cpython-311.pyc +0 -0
  10. trained_30_percents/added_tokens.json +3 -0
  11. trained_30_percents/config.json +66 -0
  12. trained_30_percents/config.yaml +7 -0
  13. trained_30_percents/configuration_rwkv7.py +91 -0
  14. trained_30_percents/generation_config.json +6 -0
  15. trained_30_percents/hf_rwkv_tokenizer.py +280 -0
  16. trained_30_percents/kafka.wav +3 -0
  17. trained_30_percents/model.safetensors +3 -0
  18. trained_30_percents/modeling_rwkvspeech.py +6 -0
  19. trained_30_percents/output.wav +3 -0
  20. trained_30_percents/rwkv_vocab_v20230424.txt +0 -0
  21. trained_30_percents/spark_llm.py +202 -0
  22. trained_30_percents/sparktts/models/__pycache__/audio_tokenizer.cpython-311.pyc +0 -0
  23. trained_30_percents/sparktts/models/__pycache__/bicodec.cpython-311.pyc +0 -0
  24. trained_30_percents/sparktts/models/audio_tokenizer.py +167 -0
  25. trained_30_percents/sparktts/models/bicodec.py +247 -0
  26. trained_30_percents/sparktts/modules/blocks/__pycache__/layers.cpython-311.pyc +0 -0
  27. trained_30_percents/sparktts/modules/blocks/__pycache__/samper.cpython-311.pyc +0 -0
  28. trained_30_percents/sparktts/modules/blocks/__pycache__/vocos.cpython-311.pyc +0 -0
  29. trained_30_percents/sparktts/modules/blocks/layers.py +73 -0
  30. trained_30_percents/sparktts/modules/blocks/samper.py +115 -0
  31. trained_30_percents/sparktts/modules/blocks/vocos.py +373 -0
  32. trained_30_percents/sparktts/modules/encoder_decoder/__pycache__/feat_decoder.cpython-311.pyc +0 -0
  33. trained_30_percents/sparktts/modules/encoder_decoder/__pycache__/feat_encoder.cpython-311.pyc +0 -0
  34. trained_30_percents/sparktts/modules/encoder_decoder/__pycache__/wave_generator.cpython-311.pyc +0 -0
  35. trained_30_percents/sparktts/modules/encoder_decoder/feat_decoder.py +115 -0
  36. trained_30_percents/sparktts/modules/encoder_decoder/feat_encoder.py +105 -0
  37. trained_30_percents/sparktts/modules/encoder_decoder/wave_generator.py +88 -0
  38. trained_30_percents/sparktts/modules/fsq/__pycache__/finite_scalar_quantization.cpython-311.pyc +0 -0
  39. trained_30_percents/sparktts/modules/fsq/__pycache__/residual_fsq.cpython-311.pyc +0 -0
  40. trained_30_percents/sparktts/modules/fsq/finite_scalar_quantization.py +251 -0
  41. trained_30_percents/sparktts/modules/fsq/residual_fsq.py +355 -0
  42. trained_30_percents/sparktts/modules/speaker/__pycache__/ecapa_tdnn.cpython-311.pyc +0 -0
  43. trained_30_percents/sparktts/modules/speaker/__pycache__/perceiver_encoder.cpython-311.pyc +0 -0
  44. trained_30_percents/sparktts/modules/speaker/__pycache__/pooling_layers.cpython-311.pyc +0 -0
  45. trained_30_percents/sparktts/modules/speaker/__pycache__/speaker_encoder.cpython-311.pyc +0 -0
  46. trained_30_percents/sparktts/modules/speaker/ecapa_tdnn.py +267 -0
  47. trained_30_percents/sparktts/modules/speaker/perceiver_encoder.py +360 -0
  48. trained_30_percents/sparktts/modules/speaker/pooling_layers.py +298 -0
  49. trained_30_percents/sparktts/modules/speaker/speaker_encoder.py +136 -0
  50. trained_30_percents/sparktts/modules/vq/__pycache__/factorized_vector_quantize.cpython-311.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ trained_30_percents/kafka.wav filter=lfs diff=lfs merge=lfs -text
+ trained_30_percents/output.wav filter=lfs diff=lfs merge=lfs -text
trained_30_percents/.gitignore ADDED
@@ -0,0 +1,16 @@
+ # Python build artifacts
+ __pycache__/
+ *.pyc
+
+ # Environment variables
+ .env
+
+ # Virtual environment
+ venv/
+
+ # Model backups and outputs
+ model.fp32.safetensors
+ output.wav
+
+ # Temporary scripts
+ check_dtype.py
trained_30_percents/BiCodec/config.yaml ADDED
@@ -0,0 +1,60 @@
+ audio_tokenizer:
+   mel_params:
+     sample_rate: 16000
+     n_fft: 1024
+     win_length: 640
+     hop_length: 320
+     mel_fmin: 10
+     mel_fmax: null
+     num_mels: 128
+
+   encoder:
+     input_channels: 1024
+     vocos_dim: 384
+     vocos_intermediate_dim: 2048
+     vocos_num_layers: 12
+     out_channels: 1024
+     sample_ratios: [1,1]
+
+   decoder:
+     input_channel: 1024
+     channels: 1536
+     rates: [8, 5, 4, 2]
+     kernel_sizes: [16,11,8,4]
+
+   quantizer:
+     input_dim: 1024
+     codebook_size: 8192
+     codebook_dim: 8
+     commitment: 0.25
+     codebook_loss_weight: 2.0
+     use_l2_normlize: True
+     threshold_ema_dead_code: 0.2
+
+   speaker_encoder:
+     input_dim: 128
+     out_dim: 1024
+     latent_dim: 128
+     token_num: 32
+     fsq_levels: [4, 4, 4, 4, 4, 4]
+     fsq_num_quantizers: 1
+
+   prenet:
+     input_channels: 1024
+     vocos_dim: 384
+     vocos_intermediate_dim: 2048
+     vocos_num_layers: 12
+     out_channels: 1024
+     condition_dim: 1024
+     sample_ratios: [1,1]
+     use_tanh_at_final: False
+
+   postnet:
+     input_channels: 1024
+     vocos_dim: 384
+     vocos_intermediate_dim: 2048
+     vocos_num_layers: 6
+     out_channels: 1024
+     use_tanh_at_final: False
+
trained_30_percents/BiCodec/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e9940cd48d4446e4340ced82d234bf5618350dd9f5db900ebe47a4fdb03867ec
+ size 625518756
trained_30_percents/Readme.md ADDED
@@ -0,0 +1,130 @@
+ ---
+ license: apache-2.0
+ ---
+
+ # ReSpark TTS Model
+
+ This repository contains the ReSpark Text-to-Speech (TTS) model, an efficient model for generating high-quality speech from text. It is based on the RWKV architecture and uses the BiCodec tokenizer for audio processing.
+
+ ## Installation
+
+ First, install the required dependencies:
+
+ ```bash
+ pip install transformers rwkv-fla torch torchaudio torchvision soundfile numpy librosa omegaconf soxr einx
+ ```
+
+ ## Usage
+
+ The `tts.py` script provides a complete example of how to use this model for text-to-speech synthesis with voice cloning.
+
+ ### Running the Test Script
+
+ To generate speech, simply run the script:
+
+ ```bash
+ python tts.py
+ ```
+
+ ### How it Works
+
+ The script performs the following steps:
+ 1. Loads the pre-trained `AutoModelForCausalLM` and `AutoTokenizer` from the current directory.
+ 2. Initializes the `BiCodecTokenizer` for audio encoding and decoding.
+ 3. Loads a reference audio file (`kafka.wav`) and its corresponding transcript (`prompt_text`) to provide a voice prompt.
+ 4. Resamples the reference audio to match the model's expected sample rate (16000 Hz, per `config.yaml`).
+ 5. Takes a target text (`text`) to be synthesized.
+ 6. Calls the `generate_speech` function, which generates audio based on the target text and the voice from the reference audio.
+ 7. Saves the generated audio to `output.wav`.
+
+ You can modify the `prompt_text`, `prompt_audio_file`, and `text` variables in `tts.py` to synthesize different text with different voices.
+
+ ### Example Code (`tts.py`)
+
+ ```python
+ import os
+ import sys
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+ print('add current dir to sys.path', current_dir)
+ sys.path.append(current_dir)
+ from sparktts.models.audio_tokenizer import BiCodecTokenizer
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import soundfile as sf
+ import numpy as np
+ import torch
+ from utilities import generate_embeddings
+
+ def generate_speech(model, tokenizer, text, bicodec, prompt_text=None, prompt_audio=None,
+                     max_new_tokens=3000, do_sample=True, top_k=50, top_p=0.95,
+                     temperature=1.0, device="cuda:0"):
+     """
+     Function to generate speech.
+     """
+     eos_token_id = model.config.vocab_size - 1
+
+     embeddings = generate_embeddings(
+         model=model,
+         tokenizer=tokenizer,
+         text=text,
+         bicodec=bicodec,
+         prompt_text=prompt_text,
+         prompt_audio=prompt_audio
+     )
+
+     global_tokens = embeddings['global_tokens'].unsqueeze(0)
+     model.eval()
+
+     with torch.no_grad():
+         generated_outputs = model.generate(
+             inputs_embeds=embeddings['input_embs'],
+             attention_mask=torch.ones((1, embeddings['input_embs'].shape[1]), dtype=torch.long, device=device),
+             max_new_tokens=max_new_tokens,
+             do_sample=do_sample,
+             top_k=top_k,
+             top_p=top_p,
+             temperature=temperature,
+             eos_token_id=eos_token_id,
+             pad_token_id=tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else tokenizer.eos_token_id,
+             use_cache=True
+         )
+
+     semantic_tokens_tensor = generated_outputs[:, :-1]
+
+     with torch.no_grad():
+         wav = bicodec.detokenize(global_tokens, semantic_tokens_tensor)
+
+     return wav
+
+ # --- Main execution ---
+ device = 'cuda:0'
+
+ # Initialize tokenizers and model
+ audio_tokenizer = BiCodecTokenizer(model_dir=current_dir, device=device)
+ tokenizer = AutoTokenizer.from_pretrained(current_dir, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(current_dir, trust_remote_code=True)
+
+ model = model.bfloat16().to(device)
+ model.eval()
+
+ # Prepare prompt audio and text for voice cloning
+ prompt_text = "我们并不是通过物理移动手段找到星河的。"
+ prompt_audio_file = os.path.join(current_dir, 'kafka.wav')
+ prompt_audio, sampling_rate = sf.read(prompt_audio_file)
+
+ # Resample audio if necessary
+ target_sample_rate = audio_tokenizer.config['sample_rate']
+ if sampling_rate != target_sample_rate:
+     from librosa import resample
+     prompt_audio = resample(prompt_audio, orig_sr=sampling_rate, target_sr=target_sample_rate)
+     prompt_audio = np.array(prompt_audio, dtype=np.float32)
+
+ # Text to synthesize
+ text = "科学技术是第一生产力,最近 AI的迅猛发展让我们看到了迈向星辰大海的希望。"
+
+ # Generate speech
+ wav = generate_speech(model, tokenizer, text, audio_tokenizer, prompt_audio=prompt_audio, device=device)
+
+ # Save the output
+ sf.write('output.wav', wav, target_sample_rate)
+ print("Generated audio saved to output.wav")
+ ```
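
One detail of the example above that is easy to miss is how the end-of-sequence id is derived: `config.json` in this folder sets `vocab_size` to 8193 (8192 BiCodec semantic codes plus one reserved end-of-audio token), so `eos_token_id = model.config.vocab_size - 1` resolves to 8192, and the trailing EOS is then dropped before detokenization via `generated_outputs[:, :-1]`. A minimal sketch of that arithmetic, assuming only the shipped `config.json` and that the model folder is the current directory:

```python
from transformers import AutoConfig

# trust_remote_code is required because the config class is defined in
# modeling_rwkvspeech.py (see auto_map in config.json).
config = AutoConfig.from_pretrained(".", trust_remote_code=True)

eos_token_id = config.vocab_size - 1    # 8193 - 1 = 8192
print(config.vocab_size, eos_token_id)  # expected: 8193 8192
```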
trained_30_percents/Readme_zh.md ADDED
@@ -0,0 +1,130 @@
+ ---
+ license: apache-2.0
+ ---
+
+ # ReSpark TTS 模型
+
+ 本仓库包含 ReSpark 文本转语音 (TTS) 模型，这是一个高效的模型，可以从文本生成高质量的语音。它基于 RWKV 架构，并使用 BiCodec tokenizer 进行音频处理。
+
+ ## 安装
+
+ 首先，请安装所需的依赖库:
+
+ ```bash
+ pip install transformers rwkv-fla torch torchaudio torchvision soundfile numpy librosa omegaconf soxr einx
+ ```
+
+ ## 使用方法
+
+ `tts.py` 脚本提供了一个完整的使用该模型进行文本转语音合成（带声音克隆功能）的示例。
+
+ ### 运行测试脚本
+
+ 要生成语音，只需运行以下脚本:
+
+ ```bash
+ python tts.py
+ ```
+
+ ### 工作原理
+
+ 该脚本执行以下步骤:
+ 1. 从当前目录加载预训练的 `AutoModelForCausalLM` 和 `AutoTokenizer`。
+ 2. 初始化用于音频编码和解码的 `BiCodecTokenizer`。
+ 3. 加载一个参考音频文件 (`kafka.wav`) 及其对应的文本 (`prompt_text`) 以提供声音提示（voice prompt）。
+ 4. 如果需要，将参考音频重采样以匹配模型期望的采样率（16000 Hz，见 `config.yaml`）。
+ 5. 指定一个需要被合成的目标文本 (`text`)。
+ 6. 调用 `generate_speech` 函数，该函数会根据目标文本和参考音频中的声音生成音频。
+ 7. 将生成的音频保存到 `output.wav`。
+
+ 您可以修改 `tts.py` 文件中的 `prompt_text`、`prompt_audio_file` 和 `text` 变量，以使用不同的声音合成不同的文本。
+
+ ### 示例代码 (`tts.py`)
+
+ ```python
+ import os
+ import sys
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+ print('add current dir to sys.path', current_dir)
+ sys.path.append(current_dir)
+ from sparktts.models.audio_tokenizer import BiCodecTokenizer
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import soundfile as sf
+ import numpy as np
+ import torch
+ from utilities import generate_embeddings
+
+ def generate_speech(model, tokenizer, text, bicodec, prompt_text=None, prompt_audio=None,
+                     max_new_tokens=3000, do_sample=True, top_k=50, top_p=0.95,
+                     temperature=1.0, device="cuda:0"):
+     """
+     生成语音的函数
+     """
+     eos_token_id = model.config.vocab_size - 1
+
+     embeddings = generate_embeddings(
+         model=model,
+         tokenizer=tokenizer,
+         text=text,
+         bicodec=bicodec,
+         prompt_text=prompt_text,
+         prompt_audio=prompt_audio
+     )
+
+     global_tokens = embeddings['global_tokens'].unsqueeze(0)
+     model.eval()
+
+     with torch.no_grad():
+         generated_outputs = model.generate(
+             inputs_embeds=embeddings['input_embs'],
+             attention_mask=torch.ones((1, embeddings['input_embs'].shape[1]), dtype=torch.long, device=device),
+             max_new_tokens=max_new_tokens,
+             do_sample=do_sample,
+             top_k=top_k,
+             top_p=top_p,
+             temperature=temperature,
+             eos_token_id=eos_token_id,
+             pad_token_id=tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else tokenizer.eos_token_id,
+             use_cache=True
+         )
+
+     semantic_tokens_tensor = generated_outputs[:, :-1]
+
+     with torch.no_grad():
+         wav = bicodec.detokenize(global_tokens, semantic_tokens_tensor)
+
+     return wav
+
+ # --- 主程序 ---
+ device = 'cuda:0'
+
+ # 初始化分词器和模型
+ audio_tokenizer = BiCodecTokenizer(model_dir=current_dir, device=device)
+ tokenizer = AutoTokenizer.from_pretrained(current_dir, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(current_dir, trust_remote_code=True)
+
+ model = model.bfloat16().to(device)
+ model.eval()
+
+ # 准备用于声音克隆的提示音频和文本
+ prompt_text = "我们并不是通过物理移动手段找到星河的。"
+ prompt_audio_file = os.path.join(current_dir, 'kafka.wav')
+ prompt_audio, sampling_rate = sf.read(prompt_audio_file)
+
+ # 如果需要，重采样音频
+ target_sample_rate = audio_tokenizer.config['sample_rate']
+ if sampling_rate != target_sample_rate:
+     from librosa import resample
+     prompt_audio = resample(prompt_audio, orig_sr=sampling_rate, target_sr=target_sample_rate)
+     prompt_audio = np.array(prompt_audio, dtype=np.float32)
+
+ # 要合成的文本
+ text = "科学技术是第一生产力,最近 AI的迅猛发展让我们看到了迈向星辰大海的希望。"
+
+ # 生成语音
+ wav = generate_speech(model, tokenizer, text, audio_tokenizer, prompt_audio=prompt_audio, device=device)
+
+ # 保存输出
+ sf.write('output.wav', wav, target_sample_rate)
+ print("生成的音频已保存到 output.wav")
+ ```
trained_30_percents/__init__.py ADDED
File without changes
trained_30_percents/__pycache__/spark_llm.cpython-311.pyc ADDED
Binary file (10.6 kB).
trained_30_percents/__pycache__/utilities.cpython-311.pyc ADDED
Binary file (3.79 kB).
trained_30_percents/added_tokens.json ADDED
@@ -0,0 +1,3 @@
+ {
+   "<|rwkv_tokenizer_end_of_text|>": 0
+ }
trained_30_percents/config.json ADDED
@@ -0,0 +1,66 @@
+ {
+   "a_low_rank_dim": 64,
+   "architectures": [
+     "RWKV7ForSpeech"
+   ],
+   "attn": null,
+   "attn_mode": "chunk",
+   "audio_global_vocab_size": 4096,
+   "auto_map": {
+     "AutoConfig": "modeling_rwkvspeech.RWKV7SpeechConfig",
+     "AutoModel": "modeling_rwkvspeech.RWKV7Model",
+     "AutoModelForCausalLM": "modeling_rwkvspeech.RWKV7ForSpeech"
+   },
+   "bos_token_id": 0,
+   "decay_low_rank_dim": 64,
+   "eos_token_id": 0,
+   "fuse_cross_entropy": true,
+   "fuse_norm": false,
+   "gate_low_rank_dim": 128,
+   "head_dim": 64,
+   "hidden_act": "sqrelu",
+   "hidden_ratio": 4.0,
+   "hidden_size": 1024,
+   "initializer_range": 0.006,
+   "intermediate_size": 4096,
+   "max_position_embeddings": 2048,
+   "model_type": "rwkv7",
+   "norm_bias": true,
+   "norm_eps": 1e-05,
+   "norm_first": true,
+   "num_heads": 32,
+   "num_hidden_layers": 24,
+   "text_vocab_size": 65536,
+   "tie_word_embeddings": false,
+   "torch_dtype": "float32",
+   "transformers_version": "4.52.4",
+   "use_cache": true,
+   "v_low_rank_dim": 32,
+   "value_dim": [
+     1024,
+     1024,
+     1024,
+     1024,
+     1024,
+     1024,
+     1024,
+     1024,
+     1024,
+     1024,
+     1024,
+     1024,
+     1024,
+     1024,
+     1024,
+     1024,
+     1024,
+     1024,
+     1024,
+     1024,
+     1024,
+     1024,
+     1024,
+     1024
+   ],
+   "vocab_size": 8193
+ }
trained_30_percents/config.yaml ADDED
@@ -0,0 +1,7 @@
+ highpass_cutoff_freq: 40
+ sample_rate: 16000
+ segment_duration: 2.4 # (s)
+ max_val_duration: 12 # (s)
+ latent_hop_length: 320
+ ref_segment_duration: 6
+ volume_normalize: true
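
The sample rate and hop length above determine the semantic-token rate: 16000 / 320 = 50 tokens per second, which is also why the BiCodec decoder upsampling rates in `BiCodec/config.yaml` multiply out to the hop length (8 · 5 · 4 · 2 = 320). The 6 s reference segment used for speaker embedding is rounded down to a multiple of the hop length, as in `BiCodecTokenizer.get_ref_clip` further down. A small sanity check, assuming only the values in this file:

```python
# Derived quantities from config.yaml (no model code needed).
sample_rate = 16000
latent_hop_length = 320
ref_segment_duration = 6

tokens_per_second = sample_rate // latent_hop_length                 # 50 semantic tokens/s
ref_samples = (int(sample_rate * ref_segment_duration)
               // latent_hop_length * latent_hop_length)             # 96000 samples, hop-aligned
print(tokens_per_second, ref_samples)                                # 50 96000
```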
trained_30_percents/configuration_rwkv7.py ADDED
@@ -0,0 +1,91 @@
+ # -*- coding: utf-8 -*-
+
+ from typing import Dict, Optional
+
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ class RWKV7Config(PretrainedConfig):
+
+     model_type = 'rwkv7'
+     keys_to_ignore_at_inference = ['past_key_values']
+
+     def __init__(
+         self,
+         attn_mode: str = "chunk",
+         hidden_size: int = 2048,
+         hidden_ratio: Optional[int] = 4,
+         intermediate_size: Optional[int] = None,
+         num_hidden_layers: int = 24,
+         head_dim: Optional[int] = 64,
+         num_heads: Optional[int] = None,
+         decay_low_rank_dim: int = 64,
+         gate_low_rank_dim: int = 128,
+         a_low_rank_dim: int = 64,
+         v_low_rank_dim: int = 16,
+         hidden_act: str = "sqrelu",
+         max_position_embeddings: int = 2048,
+         norm_first: bool = True,
+         norm_bias: bool = True,
+         norm_eps: float = 1e-5,
+         attn: Optional[Dict] = None,
+         use_cache: bool = True,
+         pad_token_id: int = None,
+         bos_token_id: int = 1,
+         eos_token_id: int = 2,
+         tie_word_embeddings: bool = False,
+         initializer_range: float = 0.006,
+         fuse_norm: bool = True,
+         fuse_cross_entropy: bool = True,
+         vocab_size: int = 32000,
+         **kwargs
+     ):
+         self.attn_mode = attn_mode
+         self.hidden_size = hidden_size
+         self.hidden_ratio = hidden_ratio
+         self.intermediate_size = intermediate_size
+         self.norm_first = norm_first
+         self.num_hidden_layers = num_hidden_layers
+
+         if head_dim is None and num_heads is not None:
+             head_dim = int(hidden_size // num_heads)
+         elif head_dim is not None and num_heads is None:
+             num_heads = int(hidden_size // head_dim)
+
+         self.head_dim = head_dim
+         self.num_heads = num_heads
+
+         self.decay_low_rank_dim = decay_low_rank_dim
+         self.gate_low_rank_dim = gate_low_rank_dim
+         self.a_low_rank_dim = a_low_rank_dim
+         self.v_low_rank_dim = v_low_rank_dim
+         self.hidden_act = hidden_act
+         self.max_position_embeddings = max_position_embeddings
+         self.norm_bias = norm_bias
+         self.norm_eps = norm_eps
+         self.attn = attn
+         self.use_cache = use_cache
+         self.initializer_range = initializer_range
+         self.fuse_norm = fuse_norm
+         self.fuse_cross_entropy = fuse_cross_entropy
+         self.vocab_size = vocab_size
+
+         if attn is not None:
+             if not isinstance(attn, Dict):
+                 raise ValueError("attn must be a dictionary")
+             if 'layers' not in attn:
+                 raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
+             if 'num_heads' not in attn:
+                 raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
+             attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
+             attn['qkv_bias'] = attn.get('qkv_bias', False)
+             attn['window_size'] = attn.get('window_size', None)
+             attn['rope_theta'] = attn.get('rope_theta', 10000.)
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
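
In `RWKV7Config.__init__` above, `head_dim` and `num_heads` are derived from one another only when exactly one of them is given; if both are present (as in this repository's `config.json`, which sets both explicitly), they are stored as-is. A minimal sketch of that behaviour, with hypothetical values, assuming this file is importable from the model folder:

```python
from configuration_rwkv7 import RWKV7Config  # the class defined in this file

# Only head_dim given: num_heads is derived as hidden_size // head_dim.
cfg = RWKV7Config(hidden_size=1024, head_dim=64, num_heads=None)
print(cfg.num_heads)  # 16

# Only num_heads given: head_dim is derived as hidden_size // num_heads.
cfg = RWKV7Config(hidden_size=1024, head_dim=None, num_heads=8)
print(cfg.head_dim)   # 128
```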
trained_30_percents/generation_config.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 0,
+   "eos_token_id": 0,
+   "transformers_version": "4.52.4"
+ }
trained_30_percents/hf_rwkv_tokenizer.py ADDED
@@ -0,0 +1,280 @@
+ # coding=utf-8
+ # Copyright 2024 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Tokenization classes for RWKV."""
+
+ import os
+ import re
+ from typing import TYPE_CHECKING, List, Optional, Tuple
+
+ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
+ from transformers.utils import logging
+
+
+ if TYPE_CHECKING:
+     pass
+
+ logger = logging.get_logger(__name__)
+
+
+ VOCAB_FILES_NAMES = {
+     "vocab_file": "rwkv_vocab_v20230424.txt",
+ }
+
+ class TRIE:
+     __slots__ = tuple("ch,to,values,front".split(","))
+     to: list
+     values: set
+
+     def __init__(self, front=None, ch=None):
+         self.ch = ch
+         self.to = [None for ch in range(256)]
+         self.values = set()
+         self.front = front
+
+     def __repr__(self):
+         fr = self
+         ret = []
+         while fr != None:
+             if fr.ch != None:
+                 ret.append(fr.ch)
+             fr = fr.front
+         return "<TRIE %s %s>" % (ret[::-1], self.values)
+
+     def add(self, key: bytes, idx: int = 0, val=None):
+         if idx == len(key):
+             if val is None:
+                 val = key
+             self.values.add(val)
+             return self
+         ch = key[idx]
+         if self.to[ch] is None:
+             self.to[ch] = TRIE(front=self, ch=ch)
+         return self.to[ch].add(key, idx=idx + 1, val=val)
+
+     def find_longest(self, key: bytes, idx: int = 0):
+         u: TRIE = self
+         ch: int = key[idx]
+
+         while u.to[ch] is not None:
+             u = u.to[ch]
+             idx += 1
+             if u.values:
+                 ret = idx, u, u.values
+             if idx == len(key):
+                 break
+             ch = key[idx]
+         return ret
+
+
+ class RWKV_TOKENIZER:
+     def __init__(self, file_name):
+         self.idx2token = {}
+         sorted = []  # must be already sorted
+         with open(file_name, "r", encoding="utf-8") as f:
+             lines = f.readlines()
+         for l in lines:
+             idx = int(l[: l.index(" ")])
+             x = eval(l[l.index(" ") : l.rindex(" ")])
+             x = x.encode("utf-8") if isinstance(x, str) else x
+             assert isinstance(x, bytes)
+
+             assert len(x) == int(l[l.rindex(" ") :])
+             sorted += [x]
+             self.idx2token[idx] = x
+
+         self.token2idx = {}
+         for k, v in self.idx2token.items():
+             self.token2idx[v] = int(k)
+
+         self.root = TRIE()
+         for t, i in self.token2idx.items():
+             _ = self.root.add(t, val=(t, i))
+
+     def encodeBytes(self, src: bytes):
+         idx: int = 0
+         tokens = []
+         while idx < len(src):
+             _idx: int = idx
+             idx, _, values = self.root.find_longest(src, idx)
+             assert idx != _idx
+             _, token = next(iter(values))
+             tokens.append(token)
+         return tokens
+
+     def decodeBytes(self, tokens):
+         return b"".join(map(lambda i: self.idx2token[i], tokens))
+
+     def encode(self, src):
+         if isinstance(src, str):
+             return [self.encodeBytes(src.encode("utf-8"))]
+         elif isinstance(src, list):
+             return [self.encodeBytes(s.encode("utf-8")) for s in src]
+
+     def decode(self, tokens):
+         return [self.decodeBytes(batch).decode("utf-8") for batch in tokens]
+         # try:
+         #     return self.decodeBytes(tokens).decode('utf-8')
+         # except:
+         #     return '\ufffd'  # bad utf-8
+
+     def printTokens(self, tokens):
+         for i in tokens:
+             s = self.idx2token[i]
+             try:
+                 s = s.decode("utf-8")
+             except:
+                 pass
+             print(f"{repr(s)}{i}", end=" ")
+         print()
+
+
+ class RwkvTokenizer(PreTrainedTokenizer):
+     vocab_files_names = VOCAB_FILES_NAMES
+     model_input_names = ["input_ids", "attention_mask"]
+
+     def __init__(
+         self, vocab_file, bos_token="<|rwkv_tokenizer_end_of_text|>", eos_token="<|rwkv_tokenizer_end_of_text|>", unk_token="<|rwkv_tokenizer_end_of_text|>", **kwargs
+     ):
+         if not os.path.isfile(vocab_file):
+             raise ValueError(
+                 f"Can't find a vocabulary file at path '{vocab_file}'."
+             )
+
+         with open(vocab_file, "r", encoding="utf-8") as reader:
+             tokens = reader.readlines()
+
+         if "add_bos_token" in kwargs:
+             self.add_bos_token = kwargs["add_bos_token"]
+         else:
+             self.add_bos_token = False
+         self.trie_tokenizer = RWKV_TOKENIZER(vocab_file)
+         vocab = self.trie_tokenizer.token2idx
+         self.encoder = vocab
+         self.decoder = {v: k for k, v in vocab.items()}
+         self._added_tokens_decoder = {0: AddedToken(str(bos_token))}
+         super().__init__(
+             bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs
+         )
+
+     @property
+     def vocab_size(self):
+         return len(self.encoder)
+
+     def get_vocab(self):
+         vocab = self.encoder
+         vocab.update(self.added_tokens_encoder)
+         vocab = dict(sorted(vocab.items(), key=lambda item: item[1]))
+         return vocab
+
+     def _tokenize(self, text, split_special_tokens=False):
+         # return self.wordpiece_tokenizer.tokenize(text.encode("utf-8"))
+         return self.trie_tokenizer.encode(text)[0]
+
+     def _convert_token_to_id(self, token):
+         return token
+
+     def _convert_id_to_token(self, index):
+         """Converts an index (integer) into a token (byte) using the vocab."""
+         token = self.decoder.get(index, self.unk_token)
+         if isinstance(token, (bytes)):
+             token = token.decode("utf-8", errors="replace")
+         return token
+
+     def convert_tokens_to_string(self, tokens):
+         """Converts a sequence of tokens (bytes) into a single string. Additional tokens are encoded to bytes."""
+         out_string = b"".join(
+             [k.encode(errors="replace") if isinstance(k, str) else k for k in tokens]
+         ).decode("utf-8")
+         return out_string
+
+     def save_vocabulary(
+         self, save_directory: str, filename_prefix: Optional[str] = None
+     ) -> Tuple[str]:
+         index = 0
+         if os.path.isdir(save_directory):
+             vocab_file = os.path.join(
+                 save_directory,
+                 (filename_prefix + "-" if filename_prefix else "") + "vocab.txt",
+             )
+         else:
+             vocab_file = (
+                 filename_prefix + "-" if filename_prefix else ""
+             ) + save_directory
+         with open(vocab_file, "w", encoding="utf-8") as writer:
+             for token, token_index in sorted(
+                 self.encoder.items(), key=lambda kv: kv[1]
+             ):
+                 if index != token_index:
+                     logger.warning(
+                         f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
+                         " Please check that the vocabulary is not corrupted!"
+                     )
+                     index = token_index
+                 writer.write(str(token) + "\n")
+                 index += 1
+         return (vocab_file,)
+
+     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+         if self.add_bos_token:
+             bos_token_ids = [self.bos_token_id]
+         else:
+             bos_token_ids = []
+
+         output = bos_token_ids + token_ids_0
+
+         if token_ids_1 is None:
+             return output
+
+         return output + bos_token_ids + token_ids_1
+
+     def get_special_tokens_mask(
+         self,
+         token_ids_0: List[int],
+         token_ids_1: Optional[List[int]] = None,
+         already_has_special_tokens: bool = False,
+     ) -> List[int]:
+         """
+         Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
+         special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.
+
+         Args:
+             token_ids_0 (`List[int]`):
+                 List of IDs.
+             token_ids_1 (`List[int]`, *optional*):
+                 Optional second list of IDs for sequence pairs.
+             already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                 Whether or not the token list is already formatted with special tokens for the model.
+
+         Returns:
+             `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+         """
+         if already_has_special_tokens:
+             return super().get_special_tokens_mask(
+                 token_ids_0=token_ids_0,
+                 token_ids_1=token_ids_1,
+                 already_has_special_tokens=True,
+             )
+
+         if not self.add_bos_token:
+             return super().get_special_tokens_mask(
+                 token_ids_0=token_ids_0,
+                 token_ids_1=token_ids_1,
+                 already_has_special_tokens=False,
+             )
+
+         if token_ids_1 is None:
+             return [1] + ([0] * len(token_ids_0))
+         return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
+
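
The `RwkvTokenizer` above is byte-level: `_tokenize` returns integer ids straight from the greedy longest-match trie over `rwkv_vocab_v20230424.txt`, and `_convert_token_to_id` is the identity. A minimal round-trip sketch, assuming this model folder is the current directory and ships a tokenizer config that points `AutoTokenizer` at `hf_rwkv_tokenizer.RwkvTokenizer` (hence the `trust_remote_code=True` used in the README):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(".", trust_remote_code=True)

ids = tok("Hello world")["input_ids"]  # ids produced by the byte-level trie
print(ids)
print(tok.decode(ids))                 # should reconstruct "Hello world"
```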
trained_30_percents/kafka.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b7928aeaf90600d6a014a5fececdc59cdf0e2971db327a0cf56b922b7cd8f8a7
+ size 265524
trained_30_percents/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f7c0f6731a59dd80dbfdf692ce77e11a65af53b9c55795285aa5cfee11a97fae
+ size 809355976
trained_30_percents/modeling_rwkvspeech.py ADDED
@@ -0,0 +1,6 @@
+ from spark_llm import RWKV7SpeechConfig, RWKV7ForSpeech
+ from rwkvfla.models.rwkv7 import RWKV7Model
+
+ RWKV7ForCausalLM = RWKV7ForSpeech
+ RWKV7Model = RWKV7Model
+ RWKV7Config = RWKV7SpeechConfig
trained_30_percents/output.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:54c3ce8433267b9d2e8812fed1c87af29cdcc66a017b3700016b40b243930a34
+ size 180524
trained_30_percents/rwkv_vocab_v20230424.txt ADDED
The diff for this file is too large to render. See raw diff
 
trained_30_percents/spark_llm.py ADDED
@@ -0,0 +1,202 @@
+ import torch
+ import torch.nn as nn
+ from typing import Optional, Union, Tuple, Dict, Unpack
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ from transformers.utils.deprecation import deprecate_kwarg
+ from rwkvfla.models.rwkv7.modeling_rwkv7 import RWKV7Model, RWKV7PreTrainedModel, Cache, RWKV7ForCausalLM
+ from rwkvfla.models.rwkv7.modeling_rwkv7 import FusedLinearCrossEntropyLoss, FusedCrossEntropyLoss
+ from transformers.generation.utils import GenerationMixin
+
+ from rwkvfla.models.rwkv7.configuration_rwkv7 import RWKV7Config
+
+ class RWKV7SpeechConfig(RWKV7Config):
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+         self.text_vocab_size = kwargs.get("text_vocab_size")
+         self.audio_global_vocab_size = kwargs.get("audio_global_vocab_size")
+
+
+ class RWKV7ForSpeech(RWKV7ForCausalLM):
+     config_class = RWKV7SpeechConfig
+
+     def __init__(self, config: RWKV7SpeechConfig):
+         super().__init__(config)
+         self.model = RWKV7Model(config)
+         self.vocab_size = config.vocab_size
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)  # Spark 0.5B vocab size is 8192 + 1 for EOS, resulting in 8193
+         self.criterion = None
+         self.text_embedder = nn.Embedding(config.text_vocab_size, config.hidden_size)
+         self.global_embedder = nn.Embedding(config.audio_global_vocab_size, config.hidden_size)  # Spark 0.5B global token size is 4096
+         # TTS tags: GLOBAL=0, SEMANTIC=1, START_TTS=2
+         self.tts_tag_embedder = nn.Embedding(3, config.hidden_size)
+         # Initialize weights and apply final processing
+         self.post_init()
+         self.dropout = torch.nn.Dropout(0.02)
+
+     def get_input_embeddings(self):
+         return self.model.embeddings
+
+     def set_input_embeddings(self, value):
+         self.model.embeddings = value
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+
+     def set_decoder(self, decoder):
+         self.model = decoder
+
+     def get_decoder(self):
+         return self.model
+
+     def generate(self, *args, **kwargs):
+         try:
+             return super().generate(*args, **kwargs)
+         except AttributeError as exception:
+             if 'past_key_values' in str(exception):
+                 raise AttributeError(
+                     f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
+                     f"which is not supported for {self.__class__.__name__}. "
+                     f"Try another generation strategy instead. "
+                     f"For the available generation strategies, check this doc: "
+                     f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
+                 )
+             else:
+                 raise exception
+
+     @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
+     def prepare_inputs_for_generation(
+         self,
+         input_ids: torch.LongTensor = None,
+         past_key_values: Optional[Cache] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         use_cache: bool = True,
+         logits_to_keep: Optional[int] = None,
+         **kwargs
+     ):
+         # only keep the last token of `input_ids` if `past_key_values` is not empty
+         if past_key_values is not None and len(past_key_values) > 0:
+             input_ids = input_ids[:, -1:]
+         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+         if inputs_embeds is not None and len(past_key_values) == 0:
+             model_inputs = {'inputs_embeds': inputs_embeds}
+         else:
+             # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+             # recompiles graphs as the stride of the inputs is a guard.
+             # Ref: https://github.com/huggingface/transformers/pull/29114
+             # TODO: use `next_tokens` directly instead.
+             model_inputs = {'input_ids': input_ids.contiguous()}
+
+         if logits_to_keep is not None:
+             model_inputs['logits_to_keep'] = logits_to_keep
+
+         model_inputs.update({
+             'past_key_values': past_key_values,
+             'use_cache': use_cache,
+             'attention_mask': attention_mask,
+             'logits_to_keep': logits_to_keep,
+         })
+         return model_inputs
+
+     @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         past_key_values: Optional[Cache] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         logits_to_keep: Optional[int] = 0,
+         **kwargs: Unpack[Dict]
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         if self.training and inputs_embeds is not None:
+             inputs_embeds = self.dropout(inputs_embeds)
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             inputs_embeds=inputs_embeds,
+             past_key_values=past_key_values,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             **kwargs
+         )
+
+         hidden_states = outputs[0]
+         fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
+
+         loss, logits = None, None
+         if not fuse_linear_and_cross_entropy or labels is None:
+             logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
+         if labels is not None:
+             if getattr(self, 'criterion', None) is None:
+                 if fuse_linear_and_cross_entropy:
+                     criterion = FusedLinearCrossEntropyLoss()
+                 elif self.config.fuse_cross_entropy:
+                     criterion = FusedCrossEntropyLoss(inplace_backward=True)
+                 else:
+                     criterion = nn.CrossEntropyLoss()
+             else:
+                 criterion = self.criterion
+             # Enable model parallelism
+             labels = labels.to(hidden_states.device)
+             labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
+             if fuse_linear_and_cross_entropy:
+                 loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
+             else:
+                 loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
+
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             return (loss,) + output if loss is not None else output
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+     def copy_state_dict(self, state_dict: dict):
+         """Copy parameters from a source state dict into this model, excluding embeddings and lm_head.
+
+         The state dict comes from an original RWKV7 language model.
+
+         Args:
+             state_dict: source state dict
+         """
+         # Get the current model's state dict
+         target_dict = self.state_dict()
+
+         # New state dict holding the parameters to copy
+         new_state_dict = {}
+
+         # Iterate over the keys of the source state dict
+         for key in state_dict.keys():
+             # The source LM's text embeddings are remapped to this model's text_embedder
+             if key == 'model.embeddings.weight':
+                 new_state_dict['text_embedder.weight'] = state_dict[key]
+                 continue
+             # Skip the remaining embeddings and lm_head parameters
+             if 'embeddings' in key or 'lm_head' in key:
+                 continue
+             # Copy the parameter if the key exists in the current model
+             if key in target_dict:
+                 new_state_dict[key] = state_dict[key]
+
+         # Load the new state dict into the current model
+         info = self.load_state_dict(new_state_dict, strict=False)
+         print(info)
+         return self
+
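
`copy_state_dict` above is a convenience for warm-starting the speech model from a plain RWKV7 text language model: the LM's `model.embeddings.weight` is remapped to `text_embedder.weight`, its `lm_head` is dropped (the speech head predicts 8193 semantic codes instead), and every other matching key is copied. A hedged sketch of how it might be used; the source checkpoint path is hypothetical, and the shapes must match this repository's `config.json`:

```python
from transformers import AutoModelForCausalLM
from spark_llm import RWKV7SpeechConfig, RWKV7ForSpeech

# Hypothetical: any pretrained RWKV7 text LM with the same hidden size and layer count.
text_lm = AutoModelForCausalLM.from_pretrained("path/to/rwkv7-text-lm", trust_remote_code=True)

config = RWKV7SpeechConfig.from_pretrained(".")  # config.json in this folder
speech_model = RWKV7ForSpeech(config)

# Copies matching weights, maps model.embeddings.weight -> text_embedder.weight,
# and leaves the new lm_head / audio embedders randomly initialized.
speech_model.copy_state_dict(text_lm.state_dict())
```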
trained_30_percents/sparktts/models/__pycache__/audio_tokenizer.cpython-311.pyc ADDED
Binary file (8.95 kB).
trained_30_percents/sparktts/models/__pycache__/bicodec.cpython-311.pyc ADDED
Binary file (11 kB).
trained_30_percents/sparktts/models/audio_tokenizer.py ADDED
@@ -0,0 +1,167 @@
+ # Copyright (c) 2025 SparkAudio
+ #               2025 Xinsheng Wang ([email protected])
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ import torch
+ import numpy as np
+ from pathlib import Path
+ from typing import Any, Dict, Tuple, Optional, Union
+ from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
+
+ from sparktts.utils.file import load_config
+ from sparktts.utils.audio import load_audio
+ from sparktts.models.bicodec import BiCodec
+
+
+ class BiCodecTokenizer:
+     """BiCodec tokenizer for handling audio input and tokenization."""
+
+     def __init__(self, model_dir: Path, device: torch.device = None, **kwargs):
+         """
+         Args:
+             model_dir: Path to the model directory.
+             device: Device to run the model on (default is GPU if available).
+         """
+         super().__init__()
+         self.device = device
+         self.model_dir = model_dir
+         self.config = load_config(f"{model_dir}/config.yaml")
+         self._initialize_model()
+
+     def _initialize_model(self):
+         """Load and initialize the BiCodec model and Wav2Vec2 feature extractor."""
+         self.model = BiCodec.load_from_checkpoint(f"{self.model_dir}/BiCodec").to(
+             self.device
+         )
+         self.processor = Wav2Vec2FeatureExtractor.from_pretrained(
+             f"{self.model_dir}/wav2vec2-large-xlsr-53"
+         )
+         self.feature_extractor = Wav2Vec2Model.from_pretrained(
+             f"{self.model_dir}/wav2vec2-large-xlsr-53"
+         ).to(self.device)
+         self.feature_extractor.config.output_hidden_states = True
+
+     def get_ref_clip(self, wav: np.ndarray) -> np.ndarray:
+         """Get reference audio clip for speaker embedding."""
+         ref_segment_length = (
+             int(self.config["sample_rate"] * self.config["ref_segment_duration"])
+             // self.config["latent_hop_length"]
+             * self.config["latent_hop_length"]
+         )
+         wav_length = len(wav)
+
+         if ref_segment_length > wav_length:
+             # Repeat and truncate to handle insufficient length
+             wav = np.tile(wav, ref_segment_length // wav_length + 1)
+
+         return wav[:ref_segment_length]
+
+     def process_audio(self, wav_path: Optional[Union[Path, np.ndarray]]) -> Tuple[np.ndarray, torch.Tensor]:
+         """Load audio and get the reference clip from a wav path or array."""
+         if isinstance(wav_path, Path):
+             wav = load_audio(
+                 wav_path,
+                 sampling_rate=self.config["sample_rate"],
+                 volume_normalize=self.config["volume_normalize"],
+             )
+         elif isinstance(wav_path, np.ndarray):
+             wav = wav_path
+         else:
+             raise ValueError(f"Unsupported audio type: {type(wav_path)}")
+
+         wav_ref = self.get_ref_clip(wav)
+
+         wav_ref = torch.from_numpy(wav_ref).unsqueeze(0).float()
+         return wav, wav_ref
+
+     def extract_wav2vec2_features(self, wavs: torch.Tensor) -> torch.Tensor:
+         """Extract wav2vec2 features."""
+         inputs = self.processor(
+             wavs,
+             sampling_rate=16000,
+             return_tensors="pt",
+             padding=True,
+             output_hidden_states=True,
+         ).input_values.to(self.feature_extractor.dtype)
+         feat = self.feature_extractor(inputs.to(self.feature_extractor.device))
+         feats_mix = (
+             feat.hidden_states[11] + feat.hidden_states[14] + feat.hidden_states[16]
+         ) / 3
+
+         return feats_mix
+
+     def tokenize_batch(self, batch: Dict[str, Any]) -> torch.Tensor:
+         """Tokenize a batch of audio.
+
+         Args:
+             batch:
+                 wavs (List[np.ndarray]): batch of audio
+                 ref_wavs (torch.Tensor): reference audio. shape: (batch_size, seq_len)
+
+         Returns:
+             global_tokens: global tokens. shape: (batch_size, seq_len, global_dim)
+             semantic_tokens: semantic tokens. shape: (batch_size, seq_len, latent_dim)
+         """
+         feats = self.extract_wav2vec2_features(batch["wav"])
+         batch["feat"] = feats
+         semantic_tokens, global_tokens = self.model.tokenize(batch)
+
+         return global_tokens, semantic_tokens
+
+     def tokenize(self, audio_path: str) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Tokenize a single audio file or array."""
+         wav, ref_wav = self.process_audio(audio_path)
+         feat = self.extract_wav2vec2_features(wav)
+         batch = {
+             "wav": torch.from_numpy(wav).unsqueeze(0).float().to(self.device),
+             "ref_wav": ref_wav.to(self.device),
+             "feat": feat.to(self.device),
+         }
+         semantic_tokens, global_tokens = self.model.tokenize(batch)
+
+         return global_tokens, semantic_tokens
+
+     def detokenize(
+         self, global_tokens: torch.Tensor, semantic_tokens: torch.Tensor
+     ) -> np.ndarray:
+         """Detokenize the tokens to a waveform.
+
+         Args:
+             global_tokens: global tokens. shape: (batch_size, global_dim)
+             semantic_tokens: semantic tokens. shape: (batch_size, latent_dim)
+
+         Returns:
+             wav_rec: waveform. shape: (batch_size, seq_len) for batch or (seq_len,) for single
+         """
+         global_tokens = global_tokens.unsqueeze(1)
+         wav_rec = self.model.detokenize(semantic_tokens, global_tokens)
+         return wav_rec.detach().squeeze().cpu().numpy()
+
+
+ # test
+ if __name__ == "__main__":
+     import soundfile as sf
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     tokenizer = BiCodecTokenizer(
+         model_dir="pretrained_models/Spark-TTS-0.5B",
+         device=device,
+     )
+     wav_path = "example/prompt_audio.wav"
+
+     global_tokens, semantic_tokens = tokenizer.tokenize(wav_path)
+
+     wav_rec = tokenizer.detokenize(global_tokens.squeeze(0), semantic_tokens)
+     sf.write("example/prompt_recon.wav", wav_rec, 16000)
trained_30_percents/sparktts/models/bicodec.py ADDED
@@ -0,0 +1,247 @@
+ # Copyright (c) 2025 SparkAudio
+ #               2025 Xinsheng Wang ([email protected])
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+ import torch.nn as nn
+ from pathlib import Path
+ from typing import Dict, Any
+ from omegaconf import DictConfig
+ from safetensors.torch import load_file
+
+ from sparktts.utils.file import load_config
+ from sparktts.modules.speaker.speaker_encoder import SpeakerEncoder
+ from sparktts.modules.encoder_decoder.feat_encoder import Encoder
+ from sparktts.modules.encoder_decoder.feat_decoder import Decoder
+ from sparktts.modules.encoder_decoder.wave_generator import WaveGenerator
+ from sparktts.modules.vq.factorized_vector_quantize import FactorizedVectorQuantize
+
+
+ class BiCodec(nn.Module):
+     """
+     BiCodec model for speech synthesis, incorporating a speaker encoder, feature encoder/decoder,
+     quantizer, and wave generator.
+     """
+
+     def __init__(
+         self,
+         mel_params: Dict[str, Any],
+         encoder: nn.Module,
+         decoder: nn.Module,
+         quantizer: nn.Module,
+         speaker_encoder: nn.Module,
+         prenet: nn.Module,
+         postnet: nn.Module,
+         **kwargs
+     ) -> None:
+         """
+         Initializes the BiCodec model with the required components.
+
+         Args:
+             mel_params (dict): Parameters for the mel-spectrogram transformer.
+             encoder (nn.Module): Encoder module.
+             decoder (nn.Module): Decoder module.
+             quantizer (nn.Module): Quantizer module.
+             speaker_encoder (nn.Module): Speaker encoder module.
+             prenet (nn.Module): Prenet network.
+             postnet (nn.Module): Postnet network.
+         """
+         super().__init__()
+         self.encoder = encoder
+         self.decoder = decoder
+         self.quantizer = quantizer
+         self.speaker_encoder = speaker_encoder
+         self.prenet = prenet
+         self.postnet = postnet
+         self.init_mel_transformer(mel_params)
+
+     @classmethod
+     def load_from_checkpoint(cls, model_dir: Path, **kwargs) -> "BiCodec":
+         """
+         Loads the model from a checkpoint.
+
+         Args:
+             model_dir (Path): Path to the model directory containing checkpoint and config.
+
+         Returns:
+             BiCodec: The initialized BiCodec model.
+         """
+         ckpt_path = f'{model_dir}/model.safetensors'
+         config = load_config(f'{model_dir}/config.yaml')['audio_tokenizer']
+         mel_params = config["mel_params"]
+         encoder = Encoder(**config["encoder"])
+         quantizer = FactorizedVectorQuantize(**config["quantizer"])
+         prenet = Decoder(**config["prenet"])
+         postnet = Decoder(**config["postnet"])
+         decoder = WaveGenerator(**config["decoder"])
+         speaker_encoder = SpeakerEncoder(**config["speaker_encoder"])
+
+         model = cls(
+             mel_params=mel_params,
+             encoder=encoder,
+             decoder=decoder,
+             quantizer=quantizer,
+             speaker_encoder=speaker_encoder,
+             prenet=prenet,
+             postnet=postnet,
+         )
+
+         state_dict = load_file(ckpt_path)
+         missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+
+         for key in missing_keys:
+             print(f"Missing tensor: {key}")
+         for key in unexpected_keys:
+             print(f"Unexpected tensor: {key}")
+
+         model.eval()
+         model.remove_weight_norm()
+
+         return model
+
+     def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         Performs a forward pass through the model.
+
+         Args:
+             batch (dict): A dictionary containing features, reference waveform, and target waveform.
+
+         Returns:
+             dict: A dictionary containing the reconstruction, features, and other metrics.
+         """
+         feat = batch["feat"]
+         mel = self.mel_transformer(batch["ref_wav"]).squeeze(1)
+
+         z = self.encoder(feat.transpose(1, 2))
+         vq_outputs = self.quantizer(z)
+
+         x_vector, d_vector = self.speaker_encoder(mel.transpose(1, 2))
+
+         conditions = d_vector
+         with_speaker_loss = False
+
+         x = self.prenet(vq_outputs["z_q"], conditions)
+         pred_feat = self.postnet(x)
+         x = x + conditions.unsqueeze(-1)
+         wav_recon = self.decoder(x)
+
+         return {
+             "vq_loss": vq_outputs["vq_loss"],
+             "perplexity": vq_outputs["perplexity"],
+             "cluster_size": vq_outputs["active_num"],
+             "recons": wav_recon,
+             "pred_feat": pred_feat,
+             "x_vector": x_vector,
+             "d_vector": d_vector,
+             "audios": batch["wav"].unsqueeze(1),
+             "with_speaker_loss": with_speaker_loss,
+         }
+
+     @torch.no_grad()
+     def tokenize(self, batch: Dict[str, Any]):
+         """
+         Tokenizes the input audio into semantic and global tokens.
+
+         Args:
+             batch (dict): The input audio features and reference waveform.
+
+         Returns:
+             tuple: Semantic tokens and global tokens.
+         """
+         feat = batch["feat"]
+         mel = self.mel_transformer(batch["ref_wav"]).squeeze(1)
+
+         z = self.encoder(feat.transpose(1, 2))
+         semantic_tokens = self.quantizer.tokenize(z)
+         global_tokens = self.speaker_encoder.tokenize(mel.transpose(1, 2))
+
+         return semantic_tokens, global_tokens
+
+     @torch.no_grad()
+     def detokenize(self, semantic_tokens, global_tokens):
+         """
+         Detokenizes the semantic and global tokens into a waveform.
+
+         Args:
+             semantic_tokens (tensor): Semantic tokens.
+             global_tokens (tensor): Global tokens.
+
+         Returns:
+             tensor: Reconstructed waveform.
+         """
+         z_q = self.quantizer.detokenize(semantic_tokens)
+         d_vector = self.speaker_encoder.detokenize(global_tokens)
+         x = self.prenet(z_q, d_vector)
+         x = x + d_vector.unsqueeze(-1)
+         wav_recon = self.decoder(x)
+
+         return wav_recon
+
+     def init_mel_transformer(self, config: Dict[str, Any]):
+         """
+         Initializes the MelSpectrogram transformer based on the provided configuration.
+
+         Args:
+             config (dict): Configuration parameters for MelSpectrogram.
+         """
+         import torchaudio.transforms as TT
+
+         self.mel_transformer = TT.MelSpectrogram(
+             config["sample_rate"],
+             config["n_fft"],
+             config["win_length"],
+             config["hop_length"],
+             config["mel_fmin"],
+             config["mel_fmax"],
+             n_mels=config["num_mels"],
+             power=1,
+             norm="slaney",
+             mel_scale="slaney",
+         )
+
+     def remove_weight_norm(self):
+         """Removes weight normalization from all layers."""
+         def _remove_weight_norm(m):
+             try:
+                 torch.nn.utils.remove_weight_norm(m)
+             except ValueError:
+                 pass  # The module didn't have weight norm
+
+         self.apply(_remove_weight_norm)
+
+
+ # Test the model
+ if __name__ == "__main__":
+
+     config = load_config("pretrained_models/SparkTTS-0.5B/BiCodec/config.yaml")
+     model = BiCodec.load_from_checkpoint(
+         model_dir="pretrained_models/SparkTTS-0.5B/BiCodec",
+     )
+
+     # Generate random inputs for testing
+     duration = 0.96
+     x = torch.randn(20, 1, int(duration * 16000))
+     feat = torch.randn(20, int(duration * 50), 1024)
+     inputs = {"feat": feat, "wav": x, "ref_wav": x}
+
+     # Forward pass
+     outputs = model(inputs)
+     semantic_tokens, global_tokens = model.tokenize(inputs)
+     wav_recon = model.detokenize(semantic_tokens, global_tokens)
+
+     # Verify if the reconstruction matches
+     if torch.allclose(outputs["recons"].detach(), wav_recon):
+         print("Test successful")
+     else:
+         print("Test failed")
trained_30_percents/sparktts/modules/blocks/__pycache__/layers.cpython-311.pyc ADDED
Binary file (4.17 kB).
trained_30_percents/sparktts/modules/blocks/__pycache__/samper.cpython-311.pyc ADDED
Binary file (4.52 kB).
trained_30_percents/sparktts/modules/blocks/__pycache__/vocos.cpython-311.pyc ADDED
Binary file (17.6 kB).
trained_30_percents/sparktts/modules/blocks/layers.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Adapted from https://github.com/descriptinc/descript-audio-codec under the Apache License 2.0
17
+
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ from torch.nn.utils import weight_norm
22
+
23
+
24
+ def WNConv1d(*args, **kwargs):
25
+ return weight_norm(nn.Conv1d(*args, **kwargs))
26
+
27
+
28
+ def WNConvTranspose1d(*args, **kwargs):
29
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
30
+
31
+
32
+ # Scripting this brings model speed up 1.4x
33
+ @torch.jit.script
34
+ def snake(x, alpha):
35
+ shape = x.shape
36
+ x = x.reshape(shape[0], shape[1], -1)
37
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
38
+ x = x.reshape(shape)
39
+ return x
40
+
41
+
42
+ class Snake1d(nn.Module):
43
+ def __init__(self, channels):
44
+ super().__init__()
45
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
46
+
47
+ def forward(self, x):
48
+ return snake(x, self.alpha)
49
+
50
+
51
+ class ResidualUnit(nn.Module):
52
+ def __init__(self, dim: int = 16, dilation: int = 1):
53
+ super().__init__()
54
+ pad = ((7 - 1) * dilation) // 2
55
+ self.block = nn.Sequential(
56
+ Snake1d(dim),
57
+ WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
58
+ Snake1d(dim),
59
+ WNConv1d(dim, dim, kernel_size=1),
60
+ )
61
+
62
+ def forward(self, x):
63
+ y = self.block(x)
64
+ pad = (x.shape[-1] - y.shape[-1]) // 2
65
+ if pad > 0:
66
+ x = x[..., pad:-pad]
67
+ return x + y
68
+
69
+
70
+ def init_weights(m):
71
+ if isinstance(m, nn.Conv1d):
72
+ nn.init.trunc_normal_(m.weight, std=0.02)
73
+ nn.init.constant_(m.bias, 0)
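For orientation, the Snake activation defined above computes x + (1/alpha) * sin^2(alpha * x) per channel (with a small epsilon for stability). A minimal shape check, written here as an illustrative sketch rather than part of the repository:

    import torch
    from sparktts.modules.blocks.layers import Snake1d

    x = torch.randn(2, 16, 100)   # (batch, channels, time)
    act = Snake1d(channels=16)    # one learnable alpha per channel
    y = act(x)
    assert y.shape == x.shape     # Snake is element-wise, so the shape is preserved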
trained_30_percents/sparktts/modules/blocks/samper.py ADDED
@@ -0,0 +1,115 @@
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+
21
+
22
+ class SamplingBlock(nn.Module):
23
+ """Sampling block for upsampling or downsampling"""
24
+
25
+ def __init__(
26
+ self,
27
+ dim: int,
28
+ groups: int = 1,
29
+ upsample_scale: int = 1,
30
+ downsample_scale: int = 1,
31
+ ) -> None:
32
+ """
33
+ Args:
34
+ dim: input dimension
35
+ groups: number of groups
36
+ upsample_scale: upsampling scale
37
+ downsample_scale: downsampling scale
38
+ """
39
+ super(SamplingBlock, self).__init__()
40
+
41
+ self.upsample_scale = upsample_scale
42
+ self.downsample_scale = downsample_scale
43
+
44
+ if self.upsample_scale > 1:
45
+ self.de_conv_upsampler = nn.Sequential(
46
+ nn.LeakyReLU(0.2),
47
+ nn.ConvTranspose1d(
48
+ dim,
49
+ dim,
50
+ kernel_size=upsample_scale * 2,
51
+ stride=upsample_scale,
52
+ padding=upsample_scale // 2 + upsample_scale % 2,
53
+ output_padding=upsample_scale % 2,
54
+ groups=groups,
55
+ ),
56
+ )
57
+
58
+ if self.downsample_scale > 1:
59
+ self.conv_downsampler = nn.Sequential(
60
+ nn.LeakyReLU(0.2),
61
+ nn.Conv1d(
62
+ dim,
63
+ dim,
64
+ kernel_size=2 * downsample_scale,
65
+ stride=downsample_scale,
66
+ padding=downsample_scale // 2 + downsample_scale % 2,
67
+ groups=groups,
68
+ ),
69
+ )
70
+
71
+ @staticmethod
72
+ def repeat_upsampler(x, upsample_scale):
73
+ return x.repeat_interleave(upsample_scale, dim=2)
74
+
75
+ @staticmethod
76
+ def skip_downsampler(x, downsample_scale):
77
+ return F.avg_pool1d(x, kernel_size=downsample_scale, stride=downsample_scale)
78
+
79
+ def forward(self, x):
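+ # x is expected as (B, T, C); it is transposed to (B, C, T) below for the Conv1d / ConvTranspose1d blocks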
80
+ x = x.transpose(1, 2)
81
+ if self.upsample_scale > 1:
82
+ repeat_res = self.repeat_upsampler(x, self.upsample_scale)
83
+ deconv_res = self.de_conv_upsampler(x)
84
+ upmerge_res = repeat_res + deconv_res
85
+ else:
86
+ upmerge_res = x
87
+ repeat_res = x
88
+
89
+ if self.downsample_scale > 1:
90
+ conv_res = self.conv_downsampler(upmerge_res)
91
+ skip2_res = self.skip_downsampler(upmerge_res, self.downsample_scale)
92
+ skip1_res = self.skip_downsampler(repeat_res, self.downsample_scale)
93
+ else:
94
+ conv_res = upmerge_res
95
+ skip2_res = upmerge_res
96
+ skip1_res = repeat_res
97
+
98
+ final_res = conv_res + skip1_res + skip2_res
99
+
100
+ return final_res
101
+
102
+
103
+ # test
104
+ if __name__ == "__main__":
105
+ test_input = torch.randn(8, 50, 1024) # Batch size = 8, length = 50, 1024 channels (B, T, C), matching the transpose in forward
106
+ model = SamplingBlock(1024, 1024, upsample_scale=2)
107
+ model_down = SamplingBlock(1024, 1024, downsample_scale=2)
108
+ output = model(test_input)
109
+ output_down = model_down(test_input)
110
+ print("shape after upsample * 2", output.shape) # torch.Size([8, 1024, 100])
111
+ print("shape after downsample * 2", output_down.shape) # torch.Size([8, 1024, 25])
112
+ if output.shape == torch.Size([8, 1024, 100]) and output_down.shape == torch.Size(
113
+ [8, 1024, 25]
114
+ ):
115
+ print("test successful")
trained_30_percents/sparktts/modules/blocks/vocos.py ADDED
@@ -0,0 +1,373 @@
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from typing import Tuple
21
+ from torch.nn.utils import weight_norm, remove_weight_norm
22
+
23
+ from typing import Optional
24
+
25
+
26
+ class ConvNeXtBlock(nn.Module):
27
+ """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
28
+
29
+ Args:
30
+ dim (int): Number of input channels.
31
+ intermediate_dim (int): Dimensionality of the intermediate layer.
32
+ layer_scale_init_value (float): Initial value for the layer scale.
33
+ Values <= 0 disable scaling.
34
+ condition_dim (int, optional): Dimension of the conditioning vector for AdaLayerNorm.
35
+ None means non-conditional LayerNorm. Defaults to None.
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ dim: int,
41
+ intermediate_dim: int,
42
+ layer_scale_init_value: float,
43
+ condition_dim: Optional[int] = None,
44
+ ):
45
+ super().__init__()
46
+ self.dwconv = nn.Conv1d(
47
+ dim, dim, kernel_size=7, padding=3, groups=dim
48
+ ) # depthwise conv
49
+ self.adanorm = condition_dim is not None
50
+ if condition_dim:
51
+ self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6)
52
+ else:
53
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
54
+ self.pwconv1 = nn.Linear(
55
+ dim, intermediate_dim
56
+ ) # pointwise/1x1 convs, implemented with linear layers
57
+ self.act = nn.GELU()
58
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
59
+ self.gamma = (
60
+ nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
61
+ if layer_scale_init_value > 0
62
+ else None
63
+ )
64
+
65
+ def forward(
66
+ self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
67
+ ) -> torch.Tensor:
68
+ residual = x
69
+ x = self.dwconv(x)
70
+ x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
71
+ if self.adanorm:
72
+ assert cond_embedding_id is not None
73
+ x = self.norm(x, cond_embedding_id)
74
+ else:
75
+ x = self.norm(x)
76
+ x = self.pwconv1(x)
77
+ x = self.act(x)
78
+ x = self.pwconv2(x)
79
+ if self.gamma is not None:
80
+ x = self.gamma * x
81
+ x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
82
+
83
+ x = residual + x
84
+ return x
85
+
86
+
87
+ class AdaLayerNorm(nn.Module):
88
+ """
89
+ Adaptive Layer Normalization module whose scale and shift are predicted from a condition vector
90
+
91
+ Args:
92
+ condition_dim (int): Dimension of the condition.
93
+ embedding_dim (int): Dimension of the embeddings.
94
+ """
95
+
96
+ def __init__(self, condition_dim: int, embedding_dim: int, eps: float = 1e-6):
97
+ super().__init__()
98
+ self.eps = eps
99
+ self.dim = embedding_dim
100
+ self.scale = nn.Linear(condition_dim, embedding_dim)
101
+ self.shift = nn.Linear(condition_dim, embedding_dim)
102
+ torch.nn.init.ones_(self.scale.weight)
103
+ torch.nn.init.zeros_(self.shift.weight)
104
+
105
+ def forward(self, x: torch.Tensor, cond_embedding: torch.Tensor) -> torch.Tensor:
106
+ scale = self.scale(cond_embedding)
107
+ shift = self.shift(cond_embedding)
108
+ x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
109
+ x = x * scale.unsqueeze(1) + shift.unsqueeze(1)
110
+ return x
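A small illustrative sketch of how this conditional normalization is driven by a per-utterance condition vector (the shapes below are assumptions for illustration only):

    import torch
    from sparktts.modules.blocks.vocos import AdaLayerNorm

    norm = AdaLayerNorm(condition_dim=256, embedding_dim=384)
    x = torch.randn(4, 100, 384)   # (batch, time, channels)
    cond = torch.randn(4, 256)     # one condition vector per utterance
    y = norm(x, cond)              # scale and shift are predicted from cond
    assert y.shape == x.shape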
111
+
112
+
113
+ class ResBlock1(nn.Module):
114
+ """
115
+ ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
116
+ but without upsampling layers.
117
+
118
+ Args:
119
+ dim (int): Number of input channels.
120
+ kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
121
+ dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
122
+ Defaults to (1, 3, 5).
123
+ lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
124
+ Defaults to 0.1.
125
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
126
+ Defaults to None.
127
+ """
128
+
129
+ def __init__(
130
+ self,
131
+ dim: int,
132
+ kernel_size: int = 3,
133
+ dilation: Tuple[int, int, int] = (1, 3, 5),
134
+ lrelu_slope: float = 0.1,
135
+ layer_scale_init_value: Optional[float] = None,
136
+ ):
137
+ super().__init__()
138
+ self.lrelu_slope = lrelu_slope
139
+ self.convs1 = nn.ModuleList(
140
+ [
141
+ weight_norm(
142
+ nn.Conv1d(
143
+ dim,
144
+ dim,
145
+ kernel_size,
146
+ 1,
147
+ dilation=dilation[0],
148
+ padding=self.get_padding(kernel_size, dilation[0]),
149
+ )
150
+ ),
151
+ weight_norm(
152
+ nn.Conv1d(
153
+ dim,
154
+ dim,
155
+ kernel_size,
156
+ 1,
157
+ dilation=dilation[1],
158
+ padding=self.get_padding(kernel_size, dilation[1]),
159
+ )
160
+ ),
161
+ weight_norm(
162
+ nn.Conv1d(
163
+ dim,
164
+ dim,
165
+ kernel_size,
166
+ 1,
167
+ dilation=dilation[2],
168
+ padding=self.get_padding(kernel_size, dilation[2]),
169
+ )
170
+ ),
171
+ ]
172
+ )
173
+
174
+ self.convs2 = nn.ModuleList(
175
+ [
176
+ weight_norm(
177
+ nn.Conv1d(
178
+ dim,
179
+ dim,
180
+ kernel_size,
181
+ 1,
182
+ dilation=1,
183
+ padding=self.get_padding(kernel_size, 1),
184
+ )
185
+ ),
186
+ weight_norm(
187
+ nn.Conv1d(
188
+ dim,
189
+ dim,
190
+ kernel_size,
191
+ 1,
192
+ dilation=1,
193
+ padding=self.get_padding(kernel_size, 1),
194
+ )
195
+ ),
196
+ weight_norm(
197
+ nn.Conv1d(
198
+ dim,
199
+ dim,
200
+ kernel_size,
201
+ 1,
202
+ dilation=1,
203
+ padding=self.get_padding(kernel_size, 1),
204
+ )
205
+ ),
206
+ ]
207
+ )
208
+
209
+ self.gamma = nn.ParameterList(
210
+ [
211
+ (
212
+ nn.Parameter(
213
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
214
+ )
215
+ if layer_scale_init_value is not None
216
+ else None
217
+ ),
218
+ (
219
+ nn.Parameter(
220
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
221
+ )
222
+ if layer_scale_init_value is not None
223
+ else None
224
+ ),
225
+ (
226
+ nn.Parameter(
227
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
228
+ )
229
+ if layer_scale_init_value is not None
230
+ else None
231
+ ),
232
+ ]
233
+ )
234
+
235
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
236
+ for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
237
+ xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
238
+ xt = c1(xt)
239
+ xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
240
+ xt = c2(xt)
241
+ if gamma is not None:
242
+ xt = gamma * xt
243
+ x = xt + x
244
+ return x
245
+
246
+ def remove_weight_norm(self):
247
+ for l in self.convs1:
248
+ remove_weight_norm(l)
249
+ for l in self.convs2:
250
+ remove_weight_norm(l)
251
+
252
+ @staticmethod
253
+ def get_padding(kernel_size: int, dilation: int = 1) -> int:
254
+ return int((kernel_size * dilation - dilation) / 2)
255
+
256
+
257
+ class Backbone(nn.Module):
258
+ """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
259
+
260
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
261
+ """
262
+ Args:
263
+ x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
264
+ C denotes output features, and L is the sequence length.
265
+
266
+ Returns:
267
+ Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
268
+ and H denotes the model dimension.
269
+ """
270
+ raise NotImplementedError("Subclasses must implement the forward method.")
271
+
272
+
273
+ class VocosBackbone(Backbone):
274
+ """
275
+ Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
276
+
277
+ Args:
278
+ input_channels (int): Number of input features channels.
279
+ dim (int): Hidden dimension of the model.
280
+ intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
281
+ num_layers (int): Number of ConvNeXtBlock layers.
282
+ layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
283
+ condition_dim (int, optional): Dimension of the conditioning vector for AdaLayerNorm.
285
+ None means non-conditional model. Defaults to None.
285
+ """
286
+
287
+ def __init__(
288
+ self,
289
+ input_channels: int,
290
+ dim: int,
291
+ intermediate_dim: int,
292
+ num_layers: int,
293
+ layer_scale_init_value: Optional[float] = None,
294
+ condition_dim: Optional[int] = None,
295
+ ):
296
+ super().__init__()
297
+ self.input_channels = input_channels
298
+ self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
299
+ self.adanorm = condition_dim is not None
300
+ if condition_dim:
301
+ self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6)
302
+ else:
303
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
304
+ layer_scale_init_value = layer_scale_init_value or 1 / num_layers
305
+ self.convnext = nn.ModuleList(
306
+ [
307
+ ConvNeXtBlock(
308
+ dim=dim,
309
+ intermediate_dim=intermediate_dim,
310
+ layer_scale_init_value=layer_scale_init_value,
311
+ condition_dim=condition_dim,
312
+ )
313
+ for _ in range(num_layers)
314
+ ]
315
+ )
316
+ self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
317
+ self.apply(self._init_weights)
318
+
319
+ def _init_weights(self, m):
320
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
321
+ nn.init.trunc_normal_(m.weight, std=0.02)
322
+ nn.init.constant_(m.bias, 0)
323
+
324
+ def forward(self, x: torch.Tensor, condition: torch.Tensor = None) -> torch.Tensor:
325
+ x = self.embed(x)
326
+ if self.adanorm:
327
+ assert condition is not None
328
+ x = self.norm(x.transpose(1, 2), condition)
329
+ else:
330
+ x = self.norm(x.transpose(1, 2))
331
+ x = x.transpose(1, 2)
332
+ for conv_block in self.convnext:
333
+ x = conv_block(x, condition)
334
+ x = self.final_layer_norm(x.transpose(1, 2))
335
+ return x
336
+
337
+
338
+ class VocosResNetBackbone(Backbone):
339
+ """
340
+ Vocos backbone module built with ResBlocks.
341
+
342
+ Args:
343
+ input_channels (int): Number of input features channels.
344
+ dim (int): Hidden dimension of the model.
345
+ num_blocks (int): Number of ResBlock1 blocks.
346
+ layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
347
+ """
348
+
349
+ def __init__(
350
+ self,
351
+ input_channels,
352
+ dim,
353
+ num_blocks,
354
+ layer_scale_init_value=None,
355
+ ):
356
+ super().__init__()
357
+ self.input_channels = input_channels
358
+ self.embed = weight_norm(
359
+ nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
360
+ )
361
+ layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
362
+ self.resnet = nn.Sequential(
363
+ *[
364
+ ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
365
+ for _ in range(num_blocks)
366
+ ]
367
+ )
368
+
369
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
370
+ x = self.embed(x)
371
+ x = self.resnet(x)
372
+ x = x.transpose(1, 2)
373
+ return x
trained_30_percents/sparktts/modules/encoder_decoder/__pycache__/feat_decoder.cpython-311.pyc ADDED
Binary file (4.26 kB). View file
 
trained_30_percents/sparktts/modules/encoder_decoder/__pycache__/feat_encoder.cpython-311.pyc ADDED
Binary file (3.44 kB). View file
 
trained_30_percents/sparktts/modules/encoder_decoder/__pycache__/wave_generator.cpython-311.pyc ADDED
Binary file (3.36 kB). View file
 
trained_30_percents/sparktts/modules/encoder_decoder/feat_decoder.py ADDED
@@ -0,0 +1,115 @@
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from typing import List
21
+
22
+ from sparktts.modules.blocks.vocos import VocosBackbone
23
+ from sparktts.modules.blocks.samper import SamplingBlock
24
+
25
+
26
+ class Decoder(nn.Module):
27
+ """Decoder module with convnext and upsampling blocks
28
+
29
+ Args:
30
+ sample_ratios (List[int]): sample ratios
31
+ example: [2, 2] means downsample by 2x and then upsample by 2x
32
+ """
33
+
34
+ def __init__(
35
+ self,
36
+ input_channels: int,
37
+ vocos_dim: int,
38
+ vocos_intermediate_dim: int,
39
+ vocos_num_layers: int,
40
+ out_channels: int,
41
+ condition_dim: int = None,
42
+ sample_ratios: List[int] = [1, 1],
43
+ use_tanh_at_final: bool = False,
44
+ ):
45
+ super().__init__()
46
+
47
+ self.linear_pre = nn.Linear(input_channels, vocos_dim)
48
+ modules = [
49
+ nn.Sequential(
50
+ SamplingBlock(
51
+ dim=vocos_dim,
52
+ groups=vocos_dim,
53
+ upsample_scale=ratio,
54
+ ),
55
+ VocosBackbone(
56
+ input_channels=vocos_dim,
57
+ dim=vocos_dim,
58
+ intermediate_dim=vocos_intermediate_dim,
59
+ num_layers=2,
60
+ condition_dim=None,
61
+ ),
62
+ )
63
+ for ratio in sample_ratios
64
+ ]
65
+
66
+ self.downsample = nn.Sequential(*modules)
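+ # NOTE: despite its name, this stack upsamples by sample_ratios (each SamplingBlock is built with upsample_scale=ratio)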
67
+
68
+ self.vocos_backbone = VocosBackbone(
69
+ input_channels=vocos_dim,
70
+ dim=vocos_dim,
71
+ intermediate_dim=vocos_intermediate_dim,
72
+ num_layers=vocos_num_layers,
73
+ condition_dim=condition_dim,
74
+ )
75
+ self.linear = nn.Linear(vocos_dim, out_channels)
76
+ self.use_tanh_at_final = use_tanh_at_final
77
+
78
+ def forward(self, x: torch.Tensor, c: torch.Tensor = None):
79
+ """encoder forward.
80
+
81
+ Args:
82
+ x (torch.Tensor): (batch_size, input_channels, length)
83
+
84
+ Returns:
85
+ x (torch.Tensor): (batch_size, out_channels, length * prod(sample_ratios))
86
+ """
87
+ x = self.linear_pre(x.transpose(1, 2))
88
+ x = self.downsample(x).transpose(1, 2)
89
+ x = self.vocos_backbone(x, condition=c)
90
+ x = self.linear(x).transpose(1, 2)
91
+ if self.use_tanh_at_final:
92
+ x = torch.tanh(x)
93
+
94
+ return x
95
+
96
+
97
+ # test
98
+ if __name__ == "__main__":
99
+ test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50
100
+ condition = torch.randn(8, 256)
101
+ decoder = Decoder(
102
+ input_channels=1024,
103
+ vocos_dim=384,
104
+ vocos_intermediate_dim=2048,
105
+ vocos_num_layers=12,
106
+ out_channels=256,
107
+ condition_dim=256,
108
+ sample_ratios=[2, 2],
109
+ )
110
+ output = decoder(test_input, condition)
111
+ print(output.shape) # torch.Size([8, 256, 200])
112
+ if output.shape == torch.Size([8, 256, 200]):
113
+ print("Decoder test passed")
114
+ else:
115
+ print("Decoder test failed")
trained_30_percents/sparktts/modules/encoder_decoder/feat_encoder.py ADDED
@@ -0,0 +1,105 @@
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from typing import List
21
+
22
+ from sparktts.modules.blocks.vocos import VocosBackbone
23
+ from sparktts.modules.blocks.samper import SamplingBlock
24
+
25
+
26
+ class Encoder(nn.Module):
27
+ """Encoder module with convnext and downsampling blocks"""
28
+
29
+ def __init__(
30
+ self,
31
+ input_channels: int,
32
+ vocos_dim: int,
33
+ vocos_intermediate_dim: int,
34
+ vocos_num_layers: int,
35
+ out_channels: int,
36
+ sample_ratios: List[int] = [1, 1],
37
+ ):
38
+ super().__init__()
39
+ """
40
+ Encoder module with VocosBackbone and sampling blocks.
41
+
42
+ Args:
43
+ sample_ratios (List[int]): sample ratios
44
+ example: [2, 2] means downsample by 2x and then upsample by 2x
45
+ """
46
+ self.encoder = VocosBackbone(
47
+ input_channels=input_channels,
48
+ dim=vocos_dim,
49
+ intermediate_dim=vocos_intermediate_dim,
50
+ num_layers=vocos_num_layers,
51
+ condition_dim=None,
52
+ )
53
+
54
+ modules = [
55
+ nn.Sequential(
56
+ SamplingBlock(
57
+ dim=vocos_dim,
58
+ groups=vocos_dim,
59
+ downsample_scale=ratio,
60
+ ),
61
+ VocosBackbone(
62
+ input_channels=vocos_dim,
63
+ dim=vocos_dim,
64
+ intermediate_dim=vocos_intermediate_dim,
65
+ num_layers=2,
66
+ condition_dim=None,
67
+ ),
68
+ )
69
+ for ratio in sample_ratios
70
+ ]
71
+
72
+ self.downsample = nn.Sequential(*modules)
73
+
74
+ self.project = nn.Linear(vocos_dim, out_channels)
75
+
76
+ def forward(self, x: torch.Tensor, *args):
77
+ """
78
+ Args:
79
+ x (torch.Tensor): (batch_size, input_channels, length)
80
+
81
+ Returns:
82
+ x (torch.Tensor): (batch_size, encode_channels, length)
83
+ """
84
+ x = self.encoder(x)
85
+ x = self.downsample(x)
86
+ x = self.project(x)
87
+ return x.transpose(1, 2)
88
+
89
+
90
+ # test
91
+ if __name__ == "__main__":
92
+ test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50
93
+ encoder = Encoder(
94
+ input_channels=1024,
95
+ vocos_dim=384,
96
+ vocos_intermediate_dim=2048,
97
+ vocos_num_layers=12,
98
+ out_channels=256,
99
+ sample_ratios=[2, 2],
100
+ )
101
+
102
+ output = encoder(test_input)
103
+ print(output.shape) # torch.Size([8, 256, 12])
104
+ if output.shape == torch.Size([8, 256, 12]):
105
+ print("test successful")
trained_30_percents/sparktts/modules/encoder_decoder/wave_generator.py ADDED
@@ -0,0 +1,88 @@
1
+ # Copyright (c) 2024 Xinsheng Wang ([email protected])
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Adapted from https://github.com/descriptinc/descript-audio-codec under the Apache License 2.0
16
+
17
+
18
+ import torch.nn as nn
19
+
20
+ from sparktts.modules.blocks.layers import (
21
+ Snake1d,
22
+ WNConv1d,
23
+ ResidualUnit,
24
+ WNConvTranspose1d,
25
+ init_weights,
26
+ )
27
+
28
+
29
+ class DecoderBlock(nn.Module):
30
+ def __init__(
31
+ self,
32
+ input_dim: int = 16,
33
+ output_dim: int = 8,
34
+ kernel_size: int = 2,
35
+ stride: int = 1,
36
+ ):
37
+ super().__init__()
38
+ self.block = nn.Sequential(
39
+ Snake1d(input_dim),
40
+ WNConvTranspose1d(
41
+ input_dim,
42
+ output_dim,
43
+ kernel_size=kernel_size,
44
+ stride=stride,
45
+ padding=(kernel_size - stride) // 2,
46
+ ),
47
+ ResidualUnit(output_dim, dilation=1),
48
+ ResidualUnit(output_dim, dilation=3),
49
+ ResidualUnit(output_dim, dilation=9),
50
+ )
51
+
52
+ def forward(self, x):
53
+ return self.block(x)
54
+
55
+
56
+ class WaveGenerator(nn.Module):
57
+ def __init__(
58
+ self,
59
+ input_channel,
60
+ channels,
61
+ rates,
62
+ kernel_sizes,
63
+ d_out: int = 1,
64
+ ):
65
+ super().__init__()
66
+
67
+ # Add first conv layer
68
+ layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
69
+
70
+ # Add upsampling + MRF blocks
71
+ for i, (kernel_size, stride) in enumerate(zip(kernel_sizes, rates)):
72
+ input_dim = channels // 2**i
73
+ output_dim = channels // 2 ** (i + 1)
74
+ layers += [DecoderBlock(input_dim, output_dim, kernel_size, stride)]
75
+
76
+ # Add final conv layer
77
+ layers += [
78
+ Snake1d(output_dim),
79
+ WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
80
+ nn.Tanh(),
81
+ ]
82
+
83
+ self.model = nn.Sequential(*layers)
84
+
85
+ self.apply(init_weights)
86
+
87
+ def forward(self, x):
88
+ return self.model(x)
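As a rough shape check for the generator above (the channel counts, rates, and kernel sizes here are illustrative assumptions, not the values from the released BiCodec config):

    import torch
    from sparktts.modules.encoder_decoder.wave_generator import WaveGenerator

    gen = WaveGenerator(input_channel=256, channels=512, rates=[4, 4], kernel_sizes=[8, 8])
    x = torch.randn(1, 256, 50)   # (batch, input_channel, frames)
    wav = gen(x)                  # each stage upsamples by its rate: 50 * 4 * 4 = 800
    print(wav.shape)              # torch.Size([1, 1, 800]), values in [-1, 1] due to the final Tanh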
trained_30_percents/sparktts/modules/fsq/__pycache__/finite_scalar_quantization.cpython-311.pyc ADDED
Binary file (11.2 kB). View file
 
trained_30_percents/sparktts/modules/fsq/__pycache__/residual_fsq.cpython-311.pyc ADDED
Binary file (15.1 kB). View file
 
trained_30_percents/sparktts/modules/fsq/finite_scalar_quantization.py ADDED
@@ -0,0 +1,251 @@
1
+ """
2
+ Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
3
+ Code adapted from Jax version in Appendix A.1
4
+ """
5
+
6
+ from __future__ import annotations
7
+ from functools import wraps, partial
8
+ from contextlib import nullcontext
9
+ from typing import List, Tuple
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.nn import Module
14
+ from torch import Tensor, int32
15
+ from torch.amp import autocast
16
+
17
+ from einops import rearrange, pack, unpack
18
+
19
+ # helper functions
20
+
21
+
22
+ def exists(v):
23
+ return v is not None
24
+
25
+
26
+ def default(*args):
27
+ for arg in args:
28
+ if exists(arg):
29
+ return arg
30
+ return None
31
+
32
+
33
+ def maybe(fn):
34
+ @wraps(fn)
35
+ def inner(x, *args, **kwargs):
36
+ if not exists(x):
37
+ return x
38
+ return fn(x, *args, **kwargs)
39
+
40
+ return inner
41
+
42
+
43
+ def pack_one(t, pattern):
44
+ return pack([t], pattern)
45
+
46
+
47
+ def unpack_one(t, ps, pattern):
48
+ return unpack(t, ps, pattern)[0]
49
+
50
+
51
+ # tensor helpers
52
+
53
+
54
+ def round_ste(z: Tensor) -> Tensor:
55
+ """Round with straight through gradients."""
56
+ zhat = z.round()
57
+ return z + (zhat - z).detach()
58
+
59
+
60
+ # main class
61
+
62
+
63
+ class FSQ(Module):
64
+ def __init__(
65
+ self,
66
+ levels: List[int],
67
+ dim: int | None = None,
68
+ num_codebooks=1,
69
+ keep_num_codebooks_dim: bool | None = None,
70
+ scale: float | None = None,
71
+ allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64),
72
+ channel_first: bool = False,
73
+ projection_has_bias: bool = True,
74
+ return_indices=True,
75
+ force_quantization_f32=True,
76
+ ):
77
+ super().__init__()
78
+ _levels = torch.tensor(levels, dtype=int32)
79
+ self.register_buffer("_levels", _levels, persistent=False)
80
+
81
+ _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
82
+ self.register_buffer("_basis", _basis, persistent=False)
83
+
84
+ self.scale = scale
85
+
86
+ codebook_dim = len(levels)
87
+ self.codebook_dim = codebook_dim
88
+
89
+ effective_codebook_dim = codebook_dim * num_codebooks
90
+ self.num_codebooks = num_codebooks
91
+ self.effective_codebook_dim = effective_codebook_dim
92
+
93
+ keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
94
+ assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
95
+ self.keep_num_codebooks_dim = keep_num_codebooks_dim
96
+
97
+ self.dim = default(dim, len(_levels) * num_codebooks)
98
+
99
+ self.channel_first = channel_first
100
+
101
+ has_projections = self.dim != effective_codebook_dim
102
+ self.project_in = (
103
+ nn.Linear(self.dim, effective_codebook_dim, bias=projection_has_bias)
104
+ if has_projections
105
+ else nn.Identity()
106
+ )
107
+ self.project_out = (
108
+ nn.Linear(effective_codebook_dim, self.dim, bias=projection_has_bias)
109
+ if has_projections
110
+ else nn.Identity()
111
+ )
112
+
113
+ self.has_projections = has_projections
114
+
115
+ self.return_indices = return_indices
116
+ if return_indices:
117
+ self.codebook_size = self._levels.prod().item()
118
+ implicit_codebook = self._indices_to_codes(torch.arange(self.codebook_size))
119
+ self.register_buffer(
120
+ "implicit_codebook", implicit_codebook, persistent=False
121
+ )
122
+
123
+ self.allowed_dtypes = allowed_dtypes
124
+ self.force_quantization_f32 = force_quantization_f32
125
+
126
+ def bound(self, z, eps: float = 1e-3):
127
+ """Bound `z`, an array of shape (..., d)."""
128
+ half_l = (self._levels - 1) * (1 + eps) / 2
129
+ offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
130
+ shift = (offset / half_l).atanh()
131
+ return (z + shift).tanh() * half_l - offset
132
+
133
+ def quantize(self, z):
134
+ """Quantizes z, returns quantized zhat, same shape as z."""
135
+ quantized = round_ste(self.bound(z))
136
+ half_width = self._levels // 2 # Renormalize to [-1, 1].
137
+ return quantized / half_width
138
+
139
+ def _scale_and_shift(self, zhat_normalized):
140
+ half_width = self._levels // 2
141
+ return (zhat_normalized * half_width) + half_width
142
+
143
+ def _scale_and_shift_inverse(self, zhat):
144
+ half_width = self._levels // 2
145
+ return (zhat - half_width) / half_width
146
+
147
+ def _indices_to_codes(self, indices):
148
+ level_indices = self.indices_to_level_indices(indices)
149
+ codes = self._scale_and_shift_inverse(level_indices)
150
+ return codes
151
+
152
+ def codes_to_indices(self, zhat):
153
+ """Converts a `code` to an index in the codebook."""
154
+ assert zhat.shape[-1] == self.codebook_dim
155
+ zhat = self._scale_and_shift(zhat)
156
+ return (zhat * self._basis).sum(dim=-1).to(int32)
157
+
158
+ def indices_to_level_indices(self, indices):
159
+ """Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings"""
160
+ indices = rearrange(indices, "... -> ... 1")
161
+ codes_non_centered = (indices // self._basis) % self._levels
162
+ return codes_non_centered
163
+
164
+ def indices_to_codes(self, indices):
165
+ """Inverse of `codes_to_indices`."""
166
+ assert exists(indices)
167
+
168
+ is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
169
+
170
+ codes = self._indices_to_codes(indices)
171
+
172
+ if self.keep_num_codebooks_dim:
173
+ codes = rearrange(codes, "... c d -> ... (c d)")
174
+
175
+ codes = self.project_out(codes)
176
+
177
+ if is_img_or_video or self.channel_first:
178
+ codes = rearrange(codes, "b ... d -> b d ...")
179
+
180
+ return codes
181
+
182
+ def forward(self, z):
183
+ """
184
+ einstein notation
185
+ b - batch
186
+ n - sequence (or flattened spatial dimensions)
187
+ d - feature dimension
188
+ c - number of codebook dim
189
+ """
190
+
191
+ is_img_or_video = z.ndim >= 4
192
+ need_move_channel_last = is_img_or_video or self.channel_first
193
+
194
+ # standardize image or video into (batch, seq, dimension)
195
+
196
+ if need_move_channel_last:
197
+ z = rearrange(z, "b d ... -> b ... d")
198
+ z, ps = pack_one(z, "b * d")
199
+
200
+ assert (
201
+ z.shape[-1] == self.dim
202
+ ), f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"
203
+
204
+ z = self.project_in(z)
205
+
206
+ z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks)
207
+
208
+ # whether to force quantization step to be full precision or not
209
+
210
+ force_f32 = self.force_quantization_f32
211
+ quantization_context = (
212
+ partial(autocast, "cuda", enabled=False) if force_f32 else nullcontext
213
+ )
214
+
215
+ with quantization_context():
216
+ orig_dtype = z.dtype
217
+
218
+ if force_f32 and orig_dtype not in self.allowed_dtypes:
219
+ z = z.float()
220
+
221
+ codes = self.quantize(z)
222
+
223
+ # returning indices could be optional
224
+
225
+ indices = None
226
+
227
+ if self.return_indices:
228
+ indices = self.codes_to_indices(codes)
229
+
230
+ codes = rearrange(codes, "b n c d -> b n (c d)")
231
+
232
+ codes = codes.type(orig_dtype)
233
+
234
+ # project out
235
+
236
+ out = self.project_out(codes)
237
+
238
+ # reconstitute image or video dimensions
239
+
240
+ if need_move_channel_last:
241
+ out = unpack_one(out, ps, "b * d")
242
+ out = rearrange(out, "b ... d -> b d ...")
243
+
244
+ indices = maybe(unpack_one)(indices, ps, "b * c")
245
+
246
+ if not self.keep_num_codebooks_dim and self.return_indices:
247
+ indices = maybe(rearrange)(indices, "... 1 -> ...")
248
+
249
+ # return quantized output and indices
250
+
251
+ return out, indices
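For orientation, a minimal round trip through the FSQ module above (the levels and shapes are assumptions for illustration, not the values used by BiCodec):

    import torch
    from sparktts.modules.fsq.finite_scalar_quantization import FSQ

    fsq = FSQ(levels=[8, 5, 5, 5])        # codebook size = 8 * 5 * 5 * 5 = 1000
    z = torch.randn(1, 32, 4)             # (batch, seq, dim), dim == len(levels)
    out, indices = fsq(z)                 # out has the shape of z; indices lie in [0, 1000)
    codes = fsq.indices_to_codes(indices)
    assert torch.allclose(out, codes)     # indices_to_codes inverts codes_to_indices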
trained_30_percents/sparktts/modules/fsq/residual_fsq.py ADDED
@@ -0,0 +1,355 @@
1
+ import random
+ from math import ceil
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import torch.distributed as dist
5
+
6
+ from typing import List
7
+ from torch import nn
8
+ from torch.nn import Module
9
+ from torch.amp import autocast
10
+ from einx import get_at
11
+ from einops import rearrange, reduce, pack, unpack
12
+
13
+ from sparktts.modules.fsq.finite_scalar_quantization import FSQ
14
+
15
+
16
+ def exists(val):
17
+ return val is not None
18
+
19
+
20
+ def first(l):
21
+ return l[0]
22
+
23
+
24
+ def default(val, d):
25
+ return val if exists(val) else d
26
+
27
+
28
+ def round_up_multiple(num, mult):
29
+ return ceil(num / mult) * mult
30
+
31
+
32
+ # distributed helpers
33
+
34
+
35
+ def is_distributed():
36
+ return dist.is_initialized() and dist.get_world_size() > 1
37
+
38
+
39
+ def get_maybe_sync_seed(device, max_size=10_000):
40
+ rand_int = torch.randint(0, max_size, (), device=device)
41
+
42
+ if is_distributed():
43
+ dist.all_reduce(rand_int)
44
+
45
+ return rand_int.item()
46
+
47
+
48
+ class ResidualFSQ(Module):
49
+ """Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf"""
50
+
51
+ def __init__(
52
+ self,
53
+ *,
54
+ levels: List[int],
55
+ num_quantizers,
56
+ dim=None,
57
+ is_channel_first=False,
58
+ quantize_dropout=False,
59
+ quantize_dropout_cutoff_index=0,
60
+ quantize_dropout_multiple_of=1,
61
+ **kwargs,
62
+ ):
63
+ super().__init__()
64
+ codebook_dim = len(levels)
65
+ dim = default(dim, codebook_dim)
66
+
67
+ requires_projection = codebook_dim != dim
68
+ self.project_in = (
69
+ nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
70
+ )
71
+ self.project_out = (
72
+ nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
73
+ )
74
+ self.has_projections = requires_projection
75
+
76
+ self.is_channel_first = is_channel_first
77
+ self.num_quantizers = num_quantizers
78
+
79
+ self.levels = levels
80
+ self.layers = nn.ModuleList([])
81
+
82
+ levels_tensor = torch.Tensor(levels)
83
+
84
+ scales = []
85
+
86
+ for ind in range(num_quantizers):
87
+ scales.append((levels_tensor - 1) ** -ind)
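+ # quantizer ind works on residual / scale with scale = (levels - 1) ** -ind and rescales its output back in forward()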
88
+
89
+ fsq = FSQ(levels=levels, dim=codebook_dim, **kwargs)
90
+
91
+ self.layers.append(fsq)
92
+
93
+ assert all([not fsq.has_projections for fsq in self.layers])
94
+
95
+ self.codebook_size = self.layers[0].codebook_size
96
+
97
+ self.register_buffer("scales", torch.stack(scales), persistent=False)
98
+
99
+ self.quantize_dropout = quantize_dropout and num_quantizers > 1
100
+
101
+ assert quantize_dropout_cutoff_index >= 0
102
+
103
+ self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
104
+ self.quantize_dropout_multiple_of = quantize_dropout_multiple_of # encodec paper proposes structured dropout, believe this was set to 4
105
+
106
+ @property
107
+ def codebooks(self):
108
+ codebooks = [layer.implicit_codebook for layer in self.layers]
109
+ codebooks = torch.stack(codebooks, dim=0)
110
+ return codebooks
111
+
112
+ def get_codes_from_indices(self, indices):
113
+
114
+ batch, quantize_dim = indices.shape[0], indices.shape[-1]
115
+
116
+ # may also receive indices in the shape of 'b h w q' (accept_image_fmap)
117
+
118
+ indices, ps = pack([indices], "b * q")
119
+
120
+ # because of quantize dropout, one can pass in indices that are coarse
121
+ # and the network should be able to reconstruct
122
+
123
+ if quantize_dim < self.num_quantizers:
124
+ assert (
125
+ self.quantize_dropout > 0.0
126
+ ), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
127
+ indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1)
128
+
129
+ # take care of quantizer dropout
130
+
131
+ mask = indices == -1
132
+ indices = indices.masked_fill(
133
+ mask, 0
134
+ ) # have it fetch a dummy code to be masked out later
135
+
136
+ all_codes = get_at("q [c] d, b n q -> q b n d", self.codebooks, indices)
137
+
138
+ # mask out any codes that were dropout-ed
139
+
140
+ all_codes = all_codes.masked_fill(rearrange(mask, "b n q -> q b n 1"), 0.0)
141
+
142
+ # scale the codes
143
+
144
+ scales = rearrange(self.scales, "q d -> q 1 1 d")
145
+ all_codes = all_codes * scales
146
+
147
+ # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)
148
+
149
+ (all_codes,) = unpack(all_codes, ps, "q b * d")
150
+
151
+ return all_codes
152
+
153
+ def get_output_from_indices(self, indices):
154
+ codes = self.get_codes_from_indices(indices)
155
+ codes_summed = reduce(codes, "q ... -> ...", "sum")
156
+ return self.project_out(codes_summed)
157
+
158
+ def forward(self, x, return_all_codes=False, rand_quantize_dropout_fixed_seed=None):
159
+ num_quant, quant_dropout_multiple_of, device = (
160
+ self.num_quantizers,
161
+ self.quantize_dropout_multiple_of,
162
+ x.device,
163
+ )
164
+
165
+ # handle channel first
166
+
167
+ if self.is_channel_first:
168
+ x = rearrange(x, "b d ... -> b ... d")
169
+ x, ps = pack([x], "b * d")
170
+
171
+ # maybe project in
172
+
173
+ x = self.project_in(x)
174
+
175
+ quantized_out = 0.0
176
+ residual = x
177
+
178
+ all_indices = []
179
+
180
+ should_quantize_dropout = self.training and self.quantize_dropout
181
+
182
+ # sample a layer index at which to dropout further residual quantization
183
+ # also prepare null indices
184
+
185
+ if should_quantize_dropout:
186
+
187
+ # check if seed is manually passed in
188
+
189
+ if not exists(rand_quantize_dropout_fixed_seed):
190
+ rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)
191
+
192
+ rand = random.Random(rand_quantize_dropout_fixed_seed)
193
+
194
+ rand_quantize_dropout_index = rand.randrange(
195
+ self.quantize_dropout_cutoff_index, num_quant
196
+ )
197
+
198
+ if quant_dropout_multiple_of != 1:
199
+ rand_quantize_dropout_index = (
200
+ round_up_multiple(
201
+ rand_quantize_dropout_index + 1, quant_dropout_multiple_of
202
+ )
203
+ - 1
204
+ )
205
+
206
+ null_indices = torch.full(
207
+ x.shape[:2], -1.0, device=device, dtype=torch.long
208
+ )
209
+
210
+ # go through the layers
211
+
212
+ with autocast("cuda", enabled=False):
213
+ for quantizer_index, (layer, scale) in enumerate(
214
+ zip(self.layers, self.scales)
215
+ ):
216
+
217
+ if (
218
+ should_quantize_dropout
219
+ and quantizer_index > rand_quantize_dropout_index
220
+ ):
221
+ all_indices.append(null_indices)
222
+ continue
223
+
224
+ quantized, indices = layer(residual / scale)
225
+
226
+ quantized = quantized * scale
227
+
228
+ residual = residual - quantized.detach()
229
+ quantized_out = quantized_out + quantized
230
+
231
+ all_indices.append(indices)
232
+
233
+ # project out, if needed
234
+
235
+ quantized_out = self.project_out(quantized_out)
236
+
237
+ # stack all indices
238
+
239
+ all_indices = torch.stack(all_indices, dim=-1)
240
+
241
+ # channel first out
242
+
243
+ if self.is_channel_first:
244
+ (quantized_out,) = unpack(quantized_out, ps, "b * d")
245
+ (all_indices,) = unpack(all_indices, ps, "b * d")
246
+
247
+ quantized_out = rearrange(quantized_out, "b ... d -> b d ...")
248
+ all_indices = rearrange(all_indices, "b ... d -> b d ...")
249
+
250
+ # return
251
+
252
+ ret = (quantized_out, all_indices)
253
+
254
+ if not return_all_codes:
255
+ return ret
256
+
257
+ # whether to return all codes from all codebooks across layers
258
+
259
+ all_codes = self.get_codes_from_indices(all_indices)
260
+
261
+ # will return all codes in shape (quantizer, batch, sequence length, codebook dimension)
262
+
263
+ return (*ret, all_codes)
264
+
265
+
266
+ # grouped residual fsq
267
+
268
+
269
+ class GroupedResidualFSQ(Module):
270
+ def __init__(self, *, dim, groups=1, accept_image_fmap=False, **kwargs):
271
+ super().__init__()
272
+ self.dim = dim
273
+ self.groups = groups
274
+ assert (dim % groups) == 0
275
+ dim_per_group = dim // groups
276
+
277
+ self.accept_image_fmap = accept_image_fmap
278
+
279
+ self.rvqs = nn.ModuleList([])
280
+
281
+ for _ in range(groups):
282
+ self.rvqs.append(ResidualFSQ(dim=dim_per_group, **kwargs))
283
+
284
+ self.codebook_size = self.rvqs[0].codebook_size
285
+
286
+ @property
287
+ def codebooks(self):
288
+ return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs))
289
+
290
+ @property
291
+ def split_dim(self):
292
+ return 1 if self.accept_image_fmap else -1
293
+
294
+ def get_codes_from_indices(self, indices):
295
+ codes = tuple(
296
+ rvq.get_codes_from_indices(chunk_indices)
297
+ for rvq, chunk_indices in zip(self.rvqs, indices)
298
+ )
299
+ return torch.stack(codes)
300
+
301
+ def get_output_from_indices(self, indices):
302
+ outputs = tuple(
303
+ rvq.get_output_from_indices(chunk_indices)
304
+ for rvq, chunk_indices in zip(self.rvqs, indices)
305
+ )
306
+ return torch.cat(outputs, dim=self.split_dim)
307
+
308
+ def forward(self, x, return_all_codes=False):
309
+ shape, split_dim, device = x.shape, self.split_dim, x.device
310
+ assert shape[split_dim] == self.dim
311
+
312
+ # split the feature dimension into groups
313
+
314
+ x = x.chunk(self.groups, dim=split_dim)
315
+
316
+ forward_kwargs = dict(
317
+ return_all_codes=return_all_codes,
318
+ rand_quantize_dropout_fixed_seed=(
319
+ get_maybe_sync_seed(device) if self.training else None
320
+ ),
321
+ )
322
+
323
+ # invoke residual vq on each group
324
+
325
+ out = tuple(rvq(chunk, **forward_kwargs) for rvq, chunk in zip(self.rvqs, x))
326
+ out = tuple(zip(*out))
327
+
328
+ # otherwise, get all the zipped outputs and combine them
329
+
330
+ quantized, all_indices, *maybe_all_codes = out
331
+
332
+ quantized = torch.cat(quantized, dim=split_dim)
333
+ all_indices = torch.stack(all_indices)
334
+
335
+ ret = (quantized, all_indices, *maybe_all_codes)
336
+ return ret
337
+
338
+
339
+ if __name__ == "__main__":
340
+ model = ResidualFSQ(
341
+ levels=[4, 4, 4, 4, 4, 4],
342
+ num_quantizers=1,
343
+ dim=30,
344
+ is_channel_first=True,
345
+ quantize_dropout=False,
346
+ )
347
+ x = torch.randn(2, 30, 10)
348
+ quantize, embed_ind = model(x)
349
+
350
+ emb_from_ind = model.get_output_from_indices(embed_ind.transpose(1, 2))
351
+
352
+ print(quantize == emb_from_ind.transpose(1, 2))
353
+
354
+ print("quantize shape", quantize.shape)
355
+ print("embed_ind", embed_ind)
trained_30_percents/sparktts/modules/speaker/__pycache__/ecapa_tdnn.cpython-311.pyc ADDED
Binary file (11.5 kB). View file
 
trained_30_percents/sparktts/modules/speaker/__pycache__/perceiver_encoder.cpython-311.pyc ADDED
Binary file (17.7 kB). View file
 
trained_30_percents/sparktts/modules/speaker/__pycache__/pooling_layers.cpython-311.pyc ADDED
Binary file (15.5 kB). View file
 
trained_30_percents/sparktts/modules/speaker/__pycache__/speaker_encoder.cpython-311.pyc ADDED
Binary file (7.26 kB). View file
 
trained_30_percents/sparktts/modules/speaker/ecapa_tdnn.py ADDED
@@ -0,0 +1,267 @@
1
+ # Copyright (c) 2021 Zhengyang Chen ([email protected])
2
+ # 2022 Hongji Wang ([email protected])
3
+ # 2023 Bing Han ([email protected])
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """ This implementation is adapted from github repo:
18
+ https://github.com/lawlict/ECAPA-TDNN.
19
+ """
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+
25
+ import sparktts.modules.speaker.pooling_layers as pooling_layers
26
+
27
+
28
+ class Res2Conv1dReluBn(nn.Module):
29
+ """
30
+ in_channels == out_channels == channels
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ channels,
36
+ kernel_size=1,
37
+ stride=1,
38
+ padding=0,
39
+ dilation=1,
40
+ bias=True,
41
+ scale=4,
42
+ ):
43
+ super().__init__()
44
+ assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
45
+ self.scale = scale
46
+ self.width = channels // scale
47
+ self.nums = scale if scale == 1 else scale - 1
48
+
49
+ self.convs = []
50
+ self.bns = []
51
+ for i in range(self.nums):
52
+ self.convs.append(
53
+ nn.Conv1d(
54
+ self.width,
55
+ self.width,
56
+ kernel_size,
57
+ stride,
58
+ padding,
59
+ dilation,
60
+ bias=bias,
61
+ )
62
+ )
63
+ self.bns.append(nn.BatchNorm1d(self.width))
64
+ self.convs = nn.ModuleList(self.convs)
65
+ self.bns = nn.ModuleList(self.bns)
66
+
67
+ def forward(self, x):
68
+ out = []
69
+ spx = torch.split(x, self.width, 1)
70
+ sp = spx[0]
71
+ for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
72
+ # Order: conv -> relu -> bn
73
+ if i >= 1:
74
+ sp = sp + spx[i]
75
+ sp = conv(sp)
76
+ sp = bn(F.relu(sp))
77
+ out.append(sp)
78
+ if self.scale != 1:
79
+ out.append(spx[self.nums])
80
+ out = torch.cat(out, dim=1)
81
+
82
+ return out
83
+
84
+
85
+ """ Conv1d + BatchNorm1d + ReLU
86
+ """
87
+
88
+
89
+ class Conv1dReluBn(nn.Module):
90
+
91
+ def __init__(
92
+ self,
93
+ in_channels,
94
+ out_channels,
95
+ kernel_size=1,
96
+ stride=1,
97
+ padding=0,
98
+ dilation=1,
99
+ bias=True,
100
+ ):
101
+ super().__init__()
102
+ self.conv = nn.Conv1d(
103
+ in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias
104
+ )
105
+ self.bn = nn.BatchNorm1d(out_channels)
106
+
107
+ def forward(self, x):
108
+ return self.bn(F.relu(self.conv(x)))
109
+
110
+
111
+ """ The SE connection of 1D case.
112
+ """
113
+
114
+
115
+ class SE_Connect(nn.Module):
116
+
117
+ def __init__(self, channels, se_bottleneck_dim=128):
118
+ super().__init__()
119
+ self.linear1 = nn.Linear(channels, se_bottleneck_dim)
120
+ self.linear2 = nn.Linear(se_bottleneck_dim, channels)
121
+
122
+ def forward(self, x):
123
+ out = x.mean(dim=2)
124
+ out = F.relu(self.linear1(out))
125
+ out = torch.sigmoid(self.linear2(out))
126
+ out = x * out.unsqueeze(2)
127
+
128
+ return out
129
+
130
+
131
+ """ SE-Res2Block of the ECAPA-TDNN architecture.
132
+ """
133
+
134
+
135
+ class SE_Res2Block(nn.Module):
136
+
137
+ def __init__(self, channels, kernel_size, stride, padding, dilation, scale):
138
+ super().__init__()
139
+ self.se_res2block = nn.Sequential(
140
+ Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0),
141
+ Res2Conv1dReluBn(
142
+ channels, kernel_size, stride, padding, dilation, scale=scale
143
+ ),
144
+ Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0),
145
+ SE_Connect(channels),
146
+ )
147
+
148
+ def forward(self, x):
149
+ return x + self.se_res2block(x)
150
+
151
+
152
+ class ECAPA_TDNN(nn.Module):
153
+
154
+ def __init__(
155
+ self,
156
+ channels=512,
157
+ feat_dim=80,
158
+ embed_dim=192,
159
+ pooling_func="ASTP",
160
+ global_context_att=False,
161
+ emb_bn=False,
162
+ ):
163
+ super().__init__()
164
+
165
+ self.layer1 = Conv1dReluBn(feat_dim, channels, kernel_size=5, padding=2)
166
+ self.layer2 = SE_Res2Block(
167
+ channels, kernel_size=3, stride=1, padding=2, dilation=2, scale=8
168
+ )
169
+ self.layer3 = SE_Res2Block(
170
+ channels, kernel_size=3, stride=1, padding=3, dilation=3, scale=8
171
+ )
172
+ self.layer4 = SE_Res2Block(
173
+ channels, kernel_size=3, stride=1, padding=4, dilation=4, scale=8
174
+ )
175
+
176
+ cat_channels = channels * 3
177
+ out_channels = 512 * 3
178
+ self.conv = nn.Conv1d(cat_channels, out_channels, kernel_size=1)
179
+ self.pool = getattr(pooling_layers, pooling_func)(
180
+ in_dim=out_channels, global_context_att=global_context_att
181
+ )
182
+ self.pool_out_dim = self.pool.get_out_dim()
183
+ self.bn = nn.BatchNorm1d(self.pool_out_dim)
184
+ self.linear = nn.Linear(self.pool_out_dim, embed_dim)
185
+ self.emb_bn = emb_bn
186
+ if emb_bn: # better in SSL for SV
187
+ self.bn2 = nn.BatchNorm1d(embed_dim)
188
+ else:
189
+ self.bn2 = nn.Identity()
190
+
191
+ def forward(self, x, return_latent=False):
192
+ x = x.permute(0, 2, 1) # (B,T,F) -> (B,F,T)
193
+
194
+ out1 = self.layer1(x)
195
+ out2 = self.layer2(out1)
196
+ out3 = self.layer3(out2)
197
+ out4 = self.layer4(out3)
198
+
199
+ out = torch.cat([out2, out3, out4], dim=1)
200
+ latent = F.relu(self.conv(out))
201
+ out = self.bn(self.pool(latent))
202
+ out = self.linear(out)
203
+ if self.emb_bn:
204
+ out = self.bn2(out)
205
+
206
+ if return_latent:
207
+ return out, latent
208
+ return out
209
+
210
+
211
+ def ECAPA_TDNN_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
212
+ return ECAPA_TDNN(
213
+ channels=1024,
214
+ feat_dim=feat_dim,
215
+ embed_dim=embed_dim,
216
+ pooling_func=pooling_func,
217
+ emb_bn=emb_bn,
218
+ )
219
+
220
+
221
+ def ECAPA_TDNN_GLOB_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
222
+ return ECAPA_TDNN(
223
+ channels=1024,
224
+ feat_dim=feat_dim,
225
+ embed_dim=embed_dim,
226
+ pooling_func=pooling_func,
227
+ global_context_att=True,
228
+ emb_bn=emb_bn,
229
+ )
230
+
231
+
232
+ def ECAPA_TDNN_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
233
+ return ECAPA_TDNN(
234
+ channels=512,
235
+ feat_dim=feat_dim,
236
+ embed_dim=embed_dim,
237
+ pooling_func=pooling_func,
238
+ emb_bn=emb_bn,
239
+ )
240
+
241
+
242
+ def ECAPA_TDNN_GLOB_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
243
+ return ECAPA_TDNN(
244
+ channels=512,
245
+ feat_dim=feat_dim,
246
+ embed_dim=embed_dim,
247
+ pooling_func=pooling_func,
248
+ global_context_att=True,
249
+ emb_bn=emb_bn,
250
+ )
251
+
252
+
253
+ if __name__ == "__main__":
254
+ x = torch.zeros(1, 200, 100)
255
+ model = ECAPA_TDNN_GLOB_c512(feat_dim=100, embed_dim=256, pooling_func="ASTP")
256
+ model.eval()
257
+ out, latent = model(x, True)
258
+ print(out.shape)
259
+ print(latent.shape)
260
+
261
+ num_params = sum(param.numel() for param in model.parameters())
262
+ print("{} M".format(num_params / 1e6))
263
+
264
+ # from thop import profile
265
+ # x_np = torch.randn(1, 200, 80)
266
+ # flops, params = profile(model, inputs=(x_np, ))
267
+ # print("FLOPs: {} G, Params: {} M".format(flops / 1e9, params / 1e6))
trained_30_percents/sparktts/modules/speaker/perceiver_encoder.py ADDED
@@ -0,0 +1,360 @@
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Adapted from https://github.com/lucidrains/naturalspeech2-pytorch/blob/659bec7f7543e7747e809e950cc2f84242fbeec7/naturalspeech2_pytorch/naturalspeech2_pytorch.py#L532
17
+
18
+ from collections import namedtuple
19
+ from functools import wraps
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from einops import rearrange, repeat
24
+ from einops.layers.torch import Rearrange
25
+ from packaging import version
26
+ from torch import einsum, nn
27
+
28
+
29
+ def exists(val):
30
+ return val is not None
31
+
32
+
33
+ def once(fn):
34
+ called = False
35
+
36
+ @wraps(fn)
37
+ def inner(x):
38
+ nonlocal called
39
+ if called:
40
+ return
41
+ called = True
42
+ return fn(x)
43
+
44
+ return inner
45
+
46
+
47
+ print_once = once(print)
48
+
49
+ # main class
50
+
51
+
52
+ class Attend(nn.Module):
53
+ def __init__(self, dropout=0.0, causal=False, use_flash=False):
54
+ super().__init__()
55
+ self.dropout = dropout
56
+ self.attn_dropout = nn.Dropout(dropout)
57
+
58
+ self.causal = causal
59
+ self.register_buffer("mask", None, persistent=False)
60
+
61
+ self.use_flash = use_flash
62
+ assert not (
63
+ use_flash and version.parse(torch.__version__) < version.parse("2.0.0")
64
+ ), "in order to use flash attention, you must be using pytorch 2.0 or above"
65
+
66
+ # determine efficient attention configs for cuda and cpu
67
+ self.config = namedtuple(
68
+ "EfficientAttentionConfig",
69
+ ["enable_flash", "enable_math", "enable_mem_efficient"],
70
+ )
71
+ self.cpu_config = self.config(True, True, True)
72
+ self.cuda_config = None
73
+
74
+ if not torch.cuda.is_available() or not use_flash:
75
+ return
76
+
77
+ device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
78
+
79
+ if device_properties.major == 8 and device_properties.minor == 0:
80
+ print_once(
81
+ "A100 GPU detected, using flash attention if input tensor is on cuda"
82
+ )
83
+ self.cuda_config = self.config(True, False, False)
84
+ else:
85
+ print_once(
86
+ "Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda"
87
+ )
88
+ self.cuda_config = self.config(False, True, True)
89
+
90
+ def get_mask(self, n, device):
91
+ if exists(self.mask) and self.mask.shape[-1] >= n:
92
+ return self.mask[:n, :n]
93
+
94
+ mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
95
+ self.register_buffer("mask", mask, persistent=False)
96
+ return mask
97
+
98
+ def flash_attn(self, q, k, v, mask=None):
99
+ _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda
100
+
101
+ # Recommended for multi-query single-key-value attention by Tri Dao
102
+ # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
103
+
104
+ if k.ndim == 3:
105
+ k = rearrange(k, "b ... -> b 1 ...").expand_as(q)
106
+
107
+ if v.ndim == 3:
108
+ v = rearrange(v, "b ... -> b 1 ...").expand_as(q)
109
+
110
+ # Check if mask exists and expand to compatible shape
111
+ # The mask is B L, so it would have to be expanded to B H N L
112
+
113
+ if exists(mask):
114
+ mask = rearrange(mask, "b j -> b 1 1 j")
115
+ mask = mask.expand(-1, heads, q_len, -1)
116
+
117
+ # Check if there is a compatible device for flash attention
118
+
119
+ config = self.cuda_config if is_cuda else self.cpu_config
120
+
121
+ # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
122
+
123
+ with torch.backends.cuda.sdp_kernel(**config._asdict()):
124
+ out = F.scaled_dot_product_attention(
125
+ q,
126
+ k,
127
+ v,
128
+ attn_mask=mask,
129
+ dropout_p=self.dropout if self.training else 0.0,
130
+ is_causal=self.causal,
131
+ )
132
+
133
+ return out
134
+
135
+ def forward(self, q, k, v, mask=None):
136
+ """
137
+ einstein notation
138
+ b - batch
139
+ h - heads
140
+ n, i, j - sequence length (base sequence length, source, target)
141
+ d - feature dimension
142
+ """
143
+
144
+ n, device = q.shape[-2], q.device
145
+
146
+ scale = q.shape[-1] ** -0.5
147
+
148
+ if self.use_flash:
149
+ return self.flash_attn(q, k, v, mask=mask)
150
+
151
+ kv_einsum_eq = "b j d" if k.ndim == 3 else "b h j d"
152
+
153
+ # similarity
154
+
155
+ sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale
156
+
157
+ # key padding mask
158
+
159
+ if exists(mask):
160
+ mask = rearrange(mask, "b j -> b 1 1 j")
161
+ sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
162
+
163
+ # causal mask
164
+
165
+ if self.causal:
166
+ causal_mask = self.get_mask(n, device)
167
+ sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
168
+
169
+ # attention
170
+
171
+ attn = sim.softmax(dim=-1)
172
+ attn = self.attn_dropout(attn)
173
+
174
+ # aggregate values
175
+
176
+ out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)
177
+
178
+ return out
179
+
180
+
181
+ def Sequential(*mods):
182
+ return nn.Sequential(*filter(exists, mods))
183
+
184
+
185
+ def exists(x):
186
+ return x is not None
187
+
188
+
189
+ def default(val, d):
190
+ if exists(val):
191
+ return val
192
+ return d() if callable(d) else d
193
+
194
+
195
+ class RMSNorm(nn.Module):
196
+ def __init__(self, dim, scale=True, dim_cond=None):
197
+ super().__init__()
198
+ self.cond = exists(dim_cond)
199
+ self.to_gamma_beta = nn.Linear(dim_cond, dim * 2) if self.cond else None
200
+
201
+ self.scale = dim**0.5
202
+ self.gamma = nn.Parameter(torch.ones(dim)) if scale else None
203
+
204
+ def forward(self, x, cond=None):
205
+ gamma = default(self.gamma, 1)
206
+ out = F.normalize(x, dim=-1) * self.scale * gamma
207
+
208
+ if not self.cond:
209
+ return out
210
+
211
+ assert exists(cond)
212
+ gamma, beta = self.to_gamma_beta(cond).chunk(2, dim=-1)
213
+ gamma, beta = map(lambda t: rearrange(t, "b d -> b 1 d"), (gamma, beta))
214
+ return out * gamma + beta
215
+
216
+
217
+ class CausalConv1d(nn.Conv1d):
218
+ def __init__(self, *args, **kwargs):
219
+ super().__init__(*args, **kwargs)
220
+ (kernel_size,) = self.kernel_size
221
+ (dilation,) = self.dilation
222
+ (stride,) = self.stride
223
+
224
+ assert stride == 1
225
+ self.causal_padding = dilation * (kernel_size - 1)
226
+
227
+ def forward(self, x):
228
+ causal_padded_x = F.pad(x, (self.causal_padding, 0), value=0.0)
229
+ return super().forward(causal_padded_x)
230
+
231
+
232
+ class GEGLU(nn.Module):
233
+ def forward(self, x):
234
+ x, gate = x.chunk(2, dim=-1)
235
+ return F.gelu(gate) * x
236
+
237
+
238
+ def FeedForward(dim, mult=4, causal_conv=False):
239
+ dim_inner = int(dim * mult * 2 / 3)
240
+
241
+ conv = None
242
+ if causal_conv:
243
+ conv = nn.Sequential(
244
+ Rearrange("b n d -> b d n"),
245
+ CausalConv1d(dim_inner, dim_inner, 3),
246
+ Rearrange("b d n -> b n d"),
247
+ )
248
+
249
+ return Sequential(
250
+ nn.Linear(dim, dim_inner * 2), GEGLU(), conv, nn.Linear(dim_inner, dim)
251
+ )
252
+
253
+
254
+ class Attention(nn.Module):
255
+ def __init__(
256
+ self,
257
+ dim,
258
+ *,
259
+ dim_context=None,
260
+ causal=False,
261
+ dim_head=64,
262
+ heads=8,
263
+ dropout=0.0,
264
+ use_flash=False,
265
+ cross_attn_include_queries=False,
266
+ ):
267
+ super().__init__()
268
+ self.scale = dim_head**-0.5
269
+ self.heads = heads
270
+ self.cross_attn_include_queries = cross_attn_include_queries
271
+
272
+ dim_inner = dim_head * heads
273
+ dim_context = default(dim_context, dim)
274
+
275
+ self.attend = Attend(causal=causal, dropout=dropout, use_flash=use_flash)
276
+ self.to_q = nn.Linear(dim, dim_inner, bias=False)
277
+ self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias=False)
278
+ self.to_out = nn.Linear(dim_inner, dim, bias=False)
279
+
280
+ def forward(self, x, context=None, mask=None):
281
+ h, has_context = self.heads, exists(context)
282
+
283
+ context = default(context, x)
284
+
285
+ if has_context and self.cross_attn_include_queries:
286
+ context = torch.cat((x, context), dim=-2)
287
+
288
+ q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1))
289
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
290
+
291
+ out = self.attend(q, k, v, mask=mask)
292
+
293
+ out = rearrange(out, "b h n d -> b n (h d)")
294
+ return self.to_out(out)
295
+
296
+
297
+ class PerceiverResampler(nn.Module):
298
+ def __init__(
299
+ self,
300
+ *,
301
+ dim,
302
+ depth=2,
303
+ dim_context=None,
304
+ num_latents=32,
305
+ dim_head=64,
306
+ heads=8,
307
+ ff_mult=4,
308
+ use_flash_attn=False,
309
+ ):
310
+ super().__init__()
311
+ dim_context = default(dim_context, dim)
312
+
313
+ self.proj_context = (
314
+ nn.Linear(dim_context, dim) if dim_context != dim else nn.Identity()
315
+ )
316
+
317
+ self.latents = nn.Parameter(torch.randn(num_latents, dim))
318
+ nn.init.normal_(self.latents, std=0.02)
319
+
320
+ self.layers = nn.ModuleList([])
321
+ for _ in range(depth):
322
+ self.layers.append(
323
+ nn.ModuleList(
324
+ [
325
+ Attention(
326
+ dim=dim,
327
+ dim_head=dim_head,
328
+ heads=heads,
329
+ use_flash=use_flash_attn,
330
+ cross_attn_include_queries=True,
331
+ ),
332
+ FeedForward(dim=dim, mult=ff_mult),
333
+ ]
334
+ )
335
+ )
336
+
337
+ self.norm = RMSNorm(dim)
338
+
339
+ def forward(self, x, mask=None):
340
+ batch = x.shape[0]
341
+
342
+ x = self.proj_context(x)
343
+
344
+ latents = repeat(self.latents, "n d -> b n d", b=batch)
345
+
346
+ for attn, ff in self.layers:
347
+ latents = attn(latents, x, mask=mask) + latents
348
+ latents = ff(latents) + latents
349
+
350
+ return self.norm(latents)
351
+
352
+
353
+ if __name__ == "__main__":
354
+ model = PerceiverResampler(dim=256, dim_context=80)
355
+ x = torch.randn(8, 200, 80)
356
+ out = model(x)
357
+ print(out.shape) # [8, 32, 256]
358
+
359
+ num_params = sum(param.numel() for param in model.parameters())
360
+ print("{} M".format(num_params / 1e6))
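The resampler above cross-attends a small set of learned latent vectors over the input frames (with cross_attn_include_queries=True the latents also attend to themselves), so the output length is always num_latents regardless of the input length. A brief sketch under the dimensions the speaker encoder below uses (latent_dim=128, dim_context=512*3, token_num=32); the batch and time sizes are illustrative assumptions:

import torch
from sparktts.modules.speaker.perceiver_encoder import PerceiverResampler

resampler = PerceiverResampler(dim=128, dim_context=1536, num_latents=32).eval()
feats = torch.randn(2, 250, 1536)  # (B, T, D_context), e.g. the ECAPA frame-level latent transposed
with torch.no_grad():
    out = resampler(feats)
print(out.shape)  # torch.Size([2, 32, 128]), fixed length, independent of T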
trained_30_percents/sparktts/modules/speaker/pooling_layers.py ADDED
@@ -0,0 +1,298 @@
1
+ # Copyright (c) 2021 Shuai Wang ([email protected])
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Pooling functions to aggregate frame-level deep features
16
+ into segment-level speaker embeddings
17
+
18
+ High-order statistics are surprisingly effective: on VoxCeleb, TSDP performs
19
+ about as well as TSTP even though it drops the mean statistic.
20
+ """
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+
26
+
27
+ class TAP(nn.Module):
28
+ """
29
+ Temporal average pooling, only first-order mean is considered
30
+ """
31
+
32
+ def __init__(self, in_dim=0, **kwargs):
33
+ super(TAP, self).__init__()
34
+ self.in_dim = in_dim
35
+
36
+ def forward(self, x):
37
+ pooling_mean = x.mean(dim=-1)
38
+ # To be compatible with 2D input
39
+ pooling_mean = pooling_mean.flatten(start_dim=1)
40
+ return pooling_mean
41
+
42
+ def get_out_dim(self):
43
+ self.out_dim = self.in_dim
44
+ return self.out_dim
45
+
46
+
47
+ class TSDP(nn.Module):
48
+ """
49
+ Temporal standard deviation pooling, only second-order std is considered
50
+ """
51
+
52
+ def __init__(self, in_dim=0, **kwargs):
53
+ super(TSDP, self).__init__()
54
+ self.in_dim = in_dim
55
+
56
+ def forward(self, x):
57
+ # The last dimension is the temporal axis
58
+ pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7)
59
+ pooling_std = pooling_std.flatten(start_dim=1)
60
+ return pooling_std
61
+
62
+ def get_out_dim(self):
63
+ self.out_dim = self.in_dim
64
+ return self.out_dim
65
+
66
+
67
+ class TSTP(nn.Module):
68
+ """
69
+ Temporal statistics pooling, concatenate mean and std, which is used in
70
+ x-vector
71
+ Comment: simple concatenation cannot make full use of both statistics
72
+ """
73
+
74
+ def __init__(self, in_dim=0, **kwargs):
75
+ super(TSTP, self).__init__()
76
+ self.in_dim = in_dim
77
+
78
+ def forward(self, x):
79
+ # The last dimension is the temporal axis
80
+ pooling_mean = x.mean(dim=-1)
81
+ pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7)
82
+ pooling_mean = pooling_mean.flatten(start_dim=1)
83
+ pooling_std = pooling_std.flatten(start_dim=1)
84
+ stats = torch.cat((pooling_mean, pooling_std), 1)
85
+ return stats
86
+
87
+ def get_out_dim(self):
88
+ self.out_dim = self.in_dim * 2
89
+ return self.out_dim
90
+
91
+
92
+ class ASTP(nn.Module):
93
+ """ Attentive statistics pooling: Channel- and context-dependent
94
+ statistics pooling, first used in ECAPA_TDNN.
95
+ """
96
+
97
+ def __init__(self,
98
+ in_dim,
99
+ bottleneck_dim=128,
100
+ global_context_att=False,
101
+ **kwargs):
102
+ super(ASTP, self).__init__()
103
+ self.in_dim = in_dim
104
+ self.global_context_att = global_context_att
105
+
106
+ # Use Conv1d with stride == 1 rather than Linear, then we don't
107
+ # need to transpose inputs.
108
+ if global_context_att:
109
+ self.linear1 = nn.Conv1d(
110
+ in_dim * 3, bottleneck_dim,
111
+ kernel_size=1) # equals W and b in the paper
112
+ else:
113
+ self.linear1 = nn.Conv1d(
114
+ in_dim, bottleneck_dim,
115
+ kernel_size=1) # equals W and b in the paper
116
+ self.linear2 = nn.Conv1d(bottleneck_dim, in_dim,
117
+ kernel_size=1) # equals V and k in the paper
118
+
119
+ def forward(self, x):
120
+ """
121
+ x: a 3-dimensional tensor in tdnn-based architecture (B,F,T)
122
+ or a 4-dimensional tensor in resnet architecture (B,C,F,T)
123
+ 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
124
+ """
125
+ if len(x.shape) == 4:
126
+ x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
127
+ assert len(x.shape) == 3
128
+
129
+ if self.global_context_att:
130
+ context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
131
+ context_std = torch.sqrt(
132
+ torch.var(x, dim=-1, keepdim=True) + 1e-7).expand_as(x)
133
+ x_in = torch.cat((x, context_mean, context_std), dim=1)
134
+ else:
135
+ x_in = x
136
+
137
+ # DON'T use ReLU here! With ReLU the attention weights may be hard to converge.
138
+ alpha = torch.tanh(
139
+ self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
140
+ alpha = torch.softmax(self.linear2(alpha), dim=2)
141
+ mean = torch.sum(alpha * x, dim=2)
142
+ var = torch.sum(alpha * (x**2), dim=2) - mean**2
143
+ std = torch.sqrt(var.clamp(min=1e-7))
144
+ return torch.cat([mean, std], dim=1)
145
+
146
+ def get_out_dim(self):
147
+ self.out_dim = 2 * self.in_dim
148
+ return self.out_dim
149
+
150
+
151
+ class MHASTP(torch.nn.Module):
152
+ """ Multi head attentive statistics pooling
153
+ Reference:
154
+ Self Multi-Head Attention for Speaker Recognition
155
+ https://arxiv.org/pdf/1906.09890.pdf
156
+ """
157
+
158
+ def __init__(self,
159
+ in_dim,
160
+ layer_num=2,
161
+ head_num=2,
162
+ d_s=1,
163
+ bottleneck_dim=64,
164
+ **kwargs):
165
+ super(MHASTP, self).__init__()
166
+ assert (in_dim % head_num
167
+ ) == 0 # make sure that in_dim is divisible by head_num
168
+ self.in_dim = in_dim
169
+ self.head_num = head_num
170
+ d_model = int(in_dim / head_num)
171
+ channel_dims = [bottleneck_dim for i in range(layer_num + 1)]
172
+ if d_s > 1:
173
+ d_s = d_model
174
+ else:
175
+ d_s = 1
176
+ self.d_s = d_s
177
+ channel_dims[0], channel_dims[-1] = d_model, d_s
178
+ heads_att_trans = []
179
+ for i in range(self.head_num):
180
+ att_trans = nn.Sequential()
181
+ for i in range(layer_num - 1):
182
+ att_trans.add_module(
183
+ 'att_' + str(i),
184
+ nn.Conv1d(channel_dims[i], channel_dims[i + 1], 1, 1))
185
+ att_trans.add_module('tanh' + str(i), nn.Tanh())
186
+ att_trans.add_module(
187
+ 'att_' + str(layer_num - 1),
188
+ nn.Conv1d(channel_dims[layer_num - 1], channel_dims[layer_num],
189
+ 1, 1))
190
+ heads_att_trans.append(att_trans)
191
+ self.heads_att_trans = nn.ModuleList(heads_att_trans)
192
+
193
+ def forward(self, input):
194
+ """
195
+ input: a 3-dimensional tensor in xvector architecture
196
+ or a 4-dimensional tensor in resnet architecture
197
+ 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
198
+ """
199
+ if len(input.shape) == 4: # (B, C, F, T) resnet-style input -> flatten to (B, C*F, T)
200
+ input = input.reshape(input.shape[0],
201
+ input.shape[1] * input.shape[2],
202
+ input.shape[3])
203
+ assert len(input.shape) == 3
204
+ bs, f_dim, t_dim = input.shape
205
+ chunks = torch.chunk(input, self.head_num, 1)
206
+ # split
207
+ chunks_out = []
208
+ # for i in range(self.head_num):
209
+ # att_score = self.heads_att_trans[i](chunks[i])
210
+ for i, layer in enumerate(self.heads_att_trans):
211
+ att_score = layer(chunks[i])
212
+ alpha = F.softmax(att_score, dim=-1)
213
+ mean = torch.sum(alpha * chunks[i], dim=2)
214
+ var = torch.sum(alpha * chunks[i]**2, dim=2) - mean**2
215
+ std = torch.sqrt(var.clamp(min=1e-7))
216
+ chunks_out.append(torch.cat((mean, std), dim=1))
217
+ out = torch.cat(chunks_out, dim=1)
218
+ return out
219
+
220
+ def get_out_dim(self):
221
+ self.out_dim = 2 * self.in_dim
222
+ return self.out_dim
223
+
224
+
225
+ class MQMHASTP(torch.nn.Module):
226
+ """ An attentive pooling
227
+ Reference:
228
+ multi query multi head attentive statistics pooling
229
+ https://arxiv.org/pdf/2110.05042.pdf
230
+ Args:
231
+ in_dim: the feature dimension of input
232
+ layer_num: the number of layer in the pooling layer
233
+ query_num: the number of queries
234
+ head_num: the number of heads
235
+ bottleneck_dim: the bottleneck dimension
236
+
237
+ SA (H = 1, Q = 1, n = 2, d_s = 1) ref:
238
+ https://www.danielpovey.com/files/2018_interspeech_xvector_attention.pdf
239
+ MHA (H > 1, Q = 1, n = 1, d_s = 1) ref:
240
+ https://arxiv.org/pdf/1906.09890.pdf
241
+ AS (H = 1, Q > 1, n = 2, d_s = 1) ref:
242
+ https://arxiv.org/pdf/1803.10963.pdf
243
+ VSA (H = 1, Q > 1, n = 2, d_s = d_h) ref:
244
+ http://www.interspeech2020.org/uploadfile/pdf/Mon-2-10-5.pdf
245
+ """
246
+
247
+ def __init__(self,
248
+ in_dim,
249
+ layer_num=2,
250
+ query_num=2,
251
+ head_num=8,
252
+ d_s=2,
253
+ bottleneck_dim=64,
254
+ **kwargs):
255
+ super(MQMHASTP, self).__init__()
256
+ self.n_query = nn.ModuleList([
257
+ MHASTP(in_dim,
258
+ layer_num=layer_num,
259
+ head_num=head_num,
260
+ d_s=d_s,
261
+ bottleneck_dim=bottleneck_dim) for i in range(query_num)
262
+ ])
263
+ self.query_num = query_num
264
+ self.in_dim = in_dim
265
+
266
+ def forward(self, input):
267
+ """
268
+ input: a 3-dimensional tensor in xvector architecture
269
+ or a 4-dimensional tensor in resnet architecture
270
+ 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
271
+ """
272
+ if len(input.shape) == 4: # (B, C, F, T) resnet-style input -> flatten to (B, C*F, T)
273
+ input = input.reshape(input.shape[0],
274
+ input.shape[1] * input.shape[2],
275
+ input.shape[3])
276
+ assert len(input.shape) == 3
277
+ res = []
278
+ for i, layer in enumerate(self.n_query):
279
+ res.append(layer(input))
280
+ out = torch.cat(res, dim=-1)
281
+ return out
282
+
283
+ def get_out_dim(self):
284
+ self.out_dim = self.in_dim * 2 * self.query_num
285
+ return self.out_dim
286
+
287
+
288
+ if __name__ == '__main__':
289
+ data = torch.randn(16, 512, 10, 35)
290
+ # model = StatisticsPooling()
291
+ model = MQMHASTP(512 * 10)
292
+ model = MHASTP(512 * 10)
293
+ model = MQMHASTP(512 * 10, context=False)
294
+ print(model)
295
+
296
+ out = model(data)
297
+ print(out.shape)
298
+ print(model.get_out_dim())
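All pooling variants expose get_out_dim() so the caller can size the following batch-norm and linear layers without hard-coding the statistics order. A hedged comparison sketch (the 1536-dim input mirrors the ECAPA latent; the sizes are assumptions, not requirements):

import torch
from sparktts.modules.speaker import pooling_layers

x = torch.randn(4, 1536, 200)  # (B, F, T) frame-level features
for name in ("TAP", "TSDP", "TSTP", "ASTP"):
    pool = getattr(pooling_layers, name)(in_dim=1536)
    print(name, pool(x).shape, pool.get_out_dim())
# TAP  -> (4, 1536)  mean only
# TSDP -> (4, 1536)  std only
# TSTP -> (4, 3072)  mean + std concatenated
# ASTP -> (4, 3072)  attention-weighted mean + std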
trained_30_percents/sparktts/modules/speaker/speaker_encoder.py ADDED
@@ -0,0 +1,136 @@
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+
19
+ from typing import List, Tuple
20
+ from sparktts.modules.fsq.residual_fsq import ResidualFSQ
21
+ from sparktts.modules.speaker.ecapa_tdnn import ECAPA_TDNN_GLOB_c512
22
+ from sparktts.modules.speaker.perceiver_encoder import PerceiverResampler
23
+
24
+ """
25
+ x-vector + d-vector
26
+ """
27
+
28
+
29
+ class SpeakerEncoder(nn.Module):
30
+ """
31
+
32
+ Args:
33
+ input_dim (int): acoustic feature dimension
34
+ out_dim (int): output dimension of x-vector and d-vector
35
+ latent_dim (int): latent dimension before quantization
36
+ token_num (int): sequence length of speaker tokens
37
+ fsq_levels (List[int]): quantization levels for each FSQ dimension
38
+ fsq_num_quantizers (int): number of quantizers
39
+
40
+ Return:
41
+ speaker_embs: (B, T2, out_dim)
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ input_dim: int = 100,
47
+ out_dim: int = 512,
48
+ latent_dim: int = 128,
49
+ token_num: int = 32,
50
+ fsq_levels: List[int] = [4, 4, 4, 4, 4, 4],
51
+ fsq_num_quantizers: int = 1,
52
+ ):
53
+ super(SpeakerEncoder, self).__init__()
54
+
55
+ self.speaker_encoder = ECAPA_TDNN_GLOB_c512(
56
+ feat_dim=input_dim, embed_dim=out_dim
57
+ )
58
+ self.perceiver_sampler = PerceiverResampler(
59
+ dim=latent_dim, dim_context=512 * 3, num_latents=token_num
60
+ )
61
+ self.quantizer = ResidualFSQ(
62
+ levels=fsq_levels,
63
+ num_quantizers=fsq_num_quantizers,
64
+ dim=latent_dim,
65
+ is_channel_first=True,
66
+ quantize_dropout=False,
67
+ )
68
+
69
+ self.project = nn.Linear(latent_dim * token_num, out_dim)
70
+
71
+ def get_codes_from_indices(self, indices: torch.Tensor) -> torch.Tensor:
72
+ zq = self.quantizer.get_codes_from_indices(indices.transpose(1, 2))
73
+ return zq.transpose(1, 2)
74
+
75
+ def get_indices(self, mels: torch.Tensor) -> torch.Tensor:
76
+ mels = mels.transpose(1, 2)
77
+ x = self.perceiver_sampler(mels).transpose(1, 2)
78
+ zq, indices = self.quantizer(x)
79
+ return indices
80
+
81
+ def forward(self, mels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
82
+ """
83
+ Args:
84
+ mels: (B, D_mel, T1)
85
+
86
+ Return:
87
+ x_vector: (B, out_dim)
88
+ d_vector: (B, out_dim)
89
+ """
90
+ # mels = mels.transpose(1,2)
91
+
92
+ x_vector, features = self.speaker_encoder(mels, True)
93
+ x = self.perceiver_sampler(features.transpose(1, 2)).transpose(1, 2)
94
+ zq, indices = self.quantizer(x) # zq: (B, latent_dim, T2)
95
+ x = zq.reshape(zq.shape[0], -1)
96
+ d_vector = self.project(x)
97
+
98
+ return x_vector, d_vector
99
+
100
+ def tokenize(self, mels: torch.Tensor) -> torch.Tensor:
101
+ """tokenize the input mel spectrogram"""
102
+ _, features = self.speaker_encoder(mels, True)
103
+ x = self.perceiver_sampler(features.transpose(1, 2)).transpose(1, 2)
104
+ zq, indices = self.quantizer(x)
105
+ return indices
106
+
107
+ def detokenize(self, indices: torch.Tensor) -> torch.Tensor:
108
+ """detokenize the input indices to d-vector"""
109
+ zq = self.quantizer.get_output_from_indices(indices.transpose(1, 2)).transpose(1, 2)
110
+ x = zq.reshape(zq.shape[0], -1)
111
+ d_vector = self.project(x)
112
+ return d_vector
113
+
114
+ if __name__ == "__main__":
115
+ model = SpeakerEncoder(
116
+ input_dim=100,
117
+ latent_dim=128,
118
+ token_num=32,
119
+ fsq_levels=[4, 4, 4, 4, 4, 4],
120
+ fsq_num_quantizers=1,
121
+ )
122
+ mel = torch.randn(8, 200, 100)
123
+ x_vector, d_vector = model(mel)
124
+ print("x-vector shape", x_vector.shape)
125
+ print("d-vector shape", d_vector.shape)
126
+
127
+ indices = model.tokenize(mel)
128
+ print("indices shape", indices.shape)
129
+ d_vector_post = model.detokenize(indices)
130
+ print("d-vector shape", d_vector_post.shape)
131
+ if torch.allclose(d_vector_post, d_vector):
132
+ print("d-vector post and d-vector are the same")
133
+ else:
134
+ print("d-vector post and d-vector are different")
135
+ num_params = sum(param.numel() for param in model.parameters())
136
+ print("{} M".format(num_params / 1e6))
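The tokenize/detokenize pair above means the speaker timbre can be carried as a short sequence of discrete FSQ indices and turned back into a d-vector on the decoding side. A hypothetical round-trip sketch (sizes follow the __main__ defaults; torch.allclose is used for the comparison):

import torch
from sparktts.modules.speaker.speaker_encoder import SpeakerEncoder

enc = SpeakerEncoder(input_dim=100, latent_dim=128, token_num=32,
                     fsq_levels=[4, 4, 4, 4, 4, 4], fsq_num_quantizers=1).eval()
mel = torch.randn(2, 200, 100)  # (B, T, D_mel)
with torch.no_grad():
    x_vector, d_vector = enc(mel)         # (2, 512), (2, 512)
    tokens = enc.tokenize(mel)            # discrete speaker tokens
    d_vector_rt = enc.detokenize(tokens)  # rebuilt from the tokens alone
print(torch.allclose(d_vector, d_vector_rt, atol=1e-5))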
trained_30_percents/sparktts/modules/vq/__pycache__/factorized_vector_quantize.cpython-311.pyc ADDED
Binary file (9.15 kB). View file