diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..a6344aac8c09253b3b630fb776ae94478aa0275b
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,35 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
diff --git a/MMaDA/.cursor/rules/python-env.mdc b/MMaDA/.cursor/rules/python-env.mdc
new file mode 100644
index 0000000000000000000000000000000000000000..c1ff6bbc9d8b0c60d891753ae6b69b02dffa02de
--- /dev/null
+++ b/MMaDA/.cursor/rules/python-env.mdc
@@ -0,0 +1,4 @@
+---
+alwaysApply: true
+---
+When running Python scripts, use the conda env `mmada`.
\ No newline at end of file
diff --git a/MMaDA/.gitignore b/MMaDA/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..4bf4b61a759ad4f6794799e6596ce371df3c8858
--- /dev/null
+++ b/MMaDA/.gitignore
@@ -0,0 +1,2 @@
+exp
+wandb
\ No newline at end of file
diff --git a/MMaDA/AIDAS-Omni-Modal-Diffusion/app.py b/MMaDA/AIDAS-Omni-Modal-Diffusion/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..d27947cc49b2a23d6ea5de59999022f729c6c9d8
--- /dev/null
+++ b/MMaDA/AIDAS-Omni-Modal-Diffusion/app.py
@@ -0,0 +1,16 @@
+import gradio as gr
+import spaces
+import torch
+
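+# Hugging Face Spaces ZeroGPU demo: outside a @spaces.GPU-decorated function the tensor
+# below stays on CPU; inside the decorated function a GPU is attached and CUDA is available.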
+zero = torch.Tensor([0]).cuda()
+print(zero.device) # should print 'cpu' until GPU context is enabled
+
+
+@spaces.GPU
+def greet(n):
+ print(zero.device) # now this should print 'cuda:0'
+ return f"Hello {zero + n} Tensor"
+
+
+demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text())
+demo.launch()
diff --git a/MMaDA/LICENSE b/MMaDA/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..9ccb18421170127e4d4eab6963699a798d3b4ad3
--- /dev/null
+++ b/MMaDA/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 Ling Yang
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/MMaDA/README.md b/MMaDA/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d837c46be3c49c5e478d8b8c6e9fd37322c066d3
--- /dev/null
+++ b/MMaDA/README.md
@@ -0,0 +1,209 @@
+
+
+
+## 🌌 Introduction
+MMaDA is a new family of **multimodal diffusion foundation models** designed to achieve superior performance across diverse domains such as textual reasoning, multimodal understanding, and text-to-image generation. MMaDA is distinguished by three key innovations:
+1. MMaDA adopts a **unified diffusion architecture** with a shared probabilistic formulation and a modality-agnostic design, eliminating the need for modality-specific components.
+2. MMaDA introduces a **mixed long chain-of-thought (CoT) fine-tuning** strategy that curates a unified CoT format across modalities.
+3. MMaDA adopts a unified policy-gradient-based RL algorithm, which we call **UniGRPO**, tailored for diffusion foundation models. Utilizing diversified reward modeling, **UniGRPO** unifies post-training across both reasoning and generation tasks, ensuring consistent performance improvements.
+
+
+
+
+     MMaDA's decoding demo. This video showcases how a diffusion foundation model generates text and images.
+ The "Text Generation" part uses a semi-autoregressive sampling method, while the "Multimodal Generation" part adopts non-autoregressive diffusion denoising.
+
+
+
+
+
+
+
+
+
+
+
+## 📰 Latest Updates
+* **[2025-06-02]** We open source our **MMaDA-8B-MixCoT** at [Huggingface](https://huggingface.co/Gen-Verse/MMaDA-8B-MixCoT).
+* **[2025-05-24]** We add support for MPS inference, tested on M4.
+* **[2025-05-22]** We release the inference and training code of MMaDA for text generation, multimodal generation and image generation.
+* **[2025-05-22]** We open source our **MMaDA-8B-Base** at [Huggingface](https://huggingface.co/Gen-Verse/MMaDA-8B-Base). **MMaDA-8B-MixCoT** and **MMaDA-8B-Max** will be released in the near future.
+* **[2025-05-22]** We release our [research paper](https://arxiv.org/abs/2505.15809) and [demo](https://huggingface.co/spaces/Gen-Verse/MMaDA) for the first unified multimodal diffusion model: MMaDA.
+
+
+## 🧬 MMaDA Series Overview
+
+MMaDA includes a series of checkpoints reflecting different training stages:
+1. **[MMaDA-8B-Base](https://huggingface.co/Gen-Verse/MMaDA-8B-Base)**: After pretraining and instruction tuning. Capable of basic text generation, image generation, image captioning, and basic **thinking abilities**.
+2. **[MMaDA-8B-MixCoT](https://huggingface.co/Gen-Verse/MMaDA-8B-MixCoT)**: After mixed long chain-of-thought (CoT) fine-tuning. Capable of **complex** textual, multimodal and image generation reasoning.
+3. **MMaDA-8B-Max (coming soon)**: After UniGRPO reinforcement learning. Excels at complex reasoning and awesome visual generation. Will be released in the future.
+
+
+
+Overview of MMaDA's capabilities.
+
+
+
+
+
+## ✅ TODO
+- [x] Release [MMaDA-8B-MixCoT](https://huggingface.co/Gen-Verse/MMaDA-8B-MixCoT)
+- [ ] Release MMaDA-8B-Max and OpenRLHF-based UniGRPO training code.
+
+## ⚙️ Quick Start
+First, set up the environment:
+```
+pip install -r requirements.txt
+```
+Launch local Gradio demo:
+```
+python app.py
+```
+Or try it online via our [Huggingface Demo](https://huggingface.co/spaces/Gen-Verse/MMaDA).
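+
+If you prefer to load the checkpoint directly in Python, the sketch below mirrors how `app.py` in this repository loads the model and tokenizer (the `models` package ships with the repo; adjust the device to your hardware):
+
+```python
+import torch
+from transformers import AutoTokenizer
+
+from models import MMadaModelLM  # provided by this repository
+
+# Load MMaDA-8B-Base in bfloat16, as app.py does
+model = MMadaModelLM.from_pretrained(
+    "Gen-Verse/MMaDA-8B-Base", trust_remote_code=True, torch_dtype=torch.bfloat16
+).to("cuda").eval()
+tokenizer = AutoTokenizer.from_pretrained("Gen-Verse/MMaDA-8B-Base", trust_remote_code=True)
+```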
+
+## 🚀 Inference
+For batch-level inference, we provide our inference scripts here.
+### 1. Text Generation
+For text generation, we follow LLaDA's configuration and generation script. Simply run:
+```bash
+python generate.py
+```
+
+### 2. MultiModal Generation
+For multimodal generation and text-to-image generation, first log in to your wandb account:
+```
+wandb login
+```
+Run the inference demo for multimodal generation; you can view the results on wandb:
+```
+python3 inference_mmu.py config=configs/mmada_demo.yaml mmu_image_root=./mmu_validation question='Please describe this image in detail.'
+```
+
+### 3. Text-to-Image Generation
+For text-to-image generation, first log in to your wandb account:
+```
+wandb login
+```
+Run the inference demo for text-to-image generation; you can view the results on wandb:
+```
+python3 inference_t2i.py config=configs/mmada_demo.yaml batch_size=1 validation_prompts_file=validation_prompts/text2image_prompts.txt guidance_scale=3.5 generation_timesteps=15 mode='t2i'
+```
+
+## 🔧 Training
+**Update your training data path in `configs/xx.yaml`.**
+
+### Stage 0. Prepare your accelerate configs
+Please first prepare your accelerate configs. You can simply run:
+```
+accelerate config
+```
+
+Or use our provided configs in `accelerate_configs`:
+```
+├── accelerate_configs/
+|   ├── 1_gpu.yaml
+|   └── 8_node_8_gpus_deepspeed_zero2.yaml (for 8 nodes * 8 GPUs)
+```
+
+### Stage 1.1: Pre-training on ImageNet
+First, we initialize our model from LLaDA-8B-Instruct and train on ImageNet for basic visual capabilities.
+```
+accelerate launch --config_file path/to/your/accelerate_config --main_process_port=8888 training/train_mmada.py config=configs/mmada_pretraining_stage1_llada_instruct.yaml
+```
+
+### Stage 1.2 Pre-training on Image-Text Dataset
+Then we replace the ImageNet dataset from Stage 1.1 with an image-text dataset. Please change the pretrained model path in `mmada_pretraining_stage2_llada_instruct.yaml` to your checkpoint from Stage 1.1.
+```
+accelerate launch --config_file path/to/your/accelerate_config --main_process_port=8888 training/train_mmada_stage2.py config=configs/mmada_pretraining_stage2_llada_instruct.yaml
+```
+
+### Stage 1.3 Pre-training on Text Instruction following
+In this stage, we begin training on text instruction following and include the corresponding validations. Please change the pretrained model path in `mmada_pretraining_stage3_llada_instruct.yaml` to your checkpoint from Stage 1.2.
+```
+accelerate launch --config_file path/to/your/accelerate_config --main_process_port=8888 training/train_mmada_stage3.py config=configs/mmada_pretraining_stage3_llada_instruct.yaml
+```
+
+### Stage 2.1 Mix-CoT Training (Text Only)
+In this stage, we begin our Mix-CoT fine-tuning with text reasoning first, along with improved image quality. Please change the pretrained model path in `mmada_pretraining_stage3_llada_instruct_512_cot.yaml` to your checkpoint from Stage 1.3 and prepare your CoT data.
+```
+accelerate launch --config_file path/to/your/accelerate_config --main_process_port=8888 training/train_mmada_stage_cot_sft.py config=configs/mmada_pretraining_stage3_llada_instruct_512_cot.yaml
+```
+
+### Stage 2.2 Mix-CoT Training (with MultiModal Reasoning)
+In this stage, we include multimodal reasoning, along with improved image quality. Please change the pretrained model path in `mmada_pretraining_stage4_llada_instruct.yaml` to your checkpoint from Stage 2.1 and prepare your CoT data.
+```
+accelerate launch --config_file path/to/your/accelerate_config --main_process_port=8888 training/train_mmada_stage4.py config=configs/mmada_pretraining_stage4_llada_instruct.yaml
+```
+
+### Stage 3 UniGRPO RL
+[Will be released once we finish our code transition to OpenRLHF.]
+
+
+## 📖 Citation
+```
+@article{yang2025mmada,
+ title={MMaDA: Multimodal Large Diffusion Language Models},
+ author={Yang, Ling and Tian, Ye and Li, Bowen and Zhang, Xinchen and Shen, Ke and Tong, Yunhai and Wang, Mengdi},
+ journal={arXiv preprint arXiv:2505.15809},
+ year={2025}
+}
+```
+
+## 🤝 Acknowledgments
+This work is heavily based on [Show-o](https://github.com/showlab/Show-o), [LLaDA](https://github.com/ML-GSAI/LLaDA), [maskgit](https://github.com/google-research/maskgit), [transformers](https://github.com/huggingface/transformers), [accelerate](https://github.com/huggingface/accelerate) and [webdataset](https://github.com/webdataset/webdataset). Thanks to all the authors for their great work.
+
+## 💬 Discussion and Collaboration
+
+We welcome discussion and collaboration to continuously improve MMaDA. If you run into any bad cases, please kindly share them in this [issue](https://github.com/Gen-Verse/MMaDA/issues/4#issue-3083196081).
+
+Also, you can reach us with this WeChat QR code!
+
+
+
+
diff --git a/MMaDA/accelerate_configs/1_gpu.yaml b/MMaDA/accelerate_configs/1_gpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2f4525c78804745e0805a5df05fee00197163a09
--- /dev/null
+++ b/MMaDA/accelerate_configs/1_gpu.yaml
@@ -0,0 +1,15 @@
+compute_environment: LOCAL_MACHINE
+distributed_type: 'NO'
+downcast_bf16: 'no'
+gpu_ids: '0'
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 1
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
\ No newline at end of file
diff --git a/MMaDA/accelerate_configs/1_node_8_gpus_deepspeed_zero2.yaml b/MMaDA/accelerate_configs/1_node_8_gpus_deepspeed_zero2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..30b2d800f85d86f33ec7d3385723a88885bffa36
--- /dev/null
+++ b/MMaDA/accelerate_configs/1_node_8_gpus_deepspeed_zero2.yaml
@@ -0,0 +1,21 @@
+compute_environment: LOCAL_MACHINE
+deepspeed_config:
+ deepspeed_multinode_launcher: standard
+ gradient_accumulation_steps: 1
+ gradient_clipping: 1.0
+ offload_optimizer_device: cpu
+ offload_param_device: cpu
+ zero3_init_flag: true
+ zero_stage: 2
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
\ No newline at end of file
diff --git a/MMaDA/accelerate_configs/1_node_8_gpus_deepspeed_zero3.yaml b/MMaDA/accelerate_configs/1_node_8_gpus_deepspeed_zero3.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0e16c6bd6051caadf07df618bccabf55fdfe425b
--- /dev/null
+++ b/MMaDA/accelerate_configs/1_node_8_gpus_deepspeed_zero3.yaml
@@ -0,0 +1,24 @@
+compute_environment: LOCAL_MACHINE
+deepspeed_config:
+ deepspeed_multinode_launcher: standard
+ gradient_accumulation_steps: 2
+ gradient_clipping: 1.0
+ offload_optimizer_device: cpu
+ offload_param_device: cpu
+ zero3_init_flag: true
+ zero3_save_16bit_model: true
+ zero_stage: 3
+ zero_optimization:
+ overlap_comm: false
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
\ No newline at end of file
diff --git a/MMaDA/accelerate_configs/1_node_8_gpus_deepspeed_zero4.yaml b/MMaDA/accelerate_configs/1_node_8_gpus_deepspeed_zero4.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ea5b9f219008290ac4c62ded876e7e5bb6729880
--- /dev/null
+++ b/MMaDA/accelerate_configs/1_node_8_gpus_deepspeed_zero4.yaml
@@ -0,0 +1,24 @@
+compute_environment: LOCAL_MACHINE
+deepspeed_config:
+ deepspeed_multinode_launcher: standard
+ gradient_accumulation_steps: 1
+ gradient_clipping: 1.0
+ offload_optimizer_device: cpu
+ offload_param_device: cpu
+ zero3_init_flag: true
+ zero3_save_16bit_model: true
+ zero_stage: 2
+ zero_optimization:
+ overlap_comm: false
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
\ No newline at end of file
diff --git a/MMaDA/accelerate_configs/2_node_8_gpus_deepspeed_zero2_aidas.yaml b/MMaDA/accelerate_configs/2_node_8_gpus_deepspeed_zero2_aidas.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9029f3a85ad24f5e51b9188f13af001335a42c03
--- /dev/null
+++ b/MMaDA/accelerate_configs/2_node_8_gpus_deepspeed_zero2_aidas.yaml
@@ -0,0 +1,25 @@
+compute_environment: LOCAL_MACHINE
+deepspeed_config:
+ deepspeed_multinode_launcher: standard
+ gradient_accumulation_steps: 1
+ gradient_clipping: 1.0
+ offload_optimizer_device: cpu
+ offload_param_device: cpu
+ zero3_init_flag: true
+ zero3_save_16bit_model: true
+ zero_stage: 2
+ zero_optimization:
+ overlap_comm: false
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+main_process_ip: 172.51.80.134
+main_training_function: main
+num_machines: 2
+num_processes: 16
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
\ No newline at end of file
diff --git a/MMaDA/accelerate_configs/2_node_8_gpus_deepspeed_zero2_aidas2.yaml b/MMaDA/accelerate_configs/2_node_8_gpus_deepspeed_zero2_aidas2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..57c02ff4043d86350223992514199024f3a409db
--- /dev/null
+++ b/MMaDA/accelerate_configs/2_node_8_gpus_deepspeed_zero2_aidas2.yaml
@@ -0,0 +1,25 @@
+compute_environment: LOCAL_MACHINE
+deepspeed_config:
+ deepspeed_multinode_launcher: standard
+ gradient_accumulation_steps: 1
+ gradient_clipping: 1.0
+ offload_optimizer_device: cpu
+ offload_param_device: cpu
+ zero3_init_flag: true
+ zero3_save_16bit_model: true
+ zero_stage: 2
+ zero_optimization:
+ overlap_comm: false
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+main_process_ip: 172.51.80.136
+main_training_function: main
+num_machines: 4
+num_processes: 32
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
\ No newline at end of file
diff --git a/MMaDA/accelerate_configs/2_node_8_gpus_deepspeed_zero4.yaml b/MMaDA/accelerate_configs/2_node_8_gpus_deepspeed_zero4.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..48df01331374b048882fca4dfe23484303c85f43
--- /dev/null
+++ b/MMaDA/accelerate_configs/2_node_8_gpus_deepspeed_zero4.yaml
@@ -0,0 +1,26 @@
+compute_environment: LOCAL_MACHINE
+deepspeed_config:
+ deepspeed_multinode_launcher: standard
+ gradient_accumulation_steps: 4
+ gradient_clipping: 1.0
+ offload_optimizer_device: cpu
+ offload_param_device: cpu
+ zero3_init_flag: true
+ zero3_save_16bit_model: true
+ zero_stage: 2
+ zero_optimization:
+ overlap_comm: false
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+main_process_ip: 172.51.64.134
+main_training_function: main
+num_machines: 2
+num_processes: 16
+machine_rank: 1
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
\ No newline at end of file
diff --git a/MMaDA/accelerate_configs/3_node_8_gpus_deepspeed_zero1.yaml b/MMaDA/accelerate_configs/3_node_8_gpus_deepspeed_zero1.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..707453ded5101d525d7810574c128f661c59227c
--- /dev/null
+++ b/MMaDA/accelerate_configs/3_node_8_gpus_deepspeed_zero1.yaml
@@ -0,0 +1,25 @@
+compute_environment: LOCAL_MACHINE
+deepspeed_config:
+ deepspeed_multinode_launcher: standard
+ gradient_accumulation_steps: 4
+ gradient_clipping: 1.0
+ offload_optimizer_device: cpu
+ offload_param_device: cpu
+ zero3_init_flag: true
+ zero3_save_16bit_model: true
+ zero_stage: 2
+ zero_optimization:
+ overlap_comm: false
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+main_process_ip: 172.51.64.130
+main_training_function: main
+num_machines: 3
+num_processes: 24
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
\ No newline at end of file
diff --git a/MMaDA/accelerate_configs/4_node_8_gpus_deepspeed_zero2.yaml b/MMaDA/accelerate_configs/4_node_8_gpus_deepspeed_zero2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0dd7a4cbbbd7902ce3abc1c2cfac1b4649d7fbd8
--- /dev/null
+++ b/MMaDA/accelerate_configs/4_node_8_gpus_deepspeed_zero2.yaml
@@ -0,0 +1,21 @@
+compute_environment: LOCAL_MACHINE
+deepspeed_config:
+ deepspeed_multinode_launcher: standard
+ gradient_accumulation_steps: 4
+ gradient_clipping: 1.0
+ offload_optimizer_device: cpu
+ offload_param_device: cpu
+ zero3_init_flag: true
+ zero_stage: 2
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+main_training_function: main
+mixed_precision: bf16
+num_machines: 4
+num_processes: 32
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
\ No newline at end of file
diff --git a/MMaDA/accelerate_configs/4_node_8_gpus_deepspeed_zero2_aidas.yaml b/MMaDA/accelerate_configs/4_node_8_gpus_deepspeed_zero2_aidas.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..13b4a8d0f00d68de8373deb98974287431292e7d
--- /dev/null
+++ b/MMaDA/accelerate_configs/4_node_8_gpus_deepspeed_zero2_aidas.yaml
@@ -0,0 +1,25 @@
+compute_environment: LOCAL_MACHINE
+deepspeed_config:
+ deepspeed_multinode_launcher: standard
+ gradient_accumulation_steps: 1
+ gradient_clipping: 1.0
+ offload_optimizer_device: none #cpu
+ offload_param_device: none #cpu
+ zero3_init_flag: true
+ zero3_save_16bit_model: true
+ zero_stage: 2
+ zero_optimization:
+ overlap_comm: false
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+enable_cpu_affinity: true
+main_process_ip: 172.51.133.6
+main_training_function: main
+num_machines: 4
+num_processes: 32
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
\ No newline at end of file
diff --git a/MMaDA/accelerate_configs/8_node_8_gpus_deepspeed_zero2.yaml b/MMaDA/accelerate_configs/8_node_8_gpus_deepspeed_zero2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a9df6e71c045830d81d6174c9a5ae3efbb3df126
--- /dev/null
+++ b/MMaDA/accelerate_configs/8_node_8_gpus_deepspeed_zero2.yaml
@@ -0,0 +1,21 @@
+compute_environment: LOCAL_MACHINE
+deepspeed_config:
+ deepspeed_multinode_launcher: standard
+ gradient_accumulation_steps: 1
+ gradient_clipping: 1.0
+ offload_optimizer_device: cpu
+ offload_param_device: cpu
+ zero3_init_flag: true
+ zero_stage: 2
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+main_training_function: main
+mixed_precision: bf16
+num_machines: 8
+num_processes: 64
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
\ No newline at end of file
diff --git a/MMaDA/app.py b/MMaDA/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..1dbb3aa3c14863e1292379a4ea34ad93e2b92cb6
--- /dev/null
+++ b/MMaDA/app.py
@@ -0,0 +1,894 @@
+import gradio as gr
+import torch
+import numpy as np
+import torch.nn.functional as F
+from transformers import AutoTokenizer
+from torchvision import transforms
+from models import MAGVITv2, get_mask_schedule, MMadaModelLM
+from training.prompting_utils import UniversalPrompting
+from PIL import Image
+
+def image_transform(image, resolution=256, normalize=True):
+ image = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BICUBIC)(image)
+ image = transforms.CenterCrop((resolution, resolution))(image)
+ image = transforms.ToTensor()(image)
+ if normalize:
+ image = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)(image)
+ return image
+
+def add_gumbel_noise(logits, temperature):
+ """
+ Adds Gumbel noise to logits for stochastic sampling.
+ Equivalent to argmax(logits + temperature * G) where G ~ Gumbel(0,1).
+ This version is more numerically stable than a version involving exp() and division.
+ """
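+    # Illustrative note: argmax(logits + temperature * G), with G ~ Gumbel(0, 1), draws a sample
+    # from softmax(logits / temperature); higher temperatures make unmasking more random, and
+    # temperature == 0 (handled below) falls back to plain greedy argmax.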
+ if abs(temperature) < 1e-9: # Effectively zero temperature
+ return logits
+    # Cast to higher precision before adding noise (float32 on MPS, which lacks float64 support; float64 elsewhere).
+ if DEVICE == "mps":
+ logits = logits.to(torch.float32)
+ else:
+ logits = logits.to(torch.float64)
+ # Standard Gumbel noise: -log(-log(U)), U ~ Uniform(0,1)
+ # Add small epsilon for numerical stability inside logs
+ if DEVICE == "mps":
+ noise = torch.rand_like(logits, dtype=torch.float32)
+ else:
+ noise = torch.rand_like(logits, dtype=torch.float64)
+ standard_gumbel_noise = -torch.log(-torch.log(noise + 1e-20) + 1e-20)
+ return logits + temperature * standard_gumbel_noise
+
+def get_num_transfer_tokens(mask_index, steps):
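+    """
+    Distributes the number of currently masked tokens in each sequence evenly across `steps`
+    denoising steps: every step gets mask_num // steps tokens and the remainder
+    (mask_num % steps) is assigned to the earliest steps. Returns a (batch, steps) long
+    tensor whose rows sum to the per-sequence mask counts.
+    """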
+ mask_num = mask_index.sum(dim=1, keepdim=True)
+ # Ensure steps is at least 1 to avoid division by zero if mask_num is also 0 (though sum should be >=0)
+ steps = max(1, int(steps)) # Ensure steps is a positive integer
+ base = mask_num // steps
+ remainder = mask_num % steps
+ num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.long) + base
+ for i in range(mask_num.size(0)): # Iterate over batch
+ if remainder[i] > 0 : # Ensure remainder is positive before indexing
+ num_transfer_tokens[i, :remainder[i].item()] += 1 # .item() for single value tensor to int
+ return num_transfer_tokens
+
+MODEL = None
+TOKENIZER = None
+DEVICE = (
+ "cuda"
+ if torch.cuda.is_available()
+ else "mps" if torch.backends.mps.is_available() else "cpu"
+)
+MASK_ID = None
+uni_prompting = None
+VQ_MODEL = MAGVITv2().from_pretrained("showlab/magvitv2").to(DEVICE)
+
+DEFAULT_MODEL_PATH = "Gen-Verse/MMaDA-8B-Base" # Default
+CURRENT_MODEL_PATH = None
+
+MODEL_CHOICES = [
+ "MMaDA-8B-Base",
+ "MMaDA-8B-MixCoT (coming soon)",
+ "MMaDA-8B-Max (coming soon)"
+]
+MODEL_ACTUAL_PATHS = {
+ "MMaDA-8B-Base": DEFAULT_MODEL_PATH,
+}
+
+def clear_outputs_action():
+ return None, None
+
+def _load_model_and_tokenizer_core(model_path_to_load, model_display_name_for_status):
+ global MODEL, TOKENIZER, MASK_ID, CURRENT_MODEL_PATH, DEVICE, uni_prompting
+
+ if MODEL is not None and CURRENT_MODEL_PATH == model_path_to_load:
+ return f"Model '{model_display_name_for_status}' from '{model_path_to_load}' is already loaded. MASK_ID: {MASK_ID}"
+
+ CURRENT_MODEL_PATH = model_path_to_load
+
+ status_msg_parts = [f"Loading '{model_display_name_for_status}'..."]
+ try:
+ TOKENIZER = AutoTokenizer.from_pretrained(model_path_to_load, trust_remote_code=True)
+ status_msg_parts.append(f"Tokenizer for '{model_display_name_for_status}' loaded.")
+
+ MODEL = MMadaModelLM.from_pretrained(model_path_to_load, trust_remote_code=True, torch_dtype=torch.bfloat16).to(DEVICE).eval()
+ status_msg_parts.append(f"Model '{model_display_name_for_status}' loaded to {DEVICE}.")
+
+ uni_prompting = UniversalPrompting(TOKENIZER, max_text_len=512, special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),ignore_id=-100, cond_dropout_prob=0.1, use_reserved_token=True)
+
+ if hasattr(TOKENIZER, 'mask_token_id') and TOKENIZER.mask_token_id is not None:
+ MASK_ID = TOKENIZER.mask_token_id
+ status_msg_parts.append(f"Using MASK_ID from tokenizer: {MASK_ID}.")
+ else:
+ MASK_ID = 126336
+ status_msg_parts.append(f"Using default MASK_ID: {MASK_ID}.")
+
+ if TOKENIZER.pad_token_id is None:
+ if TOKENIZER.eos_token_id is not None:
+ TOKENIZER.pad_token_id = TOKENIZER.eos_token_id
+ TOKENIZER.pad_token = TOKENIZER.eos_token
+ status_msg_parts.append(f"Set pad_token_id to eos_token_id ({TOKENIZER.eos_token_id}).")
+ else:
+ status_msg_parts.append("Warning: pad_token_id is None and no eos_token_id.")
+
+ if TOKENIZER.eos_token_id is None: # Important for cleaning up output in visualization
+ status_msg_parts.append("Warning: tokenizer.eos_token_id is None. EOS cleanup might not work.")
+
+ TOKENIZER.chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n' }}"
+
+ return " ".join(status_msg_parts)
+ except Exception as e:
+ MODEL = None
+ TOKENIZER = None
+ MASK_ID = None
+ CURRENT_MODEL_PATH = None
+ return f"Error loading model '{model_display_name_for_status}': {str(e)}"
+
+def handle_model_selection_change(selected_model_name_ui):
+ if "coming soon" in selected_model_name_ui.lower():
+ global MODEL, TOKENIZER, MASK_ID, CURRENT_MODEL_PATH
+ MODEL = None
+ TOKENIZER = None
+ MASK_ID = None
+ CURRENT_MODEL_PATH = None
+        return f"'{selected_model_name_ui}' is not yet available. Please select 'MMaDA-8B-Base'."
+
+ actual_path = MODEL_ACTUAL_PATHS.get(selected_model_name_ui)
+ if not actual_path:
+ return f"Path for '{selected_model_name_ui}' is not defined. Cannot load."
+
+ return _load_model_and_tokenizer_core(actual_path, selected_model_name_ui)
+
+
+def get_highlighted_text_tuples(current_x_ids_batch, prompt_input_ids, prompt_len, tk, current_mask_id, raw_prompt_attention_mask):
+ if current_x_ids_batch is None or current_x_ids_batch.ndim == 0 or current_x_ids_batch.shape[0] == 0:
+ return [("Error in sequence data for visualization.", "ERROR")]
+ # only answer part
+ current_x_ids_batch = current_x_ids_batch[:, prompt_len:]
+ seq_ids = current_x_ids_batch[0].tolist()
+ eos_token_id = tk.eos_token_id # Get EOS token ID
+
+ # Stage 1: Build initial list of tuples with (token_str, label, token_id_int)
+ # This helps in identifying EOS tokens later without re-checking the type.
+ intermediate_tuples = []
+ for j, token_id_int in enumerate(seq_ids):
+ try:
+ token_str = tk.decode([token_id_int], skip_special_tokens=True, clean_up_tokenization_spaces=False)
+ except Exception: # Handle cases where a token ID might be problematic (e.g. with mock)
+ token_str = f"[ID:{token_id_int}]"
+
+ label = "ERROR"
+ if token_id_int == current_mask_id:
+ token_str = "[MASK]"
+ label = "MASK"
+ else:
+ label = "GEN"
+ intermediate_tuples.append((token_str, label, token_id_int))
+
+ return intermediate_tuples
+
+@torch.no_grad()
+def generate_viz_wrapper_t2i(prompt_text, steps, guidance_scale, mask_schedule="cosine"):
+ global MODEL, TOKENIZER, MASK_ID, DEVICE, uni_prompting
+
+ if MODEL is None or TOKENIZER is None or MASK_ID is None:
+ yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded."
+ return
+ steps = int(steps)
+ guidance_scale = float(guidance_scale)
+
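+    # Start from a fully masked sequence of 1024 image tokens and denoise it step by step;
+    # guidance_scale > 0 adds an unconditional branch for classifier-free guidance.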
+ image_tokens = torch.ones((1, 1024), dtype=torch.long, device=DEVICE) * MASK_ID
+ prompt_text = [prompt_text]
+ input_ids, attention_mask = uni_prompting((prompt_text, image_tokens), 't2i_gen')
+
+ if guidance_scale > 0:
+ uncond_input_ids, uncond_attention_mask = uni_prompting(([''], image_tokens), 't2i_gen')
+ else:
+ uncond_input_ids, uncond_attention_mask = None, None
+
+ mask_schedule = get_mask_schedule(mask_schedule)
+ blank_image = Image.new("RGB", (512, 512), (255, 255, 255))
+ yield blank_image, "Starting generation..."
+ for image_step, status_msg_step in MODEL.t2i_generate_decoding_stepwise(
+ input_ids = input_ids,
+ uncond_input_ids = uncond_input_ids,
+ attention_mask = attention_mask,
+ uncond_attention_mask = uncond_attention_mask,
+ temperature=1.0,
+ timesteps = steps,
+ guidance_scale = guidance_scale,
+ noise_schedule = mask_schedule,
+ noise_type = "mask",
+ seq_len = 1024,
+ vq_model = VQ_MODEL,
+ uni_prompting=uni_prompting):
+ yield image_step, status_msg_step
+
+
+
+
+@torch.no_grad()
+def generate_viz_wrapper_lm(prompt_text, steps, gen_length, block_length, temperature,
+ cfg_scale, remasking_strategy, thinking_mode_lm):
+ global MODEL, TOKENIZER, MASK_ID, DEVICE
+ print(f"thinking_mode_lm: {thinking_mode_lm}")
+ if MODEL is None or TOKENIZER is None or MASK_ID is None:
+ yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded."
+ return
+
+ steps = int(steps)
+ gen_length = int(gen_length)
+ block_length = int(block_length)
+
+ if thinking_mode_lm:
+        prompt_text = "You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here\n" + prompt_text
+
+ try:
+ m = [{"role": "user", "content": prompt_text}]
+ processed_prompt_text = TOKENIZER.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
+ except Exception as e:
+ yield [("Error applying chat template.", "ERROR")], f"Chat template error: {e}"
+ processed_prompt_text = prompt_text
+ try:
+ if TOKENIZER.pad_token_id is None:
+ if TOKENIZER.eos_token_id is not None:
+ TOKENIZER.pad_token_id = TOKENIZER.eos_token_id
+ else: # Should have been caught by load_model, but double check
+ yield [("Tokenizer Error", "ERROR")], "pad_token_id is not set in tokenizer."
+ return
+
+ input_ids = TOKENIZER(text=processed_prompt_text, return_tensors="pt", padding="longest", padding_side="left", truncation=True, max_length=MODEL.config.max_position_embeddings if hasattr(MODEL.config, 'max_position_embeddings') else 2048)['input_ids'].to(DEVICE)
+ raw_prompt_attention_mask = None
+
+ except Exception as e:
+ yield [("Error tokenizing prompt.", "ERROR")], f"Tokenization error: {e}"
+ return
+
+
+
+ batch_size = input_ids.shape[0]
+ prompt_len = input_ids.shape[1]
+
+ x = torch.full((batch_size, prompt_len + gen_length), MASK_ID, dtype=torch.long, device=DEVICE)
+ x[:, :prompt_len] = input_ids.clone()
+
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), "Starting generation: Prompt + Initial Masks"
+
+ if gen_length == 0:
+ final_text_output = TOKENIZER.batch_decode(x[:,prompt_len:], skip_special_tokens=True)
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_output[0] if final_text_output else ""
+ return
+
+ if block_length <= 0 or gen_length % block_length != 0 :
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
+ f"Error: gen_length ({gen_length}) must be divisible by block_length ({block_length}) and block_length > 0."
+ return
+ num_blocks = gen_length // block_length
+
+ if steps <=0 or steps % num_blocks != 0:
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
+ f"Error: steps ({steps}) must be positive and divisible by num_blocks ({num_blocks}). Steps: {steps}, Num Blocks: {num_blocks}"
+ return
+ steps_per_block = steps // num_blocks
+
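+    # Semi-autoregressive decoding: generate the response block by block; within each block,
+    # run steps_per_block denoising passes that replace [MASK] positions with the most confident
+    # predictions (per the chosen remasking strategy), yielding each intermediate state to the UI.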
+ for num_block_iter in range(num_blocks):
+ current_block_start_idx_in_x = prompt_len + num_block_iter * block_length
+ current_block_end_idx_in_x = prompt_len + (num_block_iter + 1) * block_length
+
+ block_masks_bool_current = torch.zeros_like(x, dtype=torch.bool)
+ block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x] = \
+ (x[:, current_block_start_idx_in_x:current_block_end_idx_in_x] == MASK_ID)
+
+ num_transfer_tokens_for_this_block = get_num_transfer_tokens(
+ block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x],
+ steps_per_block
+ )
+
+ for i_step_in_block in range(steps_per_block):
+ mask_index_global = (x == MASK_ID)
+
+ if cfg_scale > 0.:
+ un_x = x.clone()
+                # For the unconditional pass, mask out the original prompt tokens.
+                # raw_prompt_attention_mask is (B, prompt_len) when padding was applied; here it
+                # may be None (single un-padded prompt), in which case the whole prompt is masked.
+                if raw_prompt_attention_mask is not None:
+                    prompt_active_tokens_mask = raw_prompt_attention_mask.bool()  # True where actual prompt tokens are
+                    un_x[:, :prompt_len][prompt_active_tokens_mask] = MASK_ID
+                else:
+                    un_x[:, :prompt_len] = MASK_ID
+
+ x_cfg_input = torch.cat([x, un_x], dim=0)
+ # Pass attention_mask for CFG if model expects it, covering both parts
+ # For simplicity, not passing explicit attention_mask here; relies on model's internal handling.
+ model_output = MODEL(x_cfg_input)
+ logits_cond, logits_uncond = torch.chunk(model_output.logits, 2, dim=0)
+ logits = logits_uncond + (cfg_scale + 1) * (logits_cond - logits_uncond)
+ else:
+ # Not passing explicit attention_mask here; relies on model's internal handling.
+ model_output = MODEL(x)
+ logits = model_output.logits
+
+ logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
+ x0_predicted_tokens = torch.argmax(logits_with_noise, dim=-1)
+
+ if remasking_strategy == 'low_confidence':
+ if DEVICE == "mps":
+ probs = F.softmax(logits.to(torch.float32), dim=-1)
+ else:
+ probs = F.softmax(logits.to(torch.float64), dim=-1)
+ x0_probs = torch.gather(probs, dim=-1, index=x0_predicted_tokens.unsqueeze(-1)).squeeze(-1)
+ elif remasking_strategy == 'random':
+ if DEVICE == "mps":
+ x0_probs = torch.rand(x.shape, device=x.device, dtype=torch.float32)
+ else:
+ x0_probs = torch.rand(x.shape, device=x.device, dtype=torch.float64)
+ else:
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), f"Error: Unknown remasking strategy '{remasking_strategy}'"
+ return
+
+ confidence_for_selection = torch.full_like(x0_probs, -torch.inf)
+ candidate_positions_for_unmasking = mask_index_global & block_masks_bool_current
+ confidence_for_selection = torch.where(
+ candidate_positions_for_unmasking,
+ x0_probs,
+ -torch.inf
+ )
+
+ x0_final_candidates = torch.where(mask_index_global, x0_predicted_tokens, x)
+
+ transfer_indices_bool = torch.zeros_like(x, dtype=torch.bool)
+ num_to_transfer_this_step_batch = num_transfer_tokens_for_this_block[:, i_step_in_block]
+
+ for j_batch_idx in range(batch_size):
+ k_val = min(num_to_transfer_this_step_batch[j_batch_idx].item(),
+ candidate_positions_for_unmasking[j_batch_idx].sum().item()) # ensure k isn't too large
+
+ if k_val > 0:
+ # Ensure confidence_for_selection[j_batch_idx] is 1D for topk
+ conf_slice = confidence_for_selection[j_batch_idx]
+ if conf_slice.ndim > 1: conf_slice = conf_slice.view(-1) # Should already be 1D from x0_probs
+
+ # Check if there are enough valid (non -inf) confidences
+ valid_conf_count = (conf_slice > -torch.inf).sum().item()
+ actual_k = min(k_val, valid_conf_count)
+
+ if actual_k > 0:
+ _, topk_indices_in_x = torch.topk(conf_slice, k=actual_k)
+ transfer_indices_bool[j_batch_idx, topk_indices_in_x] = True
+
+ x[transfer_indices_bool] = x0_final_candidates[transfer_indices_bool]
+
+ current_total_step = num_block_iter * steps_per_block + i_step_in_block + 1
+ total_overall_steps = num_blocks * steps_per_block
+ status_msg = f"Block {num_block_iter+1}/{num_blocks}, Step {i_step_in_block+1}/{steps_per_block} (Total: {current_total_step}/{total_overall_steps})"
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), status_msg
+
+ final_generated_ids = x[:, prompt_len:]
+ final_text_output = TOKENIZER.batch_decode(final_generated_ids, skip_special_tokens=True)
+
+ final_text_str = final_text_output[0] if final_text_output and len(final_text_output) > 0 else ""
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_str
+
+@torch.no_grad()
+def generate_viz_wrapper(uploaded_image_pil, prompt_text, steps, gen_length, block_length, temperature,
+ cfg_scale, remasking_strategy, thinking_mode_mmu):
+ global MODEL, TOKENIZER, MASK_ID, DEVICE
+
+ if MODEL is None or TOKENIZER is None or MASK_ID is None:
+ yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded."
+ return
+
+ steps = int(steps)
+ gen_length = int(gen_length)
+ block_length = int(block_length)
+
+ if thinking_mode_mmu:
+        prompt_text = "You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here\n" + prompt_text
+
+ try:
+ m = [{"role": "user", "content": prompt_text}]
+ processed_prompt_text = TOKENIZER.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
+ except Exception as e:
+ yield [("Error applying chat template.", "ERROR")], f"Chat template error: {e}"
+ processed_prompt_text = prompt_text
+
+ image_vq_ids_tensor = None
+ if uploaded_image_pil is not None:
+ try:
+
+ image = image_transform(uploaded_image_pil, resolution=512).to(DEVICE)
+ image = image.unsqueeze(0)
+ image_vq_ids_tensor = VQ_MODEL.get_code(image) + 126349
+ except Exception as e:
+ yield [("Error processing image.", "ERROR")], f"Image to VQ tokens conversion failed: {str(e)}"
+ return
+
+
+ try:
+ if TOKENIZER.pad_token_id is None:
+ if TOKENIZER.eos_token_id is not None:
+ TOKENIZER.pad_token_id = TOKENIZER.eos_token_id
+ else:
+ yield [("Tokenizer Error", "ERROR")], "pad_token_id is not set in tokenizer."
+ return
+
+ input_ids = TOKENIZER(text=processed_prompt_text, return_tensors="pt", padding="longest", padding_side="left", truncation=True, max_length=MODEL.config.max_position_embeddings if hasattr(MODEL.config, 'max_position_embeddings') else 2048)['input_ids'].to(DEVICE)
+ raw_prompt_attention_mask = None
+ if image_vq_ids_tensor is not None:
+ if image_vq_ids_tensor.ndim == 1:
+ image_vq_ids_tensor = image_vq_ids_tensor.unsqueeze(0)
+
+ input_ids = torch.cat([
+ (torch.ones(input_ids.shape[0], 1) * torch.tensor([126089])).to(DEVICE),
+ (torch.ones(input_ids.shape[0], 1) * torch.tensor([126084])).to(DEVICE),
+ image_vq_ids_tensor,
+ (torch.ones(input_ids.shape[0], 1) * torch.tensor([126085])).to(DEVICE),
+ input_ids
+ ], dim=1).long()
+
+ else:
+ input_ids = input_ids
+
+
+ except Exception as e:
+ yield [("Error tokenizing prompt.", "ERROR")], f"Tokenization error: {e}"
+ return
+
+
+
+ batch_size = input_ids.shape[0]
+ prompt_len = input_ids.shape[1]
+
+ x = torch.full((batch_size, prompt_len + gen_length), MASK_ID, dtype=torch.long, device=DEVICE)
+ x[:, :prompt_len] = input_ids.clone()
+
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), "Starting generation: Prompt + Initial Masks"
+
+ if gen_length == 0:
+ final_text_output = TOKENIZER.batch_decode(x[:,prompt_len:], skip_special_tokens=True)
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_output[0] if final_text_output else ""
+ return
+
+ if block_length <= 0 or gen_length % block_length != 0 :
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
+ f"Error: gen_length ({gen_length}) must be divisible by block_length ({block_length}) and block_length > 0."
+ return
+ num_blocks = gen_length // block_length
+
+ if steps <=0 or steps % num_blocks != 0:
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
+ f"Error: steps ({steps}) must be positive and divisible by num_blocks ({num_blocks}). Steps: {steps}, Num Blocks: {num_blocks}"
+ return
+ steps_per_block = steps // num_blocks
+
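+    # Same semi-autoregressive scheme as the text-only path: denoise block by block, unmasking
+    # the most confident predictions at each step and yielding intermediate states to the UI.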
+ for num_block_iter in range(num_blocks):
+ current_block_start_idx_in_x = prompt_len + num_block_iter * block_length
+ current_block_end_idx_in_x = prompt_len + (num_block_iter + 1) * block_length
+
+ block_masks_bool_current = torch.zeros_like(x, dtype=torch.bool)
+ block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x] = \
+ (x[:, current_block_start_idx_in_x:current_block_end_idx_in_x] == MASK_ID)
+
+ num_transfer_tokens_for_this_block = get_num_transfer_tokens(
+ block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x],
+ steps_per_block
+ )
+
+ for i_step_in_block in range(steps_per_block):
+ mask_index_global = (x == MASK_ID)
+
+ if cfg_scale > 0.:
+ un_x = x.clone()
+                # For the unconditional pass, mask out the original prompt tokens.
+                # raw_prompt_attention_mask is (B, prompt_len) when padding was applied; here it
+                # may be None (single un-padded prompt), in which case the whole prompt is masked.
+                if raw_prompt_attention_mask is not None:
+                    prompt_active_tokens_mask = raw_prompt_attention_mask.bool()  # True where actual prompt tokens are
+                    un_x[:, :prompt_len][prompt_active_tokens_mask] = MASK_ID
+                else:
+                    un_x[:, :prompt_len] = MASK_ID
+
+ x_cfg_input = torch.cat([x, un_x], dim=0)
+ # Pass attention_mask for CFG if model expects it, covering both parts
+ # For simplicity, not passing explicit attention_mask here; relies on model's internal handling.
+ model_output = MODEL(x_cfg_input)
+ logits_cond, logits_uncond = torch.chunk(model_output.logits, 2, dim=0)
+ logits = logits_uncond + (cfg_scale + 1) * (logits_cond - logits_uncond)
+ else:
+ # Not passing explicit attention_mask here; relies on model's internal handling.
+ model_output = MODEL(x)
+ logits = model_output.logits
+
+ logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
+ x0_predicted_tokens = torch.argmax(logits_with_noise, dim=-1)
+
+ if remasking_strategy == 'low_confidence':
+ if DEVICE == "mps":
+ probs = F.softmax(logits.to(torch.float32), dim=-1)
+ else:
+ probs = F.softmax(logits.to(torch.float64), dim=-1)
+ x0_probs = torch.gather(probs, dim=-1, index=x0_predicted_tokens.unsqueeze(-1)).squeeze(-1)
+ elif remasking_strategy == 'random':
+ if DEVICE == "mps":
+ x0_probs = torch.rand(x.shape, device=x.device, dtype=torch.float32)
+ else:
+ x0_probs = torch.rand(x.shape, device=x.device, dtype=torch.float64)
+ else:
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), f"Error: Unknown remasking strategy '{remasking_strategy}'"
+ return
+
+ confidence_for_selection = torch.full_like(x0_probs, -torch.inf)
+ candidate_positions_for_unmasking = mask_index_global & block_masks_bool_current
+ confidence_for_selection = torch.where(
+ candidate_positions_for_unmasking,
+ x0_probs,
+ -torch.inf
+ )
+
+ x0_final_candidates = torch.where(mask_index_global, x0_predicted_tokens, x)
+
+ transfer_indices_bool = torch.zeros_like(x, dtype=torch.bool)
+ num_to_transfer_this_step_batch = num_transfer_tokens_for_this_block[:, i_step_in_block]
+
+ for j_batch_idx in range(batch_size):
+ k_val = min(num_to_transfer_this_step_batch[j_batch_idx].item(),
+ candidate_positions_for_unmasking[j_batch_idx].sum().item()) # ensure k isn't too large
+
+ if k_val > 0:
+ # Ensure confidence_for_selection[j_batch_idx] is 1D for topk
+ conf_slice = confidence_for_selection[j_batch_idx]
+ if conf_slice.ndim > 1: conf_slice = conf_slice.view(-1) # Should already be 1D from x0_probs
+
+ # Check if there are enough valid (non -inf) confidences
+ valid_conf_count = (conf_slice > -torch.inf).sum().item()
+ actual_k = min(k_val, valid_conf_count)
+
+ if actual_k > 0:
+ _, topk_indices_in_x = torch.topk(conf_slice, k=actual_k)
+ transfer_indices_bool[j_batch_idx, topk_indices_in_x] = True
+
+ x[transfer_indices_bool] = x0_final_candidates[transfer_indices_bool]
+
+ current_total_step = num_block_iter * steps_per_block + i_step_in_block + 1
+ total_overall_steps = num_blocks * steps_per_block
+ status_msg = f"Block {num_block_iter+1}/{num_blocks}, Step {i_step_in_block+1}/{steps_per_block} (Total: {current_total_step}/{total_overall_steps})"
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), status_msg
+
+ final_generated_ids = x[:, prompt_len:]
+ final_text_output = TOKENIZER.batch_decode(final_generated_ids, skip_special_tokens=True)
+
+ final_text_str = final_text_output[0] if final_text_output and len(final_text_output) > 0 else ""
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_str
+
+
+css_styles = """
+.gradio-container{font-family:'IBM Plex Sans',sans-serif;margin:auto;}
+.gr-input {background:#f9f9f9 !important;border:1px solid #e0e0e0 !important;}
+.gr-output{background:#f0f0f0 !important;border:1px solid #d0d0d0 !important;}
+
+.highlighted-text span{
+ padding:2px 4px;border-radius:4px;margin:1px 2px;display:inline-block;line-height:1.6;
+}
+
+footer{display:none !important}
+
+#live-update-scrollable-box {
+    max-height: 800px; /* adjust this maximum height as needed, e.g. '300px' or '50vh' */
+    overflow-y: auto !important; /* show a vertical scrollbar when content exceeds max-height */
+    display: block; /* ensure the element is block-level so max-height takes effect */
+
+}
+#think_btn {
+ background-color: #f3f4f6 !important;
+ border: 1px solid #d0d0d0 !important;
+ color: #111827 !important;
+ font-size: 16px !important;
+ font-weight: bold !important;
+}
+#think_btn:hover {
+ background-color: #e0e0e0 !important;
+ border: 1px solid #c0c0c0 !important;
+ color: #222 !important;
+}
+#think_btn:active {
+ background-color: #2563eb !important;
+ border: 1px solid #b0b0b0 !important;
+ color: white !important;
+}
+"""
+
+
+# thinking_mode_t2i = gr.State(False)
+def toggle_thinking_mode_lm(current_thinking_mode):
+ # print(f"current_thinking_mode: {current_thinking_mode}")
+ new_state = not current_thinking_mode
+    new_label = "Thinking Mode ✅" if new_state else "Thinking Mode ❌"
+ return new_state, gr.update(value=new_label)
+
+def toggle_thinking_mode_mmu(current_thinking_mode):
+ new_state = not current_thinking_mode
+    new_label = "Thinking Mode ✅" if new_state else "Thinking Mode ❌"
+ return new_state, gr.update(value=new_label)
+
+
+color_map_config = {
+ "MASK": "lightgrey",
+ "GEN": "#DCABFA",
+}
+
+theme = gr.themes.Ocean(
+ primary_hue="fuchsia",
+)
+with gr.Blocks(css=css_styles, theme=theme) as demo:
+# with gr.Blocks(css=css_styles, theme=gr.themes.Soft(primary_hue=gr.themes.colors.blue, secondary_hue=gr.themes.colors.sky)) as demo:
+# with gr.Blocks() as demo:
+ thinking_mode_lm = gr.State(False)
+ thinking_mode_mmu = gr.State(False)
+    gr.Markdown(
+        "Create speech, "
+        "transcribe audio, "
+        "describe video, "
+        "chat with text, and "
+        "generate or edit images, all from a single model. "
+        "Use the advanced sections when you want tighter control."
+    )
+
+ out = []
+ for i in str_input:
+ # Skip OOV
+        if i.item() not in r_map:
+            continue
+        out.append(r_map[i.item()])
+
+ return " ".join(out)
+
+
+class CharParser:
+ """Functor for parsing raw strings into list of int tokens.
+
+ Examples:
+ >>> parser = CharParser(['a', 'b', 'c'])
+ >>> parser('abc')
+ [0, 1, 2]
+ """
+
+ def __init__(
+ self,
+ labels: List[str],
+ *,
+ unk_id: int = -1,
+ blank_id: int = -1,
+ do_normalize: bool = True,
+ do_lowercase: bool = True,
+ add_end_space: bool = False
+ ):
+ """Creates simple mapping char parser.
+
+ Args:
+ labels: List of labels to allocate indexes for. Essentially,
+ this is a id to str mapping.
+ unk_id: Index to choose for OOV words (default: -1).
+ blank_id: Index to filter out from final list of tokens
+ (default: -1).
+ do_normalize: True if apply normalization step before tokenizing
+ (default: True).
+ do_lowercase: True if apply lowercasing at normalizing step
+ (default: True).
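+            add_end_space: True if a space token should be appended after every
+                word instead of being inserted before each non-initial word
+                (default: False).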
+ """
+
+ self._labels = labels
+ self._unk_id = unk_id
+ self._blank_id = blank_id
+ self._do_normalize = do_normalize
+ self._do_lowercase = do_lowercase
+
+ self._labels_map = {label: index for index, label in enumerate(labels)}
+ self._special_labels = set([label for label in labels if len(label) > 1])
+
+ print('INFO: CharParser add_end_space: {}'.format(add_end_space))
+ self.add_end_space = add_end_space
+
+ def __call__(self, text: str) -> Optional[List[int]]:
+ if self._do_normalize:
+ text = self._normalize(text)
+ if text is None:
+ return None
+
+ text_tokens = self._tokenize(text)
+
+ return text_tokens
+
+ def _normalize(self, text: str) -> Optional[str]:
+ text = text.strip()
+
+ if self._do_lowercase:
+ text = text.lower()
+
+ return text
+
+ def _tokenize(self, text: str) -> List[int]:
+ tokens = []
+        # Split by word to find special labels.
+ for word_id, word in enumerate(text.split(' ')):
+ if word_id != 0 and not self.add_end_space: # Not first word - so we insert space before.
+ tokens.append(self._labels_map.get(' ', self._unk_id))
+
+ if word in self._special_labels:
+ tokens.append(self._labels_map[word])
+ continue
+
+ for char in word:
+ tokens.append(self._labels_map.get(char, self._unk_id))
+
+ if self.add_end_space:
+ tokens.append(self._labels_map.get(' ', self._unk_id))
+
+ # If unk_id == blank_id, OOV tokens are removed.
+ tokens = [token for token in tokens if token != self._blank_id]
+
+ return tokens
+
+
+class ENCharParser(CharParser):
+ """Incorporates english-specific parsing logic."""
+
+ PUNCTUATION_TO_REPLACE = frozendict.frozendict({'+': 'plus', '&': 'and', '%': 'percent'})
+
+ def __init__(self, *args, **kwargs):
+ """Creates english-specific mapping char parser.
+
+ This class overrides normalizing implementation.
+
+ Args:
+ *args: Positional args to pass to `CharParser` constructor.
+ **kwargs: Key-value args to pass to `CharParser` constructor.
+ """
+
+ super().__init__(*args, **kwargs)
+
+ self._table = self.__make_trans_table()
+
+ def __make_trans_table(self):
+ punctuation = string.punctuation
+
+ for char in self.PUNCTUATION_TO_REPLACE:
+ punctuation = punctuation.replace(char, '')
+
+ for label in self._labels:
+ punctuation = punctuation.replace(label, '')
+
+ table = str.maketrans(punctuation, ' ' * len(punctuation))
+
+ return table
+
+ def _normalize(self, text: str) -> Optional[str]:
+ # noinspection PyBroadException
+ try:
+ text = cleaners.clean_text(
+ string=text, table=self._table, punctuation_to_replace=self.PUNCTUATION_TO_REPLACE,
+ )
+ except Exception:
+ return None
+
+ return text
+
+
+NAME_TO_PARSER = frozendict.frozendict({'base': CharParser, 'en': ENCharParser})
+
+
+def make_parser(labels: Optional[List[str]] = None, name: str = 'base', **kwargs, ) -> CharParser:
+ """Creates parser from labels, set of arguments and concise parser name.
+
+ Args:
+        labels: List of labels to allocate indexes for. If set to
+            None, the printable ASCII character list is used. Essentially, this
+            is an id-to-str mapping (default: None).
+        name: Concise name of the parser to create (default: 'base').
+        **kwargs: Other kwargs to pass to the parser constructor.
+
+ Returns:
+ Instance of `CharParser`.
+
+ Raises:
+ ValueError: For invalid parser name.
+
+ Examples:
+ >>> type(make_parser(['a', 'b', 'c'], 'en'))
+ ENCharParser
+ """
+
+ if name not in NAME_TO_PARSER:
+ raise ValueError('Invalid parser name.')
+
+ if labels is None:
+ labels = list(string.printable)
+
+ parser_type = NAME_TO_PARSER[name]
+ parser = parser_type(labels=labels, **kwargs)
+
+ return parser
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/perturb.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/perturb.py
new file mode 100644
index 0000000000000000000000000000000000000000..310b4a859e99fd85999554eee0cb2bf426090675
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/perturb.py
@@ -0,0 +1,1028 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Copyright (c) 2018 Ryan Leary
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# This file contains code artifacts adapted from https://github.com/ryanleary/patter
+import copy
+import io
+import os
+import random
+import subprocess
+from tempfile import NamedTemporaryFile
+from typing import List, Optional, Union
+
+import librosa
+import numpy as np
+import pandas
+import soundfile as sf
+import sox
+import webdataset as wd
+from omegaconf import DictConfig, OmegaConf
+from scipy import signal
+from torch.utils.data import IterableDataset
+
+from nemo.collections.asr.parts import collections, parsers
+from nemo.collections.asr.parts.segment import AudioSegment
+from nemo.utils import logging
+
+try:
+ from nemo.collections.asr.parts import numba_utils
+
+ HAVE_NUMBA = True
+except (ImportError, ModuleNotFoundError):
+ HAVE_NUMBA = False
+
+
+def read_one_audiosegment(manifest, target_sr, rng, tarred_audio=False, audio_dataset=None):
+
+ if tarred_audio:
+ if audio_dataset is None:
+ raise TypeError("Expected augmentation dataset but got None")
+ audio_file, file_id = next(audio_dataset)
+ manifest_idx = manifest.mapping[file_id]
+ manifest_entry = manifest[manifest_idx]
+
+ offset = 0 if manifest_entry.offset is None else manifest_entry.offset
+ duration = 0 if manifest_entry.duration is None else manifest_entry.duration
+
+ else:
+ audio_record = rng.sample(manifest.data, 1)[0]
+ audio_file = audio_record.audio_file
+ offset = 0 if audio_record.offset is None else audio_record.offset
+ duration = 0 if audio_record.duration is None else audio_record.duration
+
+ return AudioSegment.from_file(audio_file, target_sr=target_sr, offset=offset, duration=duration)
+
+
+class Perturbation(object):
+ def max_augmentation_length(self, length):
+ return length
+
+ def perturb(self, data):
+ raise NotImplementedError
+
+
+class SpeedPerturbation(Perturbation):
+ def __init__(self, sr, resample_type, min_speed_rate=0.9, max_speed_rate=1.1, num_rates=5, rng=None):
+ """
+ Performs Speed Augmentation by re-sampling the data to a different sampling rate,
+ which does not preserve pitch.
+
+ Note: This is a very slow operation for online augmentation. If space allows,
+ it is preferable to pre-compute and save the files to augment the dataset.
+
+ Args:
+ sr: Original sampling rate.
+ resample_type: Type of resampling operation that will be performed.
+ For better speed using `resampy`'s fast resampling method, use `resample_type='kaiser_fast'`.
+ For high-quality resampling, set `resample_type='kaiser_best'`.
+ To use `scipy.signal.resample`, set `resample_type='fft'` or `resample_type='scipy'`
+ min_speed_rate: Minimum sampling rate modifier.
+ max_speed_rate: Maximum sampling rate modifier.
+ num_rates: Number of discrete rates to allow. Can be a positive or negative
+ integer.
+ If a positive integer greater than 0 is provided, the range of
+ speed rates will be discretized into `num_rates` values.
+ If a negative integer or 0 is provided, the full range of speed rates
+ will be sampled uniformly.
+                Note: If a positive integer is provided and the resulting discretized
+                range of rates contains the value '1.0', then samples drawn with rate=1.0
+                are not augmented at all and are simply skipped. This avoids unnecessary
+                augmentation and reduces computation time. The effective augmentation chance
+                in such a case is `prob * ((num_rates - 1) / num_rates) * 100`%,
+                where `prob` is the global probability of a sample being augmented.
+ rng: Random seed number.
+ """
+ min_rate = min(min_speed_rate, max_speed_rate)
+        if min_rate <= 0.0:
+ raise ValueError("Minimum sampling rate modifier must be > 0.")
+
+ if resample_type not in ('kaiser_best', 'kaiser_fast', 'fft', 'scipy'):
+ raise ValueError("Supported `resample_type` values are ('kaiser_best', 'kaiser_fast', 'fft', 'scipy')")
+
+ self._sr = sr
+ self._min_rate = min_speed_rate
+ self._max_rate = max_speed_rate
+ self._num_rates = num_rates
+ if num_rates > 0:
+ self._rates = np.linspace(self._min_rate, self._max_rate, self._num_rates, endpoint=True)
+ self._res_type = resample_type
+ self._rng = random.Random() if rng is None else rng
+
+ def max_augmentation_length(self, length):
+ return length * self._max_rate
+
+ def perturb(self, data):
+ # Select speed rate either from choice or random sample
+ if self._num_rates < 0:
+ speed_rate = self._rng.uniform(self._min_rate, self._max_rate)
+ else:
+ speed_rate = self._rng.choice(self._rates)
+
+ # Skip perturbation in case of identity speed rate
+ if speed_rate == 1.0:
+ return
+
+ new_sr = int(self._sr * speed_rate)
+ data._samples = librosa.core.resample(data._samples, self._sr, new_sr, res_type=self._res_type)
+
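+# Illustrative sketch (an assumption for clarity, not part of the upstream NeMo code):
+# with the default arguments above, the discretized speed rates are evenly spaced and
+# include the identity rate 1.0, which `perturb` skips:
+#
+#     import numpy as np
+#     rates = np.linspace(0.9, 1.1, 5, endpoint=True)
+#     # -> array([0.9, 0.95, 1.0, 1.05, 1.1]); drawing 1.0 leaves the audio untouched,
+#     # so the effective augmentation chance is prob * (num_rates - 1) / num_rates.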
+
+class TimeStretchPerturbation(Perturbation):
+ def __init__(self, min_speed_rate=0.9, max_speed_rate=1.1, num_rates=5, n_fft=512, rng=None):
+ """
+ Time-stretch an audio series by a fixed rate while preserving pitch, based on [1, 2].
+
+ Note:
+ This is a simplified implementation, intended primarily for reference and pedagogical purposes.
+ It makes no attempt to handle transients, and is likely to produce audible artifacts.
+
+ Reference
+        [1] [Ellis, D. P. W. "A phase vocoder in Matlab." Columbia University, 2002.]
+ (http://www.ee.columbia.edu/~dpwe/resources/matlab/pvoc/)
+ [2] [librosa.effects.time_stretch]
+ (https://librosa.github.io/librosa/generated/librosa.effects.time_stretch.html)
+
+ Args:
+ min_speed_rate: Minimum sampling rate modifier.
+ max_speed_rate: Maximum sampling rate modifier.
+ num_rates: Number of discrete rates to allow. Can be a positive or negative
+ integer.
+ If a positive integer greater than 0 is provided, the range of
+ speed rates will be discretized into `num_rates` values.
+ If a negative integer or 0 is provided, the full range of speed rates
+ will be sampled uniformly.
+                Note: If a positive integer is provided and the resulting discretized
+                range of rates contains the value '1.0', then samples drawn with rate=1.0
+                are not augmented at all and are simply skipped. This avoids unnecessary
+                augmentation and reduces computation time. The effective augmentation chance
+                in such a case is `prob * ((num_rates - 1) / num_rates) * 100`%,
+                where `prob` is the global probability of a sample being augmented.
+ n_fft: Number of fft filters to be computed.
+ rng: Random seed number.
+ """
+ min_rate = min(min_speed_rate, max_speed_rate)
+        if min_rate <= 0.0:
+ raise ValueError("Minimum sampling rate modifier must be > 0.")
+
+ self._min_rate = min_speed_rate
+ self._max_rate = max_speed_rate
+ self._num_rates = num_rates
+ if num_rates > 0:
+ self._rates = np.linspace(self._min_rate, self._max_rate, self._num_rates, endpoint=True)
+ self._rng = random.Random() if rng is None else rng
+
+ # Pre-compute constants
+ self._n_fft = int(n_fft)
+ self._hop_length = int(n_fft // 2)
+
+ # Pre-allocate buffers
+ self._phi_advance_fast = np.linspace(0, np.pi * self._hop_length, self._hop_length + 1)
+ self._scale_buffer_fast = np.empty(self._hop_length + 1, dtype=np.float32)
+
+ self._phi_advance_slow = np.linspace(0, np.pi * self._n_fft, self._n_fft + 1)
+ self._scale_buffer_slow = np.empty(self._n_fft + 1, dtype=np.float32)
+
+ def max_augmentation_length(self, length):
+ return length * self._max_rate
+
+ def perturb(self, data):
+ # Select speed rate either from choice or random sample
+ if self._num_rates < 0:
+ speed_rate = self._rng.uniform(self._min_rate, self._max_rate)
+ else:
+ speed_rate = self._rng.choice(self._rates)
+
+ # Skip perturbation in case of identity speed rate
+ if speed_rate == 1.0:
+ return
+
+ # Increase `n_fft` based on task (speed up or slow down audio)
+ # This greatly reduces upper bound of maximum time taken
+ # to compute slowed down audio segments.
+ if speed_rate >= 1.0: # Speed up audio
+ fft_multiplier = 1
+ phi_advance = self._phi_advance_fast
+ scale_buffer = self._scale_buffer_fast
+
+ else: # Slow down audio
+ fft_multiplier = 2
+ phi_advance = self._phi_advance_slow
+ scale_buffer = self._scale_buffer_slow
+
+ n_fft = int(self._n_fft * fft_multiplier)
+ hop_length = int(self._hop_length * fft_multiplier)
+
+ # Perform short-term Fourier transform (STFT)
+ stft = librosa.core.stft(data._samples, n_fft=n_fft, hop_length=hop_length)
+
+ # Stretch by phase vocoding
+ if HAVE_NUMBA:
+ stft_stretch = numba_utils.phase_vocoder(stft, speed_rate, phi_advance, scale_buffer)
+
+ else:
+ stft_stretch = librosa.core.phase_vocoder(stft, speed_rate, hop_length)
+
+ # Predict the length of y_stretch
+ len_stretch = int(round(len(data._samples) / speed_rate))
+
+ # Invert the STFT
+ y_stretch = librosa.core.istft(
+ stft_stretch, dtype=data._samples.dtype, hop_length=hop_length, length=len_stretch
+ )
+
+ data._samples = y_stretch
+
+
+class TempoPerturbation(Perturbation):
+ def __init__(self, factors, rng=None):
+ assert len(factors) > 0
+ assert min(factors) > 0
+ self.factors = factors
+ self._max_factor = max(self.factors)
+ self._rng = random.Random() if rng is None else rng
+
+ def max_augmentation_length(self, length):
+ return length * self._max_factor
+
+ def perturb(self, data):
+ speed_rate = self._rng.choice(self.factors)
+ if speed_rate == 1.0:
+ return
+
+ tfm = sox.Transformer()
+ if abs(speed_rate - 1.0) <= 0.1:
+ tfm.stretch(speed_rate)
+ else:
+ tfm.tempo(speed_rate)
+ perturbed_data = tfm.build_array(input_array=data._samples, sample_rate_in=data._sample_rate)
+
+ data._samples = perturbed_data
+
+
+class GainPerturbation(Perturbation):
+ """
+ Applies random gain to the audio.
+
+ Args:
+ min_gain_dbfs (float): Min gain level in dB
+ max_gain_dbfs (float): Max gain level in dB
+ rng: Random number generator
+ """
+
+ def __init__(self, min_gain_dbfs=-10, max_gain_dbfs=10, rng=None):
+ self._min_gain_dbfs = min_gain_dbfs
+ self._max_gain_dbfs = max_gain_dbfs
+ self._rng = random.Random() if rng is None else rng
+
+ def perturb(self, data):
+ gain = self._rng.uniform(self._min_gain_dbfs, self._max_gain_dbfs)
+ # logging.debug("gain: %d", gain)
+ data._samples = data._samples * (10.0 ** (gain / 20.0))
+
+
+class ImpulsePerturbation(Perturbation):
+ """
+ Convolves audio with a Room Impulse Response.
+
+ Args:
+ manifest_path (list): Manifest file for RIRs
+ audio_tar_filepaths (list): Tar files, if RIR audio files are tarred
+ shuffle_n (int): Shuffle parameter for shuffling buffered files from the tar files
+ shift_impulse (bool): Shift impulse response to adjust for delay at the beginning
+ """
+
+ def __init__(self, manifest_path=None, rng=None, audio_tar_filepaths=None, shuffle_n=128, shift_impulse=False):
+ self._manifest = collections.ASRAudioText(manifest_path, parser=parsers.make_parser([]), index_by_file_id=True)
+ self._audiodataset = None
+ self._tarred_audio = False
+ self._shift_impulse = shift_impulse
+ self._data_iterator = None
+
+ if audio_tar_filepaths:
+ self._tarred_audio = True
+ self._audiodataset = AugmentationDataset(manifest_path, audio_tar_filepaths, shuffle_n)
+ self._data_iterator = iter(self._audiodataset)
+
+ self._rng = random.Random() if rng is None else rng
+
+ def perturb(self, data):
+ impulse = read_one_audiosegment(
+ self._manifest,
+ data.sample_rate,
+ self._rng,
+ tarred_audio=self._tarred_audio,
+ audio_dataset=self._data_iterator,
+ )
+ if not self._shift_impulse:
+ impulse_norm = (impulse.samples - min(impulse.samples)) / (max(impulse.samples) - min(impulse.samples))
+ data._samples = signal.fftconvolve(data._samples, impulse_norm, "same")
+ else:
+ # Find peak and shift peak to left
+ impulse_norm = (impulse.samples - min(impulse.samples)) / (max(impulse.samples) - min(impulse.samples))
+ max_ind = np.argmax(np.abs(impulse_norm))
+
+ impulse_resp = impulse_norm[max_ind:]
+ delay_after = len(impulse_resp)
+ data._samples = signal.fftconvolve(data._samples, impulse_resp, "full")[:-delay_after]
+
+
+class ShiftPerturbation(Perturbation):
+ """
+ Perturbs audio by shifting the audio in time by a random amount between min_shift_ms and max_shift_ms.
+ The final length of the audio is kept unaltered by padding the audio with zeros.
+
+
+ Args:
+ min_shift_ms (float): Minimum time in milliseconds by which audio will be shifted
+ max_shift_ms (float): Maximum time in milliseconds by which audio will be shifted
+ rng: Random number generator
+ """
+
+ def __init__(self, min_shift_ms=-5.0, max_shift_ms=5.0, rng=None):
+ self._min_shift_ms = min_shift_ms
+ self._max_shift_ms = max_shift_ms
+ self._rng = random.Random() if rng is None else rng
+
+ def perturb(self, data):
+ shift_ms = self._rng.uniform(self._min_shift_ms, self._max_shift_ms)
+ if abs(shift_ms) / 1000 > data.duration:
+ # TODO: do something smarter than just ignore this condition
+ return
+ shift_samples = int(shift_ms * data.sample_rate // 1000)
+ # logging.debug("shift: %s", shift_samples)
+ if shift_samples < 0:
+ data._samples[-shift_samples:] = data._samples[:shift_samples]
+ data._samples[:-shift_samples] = 0
+ elif shift_samples > 0:
+ data._samples[:-shift_samples] = data._samples[shift_samples:]
+ data._samples[-shift_samples:] = 0
+
+
+class NoisePerturbation(Perturbation):
+ """
+ Perturbation that adds noise to input audio.
+
+ Args:
+ manifest_path (str): Manifest file with paths to noise files
+ min_snr_db (float): Minimum SNR of audio after noise is added
+ max_snr_db (float): Maximum SNR of audio after noise is added
+ max_gain_db (float): Maximum gain that can be applied on the noise sample
+ audio_tar_filepaths (list) : Tar files, if noise audio files are tarred
+ shuffle_n (int): Shuffle parameter for shuffling buffered files from the tar files
+ orig_sr (int): Original sampling rate of the noise files
+ rng: Random number generator
+ """
+
+ def __init__(
+ self,
+ manifest_path=None,
+ min_snr_db=10,
+ max_snr_db=50,
+ max_gain_db=300.0,
+ rng=None,
+ audio_tar_filepaths=None,
+ shuffle_n=100,
+ orig_sr=16000,
+ ):
+ self._manifest = collections.ASRAudioText(manifest_path, parser=parsers.make_parser([]), index_by_file_id=True)
+ self._audiodataset = None
+ self._tarred_audio = False
+ self._orig_sr = orig_sr
+ self._data_iterator = None
+
+ if audio_tar_filepaths:
+ self._tarred_audio = True
+ self._audiodataset = AugmentationDataset(manifest_path, audio_tar_filepaths, shuffle_n)
+ self._data_iterator = iter(self._audiodataset)
+
+ self._rng = random.Random() if rng is None else rng
+ self._min_snr_db = min_snr_db
+ self._max_snr_db = max_snr_db
+ self._max_gain_db = max_gain_db
+
+ @property
+ def orig_sr(self):
+ return self._orig_sr
+
+ def get_one_noise_sample(self, target_sr):
+ return read_one_audiosegment(
+ self._manifest, target_sr, self._rng, tarred_audio=self._tarred_audio, audio_dataset=self._data_iterator
+ )
+
+ def perturb(self, data):
+ noise = read_one_audiosegment(
+ self._manifest,
+ data.sample_rate,
+ self._rng,
+ tarred_audio=self._tarred_audio,
+ audio_dataset=self._data_iterator,
+ )
+ self.perturb_with_input_noise(data, noise)
+
+ def perturb_with_input_noise(self, data, noise, data_rms=None):
+ snr_db = self._rng.uniform(self._min_snr_db, self._max_snr_db)
+ if data_rms is None:
+ data_rms = data.rms_db
+ noise_gain_db = min(data_rms - noise.rms_db - snr_db, self._max_gain_db)
+ # logging.debug("noise: %s %s %s", snr_db, noise_gain_db, noise_record.audio_file)
+
+ # calculate noise segment to use
+ start_time = self._rng.uniform(0.0, noise.duration - data.duration)
+ if noise.duration > (start_time + data.duration):
+ noise.subsegment(start_time=start_time, end_time=start_time + data.duration)
+
+ # adjust gain for snr purposes and superimpose
+ noise.gain_db(noise_gain_db)
+
+ if noise._samples.shape[0] < data._samples.shape[0]:
+ noise_idx = self._rng.randint(0, data._samples.shape[0] - noise._samples.shape[0])
+ data._samples[noise_idx : noise_idx + noise._samples.shape[0]] += noise._samples
+
+ else:
+ data._samples += noise._samples
+
+ def perturb_with_foreground_noise(
+ self, data, noise, data_rms=None, max_noise_dur=2, max_additions=1,
+ ):
+ snr_db = self._rng.uniform(self._min_snr_db, self._max_snr_db)
+ if not data_rms:
+ data_rms = data.rms_db
+
+ noise_gain_db = min(data_rms - noise.rms_db - snr_db, self._max_gain_db)
+ n_additions = self._rng.randint(1, max_additions)
+
+ for i in range(n_additions):
+ noise_dur = self._rng.uniform(0.0, max_noise_dur)
+ start_time = self._rng.uniform(0.0, noise.duration)
+ start_sample = int(round(start_time * noise.sample_rate))
+ end_sample = int(round(min(noise.duration, (start_time + noise_dur)) * noise.sample_rate))
+ noise_samples = np.copy(noise._samples[start_sample:end_sample])
+ # adjust gain for snr purposes and superimpose
+ noise_samples *= 10.0 ** (noise_gain_db / 20.0)
+
+ if noise_samples.shape[0] > data._samples.shape[0]:
+ noise_samples = noise_samples[0 : data._samples.shape[0]]
+
+ noise_idx = self._rng.randint(0, data._samples.shape[0] - noise_samples.shape[0])
+ data._samples[noise_idx : noise_idx + noise_samples.shape[0]] += noise_samples
+
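+# Illustrative sketch (an assumption for clarity, not part of the upstream NeMo code):
+# the noise gain used above follows directly from the sampled SNR in dB, capped by
+# `max_gain_db`. For example, with data at -20 dB RMS, noise at -30 dB RMS and a
+# sampled SNR of 15 dB:
+#
+#     data_rms, noise_rms, snr_db, max_gain_db = -20.0, -30.0, 15.0, 300.0
+#     noise_gain_db = min(data_rms - noise_rms - snr_db, max_gain_db)  # -> -5.0
+#     # i.e. the noise is attenuated by 5 dB before being superimposed on the signal.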
+
+class RandomNoisePerturbation(Perturbation):
+ """
+ Perturbation that adds noise to input audio.
+
+ Args:
+ manifest_path (str): Manifest file with paths to noise files
+ min_snr_db (float): Minimum SNR of audio after noise is added
+ max_snr_db (float): Maximum SNR of audio after noise is added
+ max_gain_db (float): Maximum gain that can be applied on the noise sample
+ audio_tar_filepaths (list) : Tar files, if noise audio files are tarred
+ shuffle_n (int): Shuffle parameter for shuffling buffered files from the tar files
+ orig_sr (int): Original sampling rate of the noise files
+ rng: Random number generator
+ """
+
+ def __init__(
+ self,
+ manifest_path=None,
+ min_snr_db=10,
+ max_snr_db=50,
+ max_gain_db=300.0,
+ ratio=1.0,
+ rng=None,
+ target_sr=16000,
+ data_dir='',
+ cache_noise=False
+ ):
+
+ self._rng = random.Random() if rng is None else rng
+ self._min_snr_db = min_snr_db
+ self._max_snr_db = max_snr_db
+ self._max_gain_db = max_gain_db
+ self.ratio = ratio
+
+ self.data_dir = data_dir
+ self.target_sr = target_sr
+ manifest, self._noise_weights = self.read_manifest(manifest_path)
+ self.noise_files = manifest['wav_filename'].tolist()
+ self._cache = {}
+ self.cache_noise = cache_noise
+
+ def read_manifest(self, manifest_fps):
+ manifest_files = []
+ for fp in manifest_fps:
+ manifest_files.append(pandas.read_csv(fp, encoding='utf-8'))
+ manifest = pandas.concat(manifest_files)
+
+ orig_noise_num = len(manifest)
+ wav_header_size = 44
+ # only use noise with duration longer than 1 second
+ manifest = manifest[manifest['wav_filesize'] > (1 * 16000 * 2 + wav_header_size)]
+ print('filter noise less than 1s: from {} to {} samples'.format(orig_noise_num, len(manifest)))
+ wav_data_size = manifest['wav_filesize'].values - wav_header_size
+ print('noise duration sum: {}h'.format(wav_data_size.sum() / (16000 * 2) / 60 / 60))
+ noise_weights = wav_data_size / wav_data_size.sum()
+ return manifest, noise_weights.tolist()
+
+ def perturb(self, data):
+ if self._rng.random() < self.ratio:
+ noises = self.get_noises(data.num_samples)
+ self.perturb_with_input_noise(data, noises)
+
+ def perturb_with_input_noise(self, data, noises):
+ snr_db = self._rng.uniform(self._min_snr_db, self._max_snr_db)
+ data_rms = rms_db(data._samples)
+
+ start_index = 0
+ for noise_i in noises:
+ noise_gain_db = min(data_rms - rms_db(noise_i) - snr_db, self._max_gain_db)
+ # logging.debug("noise: %s %s %s", snr_db, noise_gain_db, noise_record.audio_file)
+
+ # adjust gain for snr purposes and superimpose
+ noise_i = gain_db(noise_i, noise_gain_db)
+
+ end_index = start_index + noise_i.shape[0]
+ data._samples[start_index:end_index] += noise_i
+ start_index = end_index
+ assert end_index == data.num_samples
+
+ def get_noises(self, num_samples):
+ left_noise_samples = num_samples
+ noise_data = []
+ while left_noise_samples > 0:
+ noise = self.read_one_noise()
+ if noise.shape[0] > left_noise_samples:
+ start_pos = self._rng.randrange(0, noise.shape[0] - left_noise_samples + 1)
+ noise = noise[start_pos:start_pos + left_noise_samples]
+ left_noise_samples -= noise.shape[0]
+ noise_data.append(noise)
+ assert left_noise_samples == 0
+ return noise_data
+
+ def read_one_noise(self):
+ fp = self._rng.choices(self.noise_files, weights=self._noise_weights)[0]
+ fp = os.path.join(self.data_dir, fp)
+
+ cached_noise = self._cache.get(fp, None)
+ if cached_noise is None:
+ cached_noise = AudioSegment.from_file(fp, target_sr=self.target_sr)._samples
+ if self.cache_noise:
+ self._cache[fp] = cached_noise
+ return cached_noise.copy()
+
+
+def rms_db(samples):
+ mean_square = np.mean(samples ** 2)
+ if mean_square == 0:
+ return -np.inf
+ else:
+ return 10 * np.log10(mean_square)
+
+
+def gain_db(samples, gain):
+ return samples * (10.0 ** (gain / 20.0))
+
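+# Illustrative sketch (an assumption for clarity, not part of the upstream NeMo code):
+# a quick sanity check of the two helpers above on a constant full-scale signal,
+# whose RMS level is exactly 0 dB:
+#
+#     import numpy as np
+#     x = np.ones(16000, dtype=np.float32)
+#     rms_db(x)                  # -> 0.0
+#     rms_db(gain_db(x, -6.0))   # -> -6.0 (a -6 dB gain roughly halves the amplitude)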
+
+class WhiteNoisePerturbation(Perturbation):
+ """
+ Perturbation that adds white noise to an audio file in the training dataset.
+
+ Args:
+ min_level (int): Minimum level in dB at which white noise should be added
+ max_level (int): Maximum level in dB at which white noise should be added
+ rng: Random number generator
+ """
+
+ def __init__(self, min_level=-90, max_level=-46, rng=None):
+ self.min_level = int(min_level)
+ self.max_level = int(max_level)
+ self._rng = np.random.RandomState() if rng is None else rng
+
+ def perturb(self, data):
+ noise_level_db = self._rng.randint(self.min_level, self.max_level, dtype='int32')
+ noise_signal = self._rng.randn(data._samples.shape[0]) * (10.0 ** (noise_level_db / 20.0))
+ data._samples += noise_signal
+
+
+class RirAndNoisePerturbation(Perturbation):
+ def __init__(
+ self,
+ rir_manifest_path=None,
+ rir_prob=0.5,
+ noise_manifest_paths=None,
+ min_snr_db=0,
+ max_snr_db=50,
+ rir_tar_filepaths=None,
+ rir_shuffle_n=100,
+ noise_tar_filepaths=None,
+ apply_noise_rir=False,
+ orig_sample_rate=None,
+ max_additions=5,
+ max_duration=2.0,
+ bg_noise_manifest_paths=None,
+ bg_min_snr_db=10,
+ bg_max_snr_db=50,
+ bg_noise_tar_filepaths=None,
+ bg_orig_sample_rate=None,
+ ):
+ """
+ RIR augmentation with additive foreground and background noise.
+ In this implementation audio data is augmented by first convolving the audio with a Room Impulse Response
+ and then adding foreground noise and background noise at various SNRs. RIR, foreground and background noises
+ should either be supplied with a manifest file or as tarred audio files (faster).
+
+        Different sets of noise audio files can be supplied based on the original sampling rate of the noise. This is
+        useful when training a mixed sample rate model. For example, when training a mixed model with 8 kHz and 16 kHz audio with a
+ target sampling rate of 16 kHz, one would want to augment 8 kHz data with 8 kHz noise rather than 16 kHz noise.
+
+ Args:
+ rir_manifest_path: manifest file for RIRs
+ rir_tar_filepaths: tar files, if RIR audio files are tarred
+ rir_prob: probability of applying a RIR
+ noise_manifest_paths: foreground noise manifest path
+ min_snr_db: min SNR for foreground noise
+            max_snr_db: max SNR for foreground noise
+            noise_tar_filepaths: tar files, if noise files are tarred
+            apply_noise_rir: whether to convolve foreground noise with a random RIR
+            orig_sample_rate: original sampling rate of foreground noise audio
+            max_additions: max number of times foreground noise is added to an utterance
+ max_duration: max duration of foreground noise
+ bg_noise_manifest_paths: background noise manifest path
+ bg_min_snr_db: min SNR for background noise
+ bg_max_snr_db: max SNR for background noise
+ bg_noise_tar_filepaths: tar files, if noise files are tarred
+ bg_orig_sample_rate: original sampling rate of background noise audio
+
+ """
+ logging.info("Called Rir aug init")
+ self._rir_prob = rir_prob
+ self._rng = random.Random()
+ self._rir_perturber = ImpulsePerturbation(
+ manifest_path=rir_manifest_path,
+ audio_tar_filepaths=rir_tar_filepaths,
+ shuffle_n=rir_shuffle_n,
+ shift_impulse=True,
+ )
+ self._fg_noise_perturbers = {}
+ self._bg_noise_perturbers = {}
+ if noise_manifest_paths:
+ for i in range(len(noise_manifest_paths)):
+ if orig_sample_rate is None:
+ orig_sr = 16000
+ else:
+ orig_sr = orig_sample_rate[i]
+ self._fg_noise_perturbers[orig_sr] = NoisePerturbation(
+ manifest_path=noise_manifest_paths[i],
+ min_snr_db=min_snr_db[i],
+ max_snr_db=max_snr_db[i],
+ audio_tar_filepaths=noise_tar_filepaths[i],
+ orig_sr=orig_sr,
+ )
+ self._max_additions = max_additions
+ self._max_duration = max_duration
+ if bg_noise_manifest_paths:
+ for i in range(len(bg_noise_manifest_paths)):
+ if bg_orig_sample_rate is None:
+ orig_sr = 16000
+ else:
+ orig_sr = bg_orig_sample_rate[i]
+ self._bg_noise_perturbers[orig_sr] = NoisePerturbation(
+ manifest_path=bg_noise_manifest_paths[i],
+ min_snr_db=bg_min_snr_db[i],
+ max_snr_db=bg_max_snr_db[i],
+ audio_tar_filepaths=bg_noise_tar_filepaths[i],
+ orig_sr=orig_sr,
+ )
+
+ self._apply_noise_rir = apply_noise_rir
+
+ def perturb(self, data):
+ prob = self._rng.uniform(0.0, 1.0)
+
+ if prob < self._rir_prob:
+ self._rir_perturber.perturb(data)
+
+ orig_sr = data.orig_sr
+ if orig_sr not in self._fg_noise_perturbers:
+ orig_sr = max(self._fg_noise_perturbers.keys())
+ fg_perturber = self._fg_noise_perturbers[orig_sr]
+
+ orig_sr = data.orig_sr
+ if orig_sr not in self._bg_noise_perturbers:
+ orig_sr = max(self._bg_noise_perturbers.keys())
+ bg_perturber = self._bg_noise_perturbers[orig_sr]
+
+ data_rms = data.rms_db
+ noise = fg_perturber.get_one_noise_sample(data.sample_rate)
+ if self._apply_noise_rir:
+ self._rir_perturber.perturb(noise)
+ fg_perturber.perturb_with_foreground_noise(
+ data, noise, data_rms=data_rms, max_noise_dur=self._max_duration, max_additions=self._max_additions
+ )
+ noise = bg_perturber.get_one_noise_sample(data.sample_rate)
+ bg_perturber.perturb_with_input_noise(data, noise, data_rms=data_rms)
+
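+# Illustrative sketch (an assumption for clarity, not part of the upstream NeMo code):
+# the per-sample-rate perturbers above are keyed by the noise files' original sampling
+# rate, so the list arguments are parallel. The manifest paths below are hypothetical.
+#
+#     rir_noise = RirAndNoisePerturbation(
+#         rir_manifest_path="rir_manifest.json",
+#         rir_prob=0.5,
+#         noise_manifest_paths=["fg_noise_8k.json", "fg_noise_16k.json"],
+#         min_snr_db=[0, 0],
+#         max_snr_db=[50, 50],
+#         noise_tar_filepaths=[None, None],
+#         orig_sample_rate=[8000, 16000],
+#         bg_noise_manifest_paths=["bg_noise_16k.json"],
+#         bg_min_snr_db=[10],
+#         bg_max_snr_db=[50],
+#         bg_noise_tar_filepaths=[None],
+#         bg_orig_sample_rate=[16000],
+#     )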
+
+class TranscodePerturbation(Perturbation):
+ def __init__(self, rng=None):
+ """
+ Audio codec augmentation. This implementation uses sox to transcode audio with low rate audio codecs,
+ so users need to make sure that the installed sox version supports the codecs used here (G711 and amr-nb).
+
+ """
+ self._rng = np.random.RandomState() if rng is None else rng
+ self._codecs = ["g711", "amr-nb"]
+
+ def perturb(self, data):
+ att_factor = 0.8
+ max_level = np.max(np.abs(data._samples))
+ norm_factor = att_factor / max_level
+ norm_samples = norm_factor * data._samples
+ orig_f = NamedTemporaryFile(suffix=".wav")
+ sf.write(orig_f.name, norm_samples.transpose(), 16000)
+
+ codec_ind = random.randint(0, len(self._codecs) - 1)
+ if self._codecs[codec_ind] == "amr-nb":
+ transcoded_f = NamedTemporaryFile(suffix="_amr.wav")
+ rates = list(range(0, 8))
+ rate = rates[random.randint(0, len(rates) - 1)]
+ _ = subprocess.check_output(
+ f"sox {orig_f.name} -V0 -C {rate} -t amr-nb - | sox -t amr-nb - -V0 -b 16 -r 16000 {transcoded_f.name}",
+ shell=True,
+ )
+ elif self._codecs[codec_ind] == "g711":
+ transcoded_f = NamedTemporaryFile(suffix="_g711.wav")
+ _ = subprocess.check_output(
+ f"sox {orig_f.name} -V0 -r 8000 -c 1 -e a-law {transcoded_f.name}", shell=True
+ )
+
+ new_data = AudioSegment.from_file(transcoded_f.name, target_sr=16000)
+ data._samples = new_data._samples[0 : data._samples.shape[0]]
+ return
+
+
+perturbation_types = {
+ "speed": SpeedPerturbation,
+ "time_stretch": TimeStretchPerturbation,
+ "tempo": TempoPerturbation,
+ "gain": GainPerturbation,
+ "impulse": ImpulsePerturbation,
+ "shift": ShiftPerturbation,
+ "noise": NoisePerturbation,
+ "white_noise": WhiteNoisePerturbation,
+ "rir_noise_aug": RirAndNoisePerturbation,
+ "transcode_aug": TranscodePerturbation,
+}
+
+
+def register_perturbation(name: str, perturbation: Perturbation):
+ if name in perturbation_types.keys():
+ raise KeyError(
+ f"Perturbation with the name {name} exists. " f"Type of perturbation : {perturbation_types[name]}."
+ )
+
+ perturbation_types[name] = perturbation
+
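+# Illustrative sketch (an assumption for clarity, not part of the upstream NeMo code):
+# registering a custom perturbation makes it addressable by name from YAML configs,
+# exactly like the built-in entries above. The class and name below are hypothetical.
+#
+#     class ClipPerturbation(Perturbation):
+#         def __init__(self, max_amplitude=0.9):
+#             self._max_amplitude = max_amplitude
+#
+#         def perturb(self, data):
+#             # clip the waveform in place to +/- max_amplitude
+#             np.clip(data._samples, -self._max_amplitude, self._max_amplitude, out=data._samples)
+#
+#     register_perturbation("clip", ClipPerturbation)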
+
+class AudioAugmentor(object):
+ def __init__(self, perturbations=None, rng=None):
+ self._rng = random.Random() if rng is None else rng
+ self._pipeline = perturbations if perturbations is not None else []
+ if perturbations:
+ print('audio perturbations:', perturbations)
+
+ def perturb(self, segment):
+ for (prob, p) in self._pipeline:
+ if self._rng.random() < prob:
+ p.perturb(segment)
+ return
+
+ def max_augmentation_length(self, length):
+ newlen = length
+ for (prob, p) in self._pipeline:
+ newlen = p.max_augmentation_length(newlen)
+ return newlen
+
+ @classmethod
+ def from_config(cls, config):
+ ptbs = []
+ for p in config:
+ if p['aug_type'] not in perturbation_types:
+ logging.warning("%s perturbation not known. Skipping.", p['aug_type'])
+ continue
+ perturbation = perturbation_types[p['aug_type']]
+ ptbs.append((p['prob'], perturbation(**p['cfg'])))
+ return cls(perturbations=ptbs)
+
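+# Illustrative sketch (an assumption for clarity, not part of the upstream NeMo code):
+# `AudioAugmentor.from_config` expects a list of dicts with `aug_type`, `prob` and
+# `cfg` keys, e.g.:
+#
+#     augmentor = AudioAugmentor.from_config([
+#         {"aug_type": "white_noise", "prob": 1.0, "cfg": {"min_level": -90, "max_level": -46}},
+#         {"aug_type": "gain", "prob": 0.5, "cfg": {"min_gain_dbfs": -6, "max_gain_dbfs": 6}},
+#     ])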
+
+def process_augmentations(augmenter) -> Optional[AudioAugmentor]:
+ """Process list of online data augmentations.
+ Accepts either an AudioAugmentor object with pre-defined augmentations,
+ or a dictionary that points to augmentations that have been defined.
+ If a dictionary is passed, must follow the below structure:
+ Dict[str, Dict[str, Any]]: Which refers to a dictionary of string
+ names for augmentations, defined in `asr/parts/perturb.py`.
+ The inner dictionary may contain key-value arguments of the specific
+ augmentation, along with an essential key `prob`. `prob` declares the
+ probability of the augmentation being applied, and must be a float
+ value in the range [0, 1].
+ # Example in YAML config file
+ Augmentations are generally applied only during training, so we can add
+ these augmentations to our yaml config file, and modify the behaviour
+ for training and evaluation.
+ ```yaml
+ AudioToSpeechLabelDataLayer:
+ ... # Parameters shared between train and evaluation time
+ train:
+ augmentor:
+ shift:
+ prob: 0.5
+ min_shift_ms: -5.0
+ max_shift_ms: 5.0
+ white_noise:
+ prob: 1.0
+ min_level: -90
+ max_level: -46
+ ...
+ eval:
+ ...
+ ```
+ Then in the training script,
+ ```python
+ import copy
+ from ruamel.yaml import YAML
+ yaml = YAML(typ="safe")
+ with open(model_config) as f:
+ params = yaml.load(f)
+ # Train Config for Data Loader
+ train_dl_params = copy.deepcopy(params["AudioToTextDataLayer"])
+ train_dl_params.update(params["AudioToTextDataLayer"]["train"])
+ del train_dl_params["train"]
+ del train_dl_params["eval"]
+ data_layer_train = nemo_asr.AudioToTextDataLayer(
+ ...,
+ **train_dl_params,
+ )
+ # Evaluation Config for Data Loader
+ eval_dl_params = copy.deepcopy(params["AudioToTextDataLayer"])
+ eval_dl_params.update(params["AudioToTextDataLayer"]["eval"])
+ del eval_dl_params["train"]
+ del eval_dl_params["eval"]
+ data_layer_eval = nemo_asr.AudioToTextDataLayer(
+ ...,
+ **eval_dl_params,
+ )
+ ```
+ # Registering your own Augmentations
+    To register custom augmentations and obtain the above convenience of
+    declaring the augmentations in YAML, you can put additional keys in the
+    `perturbation_types` dictionary as follows.
+ ```python
+ from nemo.collections.asr.parts import perturb
+ # Define your own perturbation here
+ class CustomPerturbation(perturb.Perturbation):
+ ...
+ perturb.register_perturbation(name_of_perturbation, CustomPerturbation)
+ ```
+ Args:
+ augmenter: AudioAugmentor object or
+ dictionary of str -> kwargs (dict) which is parsed and used
+ to initialize an AudioAugmentor.
+            Note: It is crucial that each individual augmentation has
+                a keyword `prob` that defines a float probability in the
+                range [0, 1] of this augmentation being applied.
+                If this keyword is not present, a KeyError is raised.
+ Returns: AudioAugmentor object
+ """
+ if augmenter is None:
+ return None
+
+ if isinstance(augmenter, AudioAugmentor):
+ return augmenter
+
+ if not type(augmenter) in {dict, DictConfig}:
+ raise ValueError("Cannot parse augmenter. Must be a dict or an AudioAugmentor object ")
+
+ if isinstance(augmenter, DictConfig):
+ augmenter = OmegaConf.to_container(augmenter, resolve=True)
+
+ augmenter = copy.deepcopy(augmenter)
+
+ augmentations = []
+ for augment_name, augment_kwargs in augmenter.items():
+ prob = augment_kwargs.get('prob', None)
+
+ if prob is None:
+ raise KeyError(
+ f'Augmentation "{augment_name}" will not be applied as '
+ f'keyword argument "prob" was not defined for this augmentation.'
+ )
+
+ else:
+ _ = augment_kwargs.pop('prob')
+
+ if prob < 0.0 or prob > 1.0:
+ raise ValueError("`prob` must be a float value between 0 and 1.")
+
+ try:
+ augmentation = perturbation_types[augment_name](**augment_kwargs)
+ augmentations.append([prob, augmentation])
+ except KeyError:
+ raise KeyError(f"Invalid perturbation name. Allowed values : {perturbation_types.keys()}")
+
+ augmenter = AudioAugmentor(perturbations=augmentations)
+ return augmenter
+
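+# Illustrative sketch (an assumption for clarity, not part of the upstream NeMo code):
+# the dictionary form described in the docstring maps perturbation names to their
+# constructor kwargs plus a mandatory `prob`:
+#
+#     augmentor = process_augmentations({
+#         "shift": {"prob": 0.5, "min_shift_ms": -5.0, "max_shift_ms": 5.0},
+#         "white_noise": {"prob": 1.0, "min_level": -90, "max_level": -46},
+#     })
+#     # `augmentor` is an AudioAugmentor; augmentor.perturb(segment) applies each
+#     # perturbation with its configured probability.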
+
+class AugmentationDataset(IterableDataset):
+ """
+ A class that loads tarred audio files and cycles over the files in the dataset.
+
+ Accepts a single comma-separated JSON manifest file (in the same style as for the AudioToCharDataset/AudioToBPEDataset),
+ as well as the path(s) to the tarball(s) containing the wav files. Each line of the manifest should
+ contain the information for one audio file, including at least the transcript and name of the audio
+ file within the tarball.
+
+ Valid formats for the audio_tar_filepaths argument include:
+ (1) a single string that can be brace-expanded, e.g. 'path/to/audio.tar' or 'path/to/audio_{1..100}.tar.gz', or
+ (2) a list of file paths that will not be brace-expanded, e.g. ['audio_1.tar', 'audio_2.tar', ...].
+
+ Note: For brace expansion in (1), there may be cases where `{x..y}` syntax cannot be used due to shell interference.
+ This occurs most commonly inside SLURM scripts. Therefore we provide a few equivalent replacements.
+ Supported opening braces - { <=> (, [, < and the special tag _OP_.
+ Supported closing braces - } <=> ), ], > and the special tag _CL_.
+ For SLURM based tasks, we suggest the use of the special tags for ease of use.
+
+ See the WebDataset documentation for more information about accepted data and input formats.
+ """
+
+ def __init__(self, manifest_path: str, tar_filepaths: Union[str, List[str]], shuffle_n: int = 128):
+ self._manifest = collections.ASRAudioText(manifest_path, parser=parsers.make_parser([]), index_by_file_id=True)
+
+ if isinstance(tar_filepaths, str):
+ # Replace '(' and '[' with '{'
+ brace_keys_open = ['(', '[', '<', '_OP_']
+ for bkey in brace_keys_open:
+ if bkey in tar_filepaths:
+ tar_filepaths = tar_filepaths.replace(bkey, "{")
+
+ # Replace ')' and ']' with '}'
+ brace_keys_close = [')', ']', '>', '_CL_']
+ for bkey in brace_keys_close:
+ if bkey in tar_filepaths:
+ tar_filepaths = tar_filepaths.replace(bkey, "}")
+
+ self.audio_dataset = (
+ wd.Dataset(tar_filepaths).shuffle(shuffle_n).rename(audio='wav', key='__key__').to_tuple('audio', 'key')
+ )
+ self.audio_iter = iter(self.audio_dataset)
+
+ def __len__(self):
+ return len(self._manifest)
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ while True:
+ try:
+ audio_bytes, audio_filename = next(self.audio_iter)
+
+ except StopIteration:
+ self.audio_iter = iter(self.audio_dataset)
+ audio_bytes, audio_filename = next(self.audio_iter)
+ file_id, _ = os.path.splitext(os.path.basename(audio_filename))
+
+ # Convert audio bytes to IO stream for processing (for SoundFile to read)
+ audio_file = io.BytesIO(audio_bytes)
+ return audio_file, file_id
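+
+
+# Illustrative sketch (an assumption for clarity, not part of the upstream NeMo code):
+# the brace-tag replacement in `AugmentationDataset.__init__` lets SLURM-safe paths be
+# written with the _OP_/_CL_ tags; the hypothetical path below is rewritten as shown
+# before being passed to webdataset:
+#
+#     "path/to/audio__OP_1..100_CL_.tar"  ->  "path/to/audio_{1..100}.tar"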
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/rnnt_beam_decoding.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/rnnt_beam_decoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..29c9960ffe14fd99bff9e9cd061c5aa4b1c3511e
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/rnnt_beam_decoding.py
@@ -0,0 +1,680 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from nemo.collections.asr.modules import rnnt_abstract
+from nemo.collections.asr.parts import rnnt_utils
+from nemo.collections.asr.parts.rnnt_utils import Hypothesis, NBestHypotheses
+from nemo.core.classes import Typing, typecheck
+from nemo.core.neural_types import AcousticEncodedRepresentation, HypothesisType, LengthsType, NeuralType
+
+
+class BeamRNNTInfer(Typing):
+ @property
+ def input_types(self):
+ """Returns definitions of module input ports.
+ """
+ return {
+ "encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
+ "encoded_lengths": NeuralType(tuple('B'), LengthsType()),
+ }
+
+ @property
+ def output_types(self):
+ """Returns definitions of module output ports.
+ """
+ return {"predictions": NeuralType(elements_type=HypothesisType())}
+
+ def __init__(
+ self,
+ decoder_model: rnnt_abstract.AbstractRNNTDecoder,
+ joint_model: rnnt_abstract.AbstractRNNTJoint,
+ beam_size: int,
+ search_type: str = 'default',
+ score_norm: bool = True,
+ return_best_hypothesis: bool = True,
+ tsd_max_sym_exp_per_step: Optional[int] = 50,
+ alsd_max_target_len: Union[int, float] = 1.0,
+ nsc_max_timesteps_expansion: int = 1,
+ nsc_prefix_alpha: int = 1,
+ ):
+ """
+ Beam Search implementation ported from ESPNet implementation -
+ https://github.com/espnet/espnet/blob/master/espnet/nets/beam_search_transducer.py
+
+        Sequence level beam decoding or batched-beam decoding, performed autoregressively
+ depending on the search type chosen.
+
+ Args:
+ decoder_model: rnnt_utils.AbstractRNNTDecoder implementation.
+ joint_model: rnnt_utils.AbstractRNNTJoint implementation.
+
+ beam_size: number of beams for beam search. Must be a positive integer >= 1.
+ If beam size is 1, defaults to stateful greedy search.
+ This greedy search might result in slightly different results than
+ the greedy results obtained by GreedyRNNTInfer due to implementation differences.
+
+ For accurate greedy results, please use GreedyRNNTInfer or GreedyBatchedRNNTInfer.
+
+ search_type: str representing the type of beam search to perform.
+            Must be one of ['default', 'tsd', 'alsd']. 'nsc' is currently not supported.
+
+            Algorithm used:
+ `beam` - basic beam search strategy. Larger beams generally result in better decoding,
+ however the time required for the search also grows steadily.
+
+ `tsd` - time synchronous decoding. Please refer to the paper:
+ [Alignment-Length Synchronous Decoding for RNN Transducer](https://ieeexplore.ieee.org/document/9053040)
+ for details on the algorithm implemented.
+
+ Time synchronous decoding (TSD) execution time grows by the factor T * max_symmetric_expansions.
+ For longer sequences, T is greater, and can therefore take a long time for beams to obtain
+ good results. This also requires greater memory to execute.
+
+ `alsd` - alignment-length synchronous decoding. Please refer to the paper:
+ [Alignment-Length Synchronous Decoding for RNN Transducer](https://ieeexplore.ieee.org/document/9053040)
+ for details on the algorithm implemented.
+
+ Alignment-length synchronous decoding (ALSD) execution time is faster than TSD, with growth
+ factor of T + U_max, where U_max is the maximum target length expected during execution.
+
+ Generally, T + U_max < T * max_symmetric_expansions. However, ALSD beams are non-unique,
+ therefore it is required to use larger beam sizes to achieve the same (or close to the same)
+ decoding accuracy as TSD.
+
+ For a given decoding accuracy, it is possible to attain faster decoding via ALSD than TSD.
+
+ score_norm: bool, whether to normalize the scores of the log probabilities.
+
+ return_best_hypothesis: bool, decides whether to return a single hypothesis (the best out of N),
+                or return all N hypotheses (sorted with best score first). The container class changes based on
+                this flag -
+ When set to True (default), returns a single Hypothesis.
+ When set to False, returns a NBestHypotheses container, which contains a list of Hypothesis.
+
+ # The following arguments are specific to the chosen `search_type`
+
+ tsd_max_sym_exp_per_step: Used for `search_type=tsd`. The maximum symmetric expansions allowed
+ per timestep during beam search. Larger values should be used to attempt decoding of longer
+ sequences, but this in turn increases execution time and memory usage.
+
+ alsd_max_target_len: Used for `search_type=alsd`. The maximum expected target sequence length
+ during beam search. Larger values allow decoding of longer sequences at the expense of
+ execution time and memory.
+
+ # The following two flags are placeholders and unused until `nsc` implementation is stabilized.
+
+ nsc_max_timesteps_expansion: Unused int.
+
+ nsc_prefix_alpha: Unused int.
+ """
+ self.decoder = decoder_model
+ self.joint = joint_model
+
+ self.blank = decoder_model.blank_idx
+ self.vocab_size = decoder_model.vocab_size
+ self.search_type = search_type
+ self.return_best_hypothesis = return_best_hypothesis
+
+ if beam_size < 1:
+ raise ValueError("Beam search size cannot be less than 1!")
+
+ self.beam_size = beam_size
+ self.score_norm = score_norm
+
+ if self.beam_size == 1:
+ self.search_algorithm = self.greedy_search
+ elif search_type == "default":
+ self.search_algorithm = self.default_beam_search
+ elif search_type == "tsd":
+ self.search_algorithm = self.time_sync_decoding
+ elif search_type == "alsd":
+ self.search_algorithm = self.align_length_sync_decoding
+ elif search_type == "nsc":
+ raise NotImplementedError("`nsc` (Constrained Beam Search) has not been implemented.")
+ # self.search_algorithm = self.nsc_beam_search
+ else:
+ raise NotImplementedError(
+ f"The search type ({search_type}) supplied is not supported!\n"
+ f"Please use one of : (default, tsd, alsd, nsc)"
+ )
+
+ if tsd_max_sym_exp_per_step is None:
+ tsd_max_sym_exp_per_step = -1
+
+ if search_type in ['tsd', 'alsd', 'nsc'] and not self.decoder.blank_as_pad:
+ raise ValueError(
+ f"Search type was chosen as '{search_type}', however the decoder module provided "
+ f"does not support the `blank` token as a pad value. {search_type} requires "
+ f"the blank token as pad value support in order to perform batched beam search."
+ f"Please chose one of the other beam search methods, or re-train your model "
+ f"with this support."
+ )
+
+ self.tsd_max_symmetric_expansion_per_step = tsd_max_sym_exp_per_step
+ self.alsd_max_target_length = alsd_max_target_len
+ self.nsc_max_timesteps_expansion = nsc_max_timesteps_expansion
+ self.nsc_prefix_alpha = nsc_prefix_alpha
+
+ @typecheck()
+ def __call__(
+ self, encoder_output: torch.Tensor, encoded_lengths: torch.Tensor
+ ) -> Union[Hypothesis, NBestHypotheses]:
+ """Perform general beam search.
+
+ Args:
+ encoder_output: Encoded speech features (B, T_max, D_enc)
+ encoded_lengths: Lengths of the encoder outputs
+
+ Returns:
+            A tuple containing a list with one entry per batch element. When `return_best_hypothesis=True`,
+            each entry is a single Hypothesis; otherwise each entry is an NBestHypotheses container holding a
+            list of Hypothesis sorted such that the best hypothesis is the first element.
+ """
+ # Preserve decoder and joint training state
+ decoder_training_state = self.decoder.training
+ joint_training_state = self.joint.training
+
+ with torch.no_grad():
+ # Apply optional preprocessing
+ encoder_output = encoder_output.transpose(1, 2) # (B, T, D)
+
+ self.decoder.eval()
+ self.joint.eval()
+
+ hypotheses = []
+ with tqdm(
+ range(encoder_output.size(0)),
+ desc='Beam search progress:',
+ total=encoder_output.size(0),
+ unit='sample',
+ ) as idx_gen:
+
+ # Freeze the decoder and joint to prevent recording of gradients
+ # during the beam loop.
+ with self.decoder.as_frozen(), self.joint.as_frozen():
+
+ # Decode every sample in the batch independently.
+ for batch_idx in idx_gen:
+ inseq = encoder_output[batch_idx : batch_idx + 1, :, :] # [1, T, D]
+ logitlen = encoded_lengths[batch_idx]
+
+ # Execute the specific search strategy
+ nbest_hyps = self.search_algorithm(inseq, logitlen) # sorted list of hypothesis
+
+ # Pack the result
+ if self.return_best_hypothesis:
+ best_hypothesis = nbest_hyps[0] # type: Hypothesis
+ else:
+ best_hypothesis = NBestHypotheses(nbest_hyps) # type: NBestHypotheses
+ hypotheses.append(best_hypothesis)
+
+ self.decoder.train(decoder_training_state)
+ self.joint.train(joint_training_state)
+
+ return (hypotheses,)
+
+ def sort_nbest(self, hyps: List[Hypothesis]) -> List[Hypothesis]:
+ """Sort hypotheses by score or score given sequence length.
+
+ Args:
+ hyps: list of hypotheses
+
+ Return:
+ hyps: sorted list of hypotheses
+ """
+ if self.score_norm:
+ return sorted(hyps, key=lambda x: x.score / len(x.y_sequence), reverse=True)
+ else:
+ return sorted(hyps, key=lambda x: x.score, reverse=True)
+
+ def greedy_search(self, h: torch.Tensor, encoded_lengths: torch.Tensor) -> List[Hypothesis]:
+ """Greedy search implementation for transducer.
+ Generic case when beam size = 1. Results might differ slightly due to implementation details
+ as compared to `GreedyRNNTInfer` and `GreedyBatchRNNTInfer`.
+
+ Args:
+ h: Encoded speech features (1, T_max, D_enc)
+
+ Returns:
+ hyp: 1-best decoding results
+ """
+ # Initialize zero state vectors
+ dec_state = self.decoder.initialize_state(h)
+
+ # Construct initial hypothesis
+ hyp = Hypothesis(score=0.0, y_sequence=[self.blank], dec_state=dec_state)
+ cache = {}
+
+ # Initialize state and first token
+ y, state, _ = self.decoder.score_hypothesis(hyp, cache)
+
+ for i in range(int(encoded_lengths)):
+ hi = h[:, i : i + 1, :] # [1, 1, D]
+
+ not_blank = True
+ symbols_added = 0
+
+ while not_blank:
+ ytu = torch.log_softmax(self.joint.joint(hi, y), dim=-1) # [1, 1, 1, V + 1]
+ ytu = ytu[0, 0, 0, :] # [V + 1]
+
+ # max() requires float
+ if ytu.dtype != torch.float32:
+ ytu = ytu.float()
+
+ logp, pred = torch.max(ytu, dim=-1) # [1, 1]
+ pred = pred.item()
+
+ if pred == self.blank:
+ not_blank = False
+ else:
+ # Update state and current sequence
+ hyp.y_sequence.append(int(pred))
+ hyp.score += float(logp)
+ hyp.dec_state = state
+
+ # Compute next state and token
+ y, state, _ = self.decoder.score_hypothesis(hyp, cache)
+ symbols_added += 1
+
+ return [hyp]
+
+ def default_beam_search(self, h: torch.Tensor, encoded_lengths: torch.Tensor) -> List[Hypothesis]:
+ """Beam search implementation.
+
+ Args:
+            h: Encoded speech features (1, T_max, D_enc)
+
+ Returns:
+ nbest_hyps: N-best decoding results
+ """
+ # Initialize states
+ beam = min(self.beam_size, self.vocab_size)
+ beam_k = min(beam, (self.vocab_size - 1))
+ blank_tensor = torch.tensor([self.blank], device=h.device, dtype=torch.long)
+
+ # Precompute some constants for blank position
+ ids = list(range(self.vocab_size + 1))
+ ids.remove(self.blank)
+
+ # Used when blank token is first vs last token
+ if self.blank == 0:
+ index_incr = 1
+ else:
+ index_incr = 0
+
+ # Initialize zero vector states
+ dec_state = self.decoder.initialize_state(h)
+
+ # Initialize first hypothesis for the beam (blank)
+ kept_hyps = [Hypothesis(score=0.0, y_sequence=[self.blank], dec_state=dec_state)]
+ cache = {}
+
+ for i in range(int(encoded_lengths)):
+ hi = h[:, i : i + 1, :] # [1, 1, D]
+ hyps = kept_hyps
+ kept_hyps = []
+
+ while True:
+ max_hyp = max(hyps, key=lambda x: x.score)
+ hyps.remove(max_hyp)
+
+ # update decoder state and get next score
+ y, state, lm_tokens = self.decoder.score_hypothesis(max_hyp, cache) # [1, 1, D]
+
+ # get next token
+ ytu = torch.log_softmax(self.joint.joint(hi, y), dim=-1) # [1, 1, 1, V + 1]
+ ytu = ytu[0, 0, 0, :] # [V + 1]
+
+ # remove blank token before top k
+ top_k = ytu[ids].topk(beam_k, dim=-1)
+
+ # Two possible steps - blank token or non-blank token predicted
+ ytu = (
+ torch.cat((top_k[0], ytu[self.blank].unsqueeze(0))),
+ torch.cat((top_k[1] + index_incr, blank_tensor)),
+ )
+
+ # for each possible step
+ for logp, k in zip(*ytu):
+ # construct hypothesis for step
+ new_hyp = Hypothesis(
+ score=(max_hyp.score + float(logp)),
+ y_sequence=max_hyp.y_sequence[:],
+ dec_state=max_hyp.dec_state,
+ lm_state=max_hyp.lm_state,
+ )
+
+ # if current token is blank, dont update sequence, just store the current hypothesis
+ if k == self.blank:
+ kept_hyps.append(new_hyp)
+ else:
+ # if non-blank token was predicted, update state and sequence and then search more hypothesis
+ new_hyp.dec_state = state
+ new_hyp.y_sequence.append(int(k))
+
+ hyps.append(new_hyp)
+
+                # keep those finished hypotheses whose scores are greater than the best
+                # score still active in the next search generation
+                hyps_max = float(max(hyps, key=lambda x: x.score).score)
+                kept_most_prob = sorted([hyp for hyp in kept_hyps if hyp.score > hyps_max], key=lambda x: x.score,)
+
+                # If enough hypotheses have scores greater than the next search generation,
+                # stop the beam search.
+ if len(kept_most_prob) >= beam:
+ kept_hyps = kept_most_prob
+ break
+
+ return self.sort_nbest(kept_hyps)
+
+ def time_sync_decoding(self, h: torch.Tensor, encoded_lengths: torch.Tensor) -> List[Hypothesis]:
+ """Time synchronous beam search implementation.
+ Based on https://ieeexplore.ieee.org/document/9053040
+
+ Args:
+ h: Encoded speech features (1, T_max, D_enc)
+
+ Returns:
+ nbest_hyps: N-best decoding results
+ """
+ # Precompute some constants for blank position
+ ids = list(range(self.vocab_size + 1))
+ ids.remove(self.blank)
+
+ # Used when blank token is first vs last token
+ if self.blank == 0:
+ index_incr = 1
+ else:
+ index_incr = 0
+
+ # prepare the batched beam states
+ beam = min(self.beam_size, self.vocab_size)
+ beam_state = self.decoder.initialize_state(
+ torch.zeros(beam, device=h.device, dtype=h.dtype)
+ ) # [L, B, H], [L, B, H] (for LSTMs)
+
+ # Initialize first hypothesis for the beam (blank)
+ B = [Hypothesis(y_sequence=[self.blank], score=0.0, dec_state=self.decoder.batch_select_state(beam_state, 0))]
+ cache = {}
+
+ for i in range(int(encoded_lengths)):
+ hi = h[:, i : i + 1, :]
+
+ # Update caches
+ A = []
+ C = B
+
+ h_enc = hi
+
+ # For a limited number of symmetric expansions per timestep "i"
+ for v in range(self.tsd_max_symmetric_expansion_per_step):
+ D = []
+
+ # Decode a batch of beam states and scores
+ beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score_hypothesis(C, cache, beam_state)
+
+ # Extract the log probabilities and the predicted tokens
+ beam_logp = torch.log_softmax(self.joint.joint(h_enc, beam_y), dim=-1) # [B, 1, 1, V + 1]
+ beam_logp = beam_logp[:, 0, 0, :] # [B, V + 1]
+ beam_topk = beam_logp[:, ids].topk(beam, dim=-1)
+
+ seq_A = [h.y_sequence for h in A]
+
+ for j, hyp in enumerate(C):
+ # create a new hypothesis in A
+ if hyp.y_sequence not in seq_A:
+ # If the sequence is not in seq_A, add it as the blank token
+ # In this step, we dont add a token but simply update score
+ A.append(
+ Hypothesis(
+ score=(hyp.score + float(beam_logp[j, self.blank])),
+ y_sequence=hyp.y_sequence[:],
+ dec_state=hyp.dec_state,
+ lm_state=hyp.lm_state,
+ )
+ )
+ else:
+ # merge the existing blank hypothesis score with current score.
+ dict_pos = seq_A.index(hyp.y_sequence)
+
+ A[dict_pos].score = np.logaddexp(
+ A[dict_pos].score, (hyp.score + float(beam_logp[j, self.blank]))
+ )
+
+ if v < self.tsd_max_symmetric_expansion_per_step:
+ for j, hyp in enumerate(C):
+ # for each current hypothesis j
+ # extract the top token score and top token id for the jth hypothesis
+ for logp, k in zip(beam_topk[0][j], beam_topk[1][j] + index_incr):
+ # create new hypothesis and store in D
+ # Note: This loop does *not* include the blank token!
+ new_hyp = Hypothesis(
+ score=(hyp.score + float(logp)),
+ y_sequence=(hyp.y_sequence + [int(k)]),
+ dec_state=self.decoder.batch_select_state(beam_state, j),
+ lm_state=hyp.lm_state,
+ )
+
+ D.append(new_hyp)
+
+ # Prune beam
+ C = sorted(D, key=lambda x: x.score, reverse=True)[:beam]
+
+ # Prune beam
+ B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]
+
+ return self.sort_nbest(B)
+
+ def align_length_sync_decoding(self, h: torch.Tensor, encoded_lengths: torch.Tensor) -> List[Hypothesis]:
+ """Alignment-length synchronous beam search implementation.
+ Based on https://ieeexplore.ieee.org/document/9053040
+
+ Args:
+ h: Encoded speech features (1, T_max, D_enc)
+
+ Returns:
+ nbest_hyps: N-best decoding results
+ """
+ # Precompute some constants for blank position
+ ids = list(range(self.vocab_size + 1))
+ ids.remove(self.blank)
+
+ # Used when blank token is first vs last token
+ if self.blank == 0:
+ index_incr = 1
+ else:
+ index_incr = 0
+
+ # prepare the batched beam states
+ beam = min(self.beam_size, self.vocab_size)
+
+ h = h[0] # [T, D]
+ h_length = int(encoded_lengths)
+ beam_state = self.decoder.initialize_state(
+ torch.zeros(beam, device=h.device, dtype=h.dtype)
+ ) # [L, B, H], [L, B, H] for LSTMS
+
+ # compute u_max as either a specific static limit,
+ # or a multiple of current `h_length` dynamically.
+ if type(self.alsd_max_target_length) == float:
+ u_max = int(self.alsd_max_target_length * h_length)
+ else:
+ u_max = int(self.alsd_max_target_length)
+
+ # Initialize first hypothesis for the beam (blank)
+ B = [Hypothesis(y_sequence=[self.blank], score=0.0, dec_state=self.decoder.batch_select_state(beam_state, 0))]
+
+ final = []
+ cache = {}
+
+ # ALSD runs for T + U_max steps
+ for i in range(h_length + u_max):
+ # Update caches
+ A = []
+ B_ = []
+ h_states = []
+
+ # preserve the list of batch indices which are added into the list
+ # and those which are removed from the list
+ # This is necessary to perform state updates in the correct batch indices later
+ batch_ids = list(range(len(B))) # initialize as a list of all batch ids
+ batch_removal_ids = [] # update with sample ids which are removed
+
+ for bid, hyp in enumerate(B):
+ u = len(hyp.y_sequence) - 1
+ t = i - u + 1
+
+ if t > (h_length - 1):
+ batch_removal_ids.append(bid)
+ continue
+
+ B_.append(hyp)
+ h_states.append((t, h[t]))
+
+ if B_:
+ # Compute the subset of batch ids which were *not* removed from the list above
+ sub_batch_ids = None
+ if len(B_) != beam:
+ sub_batch_ids = batch_ids
+ for id in batch_removal_ids:
+ # sub_batch_ids contains list of ids *that were not removed*
+ sub_batch_ids.remove(id)
+
+ # extract the states of the sub batch only.
+ beam_state_ = [beam_state[state_id][:, sub_batch_ids, :] for state_id in range(len(beam_state))]
+ else:
+ # If entire batch was used (none were removed), simply take all the states
+ beam_state_ = beam_state
+
+ # Decode a batch/sub-batch of beam states and scores
+ beam_y, beam_state_, beam_lm_tokens = self.decoder.batch_score_hypothesis(B_, cache, beam_state_)
+
+ # If only a subset of batch ids were updated (some were removed)
+ if sub_batch_ids is not None:
+ # For each state in the RNN (2 for LSTM)
+ for state_id in range(len(beam_state)):
+ # Update the current batch states with the sub-batch states (in the correct indices)
+ # These indices are specified by sub_batch_ids, the ids of samples which were updated.
+ beam_state[state_id][:, sub_batch_ids, :] = beam_state_[state_id][...]
+ else:
+ # If entire batch was updated, simply update all the states
+ beam_state = beam_state_
+
+ # h_states = list of [t, h[t]]
+ # so h[1] here is a h[t] of shape [D]
+ # Simply stack all of the h[t] within the sub_batch/batch (T <= beam)
+ h_enc = torch.stack([h[1] for h in h_states]) # [T=beam, D]
+ h_enc = h_enc.unsqueeze(1) # [B=beam, T=1, D]; batch over the beams
+
+ # Extract the log probabilities and the predicted tokens
+ beam_logp = torch.log_softmax(self.joint.joint(h_enc, beam_y), dim=-1) # [B=beam, 1, 1, V + 1]
+ beam_logp = beam_logp[:, 0, 0, :] # [B=beam, V + 1]
+ beam_topk = beam_logp[:, ids].topk(beam, dim=-1)
+
+ for j, hyp in enumerate(B_):
+ # For all updated samples in the batch, add it as the blank token
+ # In this step, we dont add a token but simply update score
+ new_hyp = Hypothesis(
+ score=(hyp.score + float(beam_logp[j, self.blank])),
+ y_sequence=hyp.y_sequence[:],
+ dec_state=hyp.dec_state,
+ lm_state=hyp.lm_state,
+ )
+
+ # Add blank prediction to A
+ A.append(new_hyp)
+
+ # If the prediction "timestep" t has reached the length of the input sequence
+ # we can add it to the "finished" hypothesis list.
+ if h_states[j][0] == (h_length - 1):
+ final.append(new_hyp)
+
+ # Here, we carefully select the indices of the states that we want to preserve
+ # for the next token (non-blank) update.
+ if sub_batch_ids is not None:
+ h_states_idx = sub_batch_ids[j]
+ else:
+ h_states_idx = j
+
+ # for each current hypothesis j
+ # extract the top token score and top token id for the jth hypothesis
+ for logp, k in zip(beam_topk[0][j], beam_topk[1][j] + index_incr):
+ # create new hypothesis and store in A
+ # Note: This loop does *not* include the blank token!
+ new_hyp = Hypothesis(
+ score=(hyp.score + float(logp)),
+ y_sequence=(hyp.y_sequence[:] + [int(k)]),
+ dec_state=self.decoder.batch_select_state(beam_state, h_states_idx),
+ lm_state=hyp.lm_state,
+ )
+
+ A.append(new_hyp)
+
+                # Prune and recombine identical hypotheses.
+                # This may cause the next beam to be smaller than the max beam size;
+                # therefore, larger beam sizes may be required for better decoding.
+ B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]
+ B = self.recombine_hypotheses(B)
+
+ # If B_ is empty list, then we may be able to early exit
+ elif len(batch_ids) == len(batch_removal_ids):
+ break
+
+ if final:
+ return self.sort_nbest(final)
+ else:
+ return B
+
+ def recombine_hypotheses(self, hypotheses: List[Hypothesis]) -> List[Hypothesis]:
+ """Recombine hypotheses with equivalent output sequence.
+
+ Args:
+ hypotheses (list): list of hypotheses
+
+ Returns:
+ final (list): list of recombined hypotheses
+ """
+ final = []
+
+ for hyp in hypotheses:
+ seq_final = [f.y_sequence for f in final if f.y_sequence]
+
+ if hyp.y_sequence in seq_final:
+ seq_pos = seq_final.index(hyp.y_sequence)
+
+ final[seq_pos].score = np.logaddexp(final[seq_pos].score, hyp.score)
+ else:
+ final.append(hyp)
+
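+        # For example, two beam entries that both decode to [1, 7, 7] with scores
+        # -1.2 and -2.3 are merged into a single entry whose score is
+        # logaddexp(-1.2, -2.3) ≈ -0.91.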
+        return final
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/rnnt_greedy_decoding.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/rnnt_greedy_decoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ec12457ea9ff78f8dd7d63aec40427b48365e3b
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/rnnt_greedy_decoding.py
@@ -0,0 +1,565 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+import torch
+
+from nemo.collections.asr.modules import rnnt_abstract
+from nemo.collections.asr.parts import rnnt_utils
+from nemo.collections.common.parts.rnn import label_collate
+from nemo.core.classes import Typing, typecheck
+from nemo.core.neural_types import AcousticEncodedRepresentation, HypothesisType, LengthsType, NeuralType
+
+
+class _GreedyRNNTInfer(Typing):
+ """A greedy transducer decoder.
+
+ Provides a common abstraction for sample level and batch level greedy decoding.
+
+ Args:
+ decoder_model: rnnt_utils.AbstractRNNTDecoder implementation.
+ joint_model: rnnt_utils.AbstractRNNTJoint implementation.
+ blank_index: int index of the blank token. Can be 0 or len(vocabulary).
+ max_symbols_per_step: Optional int. The maximum number of symbols that can be added
+ to a sequence in a single time step; if set to None then there is
+ no limit.
+ """
+
+ @property
+ def input_types(self):
+ """Returns definitions of module input ports.
+ """
+ return {
+ "encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
+ "encoded_lengths": NeuralType(tuple('B'), LengthsType()),
+ }
+
+ @property
+ def output_types(self):
+ """Returns definitions of module output ports.
+ """
+ return {"predictions": NeuralType(elements_type=HypothesisType())}
+
+ def __init__(
+ self,
+ decoder_model: rnnt_abstract.AbstractRNNTDecoder,
+ joint_model: rnnt_abstract.AbstractRNNTJoint,
+ blank_index: int,
+ max_symbols_per_step: Optional[int] = None,
+ ):
+ super().__init__()
+ self.decoder = decoder_model
+ self.joint = joint_model
+
+ self._blank_index = blank_index
+        self._SOS = blank_index  # Start-of-Signal token, mapped to the blank index
+ self.max_symbols = max_symbols_per_step
+
+ def __call__(self, *args, **kwargs):
+ return self.forward(*args, **kwargs)
+
+ @torch.no_grad()
+ def _pred_step(
+ self,
+ label: Union[torch.Tensor, int],
+ hidden: Optional[torch.Tensor],
+ add_sos: bool = False,
+ batch_size: Optional[int] = None,
+ ) -> (torch.Tensor, torch.Tensor):
+ """
+ Common prediction step based on the AbstractRNNTDecoder implementation.
+
+ Args:
+ label: (int/torch.Tensor): Label or "Start-of-Signal" token.
+ hidden: (Optional torch.Tensor): RNN State vector
+            add_sos (bool): Whether to add a zero vector at the beginning as a "start of sentence" token.
+ batch_size: Batch size of the output tensor.
+
+ Returns:
+ g: (B, U, H) if add_sos is false, else (B, U + 1, H)
+ hid: (h, c) where h is the final sequence hidden state and c is
+ the final cell state:
+ h (tensor), shape (L, B, H)
+ c (tensor), shape (L, B, H)
+ """
+ if isinstance(label, torch.Tensor):
+ # label: [batch, 1]
+ if label.dtype != torch.long:
+ label = label.long()
+
+ else:
+ # Label is an integer
+ if label == self._SOS:
+ return self.decoder.predict(None, hidden, add_sos=add_sos, batch_size=batch_size)
+
+ label = label_collate([[label]])
+
+ # output: [B, 1, K]
+ return self.decoder.predict(label, hidden, add_sos=add_sos, batch_size=batch_size)
+
+ def _joint_step(self, enc, pred, log_normalize: Optional[bool] = None):
+ """
+ Common joint step based on AbstractRNNTJoint implementation.
+
+ Args:
+ enc: Output of the Encoder model. A torch.Tensor of shape [B, 1, H1]
+ pred: Output of the Decoder model. A torch.Tensor of shape [B, 1, H2]
+ log_normalize: Whether to log normalize or not. None will log normalize only for CPU.
+
+ Returns:
+ logits of shape (B, T=1, U=1, V + 1)
+ """
+ with torch.no_grad():
+ logits = self.joint.joint(enc, pred)
+
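+            # log-softmax is monotonic, so skipping it does not change the argmax used
+            # by greedy decoding; it is only applied on CPU so that the returned scores
+            # stay normalized, while the GPU path avoids the extra kernel.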
+ if log_normalize is None:
+ if not logits.is_cuda: # Use log softmax only if on CPU
+ logits = logits.log_softmax(dim=len(logits.shape) - 1)
+ else:
+ if log_normalize:
+ logits = logits.log_softmax(dim=len(logits.shape) - 1)
+
+ return logits
+
+
+class GreedyRNNTInfer(_GreedyRNNTInfer):
+ """A greedy transducer decoder.
+
+    Sequence level greedy decoding, performed autoregressively.
+
+ Args:
+ decoder_model: rnnt_utils.AbstractRNNTDecoder implementation.
+ joint_model: rnnt_utils.AbstractRNNTJoint implementation.
+ blank_index: int index of the blank token. Can be 0 or len(vocabulary).
+ max_symbols_per_step: Optional int. The maximum number of symbols that can be added
+ to a sequence in a single time step; if set to None then there is
+ no limit.
+ """
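+    # Typical wiring, shown as a sketch with illustrative variable names (the decoder
+    # and joint normally come from an RNNT model defined elsewhere in this package):
+    #   greedy = GreedyRNNTInfer(decoder_model=decoder, joint_model=joint,
+    #                            blank_index=len(vocabulary), max_symbols_per_step=30)
+    #   (hypotheses,) = greedy(encoder_output=enc_out, encoded_lengths=enc_lens)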
+
+ def __init__(
+ self,
+ decoder_model: rnnt_abstract.AbstractRNNTDecoder,
+ joint_model: rnnt_abstract.AbstractRNNTJoint,
+ blank_index: int,
+ max_symbols_per_step: Optional[int] = None,
+ ):
+ super().__init__(
+ decoder_model=decoder_model,
+ joint_model=joint_model,
+ blank_index=blank_index,
+ max_symbols_per_step=max_symbols_per_step,
+ )
+
+ @typecheck()
+ def forward(self, encoder_output: torch.Tensor, encoded_lengths: torch.Tensor):
+ """Returns a list of hypotheses given an input batch of the encoder hidden embedding.
+        Output tokens are generated autoregressively.
+
+ Args:
+ encoder_output: A tensor of size (batch, features, timesteps).
+            encoded_lengths: list of ints representing the length of each output sequence.
+
+ Returns:
+ packed list containing batch number of sentences (Hypotheses).
+ """
+ # Preserve decoder and joint training state
+ decoder_training_state = self.decoder.training
+ joint_training_state = self.joint.training
+
+ with torch.no_grad():
+ # Apply optional preprocessing
+ encoder_output = encoder_output.transpose(1, 2) # (B, T, D)
+
+ self.decoder.eval()
+ self.joint.eval()
+
+ hypotheses = []
+ # Process each sequence independently
+ with self.decoder.as_frozen(), self.joint.as_frozen():
+ for batch_idx in range(encoder_output.size(0)):
+ inseq = encoder_output[batch_idx, :, :].unsqueeze(1) # [T, 1, D]
+ logitlen = encoded_lengths[batch_idx]
+ sentence = self._greedy_decode(inseq, logitlen)
+ hypotheses.append(sentence)
+
+ # Pack results into Hypotheses
+ packed_result = [
+ rnnt_utils.Hypothesis(y_sequence=torch.tensor(sent, dtype=torch.long), score=-1.0)
+ for sent in hypotheses
+ ]
+
+ self.decoder.train(decoder_training_state)
+ self.joint.train(joint_training_state)
+
+ return (packed_result,)
+
+ @torch.no_grad()
+ def _greedy_decode(self, x: torch.Tensor, out_len: torch.Tensor):
+ # x: [T, 1, D]
+ # out_len: [seq_len]
+
+ # Initialize blank state and empty label set
+ hidden = None
+ label = []
+
+ # For timestep t in X_t
+ for time_idx in range(out_len):
+ # Extract encoder embedding at timestep t
+ # f = x[time_idx, :, :].unsqueeze(0) # [1, 1, D]
+ f = x.narrow(dim=0, start=time_idx, length=1)
+
+ # Setup exit flags and counter
+ not_blank = True
+ symbols_added = 0
+
+            # Continue while blank is not predicted and we have not exceeded max symbols per timestep
+ while not_blank and (self.max_symbols is None or symbols_added < self.max_symbols):
+ # In the first timestep, we initialize the network with RNNT Blank
+ # In later timesteps, we provide previous predicted label as input.
+ last_label = self._SOS if label == [] else label[-1]
+
+ # Perform prediction network and joint network steps.
+ g, hidden_prime = self._pred_step(last_label, hidden)
+ logp = self._joint_step(f, g, log_normalize=None)[0, 0, 0, :]
+
+ del g
+
+                # torch.max(dim=0) is not supported for FP16, so cast logits to float32.
+ if logp.dtype != torch.float32:
+ logp = logp.float()
+
+ # get index k, of max prob
+ v, k = logp.max(0)
+ k = k.item() # K is the label at timestep t_s in inner loop, s >= 0.
+
+ del logp
+
+ # If blank token is predicted, exit inner loop, move onto next timestep t
+ if k == self._blank_index:
+ not_blank = False
+ else:
+ # Append token to label set, update RNN state.
+ label.append(k)
+ hidden = hidden_prime
+
+ # Increment token counter.
+ symbols_added += 1
+
+ return label
+
+
+class GreedyBatchedRNNTInfer(_GreedyRNNTInfer):
+ """A batch level greedy transducer decoder.
+
+    Batch level greedy decoding, performed autoregressively.
+
+ Args:
+ decoder_model: rnnt_utils.AbstractRNNTDecoder implementation.
+ joint_model: rnnt_utils.AbstractRNNTJoint implementation.
+ blank_index: int index of the blank token. Can be 0 or len(vocabulary).
+ max_symbols_per_step: Optional int. The maximum number of symbols that can be added
+ to a sequence in a single time step; if set to None then there is
+ no limit.
+ """
+
+ def __init__(
+ self,
+ decoder_model: rnnt_abstract.AbstractRNNTDecoder,
+ joint_model: rnnt_abstract.AbstractRNNTJoint,
+ blank_index: int,
+ max_symbols_per_step: Optional[int] = None,
+ ):
+ super().__init__(
+ decoder_model=decoder_model,
+ joint_model=joint_model,
+ blank_index=blank_index,
+ max_symbols_per_step=max_symbols_per_step,
+ )
+
+ # Depending on availability of `blank_as_pad` support
+ # switch between more efficient batch decoding technique
+ if self.decoder.blank_as_pad:
+ self._greedy_decode = self._greedy_decode_blank_as_pad
+ else:
+ self._greedy_decode = self._greedy_decode_masked
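+        # With blank-as-pad support, the blank id can be fed to the prediction network
+        # like any other label, so the whole batch is decoded in lock-step; without it,
+        # the masked variant below has to remap blanks to a dummy label at every step.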
+
+ @typecheck()
+ def forward(self, encoder_output: torch.Tensor, encoded_lengths: torch.Tensor):
+ """Returns a list of hypotheses given an input batch of the encoder hidden embedding.
+        Output tokens are generated autoregressively.
+
+ Args:
+ encoder_output: A tensor of size (batch, features, timesteps).
+            encoded_lengths: list of ints representing the length of each output sequence.
+
+ Returns:
+ packed list containing batch number of sentences (Hypotheses).
+ """
+ # Preserve decoder and joint training state
+ decoder_training_state = self.decoder.training
+ joint_training_state = self.joint.training
+
+ with torch.no_grad():
+ # Apply optional preprocessing
+ encoder_output = encoder_output.transpose(1, 2) # (B, T, D)
+ logitlen = encoded_lengths
+
+ self.decoder.eval()
+ self.joint.eval()
+
+ with self.decoder.as_frozen(), self.joint.as_frozen():
+ inseq = encoder_output # [B, T, D]
+ hypotheses = self._greedy_decode(inseq, logitlen, device=inseq.device)
+
+ # Pack the hypotheses results
+ packed_result = [
+ rnnt_utils.Hypothesis(y_sequence=torch.tensor(sent, dtype=torch.long), score=-1.0)
+ for sent in hypotheses
+ ]
+
+ del hypotheses
+
+ self.decoder.train(decoder_training_state)
+ self.joint.train(joint_training_state)
+
+ return (packed_result,)
+
+ def _greedy_decode_blank_as_pad(self, x: torch.Tensor, out_len: torch.Tensor, device: torch.device):
+ with torch.no_grad():
+ # x: [B, T, D]
+ # out_len: [B]
+ # device: torch.device
+
+ # Initialize state
+ hidden = None
+ batchsize = x.shape[0]
+
+ # Output string buffer
+ label = [[] for _ in range(batchsize)]
+
+            # Last label buffer: the batch-level equivalent of last_label in the
+            # sample-level greedy loop, initialized to the blank index.
+ last_label = torch.full([batchsize, 1], fill_value=self._blank_index, dtype=torch.long, device=device)
+
+ # Mask buffers
+ blank_mask = torch.full([batchsize], fill_value=0, dtype=torch.bool, device=device)
+
+ # Get max sequence length
+ max_out_len = out_len.max()
+
+ for time_idx in range(max_out_len):
+ f = x.narrow(dim=1, start=time_idx, length=1) # [B, 1, D]
+
+ # Prepare t timestamp batch variables
+ not_blank = True
+ symbols_added = 0
+
+ # Reset blank mask
+ blank_mask.mul_(False)
+
+ # Update blank mask with time mask
+                # Batch: [B, T, D], but sample i may have seq len < max(seq_lens_in_batch)
+                # Forcibly mask with "blank" tokens every sample whose sequence length has
+                # already been passed at this time step
+ blank_mask = time_idx >= out_len
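+                # Samples whose sequence has already ended behave as if they predicted
+                # blank: they never append a label below, and once every sample is
+                # masked the inner loop exits early.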
+
+ # Start inner loop
+ while not_blank and (self.max_symbols is None or symbols_added < self.max_symbols):
+
+ # Batch prediction and joint network steps
+ # If very first prediction step, submit SOS tag (blank) to pred_step.
+ # This feeds a zero tensor as input to AbstractRNNTDecoder to prime the state
+ if time_idx == 0 and symbols_added == 0:
+ g, hidden_prime = self._pred_step(self._SOS, hidden, batch_size=batchsize)
+ else:
+ # Perform batch step prediction of decoder, getting new states and scores ("g")
+ g, hidden_prime = self._pred_step(last_label, hidden, batch_size=batchsize)
+
+ # Batched joint step - Output = [B, V + 1]
+ logp = self._joint_step(f, g, log_normalize=None)[:, 0, 0, :]
+
+ if logp.dtype != torch.float32:
+ logp = logp.float()
+
+ # Get index k, of max prob for batch
+ v, k = logp.max(1)
+ del v, g, logp
+
+ # Update blank mask with current predicted blanks
+ # This is accumulating blanks over all time steps T and all target steps min(max_symbols, U)
+ k_is_blank = k == self._blank_index
+ blank_mask |= k_is_blank
+
+ del k_is_blank
+
+ # If all samples predict / have predicted prior blanks, exit loop early
+ # This is equivalent to if single sample predicted k
+ if blank_mask.all():
+ not_blank = False
+ else:
+ # Collect batch indices where blanks occurred now/past
+ blank_indices = []
+ if hidden is not None:
+ blank_indices = (blank_mask == 1).nonzero(as_tuple=False)
+
+ # Recover prior state for all samples which predicted blank now/past
+ if hidden is not None:
+ # LSTM has 2 states
+ for state_id in range(len(hidden)):
+ hidden_prime[state_id][:, blank_indices, :] = hidden[state_id][:, blank_indices, :]
+
+ # Recover prior predicted label for all samples which predicted blank now/past
+ k[blank_indices] = last_label[blank_indices, 0]
+
+ # Update new label and hidden state for next iteration
+ last_label = k.clone().view(-1, 1)
+ hidden = hidden_prime
+
+ # Update predicted labels, accounting for time mask
+ # If blank was predicted even once, now or in the past,
+ # Force the current predicted label to also be blank
+                        # This ensures that blanks propagate across all timesteps
+                        # once they have occurred (normally the stopping condition of the sample level loop).
+ for kidx, ki in enumerate(k):
+ if blank_mask[kidx] == 0:
+ label[kidx].append(ki)
+
+ symbols_added += 1
+
+ return label
+
+ @torch.no_grad()
+ def _greedy_decode_masked(self, x: torch.Tensor, out_len: torch.Tensor, device: torch.device):
+ # x: [B, T, D]
+ # out_len: [B]
+ # device: torch.device
+
+ # Initialize state
+ hidden = None
+ batchsize = x.shape[0]
+
+ # Output string buffer
+ label = [[] for _ in range(batchsize)]
+
+        # Last label buffer, plus a copy with blanks remapped, acting as the
+        # batch-level equivalent of last_label in the sample-level greedy loop.
+ last_label = torch.full([batchsize, 1], fill_value=self._blank_index, dtype=torch.long, device=device)
+ last_label_without_blank = last_label.clone()
+
+ # Mask buffers
+ blank_mask = torch.full([batchsize], fill_value=0, dtype=torch.bool, device=device)
+
+ # Get max sequence length
+ max_out_len = out_len.max()
+ for time_idx in range(max_out_len):
+ f = x.narrow(dim=1, start=time_idx, length=1) # [B, 1, D]
+
+ # Prepare t timestamp batch variables
+ not_blank = True
+ symbols_added = 0
+
+ # Reset blank mask
+ blank_mask.mul_(False)
+
+ # Update blank mask with time mask
+                # Batch: [B, T, D], but sample i may have seq len < max(seq_lens_in_batch)
+                # Forcibly mask with "blank" tokens every sample whose sequence length has
+                # already been passed at this time step
+ blank_mask = time_idx >= out_len
+
+ # Start inner loop
+ while not_blank and (self.max_symbols is None or symbols_added < self.max_symbols):
+ # Batch prediction and joint network steps
+ # If very first prediction step, submit SOS tag (blank) to pred_step.
+ # This feeds a zero tensor as input to AbstractRNNTDecoder to prime the state
+ if time_idx == 0 and symbols_added == 0:
+ g, hidden_prime = self._pred_step(self._SOS, hidden, batch_size=batchsize)
+ else:
+                    # Set a dummy label wherever the previous label was blank.
+                    # The value is overwritten by "blank" again in the last-label update below;
+                    # this is needed because the prediction network's vocabulary does not
+                    # contain the RNNT "blank" token.
+ last_label_without_blank_mask = last_label == self._blank_index
+ last_label_without_blank[last_label_without_blank_mask] = 0 # temp change of label
+ last_label_without_blank[~last_label_without_blank_mask] = last_label[
+ ~last_label_without_blank_mask
+ ]
+
+ # Perform batch step prediction of decoder, getting new states and scores ("g")
+ g, hidden_prime = self._pred_step(last_label_without_blank, hidden, batch_size=batchsize)
+
+ # Batched joint step - Output = [B, V + 1]
+ logp = self._joint_step(f, g, log_normalize=None)[:, 0, 0, :]
+
+ if logp.dtype != torch.float32:
+ logp = logp.float()
+
+ # Get index k, of max prob for batch
+ v, k = logp.max(1)
+ del v, g, logp
+
+ # Update blank mask with current predicted blanks
+ # This is accumulating blanks over all time steps T and all target steps min(max_symbols, U)
+ k_is_blank = k == self._blank_index
+ blank_mask.bitwise_or_(k_is_blank)
+
+ # If all samples predict / have predicted prior blanks, exit loop early
+ # This is equivalent to if single sample predicted k
+ if blank_mask.all():
+ not_blank = False
+ else:
+ # Collect batch indices where blanks occurred now/past
+ blank_indices = []
+ if hidden is not None:
+ blank_indices = (blank_mask == 1).nonzero(as_tuple=False)
+
+ # Recover prior state for all samples which predicted blank now/past
+ if hidden is not None:
+ # LSTM has 2 states
+ for state_id in range(len(hidden)):
+ hidden_prime[state_id][:, blank_indices, :] = hidden[state_id][:, blank_indices, :]
+
+ # Recover prior predicted label for all samples which predicted blank now/past
+ k[blank_indices] = last_label[blank_indices, 0]
+
+ # Update new label and hidden state for next iteration
+ last_label = k.view(-1, 1)
+ hidden = hidden_prime
+
+ # Update predicted labels, accounting for time mask
+ # If blank was predicted even once, now or in the past,
+ # Force the current predicted label to also be blank
+                        # This ensures that blanks propagate across all timesteps
+                        # once they have occurred (normally the stopping condition of the sample level loop).
+ for kidx, ki in enumerate(k):
+ if blank_mask[kidx] == 0:
+ label[kidx].append(ki)
+
+ symbols_added += 1
+
+ return label
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/rnnt_utils.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/rnnt_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..55cd53bd89ad8b1dfd89d3297880a3331655c349
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/rnnt_utils.py
@@ -0,0 +1,65 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+
+
+@dataclass
+class Hypothesis:
+ """Hypothesis class for beam search algorithms.
+
+ score: A float score obtained from an AbstractRNNTDecoder module's score_hypothesis method.
+
+ y_sequence: Either a sequence of integer ids pointing to some vocabulary, or a packed torch.Tensor
+ behaving in the same manner. dtype must be torch.Long in the latter case.
+
+ dec_state: A list (or list of list) of LSTM-RNN decoder states. Can be None.
+
+ y: (Unused) A list of torch.Tensors representing the list of hypotheses.
+
+ lm_state: (Unused) A dictionary state cache used by an external Language Model.
+
+ lm_scores: (Unused) Score of the external Language Model.
+ """
+
+ score: float
+ y_sequence: Union[List[int], torch.Tensor]
+ dec_state: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor]]] = None
+    y: Optional[List[torch.Tensor]] = None
+    lm_state: Optional[Union[Dict[str, Any], List[Any]]] = None
+    lm_scores: Optional[torch.Tensor] = None
+
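+# Minimal construction sketch (values are illustrative): beam searches in this
+# package typically seed their beam with a blank-only hypothesis, e.g.
+#   Hypothesis(score=0.0, y_sequence=[blank_id], dec_state=None)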
+
+@dataclass
+class NBestHypotheses:
+ """List of N best hypotheses"""
+
+ n_best_hypotheses: Optional[List[Hypothesis]]
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/segment.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/segment.py
new file mode 100644
index 0000000000000000000000000000000000000000..85b4c38216ea5665982faf9c963a37e0c1fe7c57
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/segment.py
@@ -0,0 +1,223 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Copyright (c) 2018 Ryan Leary
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# This file contains code artifacts adapted from https://github.com/ryanleary/patter
+
+import random
+
+import librosa
+import numpy as np
+import soundfile as sf
+
+
+class AudioSegment(object):
+ """Monaural audio segment abstraction.
+ :param samples: Audio samples [num_samples x num_channels].
+ :type samples: ndarray.float32
+ :param sample_rate: Audio sample rate.
+ :type sample_rate: int
+ :raises TypeError: If the sample data type is not float or int.
+ """
+
+ def __init__(self, samples, sample_rate, target_sr=None, trim=False, trim_db=60, orig_sr=None):
+ """Create audio segment from samples.
+        Samples are converted to float32 internally, with integer types scaled to [-1, 1].
+ """
+ samples = self._convert_samples_to_float32(samples)
+ if target_sr is not None and target_sr != sample_rate:
+ samples = librosa.core.resample(samples, sample_rate, target_sr)
+ sample_rate = target_sr
+ if trim:
+ samples, _ = librosa.effects.trim(samples, trim_db)
+ self._samples = samples
+ self._sample_rate = sample_rate
+ if self._samples.ndim >= 2:
+ self._samples = np.mean(self._samples, 1)
+
+ self._orig_sr = orig_sr if orig_sr is not None else sample_rate
+
+ def __eq__(self, other):
+ """Return whether two objects are equal."""
+ if type(other) is not type(self):
+ return False
+ if self._sample_rate != other._sample_rate:
+ return False
+ if self._samples.shape != other._samples.shape:
+ return False
+ if np.any(self.samples != other._samples):
+ return False
+ return True
+
+ def __ne__(self, other):
+ """Return whether two objects are unequal."""
+ return not self.__eq__(other)
+
+ def __str__(self):
+ """Return human-readable representation of segment."""
+ return "%s: num_samples=%d, sample_rate=%d, duration=%.2fsec, rms=%.2fdB" % (
+ type(self),
+ self.num_samples,
+ self.sample_rate,
+ self.duration,
+ self.rms_db,
+ )
+
+ @staticmethod
+ def _convert_samples_to_float32(samples):
+ """Convert sample type to float32.
+        Audio sample type is usually integer or floating-point.
+ Integers will be scaled to [-1, 1] in float32.
+ """
+ float32_samples = samples.astype('float32')
+ if samples.dtype in np.sctypes['int']:
+ bits = np.iinfo(samples.dtype).bits
+ float32_samples *= 1.0 / 2 ** (bits - 1)
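+            # e.g. for int16 input: bits = 16, so a sample value of 16384 is scaled by
+            # 1 / 2**15 and becomes 0.5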
+ elif samples.dtype in np.sctypes['float']:
+ pass
+ else:
+ raise TypeError("Unsupported sample type: %s." % samples.dtype)
+ return float32_samples
+
+ @classmethod
+ def from_file(
+ cls, audio_file, target_sr=None, int_values=False, offset=0, duration=0, trim=False, orig_sr=None,
+ ):
+ """
+ Load a file supported by librosa and return as an AudioSegment.
+ :param audio_file: path of file to load
+ :param target_sr: the desired sample rate
+ :param int_values: if true, load samples as 32-bit integers
+ :param offset: offset in seconds when loading audio
+ :param duration: duration in seconds when loading audio
+        :return: AudioSegment instance
+ """
+ with sf.SoundFile(audio_file, 'r') as f:
+ dtype = 'int32' if int_values else 'float32'
+ sample_rate = f.samplerate
+ if offset > 0:
+ f.seek(int(offset * sample_rate))
+ if duration > 0:
+ samples = f.read(int(duration * sample_rate), dtype=dtype)
+ else:
+ samples = f.read(dtype=dtype)
+
+ samples = samples.transpose()
+ return cls(samples, sample_rate, target_sr=target_sr, trim=trim, orig_sr=orig_sr)
+
+ @classmethod
+ def segment_from_file(cls, audio_file, target_sr=None, n_segments=0, trim=False, orig_sr=None):
+        """Grabs n_segments samples from audio_file at a random offset, as opposed
+        to reading from a specified offset.
+
+ Note that audio_file can be either the file path, or a file-like object.
+ """
+ with sf.SoundFile(audio_file, 'r') as f:
+ sample_rate = f.samplerate
+ if n_segments > 0 and len(f) > n_segments:
+ max_audio_start = len(f) - n_segments
+ audio_start = random.randint(0, max_audio_start)
+ f.seek(audio_start)
+ samples = f.read(n_segments, dtype='float32')
+ else:
+ samples = f.read(dtype='float32')
+
+ samples = samples.transpose()
+ return cls(samples, sample_rate, target_sr=target_sr, trim=trim, orig_sr=orig_sr)
+
+ @property
+ def samples(self):
+ return self._samples.copy()
+
+ @property
+ def sample_rate(self):
+ return self._sample_rate
+
+ @property
+ def num_samples(self):
+ return self._samples.shape[0]
+
+ @property
+ def duration(self):
+ return self._samples.shape[0] / float(self._sample_rate)
+
+ @property
+ def rms_db(self):
+ mean_square = np.mean(self._samples ** 2)
+ return 10 * np.log10(mean_square)
+
+ @property
+ def orig_sr(self):
+ return self._orig_sr
+
+ def gain_db(self, gain):
+ self._samples *= 10.0 ** (gain / 20.0)
+
+ def pad(self, pad_size, symmetric=False):
+ """Add zero padding to the sample. The pad size is given in number
+ of samples.
+        If symmetric=True, `pad_size` zeros will be added to both sides; if False,
+        `pad_size` zeros will be added only to the end.
+ """
+ self._samples = np.pad(self._samples, (pad_size if symmetric else 0, pad_size), mode='constant',)
+
+ def subsegment(self, start_time=None, end_time=None):
+ """Cut the AudioSegment between given boundaries.
+ Note that this is an in-place transformation.
+ :param start_time: Beginning of subsegment in seconds.
+ :type start_time: float
+ :param end_time: End of subsegment in seconds.
+ :type end_time: float
+ :raise ValueError: If start_time or end_time is incorrectly set,
+            e.g. out of bounds in time.
+ """
+ start_time = 0.0 if start_time is None else start_time
+ end_time = self.duration if end_time is None else end_time
+ if start_time < 0.0:
+ start_time = self.duration + start_time
+ if end_time < 0.0:
+ end_time = self.duration + end_time
+ if start_time < 0.0:
+ raise ValueError("The slice start position (%f s) is out of bounds." % start_time)
+ if end_time < 0.0:
+ raise ValueError("The slice end position (%f s) is out of bounds." % end_time)
+ if start_time > end_time:
+ raise ValueError(
+ "The slice start position (%f s) is later than the end position (%f s)." % (start_time, end_time)
+ )
+ if end_time > self.duration:
+ raise ValueError("The slice end position (%f s) is out of bounds (> %f s)" % (end_time, self.duration))
+ start_sample = int(round(start_time * self._sample_rate))
+ end_sample = int(round(end_time * self._sample_rate))
+ self._samples = self._samples[start_sample:end_sample]
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/simple_wer_v2.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/simple_wer_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..172d2ef71ac0759a041595cfdc1d3224ae3d177e
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/simple_wer_v2.py
@@ -0,0 +1,454 @@
+# Lint as: python2, python3
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The new version script to evalute the word error rate (WER) for ASR tasks.
+
+Tensorflow and Lingvo are not required to run this script.
+
+Example of Usage:
+
+a) `python simple_wer_v2.py file_hypothesis file_reference`
+b) `python simple_wer_v2.py file_hypothesis file_reference file_keyphrases`
+
+where `file_hypothesis` is the filename for hypothesis text,
+`file_reference` is the filename for reference text, and
+`file_keyphrases` is the optional filename for important phrases
+(one phrase per line).
+
+Note that the program will also generate a html to diagnose the errors,
+and the html filename is `{$file_hypothesis}_diagnois.html`.
+
+Another way is to use this file as a stand-alone library, by calling class
+SimpleWER with the following member functions:
+
+- AddHypRef(hyp, ref): Updates the evaluation for each (hyp,ref) pair.
+- GetWER(): Computes word error rate (WER) for all the added hyp-ref pairs.
+- GetSummaries(): Generates strings to summarize word and key phrase errors.
+- GetKeyPhraseStats(): Measures stats for key phrases.
+ Stats include:
+ (1) Jaccard similarity: https://en.wikipedia.org/wiki/Jaccard_index.
+ (2) F1 score: https://en.wikipedia.org/wiki/Precision_and_recall.
+
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import re
+import sys
+from six.moves import range
+
+
+def TxtPreprocess(txt):
+  """Preprocess text before WER calculation."""
+
+ # Lowercase, remove \t and new line.
+ txt = re.sub(r'[\t\n]', ' ', txt.lower())
+
+ # Remove punctuation before space.
+ txt = re.sub(r'[,.\?!]+ ', ' ', txt)
+
+ # Remove punctuation before end.
+ txt = re.sub(r'[,.\?!]+$', ' ', txt)
+
+ # Remove punctuation after space.
+ txt = re.sub(r' [,.\?!]+', ' ', txt)
+
+ # Remove quotes, [, ], ( and ).
+ txt = re.sub(r'["\(\)\[\]]', '', txt)
+
+ # Remove extra space.
+ txt = re.sub(' +', ' ', txt.strip())
+
+ return txt
+
+
+def RemoveCommentTxtPreprocess(txt):
+  """Preprocess text and remove comments in brackets, such as [comments]."""
+
+ # Remove comments surrounded by box brackets:
+ txt = re.sub(r'\[\w+\]', '', txt)
+
+ return TxtPreprocess(txt)
+
+
+def HighlightAlignedHtml(hyp, ref, err_type):
+ """Generate a html element to highlight the difference between hyp and ref.
+
+ Args:
+ hyp: Hypothesis string.
+ ref: Reference string.
+ err_type: one of 'none', 'sub', 'del', 'ins'.
+
+ Returns:
+ a html string where disagreements are highlighted.
+      Note `hyp` is highlighted in green and marked with <del> </del> tags, while
+      `ref` is highlighted in yellow. If you want html with other styles,
+      consider writing your own function.
+
+ Raises:
+ ValueError: if err_type is not among ['none', 'sub', 'del', 'ins'].
+ or if when err_type == 'none', hyp != ref
+ """
+
+ highlighted_html = ''
+ if err_type == 'none':
+ if hyp != ref:
+ raise ValueError('hyp (%s) does not match ref (%s) for none error' %
+ (hyp, ref))
+ highlighted_html += '%s ' % hyp
+
+ elif err_type == 'sub':
+    highlighted_html += """<span style="background-color: green">
+        <del>%s</del></span> <span style="background-color: yellow">
+        %s </span> """ % (hyp, ref)
+
+  elif err_type == 'del':
+    highlighted_html += """<span style="background-color: yellow">
+        %s </span> """ % (
+        ref)
+
+  elif err_type == 'ins':
+    highlighted_html += """<span style="background-color: green">
+        <del>%s</del> </span> """ % (
+        hyp)
+
+ else:
+ raise ValueError('unknown err_type ' + err_type)
+
+ return highlighted_html
+
+
+def ComputeEditDistanceMatrix(hyp_words, ref_words):
+ """Compute edit distance between two list of strings.
+
+ Args:
+ hyp_words: the list of words in the hypothesis sentence
+ ref_words: the list of words in the reference sentence
+
+ Returns:
+ Edit distance matrix (in the format of list of lists), where the first
+ index is the reference and the second index is the hypothesis.
+ """
+ reference_length_plus = len(ref_words) + 1
+ hypothesis_length_plus = len(hyp_words) + 1
+ edit_dist_mat = [[]] * reference_length_plus
+
+ # Initialization.
+ for i in range(reference_length_plus):
+ edit_dist_mat[i] = [0] * hypothesis_length_plus
+ for j in range(hypothesis_length_plus):
+ if i == 0:
+ edit_dist_mat[0][j] = j
+ elif j == 0:
+ edit_dist_mat[i][0] = i
+
+ # Do dynamic programming.
+ for i in range(1, reference_length_plus):
+ for j in range(1, hypothesis_length_plus):
+ if ref_words[i - 1] == hyp_words[j - 1]:
+ edit_dist_mat[i][j] = edit_dist_mat[i - 1][j - 1]
+ else:
+ tmp0 = edit_dist_mat[i - 1][j - 1] + 1
+ tmp1 = edit_dist_mat[i][j - 1] + 1
+ tmp2 = edit_dist_mat[i - 1][j] + 1
+ edit_dist_mat[i][j] = min(tmp0, tmp1, tmp2)
+
+ return edit_dist_mat
+
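+# Worked example: hyp_words = ['a', 'c'] against ref_words = ['a', 'b'] produces
+# [[0, 1, 2], [1, 0, 1], [2, 1, 1]]; the bottom-right entry gives an edit distance
+# of 1 (a single substitution).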
+
+class SimpleWER(object):
+ """Compute word error rates after the alignment.
+
+ Attributes:
+ key_phrases: list of important phrases.
+    aligned_htmls: list of diagnosis htmls, each of which corresponds to a pair
+      of hypothesis and reference.
+    hyp_keyphrase_counts: dict. `hyp_keyphrase_counts[w]` counts how often a key
+      phrase `w` appears in the hypotheses.
+    ref_keyphrase_counts: dict. `ref_keyphrase_counts[w]` counts how often a key
+      phrase `w` appears in the references.
+    matched_keyphrase_counts: dict. `matched_keyphrase_counts[w]` counts how
+      often a key phrase `w` appears in the aligned transcripts when the
+      reference and hyp_keyphrase match.
+    wer_info: dict with four keys: 'sub' (substitution error), 'ins' (insertion
+      error), 'del' (deletion error), 'nw' (number of words). We can use
+ wer_info to compute word error rate (WER) as
+ (wer_info['sub']+wer_info['ins']+wer_info['del'])*100.0/wer_info['nw']
+ """
+
+ def __init__(self,
+ key_phrases=None,
+ html_handler=HighlightAlignedHtml,
+ preprocess_handler=RemoveCommentTxtPreprocess):
+ """Initialize SimpleWER object.
+
+ Args:
+ key_phrases: list of strings as important phrases. If key_phrases is
+ None, no key_phrases related metric will be computed.
+ html_handler: function to generate a string with html tags.
+      preprocess_handler: function to preprocess text before computing WER.
+ """
+ self._preprocess_handler = preprocess_handler
+ self._html_handler = html_handler
+ self.key_phrases = key_phrases
+ self.aligned_htmls = []
+ self.wer_info = {'sub': 0, 'ins': 0, 'del': 0, 'nw': 0}
+ if key_phrases:
+ # Pre-process key_phrase list
+ if self._preprocess_handler:
+ self.key_phrases = \
+ [self._preprocess_handler(k) for k in self.key_phrases]
+
+ # Init keyphrase_counts for every key phrase
+ self.ref_keyphrase_counts = {}
+ self.hyp_keyphrase_counts = {}
+ self.matched_keyphrase_counts = {}
+ for k in self.key_phrases:
+ self.ref_keyphrase_counts[k] = 0
+ self.hyp_keyphrase_counts[k] = 0
+ self.matched_keyphrase_counts[k] = 0
+ else:
+ self.ref_keyphrase_counts = None
+ self.hyp_keyphrase_counts = None
+ self.matched_keyphrase_counts = None
+
+ def AddHypRef(self, hypothesis, reference):
+ """Update WER when adding one pair of strings: (hypothesis, reference).
+
+ Args:
+ hypothesis: Hypothesis string.
+ reference: Reference string.
+
+ Raises:
+ ValueError: when the program fails to parse edit distance matrix.
+ """
+ if self._preprocess_handler:
+ hypothesis = self._preprocess_handler(hypothesis)
+ reference = self._preprocess_handler(reference)
+
+ # Compute edit distance.
+ hyp_words = hypothesis.split()
+ ref_words = reference.split()
+ distmat = ComputeEditDistanceMatrix(hyp_words, ref_words)
+
+    # Back trace to distinguish different error types: ins, del, sub.
+ pos_hyp, pos_ref = len(hyp_words), len(ref_words)
+ wer_info = {'sub': 0, 'ins': 0, 'del': 0, 'nw': len(ref_words)}
+ aligned_html = ''
+ matched_ref = ''
+ while pos_hyp > 0 or pos_ref > 0:
+ err_type = ''
+
+ # Distinguish error type by back tracking
+ if pos_ref == 0:
+ err_type = 'ins'
+ elif pos_hyp == 0:
+ err_type = 'del'
+ else:
+ if hyp_words[pos_hyp - 1] == ref_words[pos_ref - 1]:
+ err_type = 'none' # correct error
+ elif distmat[pos_ref][pos_hyp] == distmat[pos_ref - 1][pos_hyp - 1] + 1:
+ err_type = 'sub' # substitute error
+ elif distmat[pos_ref][pos_hyp] == distmat[pos_ref - 1][pos_hyp] + 1:
+ err_type = 'del' # deletion error
+ elif distmat[pos_ref][pos_hyp] == distmat[pos_ref][pos_hyp - 1] + 1:
+          err_type = 'ins'  # insertion error
+ else:
+ raise ValueError('fail to parse edit distance matrix.')
+
+ # Generate aligned_html
+ if self._html_handler:
+ if pos_hyp == 0 or not hyp_words:
+ tmph = ' '
+ else:
+ tmph = hyp_words[pos_hyp - 1]
+ if pos_ref == 0 or not ref_words:
+ tmpr = ' '
+ else:
+ tmpr = ref_words[pos_ref - 1]
+ aligned_html = self._html_handler(tmph, tmpr, err_type) + aligned_html
+
+ # If no error, go to previous ref and hyp.
+ if err_type == 'none':
+ matched_ref = hyp_words[pos_hyp - 1] + ' ' + matched_ref
+ pos_hyp, pos_ref = pos_hyp - 1, pos_ref - 1
+ continue
+
+ # Update error.
+ wer_info[err_type] += 1
+
+ # Adjust position of ref and hyp.
+ if err_type == 'del':
+ pos_ref = pos_ref - 1
+ elif err_type == 'ins':
+ pos_hyp = pos_hyp - 1
+ else: # err_type == 'sub'
+ pos_hyp, pos_ref = pos_hyp - 1, pos_ref - 1
+
+ # Verify the computation of edit distance finishes
+ assert distmat[-1][-1] == wer_info['ins'] + \
+ wer_info['del'] + wer_info['sub']
+
+ # Accumulate err_info before the next (hyp, ref).
+ for k in wer_info:
+ self.wer_info[k] += wer_info[k]
+
+ # Collect aligned_htmls.
+ if self._html_handler:
+ self.aligned_htmls += [aligned_html]
+
+ # Update key phrase info.
+ if self.key_phrases:
+ for w in self.key_phrases:
+ self.ref_keyphrase_counts[w] += reference.count(w)
+ self.hyp_keyphrase_counts[w] += hypothesis.count(w)
+ self.matched_keyphrase_counts[w] += matched_ref.count(w)
+
+ def GetWER(self):
+    """Compute Word Error Rate (WER) to summarize word errors.
+
+ Note WER can be larger than 100.0, esp when there are many insertion errors.
+
+ Returns:
+ WER as percentage number, usually between 0.0 to 100.0
+ """
+ nref = self.wer_info['nw']
+ nref = max(1, nref) # non_zero value for division
+ total_error = self.wer_info['ins'] \
+ + self.wer_info['del'] + self.wer_info['sub']
+ return total_error * 100.0 / nref
+
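+  # Example: 3 substitutions, 1 deletion and 2 insertions against 40 reference
+  # words give WER = (3 + 1 + 2) * 100.0 / 40 = 15.0%.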
+ def GetKeyPhraseStats(self):
+ """Measure the Jaccard similarity of key phrases between hyps and refs.
+
+ Returns:
+ jaccard_similarity: jaccard similarity, between 0.0 and 1.0
+ F1_keyphrase: F1 score (=2/(1/prec + 1/recall)), between 0.0 and 1.0
+ matched_keyphrases: num of matched key phrases.
+ ref_keyphrases: num of key phrases in the reference strings.
+ hyp_keyphrases: num of key phrases in the hypothesis strings.
+ """
+
+ matched_k = sum(self.matched_keyphrase_counts.values())
+ ref_k = sum(self.ref_keyphrase_counts.values())
+ hyp_k = sum(self.hyp_keyphrase_counts.values())
+ joined_k = ref_k + hyp_k - matched_k
+ joined_k = max(1, joined_k) # non_zero value for division
+ jaccard_similarity = matched_k * 1.0 / joined_k
+
+ f1_k = 2.0 * matched_k / max(ref_k + hyp_k, 1.0)
+ return (jaccard_similarity, f1_k, matched_k, ref_k, hyp_k)
+
+ def GetSummaries(self):
+ """Generate strings to summarize word errors and key phrase errors.
+
+ Returns:
+ str_sum: string summarizing total error, total word and WER.
+ str_details: string breaking down three error types: del, ins, sub.
+      str_keyphrases_info: string summarizing key phrase information.
+ """
+ wer = self.GetWER()
+ nref = self.wer_info['nw']
+ total_error = self.wer_info['ins'] \
+ + self.wer_info['del'] + self.wer_info['sub']
+ str_sum = 'WER = %.2f%% (%.4f%%), total error = %d, total word = %d' % (
+ wer, wer, total_error, nref)
+
+ str_details = 'Error breakdown: del = %.2f%%, ins=%.2f%%, sub=%.2f%%' % (
+ self.wer_info['del'] * 100.0 / nref, self.wer_info['ins'] * 100.0 /
+ nref, self.wer_info['sub'] * 100.0 / nref)
+
+ str_keyphrases_info = ''
+ if self.key_phrases:
+ jaccard_p, f1_p, matched_p, ref_p, hyp_p = self.GetKeyPhraseStats()
+ str_keyphrases_info = ('matched %d key phrases (%d in ref, %d in hyp), '
+ 'jaccard similarity=%.2f, F1=%.2f') % \
+ (matched_p, ref_p, hyp_p, jaccard_p, f1_p)
+
+ return str_sum, str_details, str_keyphrases_info, (wer, total_error, nref)
+
+  def write_html(self, fn_output):
+    try:
+      aligned_html = '\n'.join('<p>{}</p>'.format(html) for html in self.aligned_htmls)
+      aligned_html = '<div>{}</div>'.format(aligned_html)
+      with open(fn_output, 'wt') as fp:
+        fp.write('<html>')
+        fp.write('<body>')
+        fp.write('%s' % aligned_html)
+        fp.write('</body></html>')
+    except IOError:
+      print('failed to write diagnosis html')
+
+
+if __name__ == '__main__':
+ if len(sys.argv) < 3 or len(sys.argv) > 4:
+ print("""
+Example of Usage:
+
+ python simple_wer_v2.py file_hypothesis file_reference
+or
+ python simple_wer_v2.py file_hypothesis file_reference file_keyphrases
+
+ where file_hypothesis is the file name for hypothesis text
+ file_reference is the file name for reference text.
+ file_keyphrases (optional) is the filename of key phrases over which
+ you want to measure accuracy.
+
+Or you can use this file as a library, and call class SimpleWER
+ .AddHypRef(hyp, ref): add one pair of hypothesis/reference. You can call this
+ function multiple times.
+ .GetWER(): get the Word Error Rate (WER).
+ .GetKeyPhraseStats(): get stats for key phrases. The first value is Jaccard
+ Similarity of key phrases.
+ .GetSummaries(): generate strings to summarize word error and
+ key phrase errors.
+""")
+ sys.exit(1)
+
+ main(sys.argv)
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/spec2vec.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/spec2vec.py
new file mode 100644
index 0000000000000000000000000000000000000000..272580476d9b544e70e7d6c708b5edb35dc27887
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/spec2vec.py
@@ -0,0 +1,210 @@
+from typing import List
+
+import torch
+from torch import nn
+
+from nemo.collections.asr.models.configs import common_config as common_cfg
+from nemo.collections.asr.models.spec2vec.spec2vec_config import ConvTransformerBlock
+from nemo.collections.asr.parts.convolution_layers import ConvNormAct, create_pad_mask, Conv
+from nemo.collections.asr.parts.wav2vec import TransformerEncoder
+
+
+class FeatureEncoder(nn.Module):
+ def __init__(self, feat_in, use_conv_mask, conv2d_block: common_cfg.Conv2dBlock,
+ conv_transformer_blocks: List[ConvTransformerBlock],
+ use_tf_pad: bool, ln_eps: float = 1e-5):
+ super().__init__()
+
+ self.use_conv_mask = use_conv_mask
+
+ self.bn_moudles = []
+
+ if conv2d_block:
+ prev_out_channels = 1
+ self.conv2d_block = nn.ModuleList()
+ for conv2d_cfg_i in conv2d_block.layers:
+ layer = ConvNormAct(in_channels=prev_out_channels,
+ conv_type='2d',
+ use_tf_pad=use_tf_pad,
+ ln_eps=ln_eps,
+ **conv2d_cfg_i)
+
+ if isinstance(layer.norm, (nn.BatchNorm2d, nn.BatchNorm1d)):
+ self.bn_moudles.append(layer.norm)
+
+ prev_out_channels = conv2d_cfg_i.filters
+ self.conv2d_block.append(layer)
+ prev_out_channels = conv2d_block.output_dim
+ self.conv2d_block.apply(kaiming_init_conv_weights)
+ else:
+ self.conv2d_block = None
+ prev_out_channels = feat_in
+
+ self.block_modules = nn.ModuleList()
+ for block_cfg in conv_transformer_blocks:
+ for conv_cfg_i in block_cfg.conv_layers:
+ layer = ConvNormAct(in_channels=prev_out_channels,
+ conv_type='1d',
+ use_tf_pad=use_tf_pad,
+ ln_eps=ln_eps,
+ **conv_cfg_i)
+
+ if isinstance(layer.norm, (nn.BatchNorm2d, nn.BatchNorm1d)):
+ self.bn_moudles.append(layer.norm)
+
+ prev_out_channels = conv_cfg_i.filters
+ layer.apply(kaiming_init_conv_weights)
+ self.block_modules.append(layer)
+
+ if block_cfg.transformer_block is not None:
+ block = TransformerEncoder(block_cfg.transformer_block)
+ self.block_modules.append(block)
+ prev_out_channels = block_cfg.transformer_block.encoder.embedding_dim
+
+ self.output_dim = prev_out_channels
+
+ def forward(self, audio_signal, length):
+ # [B, F/D, T]
+ output = audio_signal
+
+ if self.use_conv_mask:
+ pad_mask = create_pad_mask(length, max_len=output.size(2))
+ else:
+ pad_mask = None
+
+ if self.conv2d_block is not None:
+ # [B, F, T] => [B, T, F] =>[B, C, T, F]
+ output = torch.transpose(output, 1, 2).unsqueeze(1)
+ for module in self.conv2d_block:
+ output, length, pad_mask = module(output, length, pad_mask=pad_mask)
+ b, c, t, f = output.size()
+ # [B, C, T, F] => [B, F, C, T] => [B, FxC/D, T]
+ output = output.permute(0, 3, 1, 2).reshape(b, f * c, t)
+
+ for module in self.block_modules:
+ if isinstance(module, ConvNormAct):
+ output, length, pad_mask = module(output, length, pad_mask=pad_mask)
+ else:
+ assert isinstance(module, TransformerEncoder)
+ # [B, D, T] => [B, T, D]
+ output = output.transpose(1, 2)
+ output = module(output, padding_mask=pad_mask)
+ # [B, T, D] => [B, D, T]
+ output = output.transpose(1, 2)
+
+ return output, length, None
+
+ def bn_eval(self):
+ for m in self.bn_moudles:
+ m.eval()
+
+ def get_subsampled_lens(self, lens):
+ if self.conv2d_block is not None:
+ for module in self.conv2d_block:
+ lens = module.update_out_seq_lens(lens)
+
+ for module in self.block_modules:
+ if isinstance(module, ConvNormAct):
+ lens = module.update_out_seq_lens(lens)
+
+ return lens
+
+
+class Projector(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+
+ self.use_conv_mask = cfg.use_conv_mask
+
+ prev_out_channels = cfg.input_dim
+
+ if cfg.conv_layers is not None:
+ self.conv_layers = nn.ModuleList()
+ for conv_cfg_i in cfg.conv_layers:
+ assert conv_cfg_i.stride == (1,)
+ layer = ConvNormAct(in_channels=prev_out_channels,
+ conv_type='1d',
+ use_tf_pad=cfg.use_tf_pad,
+ ln_eps=cfg.ln_eps,
+ **conv_cfg_i)
+ prev_out_channels = conv_cfg_i.filters
+ self.conv_layers.append(layer)
+ self.conv_layers.apply(kaiming_init_conv_weights)
+ else:
+ self.conv_layers = None
+
+ self.transformer = None if cfg.transformer is None else TransformerEncoder(cfg.transformer)
+
+ if cfg.output_dim is not None:
+ self.output_proj = nn.Linear(prev_out_channels, cfg.output_dim)
+ self.output_dim = cfg.output_dim
+ else:
+ self.output_proj = None
+ self.output_dim = prev_out_channels
+
+ def forward(self, inputs, length):
+ # [B, T, D]
+ assert inputs.shape[0] == length.shape[0]
+ output = inputs
+
+ if (self.conv_layers is not None and self.use_conv_mask) or self.transformer is not None:
+ pad_mask = create_pad_mask(length, max_len=output.size(1))
+ else:
+ pad_mask = None
+
+ if self.conv_layers is not None:
+ # [B, T, D] => [B, D, T]
+ output = output.transpose(1, 2)
+ for conv_i in self.conv_layers:
+ output, length, pad_mask = conv_i(output, length, pad_mask=pad_mask)
+ # [B, D, T] => [B, T, D]
+ output = output.transpose(1, 2)
+
+ if self.transformer is not None:
+ assert pad_mask is not None
+ output = self.transformer(output, padding_mask=pad_mask)
+
+ if self.output_proj is not None:
+ output = self.output_proj(output)
+
+ return output
+
+
+def kaiming_init_conv_weights(m):
+ if isinstance(m, (nn.Conv1d, nn.Conv2d)):
+ nn.init.kaiming_normal_(m.weight)
+ elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)):
+ pass # use default init
+        pass  # ignore modules that do not need init
+ pass # ignore modules do not need init
+ elif isinstance(m, (FeatureEncoder, ConvNormAct, Conv)):
+ pass # ignore wrapper modules
+ else:
+ raise ValueError('initializing unknown module type {}'.format(type(m)))
+
+
+class RandomMask(nn.Module):
+ def __init__(self, prob, mask_value=None, mask_dim=None):
+ super().__init__()
+ assert 0 <= prob < 1
+ self.prob = prob
+ if mask_value is not None:
+ assert mask_dim is None
+ self.mask_value = mask_value
+ self.embedding_mask = False
+ else:
+ assert isinstance(mask_dim, int)
+ self.mask_value = nn.Parameter(torch.FloatTensor(mask_dim).uniform_())
+ self.embedding_mask = True
+
+ def forward(self, inputs: torch.Tensor):
+ if not self.training:
+ return inputs
+
+ if self.embedding_mask:
+ mask_shape = inputs.size()[:-1]
+ else:
+ mask_shape = inputs.size()
+ mask_indices = torch.bernoulli(torch.full(mask_shape, self.prob, device=inputs.device)).type(torch.bool)
+ inputs[mask_indices] = self.mask_value
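+        # Note: masking writes into `inputs` in place; positions are drawn i.i.d. from
+        # Bernoulli(prob) and replaced either by the fixed scalar mask_value or, when
+        # mask_dim is given, by the learned mask embedding broadcast over the last dim.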
+ return inputs
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/spectr_augment.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/spectr_augment.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e2b1cc0cea2291af4a1a0f5761d131a68ae5580
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/spectr_augment.py
@@ -0,0 +1,144 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+
+import torch
+import torch.nn as nn
+
+
+GAUSSIAN_MASK = [-0.008577285334467888, -0.008312131278216839, -0.003433679463341832, -0.007991528138518333, 0.022046754136681557, -0.014587175101041794, 0.010684879496693611, 0.0032962223049253225, 0.0006217826739884913, -0.0109024653211236, 0.003951661288738251, 0.006253963336348534, 0.005004990380257368, -0.004416681360453367, -0.014572846703231335, -0.006137363612651825, 0.015708647668361664, 0.013687868602573872, -0.0034236200153827667, -0.017063571140170097, -0.005857351701706648, 0.00830098520964384, 0.001996119972318411, 0.0001801295584300533, 0.007846095599234104, 0.006822641938924789, 0.0053718010894954205, -0.006921472027897835, -0.008087077178061008, 0.0019816632848232985, 0.017332371324300766, 0.0013035658048465848, -0.005086318589746952, 0.0038084506522864103, -0.0028235530480742455, 0.004277005326002836, -0.008790107443928719, -0.01645108126103878, -0.00870309118181467, 0.0030529664363712072, 0.0044243973679840565, -0.001729264622554183, 0.0018796413205564022, 0.001113063539378345, 0.001685396651737392, 0.007856708019971848, -0.008503658697009087, 0.004928226582705975, 0.003501072758808732, 0.013370650820434093, 0.015068281441926956, -0.008089980110526085, -0.00913938321173191, 0.010570412501692772, -0.0008485731086693704, -0.0017308632377535105, 0.009554422460496426, 0.0008375696488656104, 0.01573362946510315, -0.00033936771797016263, -0.00738796591758728, 0.004305792506784201, 0.01625652238726616, 0.004601571708917618, 0.0033054756931960583, -0.006255044601857662, -0.004542464856058359, 0.013747476041316986, -0.010799456387758255, 0.0024993526749312878, 0.005865572020411491, -0.013002256862819195, 0.005194664001464844, -0.021207045763731003, 0.007146795745939016, 0.001368773402646184, -0.014667760580778122, 0.010809154249727726, -0.005281668156385422, 0.009397738613188267, -0.0117409098893404, 0.010593295097351074, 0.00379750388674438, 0.005227712448686361, 0.005195990204811096, 0.007941234856843948, -0.008073929697275162, 0.0005659427843056619, -0.002949729096144438, 0.004707518499344587, 0.00013310337089933455, -0.0021645540837198496, 0.005021876189857721, 0.006173995789140463, 0.007487660273909569, 0.00892797950655222, 0.008880391716957092, 0.017241517081856728, 0.007244352716952562, -0.0020703296177089214, -0.003922614734619856, -0.017371725291013718, 0.0036656244192272425, 0.0008915121434256434, -0.010720319114625454, -0.0029386195819824934, -0.004162450321018696, -0.003555792849510908, -0.009222142398357391, 0.008136272430419922, -0.016355125233530998, -0.0019096146570518613, 0.009561257436871529, -0.0009358802926726639, -0.009368732571601868, -0.0006803714786656201, -0.0032679459545761347, -0.010729965753853321, -0.006920733489096165, -0.007359025534242392, -0.0036679436452686787, -0.01085688453167677, 0.0032564126886427402, -0.00396075751632452, 0.003947695717215538, 0.006906456314027309, -0.00874226726591587, -0.010531643405556679, 0.002544721122831106, 0.003992108628153801, 0.002931957133114338, -0.008593415841460228, -0.00834380928426981, -0.0004008698451798409, -0.007474125362932682, -0.0018838047981262207, -0.010623954236507416, 0.0004698234552051872, -0.0064691524021327496, -0.001641935552470386, 0.030333172529935837, 0.01937473751604557, 0.008528976701200008, -0.015279320999979973, -0.009957612492144108, 0.004345834255218506, 0.0052221533842384815, 0.028346601873636246, 0.00015258256462402642, 0.017097460106015205, -0.011207249946892262, -0.0013319909339770675, -0.01574462465941906, -0.001186217530630529, 0.002741827629506588, -0.0027266093529760838, 
-0.0019425811478868127, -0.016291635110974312, 0.0014635918196290731, -0.0068401917815208435, -0.0023575483355671167, -0.023059917613863945, -0.004130368120968342, 0.018156377598643303, -0.010979206301271915, 0.0030661290511488914, 0.001635069027543068, -0.013064907863736153, -0.0013261985732242465, 0.008750525303184986, 0.01025752630084753, 0.006366524379700422, -0.019202180206775665, 0.012242366559803486, 0.005239878781139851, 0.006112856324762106, -0.0009505008929409087, 0.007498623337596655, 0.004345917142927647, 0.0019737775437533855, 0.01724175363779068, -0.011013401672244072, 0.01421738788485527, -0.00010386518988525495, 0.0005218135192990303, -7.810740498825908e-05, -0.0013396443100646138, 0.00891269650310278, 0.006826599128544331, -0.005983854178339243, -0.008014648221433163, -0.004054467659443617, 0.003911960404366255, 0.012865566648542881, 0.006723911035805941, 0.00459368946030736, 0.008541394025087357, 0.014683294110000134, 0.017786560580134392, -0.006093125324696302, 0.0011138939298689365, 0.005661547649651766, -0.001207416644319892, -0.008716529235243797, 0.004718406591564417, -0.004078548867255449, 0.004175822716206312, -0.007703948300331831, 0.01632404327392578, -0.00023797148605808616, 0.0023386836983263493, 0.0049704634584486485, 0.01812780275940895, 0.008168922737240791, -0.01889769174158573, -0.007803930900990963, 0.004977164324373007, 0.005988257005810738, -0.0024359619710594416, 0.015691304579377174, -0.004170612432062626, 0.011284936219453812, -0.012687956914305687, -0.013468815013766289, -0.0056364708580076694, 0.003777650184929371, -0.010477692820131779, -3.670304431580007e-05, 0.002994526643306017, -0.02478531375527382, -0.0186957735568285, -0.000980377197265625, -0.006090754177421331, -0.0008337362669408321, 0.0042216661386191845, -0.0017890859162434936, -0.017191831022500992, 0.007558343932032585, -0.0004839670436922461, 0.0006820259150117636, -0.0036172219552099705, 0.012755093164741993, 0.007131294813007116, 0.01682445965707302, -0.0033714391756802797, 0.011848727241158485, -0.008541670627892017, 0.005448013544082642, 0.005193810444325209, 0.0003855297400150448, 0.009456363506615162, -0.004641769919544458, 0.025975538417696953, -0.008149398490786552, -0.014135519042611122, 0.006781088653951883, 0.004117966629564762, -0.006494637578725815, -0.007128587923943996, 0.006647801958024502, -0.012087219394743443, 0.001748546026647091, 0.01748277060687542, -0.004554575774818659, 0.0036450892221182585, -0.010104892775416374, -0.0035864445380866528, 0.0058875922113657, -0.0011381434742361307, 0.0015288969734683633, 0.0013207353185862303, -0.003274752525612712, 0.0023052070755511522, -0.011285059154033661, 0.011404736898839474, -0.001985494513064623, -0.008716653101146221, 0.005484268069267273, -0.0021095615811645985, 0.015584509819746017, -0.008399765007197857, 0.007663742173463106, -0.0046447026543319225, 0.01980607770383358, -0.021430866792798042, 0.0019235932268202305, 0.01850181445479393, 0.014331997372210026, 0.004704942926764488, -0.004087083041667938, -0.0052873240783810616, -0.004040342289954424, -0.0009587344247847795, -0.001312632579356432, 0.006471287924796343, -0.004858402069658041, 0.006215596571564674, -0.01835842989385128, -0.013215125538408756, -0.012908847071230412, -0.005191713571548462, 0.005219723097980022, -0.0015103589976206422, -0.006740736309438944, -0.005966144613921642, -0.0008368652779608965, 0.003591155633330345, -0.0025086551904678345, 0.018543150275945663, 0.01635683700442314, 0.003782284678891301, -0.012239447794854641, 
0.0011241791071370244, -8.145845640683547e-05, 0.01495729386806488, 0.0033017871901392937, 0.00019223176059313118, 0.007682753726840019, 0.011071891523897648, 0.011146822944283485, -0.006912447512149811, 0.0034845354966819286, 0.005456835497170687, 0.002649294678121805, -0.005829342640936375, -0.002240463625639677, 0.0029810150153934956, 0.0026485121343284845, 0.012740016914904118, 0.006942421197891235, 0.008712748065590858, -0.002921474166214466, -0.003271800698712468, -0.0006888209609314799, 0.000602235842961818, -0.009740591049194336, -0.005469066556543112, 0.0027335130143910646, -0.007446215022355318, -0.010410390794277191, -0.020527798682451248, -0.0007794187986291945, -0.010217029601335526, -0.008854899555444717, 0.0018076488049700856, 0.003745997091755271, -0.008995960466563702, 0.006499497219920158, -0.009520279243588448, -0.0013822759501636028, 0.005432894919067621, 0.0048836348578333855, -0.006969304755330086, 0.009624844416975975, 0.002221961971372366, 0.012857971712946892, 0.00029701043968088925, 0.01846439018845558, -0.0010167461587116122, 0.003362444229424, 0.018802843987941742, 0.007899937219917774, 0.007362710777670145, -0.01259714737534523, -0.004353848285973072, 0.009529988281428814, -0.010434559546411037, -0.00114968151319772, -0.015083436854183674, -0.000513288367073983, -0.012708902359008789, 0.009817283600568771, -0.01616625115275383, 0.02099389024078846, 0.0011255480349063873, 0.002087818691506982, 0.004141946788877249, -0.004236323293298483, -0.008043688721954823, -0.01005634292960167, -0.01211725827306509, 0.008502005599439144, -0.003608573926612735, 0.0013564183609560132, 0.0025668770540505648, 0.006136550568044186, 0.01201977115124464, 0.025350334122776985, 0.012129077687859535, -0.0032070353627204895, 0.0003363523574080318, -0.00036432372871786356, 6.168012623675168e-05, -0.010630798526108265, -0.001909107668325305, 0.0031751287169754505, 0.023505495861172676, 0.006265965756028891, -0.000127899824292399, 0.027016283944249153, -0.003139286069199443, 0.007756243925541639, 0.012871683575212955, -0.005639276001602411, 0.006201721262186766, -0.007926560938358307, -0.007784618530422449, 0.0018705694237723947, 0.011534891091287136, 0.003263025777414441, -0.008872064761817455, 0.011395181529223919, -0.0061043002642691135, -0.0020404469687491655, -0.002608739770948887, 0.0014213520335033536, 7.451167039107531e-05, -0.0006841712747700512, 0.00638044998049736, 0.02477930672466755, -0.0001436929014744237, 0.009057887829840183, -0.0005273017450235784, -0.016813546419143677, -0.011913691647350788, 0.008906804956495762, 0.0009794242214411497, -0.002302555600181222, 0.012043703347444534, -0.00899094995111227, 0.0017951279878616333, 0.010412560775876045, -0.0018607425736263394, 0.003462078981101513, 0.010766361840069294, 0.007733880076557398, 0.0037096040323376656, 0.013803163543343544, -0.010908039286732674, 0.012310138903558254, 0.011781050823628902, -0.022002901881933212, 0.01713118888437748, 0.004179463256150484, 0.00031042384216561913, 0.002601897343993187, 0.002513417275622487, -0.0009618049371056259, 0.0004392332921270281, -0.01457865722477436, -0.00014599964197259396, -0.016580622643232346, -0.014996372163295746, 0.0037527403328567743, -0.008046748116612434, 0.0020090779289603233, -0.0033000593539327383, 0.0052381521090865135, 0.001489385380409658, -0.010585324838757515, -0.015648989006876945, -0.0017674925038591027, 0.004076640121638775, 0.019764112308621407, -0.010859061032533646, -0.0004517346096690744, 0.019380798563361168, -0.0069849626161158085, 
0.006805767305195332, 0.0010639704996719956, -0.0008604522445239127, 0.014004685916006565, -0.009115546941757202, 0.020015710964798927, 0.0020416032057255507, -0.005136480089277029, -0.014490129426121712, -0.013324854895472527, 0.02655731327831745, -0.006656433921307325, -0.003929227590560913, 0.0036637713201344013, -0.002843528985977173, 0.016735395416617393, -0.0026499521918594837, -9.097589645534754e-05, 0.0009586272644810379, -0.008444109931588173, -0.01267432514578104, -0.00987020693719387, 0.024359915405511856, 0.006355916149914265, -0.010813186876475811, -0.008563755080103874, 0.001825400278903544, 0.015858624130487442, 0.008124123327434063, 0.007248807232826948, 0.0014658113941550255, 0.00889967568218708, -0.018838992342352867, 0.012073985300958157, 0.007954470813274384, 0.0040962728671729565, 0.003913100343197584, -0.009089702740311623, -0.01628614217042923, -0.005144990514963865, 0.006413666065782309, -0.014760496094822884, -0.010384202003479004, -0.013040153309702873, 0.018687430769205093, -0.0015114899724721909, -0.007259591482579708, -0.002887800568714738, 0.01662031002342701, 0.01137789711356163, 0.0007028636755421758, 0.004064011387526989, 0.007000538054853678, -0.005053787026554346, 0.012506005354225636, -0.009795127436518669, 0.0013481989735737443, 0.01153571903705597, 0.0005770407733507454, 0.00843889731913805, -0.004305741284042597, -0.00523682776838541, -0.003950543235987425, -0.030804451555013657, 0.0019242960261180997, -0.001121998648159206, -0.00024091487284749746, -0.011150459758937359, -0.0006252250750549138, -0.008173838257789612, 0.0025749732740223408, 0.012828282080590725, 0.0037352831568568945, 0.017000781372189522, 0.019764423370361328, -0.006141415797173977, 0.009571244940161705, 0.0060485838912427425, -0.003701487323269248, 0.012796318158507347, -0.005578612443059683, -0.002949218498542905, -0.0077390591613948345, 0.013134066946804523, 0.01348984893411398, -0.017807014286518097, -0.007176446728408337, -0.0025080926716327667, 0.0026706154458224773, 0.006270487792789936, -0.001722719520330429, 0.014065185561776161, -0.012098010629415512, -0.006948764901608229, 0.007107093930244446, 0.011564299464225769, 0.015365575440227985, -0.001405196264386177, 0.013922395184636116, -0.011645046062767506, 0.019185440614819527, 0.008669205941259861, -0.006402950268238783, -0.02930634282529354, 0.012674490921199322, 0.0020986138842999935, 0.0073446049354970455, -0.008211275562644005, -0.010451032780110836, 0.0039263502694666386, 0.010214317589998245, 0.005137004889547825, 0.003445018082857132, 0.011561079882085323, 0.014101037755608559, -0.0009882557205855846, -0.020502908155322075, -0.01304234005510807, -0.0010743135353550315, 0.006024991162121296, 0.005487101152539253, 0.015423699282109737, 0.010340548120439053, 0.007406091317534447, -0.002593359677121043, -0.014484986662864685, -0.008927910588681698, 0.01809288188815117, -8.335654797519965e-07, 0.016086341813206673, -0.007467319723218679, 0.006248668301850557, 0.010234993882477283, 0.01105313841253519, 0.017129629850387573, 0.01001526229083538, 0.008971969597041607, 0.0019335917895659804, 0.014964735135436058, 0.0032235209364444017, -0.01558700855821371, -0.0002231745602330193, 1.0325457878934685e-05, 0.005610847380012274, -0.0014725240180268884, -0.007073611952364445, 0.006045978516340256, 0.00661000469699502, 0.016458045691251755, 0.004776420537382364, -0.0024809655733406544, -0.001141632441431284, 0.020380591973662376, 0.001985315466299653, 0.0016066418029367924, -0.0002630813978612423, -0.0023715763818472624, 
-0.005043766926974058, -0.0006530124810524285, -0.0011994787491858006, -0.005200596526265144, -0.011772070080041885, 0.020897328853607178, -0.0030809161253273487, 0.0151188550516963, -0.002648375928401947, -0.0020464551635086536, -0.011824622750282288, 0.006495086010545492, -0.006923593580722809, -0.010816085152328014, -0.0005119291599839926, 0.019806137308478355, 0.01142488420009613, 0.006749970838427544, -0.015500311739742756, -0.00845906138420105, -0.0021400016266852617, -0.0008869305020198226, 0.016512639820575714, 0.0022560760844498873, -0.003913175780326128, 0.007876270450651646, 0.0010393362026661634, 0.009496960788965225, 0.008087781257927418, 0.0075989654287695885, 0.007650574669241905, 0.011847944930195808, 0.006889567244797945, 0.005724477581679821, 0.003269805572926998, 0.017861217260360718, 0.014128664508461952, -0.018311243504285812, -0.00034174948814325035, 0.010595879517495632, -0.007073213811963797, -0.0037547526881098747, 0.007721452973783016, -0.013909009285271168, 0.009047291241586208, -0.005391568876802921, 0.010471170768141747, -0.005925311706960201, -0.010769817978143692, -0.01213113870471716, 0.013672600500285625, -0.025970296934247017, 0.004848531447350979, -3.2736243156250566e-05, -0.0014015173073858023, -0.006384227890521288, 0.00352313625626266, 0.007666777819395065, 0.011333705857396126, -0.02562958188354969, 0.007899546064436436, -0.0003998324682470411, 0.00041209004120901227, 0.003894905559718609, 0.011146043427288532, 0.012012973427772522, 0.005776166915893555, 0.01267810445278883, 0.015597822144627571, -0.006004557479172945, -0.005119791720062494, 0.004542575217783451, -0.006028160918504, 0.0002760722709354013, 0.008363571017980576, 0.004950536414980888, 0.006884160451591015, 0.002470653271302581, 0.008553436025977135, 0.016060978174209595, -0.004487414378672838, 0.0008720596670173109, -0.0011375041212886572, 0.006429357919842005, -0.014677043072879314, 0.011124462820589542, -0.0023991605266928673, -0.003062682691961527, -0.0004514633328653872, 0.007666073273867369, 0.008637756109237671, -0.009318140335381031, 0.017036888748407364, -0.001699244137853384, -0.001919509842991829, -0.010616443119943142, -0.014744545333087444, 0.0023425209801644087, 0.015598481521010399, 0.004587096627801657, -0.006646877154707909, -0.0003353591891936958, 0.018362227827310562, -0.0008180644363164902, -0.009953021071851254, 0.009997352957725525, 0.006935094483196735, -0.0041261701844632626, 0.008880453184247017, 0.014008709229528904, -0.002823092043399811, 0.00992603413760662, -0.012414450757205486, 0.00843184906989336, 0.0020578710827976465, 0.016727814450860023, 0.019950158894062042, -0.004741964861750603, -0.0005425591953098774, -0.0058597358874976635, -0.0017816543113440275, 0.0027736839838325977, -0.012360996566712856, 0.003590812673792243, 0.005126148462295532, 0.004654150456190109, -0.006369463168084621, -0.003772721393033862, 0.02091693878173828, 0.01528538204729557, -0.003835881594568491, -0.003409450873732567, 0.010528912767767906, -0.0017828766722232103, 0.02516132965683937, 0.00814846158027649, -0.004559976980090141, -0.011295965872704983, 0.010371438227593899, 0.0006707622087560594, 0.005064355209469795, -0.010145595297217369, 0.004426772706210613, -0.005737415049225092, 0.0012003653682768345, 0.0036150291562080383, 0.011215592734515667, -0.0040545896627008915, 0.0034540158230811357, -0.008764381520450115, 0.014239713549613953, -0.0008067164453677833, -0.0077876923605799675, -0.018948623910546303, 0.003514879383146763, 0.0030799901578575373, 0.017965681850910187, 
-2.2737660401617177e-05, 0.007884345017373562, 0.003565009217709303, -0.012896131724119186, 0.0026204099413007498, -0.0020277798175811768, -0.005030298605561256, 0.0028232778422534466, -0.010799753479659557, 0.0004869748081546277, 0.00079371128231287, 0.0038255134131759405, 0.01625426858663559, -0.007327786646783352, -0.0006097709410823882, 0.021006541326642036, -0.004630157724022865, -0.015508498065173626, -0.011421660892665386, 0.004675465635955334, -0.0020752830896526575, 0.00651256600394845, -0.001388593576848507, -0.018634997308254242, -0.013168826699256897, 0.00600035535171628, -0.0051271733827888966, -0.0009046715567819774, 0.008201603777706623, 0.004723501857370138, -0.003945730160921812, 0.002696358598768711, -0.00534584978595376, 0.014200250618159771, 0.0032823574729263783, -0.009708631783723831, -0.002968631684780121, -0.0013192173792049289, -0.02307787910103798, 0.0048147342167794704, 0.009474620223045349, 0.0016143537359312177, 0.0026406359393149614, -0.004522481467574835, -0.00021635316079482436, 0.003252497175708413, -0.012805595062673092, 0.0034197745844721794, -0.016244668513536453, 0.001594359870068729, 0.01458613108843565, 0.010125475935637951, -0.009164906106889248, 0.006422918755561113, 0.002428187755867839, -0.005929546430706978, -0.012289043515920639, -0.00899637583643198, 6.348137685563415e-05, 0.01821526326239109, -0.0014832826564088464, -0.010244476608932018, -0.0010153147159144282, -0.005283149890601635, -0.012473183684051037, 0.007026445120573044, -0.011461296118795872, 0.0033159609884023666, -0.01289116870611906, 0.01181294210255146, 0.012677633203566074, -0.0013208106392994523, 0.003216641489416361, -0.012547919526696205, 0.004124876111745834, 0.0009470342774875462, -0.002027716487646103, 0.02221308834850788, -0.0008846950950101018, -0.0014873967738822103, 0.0159499179571867, 0.006823782809078693, -0.014584256336092949, -0.008290774188935757, -0.0120676439255476, -0.005099371075630188, -0.0080600306391716, 0.014079704880714417, -0.0020367265678942204, -0.008362890221178532, -0.001621987670660019, 0.010619360022246838, 0.013944074511528015, 0.0009812230709940195, 0.01836475357413292, 0.00273580988869071, -0.008425148203969002, -0.009124254807829857, 0.00011686794459819794, 0.0009999654721468687, 0.015176774933934212, -0.0009844052838161588, 0.013951435685157776, 0.004029604606330395, 0.006666754838079214, 0.003529589157551527, 0.003615013789385557, -0.02800256945192814, 0.006249892991036177, -0.008050324395298958, 0.005056394264101982, -0.006130233872681856, 0.006157447583973408, -0.010642053559422493, 0.0012126392684876919, -0.006806216202676296, 0.004176209215074778, -0.006571719888597727, -0.009268700145184994, -0.0075151631608605385, -0.0015292258467525244, -0.010599321685731411, 0.003558061085641384, 0.007374519016593695, 0.003586599137634039, 0.017905788496136665, 0.008015276864171028, -0.009034614078700542, 0.0018911826191470027, -0.005265430081635714, 0.004869093652814627, -0.0001151503311120905, -0.003885870799422264, 0.004713456612080336, 0.005261662881821394, -0.006401803344488144, 0.004026818089187145, 0.000506703567225486, 0.01196456328034401, -0.007442399859428406, 0.007578311488032341, 0.011165248230099678, -0.0009072477114386857, 0.010882778093218803, -0.0029303899500519037, 0.0024887293111532927, 0.004268537741154432, -0.003925780300050974, 0.01125251967459917, 0.0056076375767588615, 0.0013249896001070738, -0.0064854929223656654, 0.008991091512143612, 0.006860945839434862, 0.01103312149643898, 0.009761952795088291, -0.010707248002290726, 
-0.013999280519783497, 0.00729756522923708, -0.011525528505444527, 0.007183250039815903, 0.020575670525431633, -0.0008294707513414323, 0.002093465067446232, 0.005759619176387787, -0.0024755089543759823, 0.006597691681236029, 0.01883961632847786, -0.0017893409822136164, 0.004597699735313654, 0.0020806791726499796, -0.016394654288887978, 0.00436061155050993, -0.0021841854322701693, -0.01360281091183424, 0.0053757247515022755, 0.0006083177286200225, 0.0009018208365887403, 0.023406485095620155, -0.0009405831224285066, 0.018612656742334366, 0.0033334933687001467, -0.003511708928272128, -0.0026272705290466547, -0.0007344239274971187, 0.00568321393802762, -0.0053346729837358, -0.0017605028115212917, -0.010383752174675465, 0.018871944397687912, 0.0006560813635587692, -0.010139352641999722, 0.0026594139635562897, -0.0033114543184638023, 0.017703739926218987, -0.00698117958381772, -0.0020041917450726032, -0.004826934076845646, 0.0005630375235341489, 0.0032594038639217615, -0.0009713696199469268, -0.008052065037190914, -0.0027339174412190914, -0.010677228681743145, 0.004025970585644245, -0.012248994782567024, 0.015234006568789482, -0.005934533197432756, 0.01589520275592804, -0.012386002577841282, 0.009487216360867023, -0.0009444615570828319, 0.0125525938346982, -0.0013922285288572311, 0.013387584127485752, 0.0037463558837771416, -0.007817583158612251, 0.01384445745497942, -0.005437555257230997, -0.0015223644440993667, -0.015931112691760063, -0.010777517221868038, 0.02394471876323223, 0.011097057722508907, 0.010094624944031239, 0.012511318549513817, 0.0038457082118839025, 0.009449491277337074, 0.005493646953254938, 0.011773099191486835, 0.0019221196416765451, -0.013554790988564491, 0.01794261485338211, -0.0231052003800869, 0.008189410902559757, 0.006517274770885706, -0.004495919682085514, 0.001762248226441443, 0.008912311866879463, 0.009711232967674732, -0.00761059345677495, 0.02114887163043022]
+
+
+class SpecAugment(nn.Module):
+ """
+ Zeroes out(cuts) random continuous horisontal or
+ vertical segments of the spectrogram as described in
+ SpecAugment (https://arxiv.org/abs/1904.08779).
+
+ params:
+ freq_masks - how many frequency segments should be cut
+ time_masks - how many time segments should be cut
+ freq_width - maximum number of frequencies to be cut in one segment
+ time_width - maximum number of time steps to be cut in one segment.
+ Can be a positive integer or a float value in the range [0, 1].
+ If positive integer value, defines maximum number of time steps
+ to be cut in one segment.
+ If a float value, defines maximum percentage of timesteps that
+ are cut adaptively.
+ """
+
+ def __init__(
+ self, freq_masks=0, time_masks=0, freq_width=10, time_width=10, max_time_masks=20, gauss_mask_std=0.0, rng=None,
+ ):
+ super(SpecAugment, self).__init__()
+
+ self._rng = random.Random() if rng is None else rng
+
+ self.freq_masks = freq_masks
+ self.time_masks = time_masks
+ self.max_time_masks = max_time_masks
+
+ self.freq_width = freq_width
+ self.time_width = time_width
+
+ self.gauss_mask_std = gauss_mask_std
+
+ if isinstance(time_width, int):
+ self.adaptive_temporal_width = False
+ else:
+ if time_width > 1.0 or time_width < 0.0:
+ raise ValueError('If `time_width` is a float value, it must be in the range [0, 1]')
+
+ self.adaptive_temporal_width = True
+
+ if isinstance(time_masks, int):
+ self.adaptive_time_mask = False
+ else:
+ if time_masks >= 1.0 or time_masks < 0.0:
+ raise ValueError('If `time_masks` is a float value, it must be in the range [0, 1)')
+
+ self.adaptive_time_mask = True
+
+ @torch.no_grad()
+ def forward(self, x, length):
+ B, D, T = x.shape
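+ # x is a batch of spectrograms laid out as (batch, frequency bins, time steps);
+ # `length` holds the number of valid time steps per sample.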
+
+ for idx in range(B):
+ for _ in range(self.freq_masks):
+ x_left = self._rng.randint(0, D - self.freq_width)
+
+ w = self._rng.randint(0, self.freq_width)
+
+ x[idx, x_left : x_left + w, :] = 0.0
+
+ if self.adaptive_temporal_width:
+ time_width = max(1, int(length[idx] * self.time_width))
+ else:
+ time_width = self.time_width
+
+ if self.adaptive_time_mask:
+ time_masks = int(length[idx] * self.time_masks)
+ time_masks = min(time_masks, self.max_time_masks)
+ else:
+ time_masks = self.time_masks
+
+ for _ in range(time_masks):
+ y_left = self._rng.randint(0, length[idx] - time_width)
+
+ w = self._rng.randint(0, time_width)
+
+ if self.gauss_mask_std == 0:
+ x[idx, :, y_left:y_left + w] = 0.0
+ else:
+ x[idx, :, y_left:y_left + w] = torch.normal(mean=0, std=self.gauss_mask_std, size=(D, w)).to(x.device)
+
+ return x
+
+
+class SpecCutout(nn.Module):
+ """
+ Zeroes out (cuts) random rectangles in the spectrogram
+ as described in (https://arxiv.org/abs/1708.04552).
+
+ params:
+ rect_masks - how many rectangular masks should be cut
+ rect_freq - maximum size of cut rectangles along the frequency dimension
+ rect_time - maximum size of cut rectangles along the time dimension
+ """
+
+ def __init__(self, rect_masks=0, rect_time=5, rect_freq=20, rng=None):
+ super(SpecCutout, self).__init__()
+
+ self._rng = random.Random() if rng is None else rng
+
+ self.rect_masks = rect_masks
+ self.rect_time = rect_time
+ self.rect_freq = rect_freq
+
+ @torch.no_grad()
+ def forward(self, x):
+ sh = x.shape
+
+ for idx in range(sh[0]):
+ for i in range(self.rect_masks):
+ rect_x = self._rng.randint(0, sh[1] - self.rect_freq)
+ rect_y = self._rng.randint(0, sh[2] - self.rect_time)
+
+ # rectangle extents: w_x along the frequency axis (dim 1), w_y along the time axis (dim 2)
+ w_x = self._rng.randint(0, self.rect_freq)
+ w_y = self._rng.randint(0, self.rect_time)
+
+ x[idx, rect_x : rect_x + w_x, rect_y : rect_y + w_y] = 0.0
+
+ return x
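+
+
+# Illustrative usage sketch appended for clarity (not part of the original file):
+# apply both augmentations to a random batch of spectrograms. Shapes and
+# hyperparameters below are example choices, not prescribed values.
+if __name__ == "__main__":
+    torch.manual_seed(0)
+    spec = torch.randn(2, 64, 100)              # (batch, freq bins, time steps)
+    lengths = torch.tensor([100, 80])           # valid time steps per sample
+
+    spec_augment = SpecAugment(freq_masks=2, time_masks=2, freq_width=10, time_width=10)
+    spec_cutout = SpecCutout(rect_masks=2, rect_time=5, rect_freq=20)
+
+    augmented = spec_cutout(spec_augment(spec, lengths))
+    print(augmented.shape)                      # torch.Size([2, 64, 100])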
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/subsampling.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/subsampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..5477982e3f00feaaa1b7b7fa589ededdea7931dd
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/subsampling.py
@@ -0,0 +1,138 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+import torch
+import torch.nn as nn
+
+
+class ConvSubsampling(torch.nn.Module):
+ """Convolutional subsampling which supports VGGNet and striding approach introduced in:
+ VGGNet Subsampling: https://arxiv.org/pdf/1910.12977.pdf
+ Striding Subsampling:
+ "Speech-Transformer: A No-Recurrence Sequence-to-Sequence Model for Speech Recognition" by Linhao Dong et al.
+ Args:
+ subsampling (str): The subsampling technique from {"vggnet", "striding"}
+ subsampling_factor (int): The subsampling factor which should be a power of 2
+ feat_in (int): size of the input features
+ feat_out (int): size of the output features
+ conv_channels (int): Number of channels for the convolution layers.
+ activation (Module): activation function, default is nn.ReLU()
+ """
+
+ def __init__(self, subsampling, subsampling_factor, feat_in, feat_out, conv_channels, activation=nn.ReLU()):
+ super(ConvSubsampling, self).__init__()
+ self._subsampling = subsampling
+
+ if subsampling_factor % 2 != 0:
+ raise ValueError("Sampling factor should be a multiply of 2!")
+ self._sampling_num = int(math.log(subsampling_factor, 2))
+
+ in_channels = 1
+ layers = []
+ if subsampling == 'vggnet':
+ self._padding = 0
+ self._stride = 2
+ self._kernel_size = 2
+ self._ceil_mode = True
+
+ for i in range(self._sampling_num):
+ layers.append(
+ torch.nn.Conv2d(
+ in_channels=in_channels, out_channels=conv_channels, kernel_size=3, stride=1, padding=1
+ )
+ )
+ layers.append(activation)
+ layers.append(
+ torch.nn.Conv2d(
+ in_channels=conv_channels, out_channels=conv_channels, kernel_size=3, stride=1, padding=1
+ )
+ )
+ layers.append(activation)
+ layers.append(
+ torch.nn.MaxPool2d(
+ kernel_size=self._kernel_size,
+ stride=self._stride,
+ padding=self._padding,
+ ceil_mode=self._ceil_mode,
+ )
+ )
+ in_channels = conv_channels
+ elif subsampling == 'striding':
+ self._padding = 0
+ self._stride = 2
+ self._kernel_size = 3
+ self._ceil_mode = False
+
+ for i in range(self._sampling_num):
+ layers.append(
+ torch.nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=conv_channels,
+ kernel_size=self._kernel_size,
+ stride=self._stride,
+ padding=self._padding,
+ )
+ )
+ layers.append(activation)
+ in_channels = conv_channels
+ else:
+ raise ValueError(f"Not valid sub-sampling: {subsampling}!")
+
+ in_length = feat_in
+ for i in range(self._sampling_num):
+ out_length = calc_length(
+ length=int(in_length),
+ padding=self._padding,
+ kernel_size=self._kernel_size,
+ stride=self._stride,
+ ceil_mode=self._ceil_mode,
+ )
+ in_length = out_length
+
+ self.out = torch.nn.Linear(conv_channels * out_length, feat_out)
+ self.conv = torch.nn.Sequential(*layers)
+
+ def forward(self, x, lengths):
+ x = x.unsqueeze(1)
+ x = self.conv(x)
+ b, c, t, f = x.size()
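+ # Flatten channels and the reduced frequency axis so each frame is mapped by a single Linear to feat_out.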
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+
+ # TODO: improve the performance of length calculation
+ new_lengths = lengths
+ for i in range(self._sampling_num):
+ new_lengths = [
+ calc_length(
+ length=int(length),
+ padding=self._padding,
+ kernel_size=self._kernel_size,
+ stride=self._stride,
+ ceil_mode=self._ceil_mode,
+ )
+ for length in new_lengths
+ ]
+
+ new_lengths = torch.IntTensor(new_lengths).to(lengths.device)
+ return x, new_lengths
+
+
+def calc_length(length, padding, kernel_size, stride, ceil_mode):
+ """ Calculates the output length of a Tensor passed through a convolution or max pooling layer"""
+ if ceil_mode:
+ length = math.ceil((length + (2 * padding) - (kernel_size - 1) - 1) / float(stride) + 1)
+ else:
+ length = math.floor((length + (2 * padding) - (kernel_size - 1) - 1) / float(stride) + 1)
+ return length
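+
+
+# Illustrative sketch appended for clarity (not part of the original file): downsample a
+# batch of 80-dim feature frames by 4x with the 'striding' variant. All values are example choices.
+if __name__ == "__main__":
+    subsampler = ConvSubsampling(
+        subsampling='striding', subsampling_factor=4, feat_in=80, feat_out=256, conv_channels=32
+    )
+    feats = torch.randn(2, 1000, 80)            # (batch, time, features)
+    lengths = torch.tensor([1000, 640])
+    out, out_lengths = subsampler(feats, lengths)
+    print(out.shape, out_lengths)               # roughly (2, 250, 256) and lengths // 4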
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/transformert_beam_decoding.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/transformert_beam_decoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f8cb084b36262d45753dc7f34327c5431f0c142
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/transformert_beam_decoding.py
@@ -0,0 +1,869 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from nemo.collections.asr.modules import rnnt_abstract, TransformerTDecoder
+from nemo.collections.asr.parts import rnnt_utils
+from nemo.collections.asr.parts.rnnt_utils import Hypothesis, NBestHypotheses
+from nemo.core.classes import Typing, typecheck
+from nemo.core.neural_types import AcousticEncodedRepresentation, HypothesisType, LengthsType, NeuralType
+
+
+class BeamTransformerTInfer(Typing):
+ @property
+ def input_types(self):
+ """Returns definitions of module input ports.
+ """
+ return {
+ "encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
+ "encoded_lengths": NeuralType(tuple('B'), LengthsType()),
+ }
+
+ @property
+ def output_types(self):
+ """Returns definitions of module output ports.
+ """
+ return {"predictions": NeuralType(elements_type=HypothesisType())}
+
+ def __init__(
+ self,
+ decoder_model: TransformerTDecoder,
+ joint_model: rnnt_abstract.AbstractRNNTJoint,
+ blank_index: int,
+ beam_size: int,
+ beam_temperature: float,
+ beam_combine_path: bool,
+ beam_max_exp_step: int,
+ beam_prune_exp: bool,
+ beam_prune_exp_full: bool,
+ beam_word_reward_ratio: float,
+ search_type: str = 'convtt',
+ score_norm: bool = True,
+ return_best_hypothesis: bool = True,
+ tsd_max_sym_exp_per_step: Optional[int] = 50,
+ alsd_max_target_len: Union[int, float] = 1.0,
+ nsc_max_timesteps_expansion: int = 1,
+ nsc_prefix_alpha: int = 1,
+ ):
+ """
+ Beam Search implementation ported from ESPNet implementation -
+ https://github.com/espnet/espnet/blob/master/espnet/nets/beam_search_transducer.py
+
+ Sequence level beam decoding or batched-beam decoding, performed auto-regressively
+ depending on the search type chosen.
+
+ Args:
+ decoder_model: rnnt_utils.AbstractRNNTDecoder implementation.
+ joint_model: rnnt_utils.AbstractRNNTJoint implementation.
+
+ beam_size: number of beams for beam search. Must be a positive integer >= 1.
+ If beam size is 1, defaults to stateful greedy search.
+ This greedy search might produce slightly different results than
+ GreedyRNNTInfer due to implementation differences.
+
+ For accurate greedy results, please use GreedyRNNTInfer or GreedyBatchedRNNTInfer.
+
+ search_type: str representing the type of beam search to perform.
+ Must be one of ['default', 'tsd', 'alsd', 'convtt'], where 'default' is the basic
+ `beam` strategy described below. 'nsc' is currently not supported.
+
+ Algorithm used:
+ `beam` - basic beam search strategy. Larger beams generally result in better decoding,
+ however the time required for the search also grows steadily.
+
+ `tsd` - time synchronous decoding. Please refer to the paper:
+ [Alignment-Length Synchronous Decoding for RNN Transducer](https://ieeexplore.ieee.org/document/9053040)
+ for details on the algorithm implemented.
+
+ Time synchronous decoding (TSD) execution time grows by the factor T * max_symmetric_expansions.
+ For longer sequences, T is greater, and can therefore take a long time for beams to obtain
+ good results. This also requires greater memory to execute.
+
+ `alsd` - alignment-length synchronous decoding. Please refer to the paper:
+ [Alignment-Length Synchronous Decoding for RNN Transducer](https://ieeexplore.ieee.org/document/9053040)
+ for details on the algorithm implemented.
+
+ Alignment-length synchronous decoding (ALSD) execution time is faster than TSD, with growth
+ factor of T + U_max, where U_max is the maximum target length expected during execution.
+
+ Generally, T + U_max < T * max_symmetric_expansions. However, ALSD beams are non-unique,
+ therefore it is required to use larger beam sizes to achieve the same (or close to the same)
+ decoding accuracy as TSD.
+
+ For a given decoding accuracy, it is possible to attain faster decoding via ALSD than TSD.
+
+ score_norm: bool, whether to normalize the scores of the log probabilities.
+
+ return_best_hypothesis: bool, decides whether to return a single hypothesis (the best out of N),
+ or return all N hypotheses (sorted with best score first). The container class changes based
+ on this flag -
+ When set to True (default), returns a single Hypothesis.
+ When set to False, returns a NBestHypotheses container, which contains a list of Hypothesis.
+
+ # The following arguments are specific to the chosen `search_type`
+
+ tsd_max_sym_exp_per_step: Used for `search_type=tsd`. The maximum symmetric expansions allowed
+ per timestep during beam search. Larger values should be used to attempt decoding of longer
+ sequences, but this in turn increases execution time and memory usage.
+
+ alsd_max_target_len: Used for `search_type=alsd`. The maximum expected target sequence length
+ during beam search. Larger values allow decoding of longer sequences at the expense of
+ execution time and memory.
+
+ # The following two flags are placeholders and unused until `nsc` implementation is stabilized.
+
+ nsc_max_timesteps_expansion: Unused int.
+
+ nsc_prefix_alpha: Unused int.
+ """
+ self.decoder = decoder_model
+ self.joint = joint_model
+
+ self.blank = blank_index
+ assert self.blank == self.decoder.blank_idx
+ self.vocab_size = decoder_model.vocab_size
+ self.search_type = search_type
+ self.return_best_hypothesis = return_best_hypothesis
+
+ if beam_size < 1:
+ raise ValueError("Beam search size cannot be less than 1!")
+
+ self.beam_size = beam_size
+ self.score_norm = score_norm
+
+ self.beam_stepwise_ln_alpha = 0.
+ self.beam_word_reward_ratio = beam_word_reward_ratio
+ if self.beam_stepwise_ln_alpha > 0:
+ assert self.score_norm
+ assert self.beam_word_reward_ratio == 0
+ if self.beam_word_reward_ratio > 0:
+ assert self.score_norm
+ assert self.beam_stepwise_ln_alpha == 0
+ self.beam_combine_path = beam_combine_path
+ self.beam_max_exp_step = beam_max_exp_step
+ self.beam_temperature = beam_temperature
+ self.beam_prune_exp = beam_prune_exp
+ self.beam_prune_exp_full = beam_prune_exp_full
+
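+ # Note: the assertion below restricts this port to the 'convtt' search with beam_size > 1,
+ # so the other search branches configured further down are effectively unreachable here.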
+ assert self.beam_size > 1 and self.search_type == 'convtt'
+
+ if self.beam_size == 1:
+ self.search_algorithm = self.greedy_search
+ elif search_type == "default":
+ self.search_algorithm = self.default_beam_search
+ elif search_type == "tsd":
+ self.search_algorithm = self.time_sync_decoding
+ elif search_type == "alsd":
+ self.search_algorithm = self.align_length_sync_decoding
+ elif search_type == "convtt":
+ self.search_algorithm = self.convtt_beam_search
+ elif search_type == "nsc":
+ raise NotImplementedError("`nsc` (Constrained Beam Search) has not been implemented.")
+ # self.search_algorithm = self.nsc_beam_search
+ else:
+ raise NotImplementedError(
+ f"The search type ({search_type}) supplied is not supported!\n"
+ f"Please use one of : (default, tsd, alsd, nsc)"
+ )
+
+ if tsd_max_sym_exp_per_step is None:
+ tsd_max_sym_exp_per_step = -1
+
+ if search_type in ['tsd', 'alsd', 'nsc'] and not self.decoder.blank_as_pad:
+ raise ValueError(
+ f"Search type was chosen as '{search_type}', however the decoder module provided "
+ f"does not support the `blank` token as a pad value. {search_type} requires "
+ f"the blank token as pad value support in order to perform batched beam search."
+ f"Please chose one of the other beam search methods, or re-train your model "
+ f"with this support."
+ )
+
+ self.tsd_max_symmetric_expansion_per_step = tsd_max_sym_exp_per_step
+ self.alsd_max_target_length = alsd_max_target_len
+ self.nsc_max_timesteps_expansion = nsc_max_timesteps_expansion
+ self.nsc_prefix_alpha = nsc_prefix_alpha
+
+ @typecheck()
+ def __call__(
+ self, encoder_output: torch.Tensor, encoded_lengths: torch.Tensor
+ ) -> Union[Hypothesis, NBestHypotheses]:
+ """Perform general beam search.
+
+ Args:
+ encoder_output: Encoded speech features (B, T_max, D_enc)
+ encoded_lengths: Lengths of the encoder outputs
+
+ Returns:
+ Either a list containing a single Hypothesis (when `return_best_hypothesis=True`),
+ otherwise a list containing a single NBestHypotheses, which itself contains a list of
+ Hypothesis. This list is sorted such that the best hypothesis is the first element.
+ """
+ # Preserve decoder and joint training state
+ decoder_training_state = self.decoder.training
+ joint_training_state = self.joint.training
+ self.decoder.eval()
+ self.joint.eval()
+
+ with torch.no_grad():
+ # Apply optional preprocessing
+ encoder_output = encoder_output.transpose(1, 2) # (B, T, D)
+
+ hypotheses = []
+ with tqdm(
+ range(encoder_output.size(0)),
+ desc='Beam search progress:',
+ total=encoder_output.size(0),
+ unit='sample',
+ ) as idx_gen:
+ # Decode every sample in the batch independently.
+ for batch_idx in idx_gen:
+ inseq = encoder_output[batch_idx : batch_idx + 1, :, :] # [1, T, D]
+ logitlen = encoded_lengths[batch_idx]
+
+ # Execute the specific search strategy
+ nbest_hyps = self.search_algorithm(inseq, logitlen) # sorted list of hypothesis
+
+ # Pack the result
+ if self.return_best_hypothesis:
+ best_hypothesis = nbest_hyps[0] # type: Hypothesis
+ else:
+ best_hypothesis = NBestHypotheses(nbest_hyps) # type: NBestHypotheses
+ hypotheses.append(best_hypothesis)
+
+ self.decoder.train(decoder_training_state)
+ self.joint.train(joint_training_state)
+
+ return (hypotheses,)
+
+ def sort_nbest(self, hyps: List[Hypothesis]) -> List[Hypothesis]:
+ """Sort hypotheses by score or score given sequence length.
+
+ Args:
+ hyps: list of hypotheses
+
+ Return:
+ hyps: sorted list of hypotheses
+ """
+ if self.score_norm:
+ return sorted(hyps, key=lambda x: x.score / len(x.y_sequence), reverse=True)
+ else:
+ return sorted(hyps, key=lambda x: x.score, reverse=True)
+
+ def greedy_search(self, h: torch.Tensor, encoded_lengths: torch.Tensor) -> List[Hypothesis]:
+ """Greedy search implementation for transducer.
+ Generic case when beam size = 1. Results might differ slightly due to implementation details
+ as compared to `GreedyRNNTInfer` and `GreedyBatchedRNNTInfer`.
+
+ Args:
+ h: Encoded speech features (1, T_max, D_enc)
+
+ Returns:
+ hyp: 1-best decoding results
+ """
+ # Initialize zero state vectors
+ dec_state = self.decoder.initialize_state(h)
+
+ # Construct initial hypothesis
+ hyp = Hypothesis(score=0.0, y_sequence=[self.blank], dec_state=dec_state)
+ cache = {}
+
+ # Initialize state and first token
+ y, state, _ = self.decoder.score_hypothesis(hyp, cache)
+
+ for i in range(int(encoded_lengths)):
+ hi = h[:, i : i + 1, :] # [1, 1, D]
+
+ not_blank = True
+ symbols_added = 0
+
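+ # Emit symbols for this encoder frame until the joint network predicts blank.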
+ while not_blank:
+ ytu = torch.log_softmax(self.joint.joint(hi, y), dim=-1) # [1, 1, 1, V + 1]
+ ytu = ytu[0, 0, 0, :] # [V + 1]
+
+ # max() requires float
+ if ytu.dtype != torch.float32:
+ ytu = ytu.float()
+
+ logp, pred = torch.max(ytu, dim=-1) # [1, 1]
+ pred = pred.item()
+
+ if pred == self.blank:
+ not_blank = False
+ else:
+ # Update state and current sequence
+ hyp.y_sequence.append(int(pred))
+ hyp.score += float(logp)
+ hyp.dec_state = state
+
+ # Compute next state and token
+ y, state, _ = self.decoder.score_hypothesis(hyp, cache)
+ symbols_added += 1
+
+ return [hyp]
+
+ def default_beam_search(self, h: torch.Tensor, encoded_lengths: torch.Tensor) -> List[Hypothesis]:
+ """Beam search implementation.
+
+ Args:
+ h: Encoded speech features (1, T_max, D_enc)
+
+ Returns:
+ nbest_hyps: N-best decoding results
+ """
+ # Initialize states
+ beam = min(self.beam_size, self.vocab_size)
+ beam_k = min(beam, (self.vocab_size - 1))
+ blank_tensor = torch.tensor([self.blank], device=h.device, dtype=torch.long)
+
+ # Precompute some constants for blank position
+ ids = list(range(self.vocab_size + 1))
+ ids.remove(self.blank)
+
+ # Used when blank token is first vs last token
+ if self.blank == 0:
+ index_incr = 1
+ else:
+ index_incr = 0
+
+ # Initialize zero vector states
+ dec_state = self.decoder.initialize_state(h)
+
+ # Initialize first hypothesis for the beam (blank)
+ kept_hyps = [Hypothesis(score=0.0, y_sequence=[self.blank], dec_state=dec_state)]
+ cache = {}
+
+ for i in range(int(encoded_lengths)):
+ hi = h[:, i : i + 1, :] # [1, 1, D]
+ hyps = kept_hyps
+ kept_hyps = []
+
+ while True:
+ max_hyp = max(hyps, key=lambda x: x.score)
+ hyps.remove(max_hyp)
+
+ # update decoder state and get next score
+ y, state, lm_tokens = self.decoder.score_hypothesis(max_hyp, cache) # [1, 1, D]
+
+ # get next token
+ ytu = torch.log_softmax(self.joint.joint(hi, y), dim=-1) # [1, 1, 1, V + 1]
+ ytu = ytu[0, 0, 0, :] # [V + 1]
+
+ # remove blank token before top k
+ top_k = ytu[ids].topk(beam_k, dim=-1)
+
+ # Two possible steps - blank token or non-blank token predicted
+ ytu = (
+ torch.cat((top_k[0], ytu[self.blank].unsqueeze(0))),
+ torch.cat((top_k[1] + index_incr, blank_tensor)),
+ )
+
+ # for each possible step
+ for logp, k in zip(*ytu):
+ # construct hypothesis for step
+ new_hyp = Hypothesis(
+ score=(max_hyp.score + float(logp)),
+ y_sequence=max_hyp.y_sequence[:],
+ dec_state=max_hyp.dec_state,
+ lm_state=max_hyp.lm_state,
+ )
+
+ # if current token is blank, don't update the sequence, just store the current hypothesis
+ if k == self.blank:
+ kept_hyps.append(new_hyp)
+ else:
+ # if non-blank token was predicted, update state and sequence and then search more hypothesis
+ new_hyp.dec_state = state
+ new_hyp.y_sequence.append(int(k))
+
+ hyps.append(new_hyp)
+
+ # keep those hypotheses whose scores are greater than the best of the next search generation
+ hyps_max = float(max(hyps, key=lambda x: x.score).score)
+ kept_most_prob = sorted([hyp for hyp in kept_hyps if hyp.score > hyps_max], key=lambda x: x.score,)
+
+ # If enough hypotheses have scores greater than the next search generation,
+ # stop beam search.
+ if len(kept_most_prob) >= beam:
+ kept_hyps = kept_most_prob
+ break
+
+ return self.sort_nbest(kept_hyps)
+
+ def time_sync_decoding(self, h: torch.Tensor, encoded_lengths: torch.Tensor) -> List[Hypothesis]:
+ """Time synchronous beam search implementation.
+ Based on https://ieeexplore.ieee.org/document/9053040
+
+ Args:
+ h: Encoded speech features (1, T_max, D_enc)
+
+ Returns:
+ nbest_hyps: N-best decoding results
+ """
+ # Precompute some constants for blank position
+ ids = list(range(self.vocab_size + 1))
+ ids.remove(self.blank)
+
+ # Used when blank token is first vs last token
+ if self.blank == 0:
+ index_incr = 1
+ else:
+ index_incr = 0
+
+ # prepare the batched beam states
+ beam = min(self.beam_size, self.vocab_size)
+ beam_state = self.decoder.initialize_state(
+ torch.zeros(beam, device=h.device, dtype=h.dtype)
+ ) # [L, B, H], [L, B, H] (for LSTMs)
+
+ # Initialize first hypothesis for the beam (blank)
+ B = [Hypothesis(y_sequence=[self.blank], score=0.0, dec_state=self.decoder.batch_select_state(beam_state, 0))]
+ cache = {}
+
+ for i in range(int(encoded_lengths)):
+ hi = h[:, i : i + 1, :]
+
+ # Update caches
+ A = []
+ C = B
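+ # A: hypotheses that have emitted blank at frame i; C: hypotheses currently being expanded;
+ # D (below): their non-blank expansions for the current expansion step.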
+
+ h_enc = hi
+
+ # For a limited number of symmetric expansions per timestep "i"
+ for v in range(self.tsd_max_symmetric_expansion_per_step):
+ D = []
+
+ # Decode a batch of beam states and scores
+ beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score_hypothesis(C, cache, beam_state)
+
+ # Extract the log probabilities and the predicted tokens
+ beam_logp = torch.log_softmax(self.joint.joint(h_enc, beam_y), dim=-1) # [B, 1, 1, V + 1]
+ beam_logp = beam_logp[:, 0, 0, :] # [B, V + 1]
+ beam_topk = beam_logp[:, ids].topk(beam, dim=-1)
+
+ seq_A = [h.y_sequence for h in A]
+
+ for j, hyp in enumerate(C):
+ # create a new hypothesis in A
+ if hyp.y_sequence not in seq_A:
+ # If the sequence is not in seq_A, add it as the blank token
+ # In this step, we don't add a token but simply update the score
+ A.append(
+ Hypothesis(
+ score=(hyp.score + float(beam_logp[j, self.blank])),
+ y_sequence=hyp.y_sequence[:],
+ dec_state=hyp.dec_state,
+ lm_state=hyp.lm_state,
+ )
+ )
+ else:
+ # merge the existing blank hypothesis score with current score.
+ dict_pos = seq_A.index(hyp.y_sequence)
+
+ A[dict_pos].score = np.logaddexp(
+ A[dict_pos].score, (hyp.score + float(beam_logp[j, self.blank]))
+ )
+
+ if v < self.tsd_max_symmetric_expansion_per_step:
+ for j, hyp in enumerate(C):
+ # for each current hypothesis j
+ # extract the top token score and top token id for the jth hypothesis
+ for logp, k in zip(beam_topk[0][j], beam_topk[1][j] + index_incr):
+ # create new hypothesis and store in D
+ # Note: This loop does *not* include the blank token!
+ new_hyp = Hypothesis(
+ score=(hyp.score + float(logp)),
+ y_sequence=(hyp.y_sequence + [int(k)]),
+ dec_state=self.decoder.batch_select_state(beam_state, j),
+ lm_state=hyp.lm_state,
+ )
+
+ D.append(new_hyp)
+
+ # Prune beam
+ C = sorted(D, key=lambda x: x.score, reverse=True)[:beam]
+
+ # Prune beam
+ B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]
+
+ return self.sort_nbest(B)
+
+ def align_length_sync_decoding(self, h: torch.Tensor, encoded_lengths: torch.Tensor) -> List[Hypothesis]:
+ """Alignment-length synchronous beam search implementation.
+ Based on https://ieeexplore.ieee.org/document/9053040
+
+ Args:
+ h: Encoded speech features (1, T_max, D_enc)
+
+ Returns:
+ nbest_hyps: N-best decoding results
+ """
+ # Precompute some constants for blank position
+ ids = list(range(self.vocab_size + 1))
+ ids.remove(self.blank)
+
+ # Used when blank token is first vs last token
+ if self.blank == 0:
+ index_incr = 1
+ else:
+ index_incr = 0
+
+ # prepare the batched beam states
+ beam = min(self.beam_size, self.vocab_size)
+
+ h = h[0] # [T, D]
+ h_length = int(encoded_lengths)
+ beam_state = self.decoder.initialize_state(
+ torch.zeros(beam, device=h.device, dtype=h.dtype)
+ ) # [L, B, H], [L, B, H] for LSTMS
+
+ # compute u_max as either a specific static limit,
+ # or a multiple of current `h_length` dynamically.
+ if type(self.alsd_max_target_length) == float:
+ u_max = int(self.alsd_max_target_length * h_length)
+ else:
+ u_max = int(self.alsd_max_target_length)
+
+ # Initialize first hypothesis for the beam (blank)
+ B = [Hypothesis(y_sequence=[self.blank], score=0.0, dec_state=self.decoder.batch_select_state(beam_state, 0))]
+
+ final = []
+ cache = {}
+
+ # ALSD runs for T + U_max steps
+ for i in range(h_length + u_max):
+ # Update caches
+ A = []
+ B_ = []
+ h_states = []
+
+ # preserve the list of batch indices which are added into the list
+ # and those which are removed from the list
+ # This is necessary to perform state updates in the correct batch indices later
+ batch_ids = list(range(len(B))) # initialize as a list of all batch ids
+ batch_removal_ids = [] # update with sample ids which are removed
+
+ for bid, hyp in enumerate(B):
+ u = len(hyp.y_sequence) - 1
+ t = i - u + 1
+
+ if t > (h_length - 1):
+ batch_removal_ids.append(bid)
+ continue
+
+ B_.append(hyp)
+ h_states.append((t, h[t]))
+
+ if B_:
+ # Compute the subset of batch ids which were *not* removed from the list above
+ sub_batch_ids = None
+ if len(B_) != beam:
+ sub_batch_ids = batch_ids
+ for id in batch_removal_ids:
+ # sub_batch_ids contains list of ids *that were not removed*
+ sub_batch_ids.remove(id)
+
+ # extract the states of the sub batch only.
+ beam_state_ = [beam_state[state_id][:, sub_batch_ids, :] for state_id in range(len(beam_state))]
+ else:
+ # If entire batch was used (none were removed), simply take all the states
+ beam_state_ = beam_state
+
+ # Decode a batch/sub-batch of beam states and scores
+ beam_y, beam_state_, beam_lm_tokens = self.decoder.batch_score_hypothesis(B_, cache, beam_state_)
+
+ # If only a subset of batch ids were updated (some were removed)
+ if sub_batch_ids is not None:
+ # For each state in the RNN (2 for LSTM)
+ for state_id in range(len(beam_state)):
+ # Update the current batch states with the sub-batch states (in the correct indices)
+ # These indices are specified by sub_batch_ids, the ids of samples which were updated.
+ beam_state[state_id][:, sub_batch_ids, :] = beam_state_[state_id][...]
+ else:
+ # If entire batch was updated, simply update all the states
+ beam_state = beam_state_
+
+ # h_states = list of [t, h[t]]
+ # so h[1] here is a h[t] of shape [D]
+ # Simply stack all of the h[t] within the sub_batch/batch (T <= beam)
+ h_enc = torch.stack([h[1] for h in h_states]) # [T=beam, D]
+ h_enc = h_enc.unsqueeze(1) # [B=beam, T=1, D]; batch over the beams
+
+ # Extract the log probabilities and the predicted tokens
+ beam_logp = torch.log_softmax(self.joint.joint(h_enc, beam_y), dim=-1) # [B=beam, 1, 1, V + 1]
+ beam_logp = beam_logp[:, 0, 0, :] # [B=beam, V + 1]
+ beam_topk = beam_logp[:, ids].topk(beam, dim=-1)
+
+ for j, hyp in enumerate(B_):
+ # For all updated samples in the batch, add it as the blank token
+ # In this step, we don't add a token but simply update the score
+ new_hyp = Hypothesis(
+ score=(hyp.score + float(beam_logp[j, self.blank])),
+ y_sequence=hyp.y_sequence[:],
+ dec_state=hyp.dec_state,
+ lm_state=hyp.lm_state,
+ )
+
+ # Add blank prediction to A
+ A.append(new_hyp)
+
+ # If the prediction "timestep" t has reached the length of the input sequence
+ # we can add it to the "finished" hypothesis list.
+ if h_states[j][0] == (h_length - 1):
+ final.append(new_hyp)
+
+ # Here, we carefully select the indices of the states that we want to preserve
+ # for the next token (non-blank) update.
+ if sub_batch_ids is not None:
+ h_states_idx = sub_batch_ids[j]
+ else:
+ h_states_idx = j
+
+ # for each current hypothesis j
+ # extract the top token score and top token id for the jth hypothesis
+ for logp, k in zip(beam_topk[0][j], beam_topk[1][j] + index_incr):
+ # create new hypothesis and store in A
+ # Note: This loop does *not* include the blank token!
+ new_hyp = Hypothesis(
+ score=(hyp.score + float(logp)),
+ y_sequence=(hyp.y_sequence[:] + [int(k)]),
+ dec_state=self.decoder.batch_select_state(beam_state, h_states_idx),
+ lm_state=hyp.lm_state,
+ )
+
+ A.append(new_hyp)
+
+ # Prune and recombine same hypothesis
+ # This may cause next beam to be smaller than max beam size
+ # Therefore larger beam sizes may be required for better decoding.
+ B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]
+ B = self.recombine_hypotheses(B)
+
+ # If B_ is empty list, then we may be able to early exit
+ elif len(batch_ids) == len(batch_removal_ids):
+ break
+
+ if final:
+ return self.sort_nbest(final)
+ else:
+ return B
+
+ def recombine_hypotheses(self, hypotheses: List[Hypothesis]) -> List[Hypothesis]:
+ """Recombine hypotheses with equivalent output sequence.
+
+ Args:
+ hypotheses (list): list of hypotheses
+
+ Returns:
+ final (list): list of recombined hypotheses
+ """
+ final = []
+
+ for hyp in hypotheses:
+ seq_final = [f.y_sequence for f in final if f.y_sequence]
+
+ if hyp.y_sequence in seq_final:
+ seq_pos = seq_final.index(hyp.y_sequence)
+
+ final[seq_pos].score = np.logaddexp(final[seq_pos].score, hyp.score)
+ else:
+ final.append(hyp)
+
+ return final
+
+ def convtt_beam_search(self, encoder_outputs: torch.Tensor, encoded_lengths: torch.Tensor) -> List[Hypothesis]:
+ assert self.decoder.prepend_sos_label
+ start_token_id = self.decoder.sos_idx
+
+ init_pred_net_state = self.decoder.initialize_state(1, encoder_outputs.dtype, encoder_outputs.device)
+
+ kept_hyps = [Hyp(score=1.0, labels=tuple(),
+ pred_net_state=init_pred_net_state,
+ last_label=start_token_id)]
+ blank_label_id = self.decoder.blank_idx
+ assert start_token_id != blank_label_id
+
+ pred_net_cache = {}
+ pred_net_cache_steps = 0
+ pred_net_cache_missing_steps = 0
+
+ for t in range(int(encoded_lengths)):
+ blank_hyps = {}
+ exp_i = 0
+ exp_hyps = kept_hyps
+
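+ # Expand the surviving hypotheses for this encoder frame until every path has emitted blank
+ # (collected in `blank_hyps`) or `beam_max_exp_step` expansion steps have been taken.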
+ while exp_i < self.beam_max_exp_step and len(exp_hyps) > 0:
+ new_hyps = []
+ # expand hyps
+ assert len(exp_hyps) <= self.beam_size
+ for hyp_i in exp_hyps:
+ if hyp_i.last_label == blank_label_id:
+ assert hyp_i.pred_net_output is not None
+ pred_net_output = hyp_i.pred_net_output
+ pred_net_state = hyp_i.pred_net_state
+ else:
+ assert hyp_i.last_label == start_token_id or hyp_i.last_label == hyp_i.labels[-1]
+ if hyp_i.labels in pred_net_cache:
+ pred_net_cache_steps += 1
+ pred_net_output, pred_net_state = pred_net_cache[hyp_i.labels]
+ else:
+ pred_net_cache_missing_steps += 1
+ last_label_t = torch.tensor([[hyp_i.last_label]], dtype=torch.long, device=encoder_outputs.device)
+ pred_net_output, pred_net_state = self.decoder.predict(last_label_t,
+ hyp_i.pred_net_state,
+ add_sos=False)
+ pred_net_cache[hyp_i.labels] = (pred_net_output, pred_net_state)
+ # [B, T, U, V + 1]
+ logits = self.joint.joint(encoder_outputs[:, t:t+1, :], pred_net_output, apply_softmax=False)
+ if self.beam_temperature != 1.0:
+ logits = logits / self.beam_temperature
+ new_hyp_logp = torch.log_softmax(logits, dim=-1)
+ new_hyp_logp = new_hyp_logp.cpu().numpy()
+ # new_hyp_logp = new_hyp_logp.astype(np.float64)
+ new_hyp_logp += hyp_i.score
+ top_new_hyp_idx = np.argsort(new_hyp_logp)
+ assert top_new_hyp_idx.shape[:3] == (1, 1, 1)
+ top_new_hyp_idx = top_new_hyp_idx[0, 0, 0, -(self.beam_size + 1):]
+ for label_i in top_new_hyp_idx:
+ if label_i == blank_label_id:
+ continue
+ new_hyp = Hyp(score=new_hyp_logp[0, 0, 0, label_i],
+ labels=hyp_i.labels + (label_i,),
+ pred_net_state=pred_net_state,
+ pred_net_output=pred_net_output,
+ last_label=label_i)
+ new_hyps.append(new_hyp)
+ new_blank_hyp = Hyp(score=new_hyp_logp[0, 0, 0, blank_label_id],
+ labels=hyp_i.labels,
+ pred_net_state=pred_net_state,
+ pred_net_output=pred_net_output,
+ last_label=blank_label_id)
+ if self.beam_prune_exp:
+ new_hyps.append(new_blank_hyp)
+ if new_blank_hyp.labels in blank_hyps:
+ if self.beam_combine_path:
+ existing_blank_hyp = blank_hyps[new_blank_hyp.labels]
+ existing_blank_hyp.score = np.logaddexp(existing_blank_hyp.score, new_blank_hyp.score)
+ else:
+ if new_blank_hyp.score > blank_hyps[new_blank_hyp.labels].score:
+ blank_hyps[new_blank_hyp.labels] = new_blank_hyp
+ else:
+ blank_hyps[new_blank_hyp.labels] = new_blank_hyp
+
+ if self.beam_prune_exp_full and exp_i > 0:
+ exp_beam_size = len(exp_hyps)
+ else:
+ exp_beam_size = self.beam_size
+
+ # one expansion step for all beam finished, select top K hyp for further search
+ exp_hyps = get_top_k_hyps(new_hyps, exp_beam_size, ln_alpha=self.beam_stepwise_ln_alpha,
+ word_reward_ratio=self.beam_word_reward_ratio)
+ if self.beam_prune_exp:
+ # orig_exp_hyp_num = len(exp_hyps)
+ exp_hyps = [exp_hyp_i for exp_hyp_i in exp_hyps if exp_hyp_i.last_label != blank_label_id]
+ # print('pruned exp_hyps from {} to {}'.format(orig_exp_hyp_num, len(exp_hyps)))
+ exp_i += 1
+ # expand finished
+ blank_hyps = list(blank_hyps.values())
+ kept_hyps = get_top_k_hyps(blank_hyps, self.beam_size, always_sort=True,
+ ln_alpha=self.beam_stepwise_ln_alpha,
+ word_reward_ratio=self.beam_word_reward_ratio)
+
+ if self.score_norm:
+ if self.beam_stepwise_ln_alpha > 0:
+ scores_len_norm = [hyp_i.length_norm_score(self.beam_stepwise_ln_alpha) for hyp_i in kept_hyps]
+ elif self.beam_word_reward_ratio > 0:
+ scores_len_norm = [hyp_i.word_reward_score(self.beam_word_reward_ratio) for hyp_i in kept_hyps]
+ else:
+ len_normed_best_hyp, scores_len_norm = get_length_normed_best(kept_hyps)
+ else:
+ scores_len_norm = [0] * len(kept_hyps)
+
+ ret_hyps = []
+ for i in range(len(kept_hyps)):
+ ret_hyps.append(Hypothesis(score=scores_len_norm[i] if self.score_norm else kept_hyps[i].score,
+ y_sequence=[int(label_i) for label_i in kept_hyps[i].labels]))
+ return sorted(ret_hyps, key=lambda x: x.score, reverse=True)
+
+
+class Hyp:
+ __slots__ = ('score', 'labels', 'pred_net_state', 'pred_net_output', 'last_label')
+
+ def __init__(self, score, labels=None, pred_net_state=None, pred_net_output=None, last_label=None):
+ self.score = score
+ self.labels = labels
+ self.pred_net_state = pred_net_state
+ self.pred_net_output = pred_net_output
+ self.last_label = last_label
+
+ def length_norm_score(self, alpha):
+ return length_norm(self.score, len(self.labels), alpha)
+
+ def word_reward_score(self, reward_ratio):
+ label_len = len(self.labels)
+ # if not self.blank:
+ # label_len -= 1
+ return self.score + label_len * reward_ratio
+
+
+def length_norm(log_prob, label_len, alpha):
+ len_norm = (label_len + 5.) / 6.
+ if alpha != 1:
+ len_norm = len_norm ** alpha
+ return log_prob / len_norm
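+
+
+# Illustrative example (made-up values): the GNMT-style length penalty above
+# rescales a hypothesis log-probability by ((label_len + 5) / 6) ** alpha.
+def _length_norm_example():
+    # A 10-token hypothesis with log-prob -12.0 and alpha = 0.6:
+    # len_norm = ((10 + 5) / 6) ** 0.6 ~= 1.73, normalized score ~= -6.93.
+    return length_norm(-12.0, 10, 0.6)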
+
+
+def get_top_k_hyps(hyps, k, always_sort=False, ln_alpha=0., word_reward_ratio=0.):
+ if len(hyps) <= k and not always_sort:
+ return hyps
+ else:
+ if ln_alpha > 0:
+ hyp_scores = np.array([hyp_i.length_norm_score(ln_alpha) for hyp_i in hyps])
+ elif word_reward_ratio > 0:
+ hyp_scores = np.array([hyp_i.word_reward_score(word_reward_ratio) for hyp_i in hyps])
+ else:
+ hyp_scores = np.array([hyp_i.score for hyp_i in hyps])
+ top_k_idx = np.argsort(hyp_scores)[-k:]
+ return [hyps[i] for i in top_k_idx]
+
+
+def get_length_normed_best(hyps):
+ best_hyp = None
+ best_hyp_score = None
+ score_len_normed = []
+ for hyp_i in hyps:
+ length_normed_score = hyp_i.score / (len(hyp_i.labels) + 1e-16)
+ score_len_normed.append(length_normed_score)
+ if best_hyp_score is None or length_normed_score > best_hyp_score:
+ best_hyp_score = length_normed_score
+ best_hyp = hyp_i
+ assert best_hyp is not None
+ return best_hyp, score_len_normed
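+
+
+# Minimal usage sketch (assumed, hand-made hypotheses): pick the top-2 partial
+# hypotheses under length normalization with alpha = 0.6.
+def _get_top_k_hyps_example():
+    hyps = [
+        Hyp(score=-3.0, labels=(7,), last_label=7),
+        Hyp(score=-4.5, labels=(7, 12), last_label=12),
+        Hyp(score=-6.0, labels=(7, 12, 3), last_label=3),
+    ]
+    return get_top_k_hyps(hyps, 2, always_sort=True, ln_alpha=0.6)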
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/transformert_greedy_decoding.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/transformert_greedy_decoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..41b30d4c7e87907cbfe6c5e240ef8f68d7f87407
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/transformert_greedy_decoding.py
@@ -0,0 +1,540 @@
+from typing import Optional, Union
+
+import torch
+
+from nemo.collections.asr.modules import rnnt_abstract, TransformerTDecoder
+from nemo.collections.asr.parts import rnnt_utils
+from nemo.collections.common.parts.rnn import label_collate
+from nemo.core.classes import Typing, typecheck
+from nemo.core.neural_types import AcousticEncodedRepresentation, HypothesisType, LengthsType, NeuralType
+
+
+class _GreedyTransformerTInfer(Typing):
+ """A greedy transducer decoder.
+
+ Provides a common abstraction for sample level and batch level greedy decoding.
+
+ Args:
+ decoder_model: rnnt_utils.AbstractRNNTDecoder implementation.
+ joint_model: rnnt_utils.AbstractRNNTJoint implementation.
+ blank_index: int index of the blank token. Can be 0 or len(vocabulary).
+ max_symbols_per_step: Optional int. The maximum number of symbols that can be added
+ to a sequence in a single time step; if set to None then there is
+ no limit.
+ """
+
+ @property
+ def input_types(self):
+ """Returns definitions of module input ports.
+ """
+ return {
+ "encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
+ "encoded_lengths": NeuralType(tuple('B'), LengthsType()),
+ }
+
+ @property
+ def output_types(self):
+ """Returns definitions of module output ports.
+ """
+ return {"predictions": NeuralType(elements_type=HypothesisType())}
+
+ def __init__(
+ self,
+ decoder_model: TransformerTDecoder,
+ joint_model: rnnt_abstract.AbstractRNNTJoint,
+ blank_index: int,
+ max_symbols_per_step: Optional[int] = None,
+ ):
+ super().__init__()
+ self.decoder = decoder_model
+ self.joint = joint_model
+
+ self._blank_index = blank_index
+        self._SOS = blank_index  # Start-of-sequence token index (the blank id is reused as SOS)
+ self.max_symbols = max_symbols_per_step
+
+ def __call__(self, *args, **kwargs):
+ return self.forward(*args, **kwargs)
+
+ @torch.no_grad()
+ def _pred_step(
+ self,
+ label: Union[torch.Tensor, int],
+ hidden: Optional[torch.Tensor],
+ add_sos: bool = False,
+ batch_size: Optional[int] = None,
+ ) -> (torch.Tensor, torch.Tensor):
+ """
+ Common prediction step based on the AbstractRNNTDecoder implementation.
+
+ Args:
+ label: (int/torch.Tensor): Label or "Start-of-Signal" token.
+ hidden: (Optional torch.Tensor): RNN State vector
+            add_sos (bool): Whether to add a zero vector at the beginning as the "start of sentence" token.
+ batch_size: Batch size of the output tensor.
+
+ Returns:
+ g: (B, U, H) if add_sos is false, else (B, U + 1, H)
+ hid: (h, c) where h is the final sequence hidden state and c is
+ the final cell state:
+ h (tensor), shape (L, B, H)
+ c (tensor), shape (L, B, H)
+ """
+ if isinstance(label, torch.Tensor):
+ # label: [batch, 1]
+ if label.dtype != torch.long:
+ label = label.long()
+
+ else:
+ # Label is an integer
+ if label == self._SOS:
+ assert not self.decoder.prepend_sos_label
+ return self.decoder.predict(None, hidden, add_sos=add_sos, batch_size=batch_size)
+
+ label = label_collate([[label]])
+
+ # output: [B, 1, K]
+ return self.decoder.predict(label, hidden, add_sos=add_sos, batch_size=batch_size)
+
+ def _joint_step(self, enc, pred, log_normalize: Optional[bool] = None):
+ """
+ Common joint step based on AbstractRNNTJoint implementation.
+
+ Args:
+ enc: Output of the Encoder model. A torch.Tensor of shape [B, 1, H1]
+ pred: Output of the Decoder model. A torch.Tensor of shape [B, 1, H2]
+ log_normalize: Whether to log normalize or not. None will log normalize only for CPU.
+
+ Returns:
+ logits of shape (B, T=1, U=1, V + 1)
+ """
+ with torch.no_grad():
+ logits = self.joint.joint(enc, pred)
+
+ if log_normalize is None:
+ if not logits.is_cuda: # Use log softmax only if on CPU
+ logits = logits.log_softmax(dim=len(logits.shape) - 1)
+ else:
+ if log_normalize:
+ logits = logits.log_softmax(dim=len(logits.shape) - 1)
+
+ return logits
+
+
+class GreedyTransformerTInfer(_GreedyTransformerTInfer):
+ """A greedy transducer decoder.
+
+    Sequence level greedy decoding, performed autoregressively.
+
+ Args:
+ decoder_model: rnnt_utils.AbstractRNNTDecoder implementation.
+ joint_model: rnnt_utils.AbstractRNNTJoint implementation.
+ blank_index: int index of the blank token. Can be 0 or len(vocabulary).
+ max_symbols_per_step: Optional int. The maximum number of symbols that can be added
+ to a sequence in a single time step; if set to None then there is
+ no limit.
+ """
+
+ def __init__(
+ self,
+ decoder_model: rnnt_abstract.AbstractRNNTDecoder,
+ joint_model: rnnt_abstract.AbstractRNNTJoint,
+ blank_index: int,
+ max_symbols_per_step: Optional[int] = None,
+ ):
+ super().__init__(
+ decoder_model=decoder_model,
+ joint_model=joint_model,
+ blank_index=blank_index,
+ max_symbols_per_step=max_symbols_per_step,
+ )
+
+ @typecheck()
+ def forward(self, encoder_output: torch.Tensor, encoded_lengths: torch.Tensor):
+ """Returns a list of hypotheses given an input batch of the encoder hidden embedding.
+        Output tokens are generated autoregressively.
+
+ Args:
+ encoder_output: A tensor of size (batch, features, timesteps).
+            encoded_lengths: list of int representing the length of each
+                output sequence.
+
+ Returns:
+ packed list containing batch number of sentences (Hypotheses).
+ """
+ # Preserve decoder and joint training state
+ decoder_training_state = self.decoder.training
+ joint_training_state = self.joint.training
+ self.decoder.eval()
+ self.joint.eval()
+
+ with torch.no_grad():
+ # Apply optional preprocessing
+ encoder_output = encoder_output.transpose(1, 2) # (B, T, D)
+
+ hypotheses = []
+ # Process each sequence independently
+ for batch_idx in range(encoder_output.size(0)):
+ inseq = encoder_output[batch_idx, :, :].unsqueeze(1) # [T, 1, D]
+ logitlen = encoded_lengths[batch_idx]
+ sentence = self._greedy_decode(inseq, logitlen)
+ hypotheses.append(sentence)
+
+ # Pack results into Hypotheses
+ packed_result = [
+ rnnt_utils.Hypothesis(y_sequence=torch.tensor(sent, dtype=torch.long), score=-1.0)
+ for sent in hypotheses
+ ]
+
+ self.decoder.train(decoder_training_state)
+ self.joint.train(joint_training_state)
+
+ return (packed_result,)
+
+ @torch.no_grad()
+ def _greedy_decode(self, x: torch.Tensor, out_len: torch.Tensor):
+ # x: [T, 1, D]
+ # out_len: [seq_len]
+
+ hidden = self.decoder.initialize_state(x.size(1), x.dtype, x.device)
+ if not self.decoder.prepend_sos_label:
+ # Initialize blank state and empty label set
+ label = []
+ else:
+ label = [self.decoder.sos_idx]
+
+
+ # For timestep t in X_t
+ for time_idx in range(out_len):
+ # Extract encoder embedding at timestep t
+ # f = x[time_idx, :, :].unsqueeze(0) # [1, 1, D]
+ f = x.narrow(dim=0, start=time_idx, length=1)
+
+ # Setup exit flags and counter
+ not_blank = True
+ symbols_added = 0
+
+            # While blank is not predicted and we have not exceeded max symbols per timestep
+ while not_blank and (self.max_symbols is None or symbols_added < self.max_symbols):
+ # In the first timestep, we initialize the network with TransformerT Blank
+ # In later timesteps, we provide previous predicted label as input.
+ last_label = self._SOS if label == [] else label[-1]
+
+ # Perform prediction network and joint network steps.
+ g, hidden_prime = self._pred_step(last_label, hidden)
+ logp = self._joint_step(f, g, log_normalize=None)[0, 0, 0, :]
+
+ del g
+
+                # torch.max(0) op doesn't exist for FP16.
+ if logp.dtype != torch.float32:
+ logp = logp.float()
+
+ # get index k, of max prob
+ v, k = logp.max(0)
+ k = k.item() # K is the label at timestep t_s in inner loop, s >= 0.
+
+ del logp
+
+ # If blank token is predicted, exit inner loop, move onto next timestep t
+ if k == self._blank_index:
+ not_blank = False
+ else:
+ # Append token to label set, update RNN state.
+ label.append(k)
+ hidden = hidden_prime
+
+ # Increment token counter.
+ symbols_added += 1
+
+ return label
+
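+
+# Usage sketch (assumed wiring; `decoder`, `joint`, `blank_index` and the tensor
+# shapes are placeholders, not objects defined in this module): run sample-level
+# greedy decoding over a batch of encoder outputs.
+def _greedy_infer_example(decoder, joint, blank_index, encoder_output, encoded_lengths):
+    # encoder_output: [B, D, T], encoded_lengths: [B]
+    infer = GreedyTransformerTInfer(decoder_model=decoder, joint_model=joint,
+                                    blank_index=blank_index, max_symbols_per_step=30)
+    (hypotheses,) = infer(encoder_output=encoder_output, encoded_lengths=encoded_lengths)
+    return [hyp.y_sequence for hyp in hypotheses]
+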
+
+class GreedyBatchedTransformerTInfer(_GreedyTransformerTInfer):
+ """A batch level greedy transducer decoder.
+
+    Batch level greedy decoding, performed autoregressively.
+
+ Args:
+ decoder_model: rnnt_utils.AbstractRNNTDecoder implementation.
+ joint_model: rnnt_utils.AbstractRNNTJoint implementation.
+ blank_index: int index of the blank token. Can be 0 or len(vocabulary).
+ max_symbols_per_step: Optional int. The maximum number of symbols that can be added
+ to a sequence in a single time step; if set to None then there is
+ no limit.
+ """
+
+ def __init__(
+ self,
+ decoder_model: rnnt_abstract.AbstractRNNTDecoder,
+ joint_model: rnnt_abstract.AbstractRNNTJoint,
+ blank_index: int,
+ max_symbols_per_step: Optional[int] = None,
+ ):
+ super().__init__(
+ decoder_model=decoder_model,
+ joint_model=joint_model,
+ blank_index=blank_index,
+ max_symbols_per_step=max_symbols_per_step,
+ )
+
+ # Depending on availability of `blank_as_pad` support
+ # switch between more efficient batch decoding technique
+ if self.decoder.blank_as_pad:
+ self._greedy_decode = self._greedy_decode_blank_as_pad
+ else:
+ self._greedy_decode = self._greedy_decode_masked
+
+ @typecheck()
+ def forward(self, encoder_output: torch.Tensor, encoded_lengths: torch.Tensor):
+ """Returns a list of hypotheses given an input batch of the encoder hidden embedding.
+        Output tokens are generated autoregressively.
+
+ Args:
+ encoder_output: A tensor of size (batch, features, timesteps).
+            encoded_lengths: list of int representing the length of each
+                output sequence.
+
+ Returns:
+ packed list containing batch number of sentences (Hypotheses).
+ """
+ # Preserve decoder and joint training state
+ decoder_training_state = self.decoder.training
+ joint_training_state = self.joint.training
+
+ with torch.no_grad():
+ # Apply optional preprocessing
+ encoder_output = encoder_output.transpose(1, 2) # (B, T, D)
+ logitlen = encoded_lengths
+
+ self.decoder.eval()
+ self.joint.eval()
+
+ with self.decoder.as_frozen(), self.joint.as_frozen():
+ inseq = encoder_output # [B, T, D]
+ hypotheses = self._greedy_decode(inseq, logitlen, device=inseq.device)
+
+ # Pack the hypotheses results
+ packed_result = [
+ rnnt_utils.Hypothesis(y_sequence=torch.tensor(sent, dtype=torch.long), score=-1.0)
+ for sent in hypotheses
+ ]
+
+ del hypotheses
+
+ self.decoder.train(decoder_training_state)
+ self.joint.train(joint_training_state)
+
+ return (packed_result,)
+
+ def _greedy_decode_blank_as_pad(self, x: torch.Tensor, out_len: torch.Tensor, device: torch.device):
+ with torch.no_grad():
+ # x: [B, T, D]
+ # out_len: [B]
+ # device: torch.device
+
+ # Initialize state
+ hidden = None
+ batchsize = x.shape[0]
+
+ # Output string buffer
+ label = [[] for _ in range(batchsize)]
+
+ # Last Label buffer + Last Label without blank buffer
+ # batch level equivalent of the last_label
+ last_label = torch.full([batchsize, 1], fill_value=self._blank_index, dtype=torch.long, device=device)
+
+ # Mask buffers
+ blank_mask = torch.full([batchsize], fill_value=0, dtype=torch.bool, device=device)
+
+ # Get max sequence length
+ max_out_len = out_len.max()
+
+ for time_idx in range(max_out_len):
+ f = x.narrow(dim=1, start=time_idx, length=1) # [B, 1, D]
+
+ # Prepare t timestamp batch variables
+ not_blank = True
+ symbols_added = 0
+
+ # Reset blank mask
+ blank_mask.mul_(False)
+
+ # Update blank mask with time mask
+ # Batch: [B, T, D], but Bi may have seq len < max(seq_lens_in_batch)
+                # Forcibly mask with "blank" tokens for all samples whose current time step t >= seq_len
+ blank_mask = time_idx >= out_len
+
+ # Start inner loop
+ while not_blank and (self.max_symbols is None or symbols_added < self.max_symbols):
+
+ # Batch prediction and joint network steps
+ # If very first prediction step, submit SOS tag (blank) to pred_step.
+ # This feeds a zero tensor as input to AbstractRNNTDecoder to prime the state
+ if time_idx == 0 and symbols_added == 0:
+ g, hidden_prime = self._pred_step(self._SOS, hidden, batch_size=batchsize)
+ else:
+ # Perform batch step prediction of decoder, getting new states and scores ("g")
+ g, hidden_prime = self._pred_step(last_label, hidden, batch_size=batchsize)
+
+ # Batched joint step - Output = [B, V + 1]
+ logp = self._joint_step(f, g, log_normalize=None)[:, 0, 0, :]
+
+ if logp.dtype != torch.float32:
+ logp = logp.float()
+
+ # Get index k, of max prob for batch
+ v, k = logp.max(1)
+ del v, g, logp
+
+ # Update blank mask with current predicted blanks
+ # This is accumulating blanks over all time steps T and all target steps min(max_symbols, U)
+ k_is_blank = k == self._blank_index
+ blank_mask |= k_is_blank
+
+ del k_is_blank
+
+ # If all samples predict / have predicted prior blanks, exit loop early
+ # This is equivalent to if single sample predicted k
+ if blank_mask.all():
+ not_blank = False
+ else:
+ # Collect batch indices where blanks occurred now/past
+ blank_indices = []
+ if hidden is not None:
+ blank_indices = (blank_mask == 1).nonzero(as_tuple=False)
+
+ # Recover prior state for all samples which predicted blank now/past
+ if hidden is not None:
+ # LSTM has 2 states
+ for state_id in range(len(hidden)):
+ hidden_prime[state_id][:, blank_indices, :] = hidden[state_id][:, blank_indices, :]
+
+ # Recover prior predicted label for all samples which predicted blank now/past
+ k[blank_indices] = last_label[blank_indices, 0]
+
+ # Update new label and hidden state for next iteration
+ last_label = k.clone().view(-1, 1)
+ hidden = hidden_prime
+
+ # Update predicted labels, accounting for time mask
+ # If blank was predicted even once, now or in the past,
+ # Force the current predicted label to also be blank
+                        # This ensures that blanks propagate across all timesteps
+                        # once they have occurred (normally the stopping condition of the sample-level loop).
+ for kidx, ki in enumerate(k):
+ if blank_mask[kidx] == 0:
+ label[kidx].append(ki)
+
+ symbols_added += 1
+
+ return label
+
+ @torch.no_grad()
+ def _greedy_decode_masked(self, x: torch.Tensor, out_len: torch.Tensor, device: torch.device):
+ # x: [B, T, D]
+ # out_len: [B]
+ # device: torch.device
+
+ # Initialize state
+ hidden = None
+ batchsize = x.shape[0]
+
+ # Output string buffer
+ label = [[] for _ in range(batchsize)]
+
+ # Last Label buffer + Last Label without blank buffer
+ # batch level equivalent of the last_label
+ last_label = torch.full([batchsize, 1], fill_value=self._blank_index, dtype=torch.long, device=device)
+ last_label_without_blank = last_label.clone()
+
+ # Mask buffers
+ blank_mask = torch.full([batchsize], fill_value=0, dtype=torch.bool, device=device)
+
+ # Get max sequence length
+ max_out_len = out_len.max()
+ for time_idx in range(max_out_len):
+ f = x.narrow(dim=1, start=time_idx, length=1) # [B, 1, D]
+
+ # Prepare t timestamp batch variables
+ not_blank = True
+ symbols_added = 0
+
+ # Reset blank mask
+ blank_mask.mul_(False)
+
+ # Update blank mask with time mask
+ # Batch: [B, T, D], but Bi may have seq len < max(seq_lens_in_batch)
+            # Forcibly mask with "blank" tokens for all samples whose current time step t >= seq_len
+ blank_mask = time_idx >= out_len
+
+ # Start inner loop
+ while not_blank and (self.max_symbols is None or symbols_added < self.max_symbols):
+ # Batch prediction and joint network steps
+ # If very first prediction step, submit SOS tag (blank) to pred_step.
+ # This feeds a zero tensor as input to AbstractRNNTDecoder to prime the state
+ if time_idx == 0 and symbols_added == 0:
+ g, hidden_prime = self._pred_step(self._SOS, hidden, batch_size=batchsize)
+ else:
+ # Set a dummy label for the blank value
+                    # This value will be overwritten by "blank" again in the last label update below.
+                    # This is done because the vocabulary of the prediction network does not contain the "blank" token of TransformerT
+ last_label_without_blank_mask = last_label == self._blank_index
+ last_label_without_blank[last_label_without_blank_mask] = 0 # temp change of label
+ last_label_without_blank[~last_label_without_blank_mask] = last_label[
+ ~last_label_without_blank_mask
+ ]
+
+ # Perform batch step prediction of decoder, getting new states and scores ("g")
+ g, hidden_prime = self._pred_step(last_label_without_blank, hidden, batch_size=batchsize)
+
+ # Batched joint step - Output = [B, V + 1]
+ logp = self._joint_step(f, g, log_normalize=None)[:, 0, 0, :]
+
+ if logp.dtype != torch.float32:
+ logp = logp.float()
+
+ # Get index k, of max prob for batch
+ v, k = logp.max(1)
+ del v, g, logp
+
+ # Update blank mask with current predicted blanks
+ # This is accumulating blanks over all time steps T and all target steps min(max_symbols, U)
+ k_is_blank = k == self._blank_index
+ blank_mask.bitwise_or_(k_is_blank)
+
+ # If all samples predict / have predicted prior blanks, exit loop early
+ # This is equivalent to if single sample predicted k
+ if blank_mask.all():
+ not_blank = False
+ else:
+ # Collect batch indices where blanks occurred now/past
+ blank_indices = []
+ if hidden is not None:
+ blank_indices = (blank_mask == 1).nonzero(as_tuple=False)
+
+ # Recover prior state for all samples which predicted blank now/past
+ if hidden is not None:
+ # LSTM has 2 states
+ for state_id in range(len(hidden)):
+ hidden_prime[state_id][:, blank_indices, :] = hidden[state_id][:, blank_indices, :]
+
+ # Recover prior predicted label for all samples which predicted blank now/past
+ k[blank_indices] = last_label[blank_indices, 0]
+
+ # Update new label and hidden state for next iteration
+ last_label = k.view(-1, 1)
+ hidden = hidden_prime
+
+ # Update predicted labels, accounting for time mask
+ # If blank was predicted even once, now or in the past,
+ # Force the current predicted label to also be blank
+                    # This ensures that blanks propagate across all timesteps
+                    # once they have occurred (normally the stopping condition of the sample-level loop).
+ for kidx, ki in enumerate(k):
+ if blank_mask[kidx] == 0:
+ label[kidx].append(ki)
+
+ symbols_added += 1
+
+ return label
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/wav2vec.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/wav2vec.py
new file mode 100644
index 0000000000000000000000000000000000000000..708398da76ed4c3e9d1c060002cfa1058a661f8d
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/asr/parts/wav2vec.py
@@ -0,0 +1,458 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from typing import List, Tuple
+
+import numpy as np
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from nemo.collections.asr.models.wav2vec.wav2vec_config import Wav2VecConvExtractorMode, Wav2VecTransformerConfig
+from nemo.collections.asr.parts.layer_norm import LayerNorm
+from nemo.collections.asr.parts.multihead_attention import MultiheadAttention
+
+
+class TransposeLast(torch.nn.Module):
+ """
+ Transposes last dimension. Useful for adding to a sequential block.
+ """
+
+ def forward(self, x):
+ return x.transpose(-2, -1)
+
+
+class SamePad(torch.nn.Module):
+ def __init__(self, kernel_size):
+ super().__init__()
+ self.remove = kernel_size % 2 == 0
+
+ def forward(self, x):
+ if self.remove:
+ x = x[:, :, :-1]
+ return x
+
+
+class ConvFeatureEncoder(nn.Module):
+ """
+ Converts input raw audio into features for downstream transformer model.
+ Uses 1D convolutional blocks with GeLU activation.
+ """
+
+ def __init__(
+ self,
+ conv_layers: List[Tuple[int, int, int]],
+ mode: Wav2VecConvExtractorMode = Wav2VecConvExtractorMode.default,
+ conv_bias: bool = False,
+ ):
+ super().__init__()
+
+ def block(
+ n_in, n_out, k, stride, is_layer_norm=False, is_group_norm=False, conv_bias=False,
+ ):
+ def make_conv():
+ conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
+ nn.init.kaiming_normal_(conv.weight)
+ return conv
+
+ assert (is_layer_norm and is_group_norm) is False, "layer norm and group norm are exclusive"
+
+ if is_layer_norm:
+ return nn.Sequential(
+ make_conv(),
+ nn.Sequential(TransposeLast(), nn.LayerNorm(dim, elementwise_affine=True), TransposeLast()),
+ nn.GELU(),
+ )
+ elif is_group_norm:
+ return nn.Sequential(make_conv(), nn.GroupNorm(dim, dim, affine=True), nn.GELU(),)
+ else:
+ return nn.Sequential(make_conv(), nn.GELU())
+
+ in_d = 1
+ self.conv_layers = nn.ModuleList()
+ for i, cl in enumerate(conv_layers):
+ assert len(cl) == 3, "invalid conv definition: " + str(cl)
+ (dim, k, stride) = cl
+
+ self.conv_layers.append(
+ block(
+ in_d,
+ dim,
+ k,
+ stride,
+ is_layer_norm=mode is Wav2VecConvExtractorMode.layer_norm,
+ is_group_norm=mode is Wav2VecConvExtractorMode.default and i == 0,
+ conv_bias=conv_bias,
+ )
+ )
+ in_d = dim
+
+ def forward(self, x):
+ # BxT -> BxCxT
+ x = x.unsqueeze(1)
+ for conv in self.conv_layers:
+ x = conv(x)
+ return x
+
+ def get_subsampled_lens(self, lens):
+ for m in self.conv_layers:
+ conv = m[0]
+ lens = (lens + 2 * conv.padding[0] - conv.dilation[0] * (conv.kernel_size[0] - 1) - 1) // conv.stride[0] + 1
+ return lens
+
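+
+# Shape sketch (assumed layer spec, mirroring the common wav2vec 2.0 extractor):
+# seven conv blocks with a total stride of 320, so 16000 raw samples -> 49 frames.
+def _conv_feature_encoder_example():
+    conv_layers = [(512, 10, 5), (512, 3, 2), (512, 3, 2), (512, 3, 2),
+                   (512, 3, 2), (512, 2, 2), (512, 2, 2)]
+    encoder = ConvFeatureEncoder(conv_layers)
+    wav = torch.randn(2, 16000)                                # [B, T] raw audio
+    feats = encoder(wav)                                       # [B, 512, 49]
+    lens = encoder.get_subsampled_lens(torch.tensor([16000, 8000]))
+    return feats.shape, lens
+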
+
+class TransformerEncoder(nn.Module):
+ def __init__(self, args):
+ super().__init__()
+
+ conv_cfg = args.conv
+
+ self.dropout = args.dropout
+ self.embedding_dim = args.encoder.embedding_dim
+
+ self.pos_conv = nn.Conv1d(
+ self.embedding_dim,
+ self.embedding_dim,
+ kernel_size=conv_cfg.conv_pos,
+ padding=conv_cfg.conv_pos // 2,
+ groups=conv_cfg.conv_pos_groups,
+ )
+ dropout = 0
+ std = math.sqrt((4 * (1.0 - dropout)) / (conv_cfg.conv_pos * self.embedding_dim))
+ nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
+ nn.init.constant_(self.pos_conv.bias, 0)
+
+ self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
+ self.pos_conv = nn.Sequential(self.pos_conv, SamePad(conv_cfg.conv_pos), nn.GELU())
+ self.pos_conv_layer_drop = conv_cfg.layer_drop
+
+ encoder_cfg = args.encoder
+ self.layers = nn.ModuleList(
+ [
+ TransformerSentenceEncoderLayer(
+ embedding_dim=self.embedding_dim,
+ ffn_embedding_dim=encoder_cfg.ffn_embedding_dim,
+ num_attention_heads=encoder_cfg.num_attention_heads,
+ dropout=self.dropout,
+ attention_dropout=encoder_cfg.attention_dropout,
+ activation_dropout=encoder_cfg.activation_dropout,
+ activation_fn=encoder_cfg.activation_fn.value,
+ layer_norm_first=encoder_cfg.layer_norm_first,
+ )
+ for _ in range(encoder_cfg.encoder_layers)
+ ]
+ )
+
+ self.layer_norm_first = encoder_cfg.layer_norm_first
+ self.layer_norm = LayerNorm(self.embedding_dim)
+ self.layerdrop = encoder_cfg.encoder_layerdrop
+
+ self.apply(init_bert_params)
+
+ def forward(self, x, padding_mask=None):
+ x = self.extract_features(x, padding_mask)
+
+ if self.layer_norm_first:
+ x = self.layer_norm(x)
+
+ return x
+
+ def extract_features(self, x, padding_mask=None):
+
+ if padding_mask is not None:
+ x = index_put(x, padding_mask, 0)
+
+ if self.pos_conv_layer_drop > 0 and self.training and np.random.random() < self.pos_conv_layer_drop:
+ pass
+ else:
+ x_conv = self.pos_conv(x.transpose(1, 2))
+ x_conv = x_conv.transpose(1, 2)
+ x = x + x_conv
+
+ if not self.layer_norm_first:
+ x = self.layer_norm(x)
+
+ x = F.dropout(x, p=self.dropout, training=self.training)
+
+ # B x T x C -> T x B x C
+ x = x.transpose(0, 1)
+
+ # layer_results = []
+ for i, layer in enumerate(self.layers):
+ dropout_probability = np.random.random()
+ if not self.training or (dropout_probability >= self.layerdrop):
+ x, z = layer(x, self_attn_padding_mask=padding_mask, need_weights=False)
+ # layer_results.append(x)
+
+ # T x B x C -> B x T x C
+ x = x.transpose(0, 1)
+
+ return x
+
+
+class TransformerSentenceEncoderLayer(nn.Module):
+ """
+ Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
+ models.
+ """
+
+ def __init__(
+ self,
+ embedding_dim: float = 768,
+ ffn_embedding_dim: float = 3072,
+ num_attention_heads: float = 8,
+ dropout: float = 0.1,
+ attention_dropout: float = 0.1,
+ activation_dropout: float = 0.1,
+ activation_fn: str = "relu",
+ layer_norm_first: bool = False,
+ ) -> None:
+
+ super().__init__()
+ # Initialize parameters
+ self.embedding_dim = embedding_dim
+ self.dropout = dropout
+ self.activation_dropout = activation_dropout
+
+ # Initialize blocks
+ self.activation_fn = get_activation_fn(activation_fn)
+ self.self_attn = MultiheadAttention(
+ self.embedding_dim,
+ num_attention_heads,
+ dropout=attention_dropout,
+ self_attention=True,
+ )
+
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(self.activation_dropout)
+ self.dropout3 = nn.Dropout(dropout)
+
+ self.layer_norm_first = layer_norm_first
+
+ # layer norm associated with the self attention layer
+ self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
+ self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
+ self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
+
+ # layer norm associated with the position wise feed-forward NN
+ self.final_layer_norm = LayerNorm(self.embedding_dim)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ self_attn_mask: torch.Tensor = None,
+ self_attn_padding_mask: torch.Tensor = None,
+ need_weights: bool = False,
+ att_args=None,
+ ):
+ """
+ LayerNorm is applied either before or after the self-attention/ffn
+        modules, similar to the original Transformer implementation.
+ """
+ residual = x
+
+ if self.layer_norm_first:
+ x = self.self_attn_layer_norm(x)
+ x, attn = self.self_attn(
+ query=x,
+ key=x,
+ value=x,
+ key_padding_mask=self_attn_padding_mask,
+ need_weights=False,
+ attn_mask=self_attn_mask,
+ )
+ x = self.dropout1(x)
+ x = residual + x
+
+ residual = x
+ x = self.final_layer_norm(x)
+ x = self.activation_fn(self.fc1(x))
+ x = self.dropout2(x)
+ x = self.fc2(x)
+ x = self.dropout3(x)
+ x = residual + x
+ else:
+ x, attn = self.self_attn(
+ query=x,
+ key=x,
+ value=x,
+ key_padding_mask=self_attn_padding_mask,
+ need_weights=need_weights,
+ )
+
+ x = self.dropout1(x)
+ x = residual + x
+
+ x = self.self_attn_layer_norm(x)
+
+ residual = x
+ x = self.activation_fn(self.fc1(x))
+ x = self.dropout2(x)
+ x = self.fc2(x)
+ x = self.dropout3(x)
+ x = residual + x
+ x = self.final_layer_norm(x)
+
+ return x, attn
+
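+
+# Shape sketch (assumed sizes): one post-LN encoder layer applied to a [T, B, C]
+# sequence, as it is used inside TransformerEncoder above.
+def _encoder_layer_example():
+    layer = TransformerSentenceEncoderLayer(embedding_dim=768, ffn_embedding_dim=3072,
+                                            num_attention_heads=8, layer_norm_first=False)
+    x = torch.randn(50, 2, 768)           # [T, B, C]
+    out, _attn = layer(x)
+    return out.shape                      # torch.Size([50, 2, 768])
+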
+
+def index_put(tensor, indices, value):
+ tensor[indices] = value
+ return tensor
+
+
+def get_activation_fn(activation: str):
+ """ Returns the activation function corresponding to `activation` """
+ if activation == "relu":
+ return F.relu
+ elif activation == "gelu":
+ return F.gelu
+ elif activation == "tanh":
+ return torch.tanh
+ elif activation == "linear":
+ return lambda x: x
+ else:
+ raise RuntimeError("--activation-fn {} not supported".format(activation))
+
+
+class Wav2VecTransformerEncoder(nn.Module):
+ def __init__(self, cfg: Wav2VecTransformerConfig):
+ super().__init__()
+
+ conv_cfg = cfg.conv
+
+ self.dropout = cfg.dropout
+ self.embedding_dim = cfg.encoder.embedding_dim
+ self.layer_norm_first = cfg.encoder.layer_norm_first
+        assert not self.layer_norm_first, 'nn.TransformerEncoderLayer does not support layer_norm_first'
+
+ # positional convolutional embeddings
+ self.pos_conv = nn.Conv1d(
+ self.embedding_dim,
+ self.embedding_dim,
+ kernel_size=conv_cfg.conv_pos,
+ padding=conv_cfg.conv_pos // 2,
+ groups=conv_cfg.conv_pos_groups,
+ )
+
+ self.feature_dropout = nn.Dropout(self.dropout)
+
+ dropout = 0
+ std = math.sqrt((4 * (1.0 - dropout)) / (conv_cfg.conv_pos * self.embedding_dim))
+ nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
+ nn.init.constant_(self.pos_conv.bias, 0)
+
+ self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
+ self.pos_conv = nn.Sequential(self.pos_conv, SamePad(conv_cfg.conv_pos), nn.GELU())
+
+ encoder_cfg = cfg.encoder
+ self.transformer_encoder = nn.TransformerEncoder(
+ encoder_layer=nn.TransformerEncoderLayer(
+ d_model=self.embedding_dim,
+ nhead=encoder_cfg.num_attention_heads,
+ dim_feedforward=encoder_cfg.ffn_embedding_dim,
+ dropout=self.dropout,
+ activation=encoder_cfg.activation_fn.value,
+ ),
+ num_layers=encoder_cfg.encoder_layers,
+ )
+ self.layer_norm = nn.LayerNorm(self.embedding_dim)
+ self.apply(init_bert_params)
+
+ def forward(self, x, padding_mask=None):
+ x = self.extract_features(x, padding_mask)
+
+ if self.layer_norm_first:
+ x = self.layer_norm(x)
+
+ return x
+
+ def extract_features(self, x, padding_mask=None):
+
+ if padding_mask is not None:
+ x[padding_mask] = 0
+
+ x_conv = self.pos_conv(x.transpose(1, 2))
+ x_conv = x_conv.transpose(1, 2)
+ x += x_conv
+
+ if not self.layer_norm_first:
+ x = self.layer_norm(x)
+
+ x = self.feature_dropout(x)
+
+ # B x T x C -> T x B x C
+ x = x.transpose(0, 1)
+
+ x = self.transformer_encoder(x, src_key_padding_mask=padding_mask)
+
+ # T x B x C -> B x T x C
+ x = x.transpose(0, 1)
+
+ return x
+
+
+class GradMultiply(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, scale):
+ ctx.scale = scale
+ res = x.new(x)
+ return res
+
+ @staticmethod
+ def backward(ctx, grad):
+ return grad * ctx.scale, None
+
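+
+# Usage sketch (assumed scale): down-weight gradients flowing into a sub-network
+# by 0.1 while leaving the forward values unchanged.
+def _grad_multiply_example():
+    x = torch.randn(4, 8, requires_grad=True)
+    y = GradMultiply.apply(x, 0.1)        # forward: y is a copy of x
+    y.sum().backward()
+    return x.grad                         # == 0.1 * torch.ones_like(x)
+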
+
+def init_bert_params(module):
+ """
+ Initialize the weights specific to the BERT Model.
+ This overrides the default initializations depending on the specified arguments.
+ 1. If normal_init_linear_weights is set then weights of linear
+ layer will be initialized using the normal distribution and
+ bias will be set to the specified value.
+ 2. If normal_init_embed_weights is set then weights of embedding
+ layer will be initialized using the normal distribution.
+ 3. If normal_init_proj_weights is set then weights of
+ in_project_weight for MultiHeadAttention initialized using
+ the normal distribution (to be validated).
+ """
+
+ def normal_(data):
+ # with FSDP, module params will be on CUDA, so we cast them back to CPU
+ # so that the RNG is consistent with and without FSDP
+ data.copy_(
+ data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
+ )
+
+ if isinstance(module, nn.Linear):
+ normal_(module.weight.data)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ if isinstance(module, nn.Embedding):
+ normal_(module.weight.data)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ if isinstance(module, MultiheadAttention):
+ normal_(module.q_proj.weight.data)
+ normal_(module.k_proj.weight.data)
+ normal_(module.v_proj.weight.data)
+ if isinstance(module, nn.TransformerEncoderLayer):
+ normal_(module.self_attn.in_proj_weight.data)
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..52a45b36ffa09b87165240ccc6c176baee2af94e
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/__init__.py
@@ -0,0 +1,26 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import nemo.collections.common.callbacks
+from nemo.collections.common import losses, parts, tokenizers
+from nemo.package_info import __version__
+
+# Set collection version equal to NeMo version.
+__version = __version__
+
+# Authorship.
+__author__ = "NVIDIA Corporation"
+
+# Set collection name.
+__description__ = "Common collection"
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/callbacks/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/callbacks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ad5c9c85a5f1c5921c115112e313b86a1e123a2
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/callbacks/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from nemo.collections.common.callbacks.callbacks import LogEpochTimeCallback
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/callbacks/callbacks.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/callbacks/callbacks.py
new file mode 100644
index 0000000000000000000000000000000000000000..55fa5c50a1c54fa73ec2f8f2ef5b3fbfe96016e4
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/callbacks/callbacks.py
@@ -0,0 +1,32 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import time
+
+from pytorch_lightning.callbacks.base import Callback
+from pytorch_lightning.utilities import rank_zero_only
+
+
+class LogEpochTimeCallback(Callback):
+ """Simple callback that logs how long each epoch takes, in seconds, to a pytorch lightning log
+ """
+
+ @rank_zero_only
+ def on_epoch_start(self, trainer, pl_module):
+ self.epoch_start = time.time()
+
+ @rank_zero_only
+ def on_epoch_end(self, trainer, pl_module):
+ curr_time = time.time()
+ duration = curr_time - self.epoch_start
+ trainer.logger.log_metrics({"epoch_time": duration}, step=trainer.global_step)
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae7d3b0c1a8cc214d04b2b814f24f7d21ebd251d
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from nemo.collections.common.losses.aggregator import AggregatorLoss
+from nemo.collections.common.losses.cross_entropy import CrossEntropyLoss
+from nemo.collections.common.losses.mse_loss import MSELoss
+from nemo.collections.common.losses.smoothed_cross_entropy import SmoothedCrossEntropyLoss
+from nemo.collections.common.losses.spanning_loss import SpanningLoss
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/aggregator.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/aggregator.py
new file mode 100644
index 0000000000000000000000000000000000000000..1987ddd22bae30d6455adf87a7cc4313d9cdff1c
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/aggregator.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+import torch
+
+from nemo.core.classes import Loss, typecheck
+from nemo.core.neural_types import LossType, NeuralType
+
+__all__ = ['AggregatorLoss']
+
+
+class AggregatorLoss(Loss):
+ """
+ Sums several losses into one.
+
+ Args:
+ num_inputs: number of input losses
+        weights: a list of coefficients for merging the losses
+ """
+
+ @property
+ def input_types(self):
+ """Returns definitions of module input ports.
+ """
+ input_types = {}
+ for i in range(self._num_losses):
+ input_types["loss_" + str(i + 1)] = NeuralType(elements_type=LossType())
+
+ return input_types
+
+ @property
+ def output_types(self):
+ """Returns definitions of module output ports.
+ """
+ return {"loss": NeuralType(elements_type=LossType())}
+
+ def __init__(self, num_inputs: int = 2, weights: List[float] = None):
+ super().__init__()
+ self._num_losses = num_inputs
+ if weights is not None and len(weights) != num_inputs:
+ raise ValueError("Length of weights should be equal to the number of inputs (num_inputs)")
+
+ self._weights = weights
+
+ @typecheck()
+ def forward(self, **kwargs):
+ values = [kwargs[x] for x in sorted(kwargs.keys())]
+ loss = torch.zeros_like(values[0])
+ for loss_idx, loss_value in enumerate(values):
+ if self._weights is not None:
+ loss = loss.add(loss_value, alpha=self._weights[loss_idx])
+ else:
+ loss = loss.add(loss_value)
+ return loss
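+
+
+# Usage sketch (assumed values): combine two scalar losses with weights 0.7 and 0.3,
+# i.e. 0.7 * 1.0 + 0.3 * 4.0 = 1.9.
+def _aggregator_loss_example():
+    agg = AggregatorLoss(num_inputs=2, weights=[0.7, 0.3])
+    return agg(loss_1=torch.tensor(1.0), loss_2=torch.tensor(4.0))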
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/cross_entropy.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/cross_entropy.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6ed0a6b3d81929e8a544de6140d45435ce7440d
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/cross_entropy.py
@@ -0,0 +1,79 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from torch import nn
+
+from nemo.core.classes import Serialization, Typing, typecheck
+from nemo.core.neural_types import LabelsType, LogitsType, LossType, MaskType, NeuralType
+
+__all__ = ['CrossEntropyLoss']
+
+
+class CrossEntropyLoss(nn.CrossEntropyLoss, Serialization, Typing):
+ """
+ CrossEntropyLoss
+ """
+
+ @property
+ def input_types(self):
+ """Returns definitions of module input ports.
+ """
+ return {
+ "logits": NeuralType(['B'] + ['ANY'] * (self._logits_dim - 1), LogitsType()),
+ "labels": NeuralType(['B'] + ['ANY'] * (self._logits_dim - 2), LabelsType()),
+ "loss_mask": NeuralType(['B'] + ['ANY'] * (self._logits_dim - 2), MaskType(), optional=True),
+ }
+
+ @property
+ def output_types(self):
+ """Returns definitions of module output ports.
+ """
+ return {"loss": NeuralType(elements_type=LossType())}
+
+ def __init__(self, logits_ndim=2, weight=None, reduction='mean'):
+ """
+ Args:
+ logits_ndim (int): number of dimensions (or rank) of the logits tensor
+ weight (list): list of rescaling weight given to each class
+ reduction (str): type of the reduction over the batch
+ """
+ if weight is not None and not torch.is_tensor(weight):
+ weight = torch.FloatTensor(weight)
+ super().__init__(weight=weight, reduction=reduction)
+ self._logits_dim = logits_ndim
+
+ @typecheck()
+ def forward(self, logits, labels, loss_mask=None):
+ """
+ Args:
+ logits (float): output of the classifier
+ labels (long): ground truth labels
+ loss_mask (bool/float/int): tensor to specify the masking
+ """
+ logits_flatten = torch.flatten(logits, start_dim=0, end_dim=-2)
+ labels_flatten = torch.flatten(labels, start_dim=0, end_dim=-1)
+
+ if loss_mask is not None:
+ if loss_mask.dtype is not torch.bool:
+ loss_mask = loss_mask > 0.5
+ loss_mask_flatten = torch.flatten(loss_mask, start_dim=0, end_dim=-1)
+ logits_flatten = logits_flatten[loss_mask_flatten]
+ labels_flatten = labels_flatten[loss_mask_flatten]
+
+ if len(labels_flatten) == 0:
+ return super().forward(logits, torch.argmax(logits, dim=-1))
+
+ loss = super().forward(logits_flatten, labels_flatten)
+ return loss
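+
+
+# Usage sketch (assumed shapes): token-level cross entropy over [B, T, V] logits
+# with padded positions excluded via loss_mask.
+def _masked_cross_entropy_example():
+    ce = CrossEntropyLoss(logits_ndim=3)
+    logits = torch.randn(2, 5, 10)                        # [B, T, V]
+    labels = torch.randint(0, 10, (2, 5))                 # [B, T]
+    loss_mask = torch.tensor([[1, 1, 1, 0, 0],
+                              [1, 1, 1, 1, 0]])           # 0 marks padding
+    return ce(logits=logits, labels=labels, loss_mask=loss_mask)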
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/mse_loss.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/mse_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..802e8ca49204a9bf9d439b114d4f3315d453e00f
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/mse_loss.py
@@ -0,0 +1,57 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from torch import Tensor, nn
+
+from nemo.core.classes import Serialization, Typing, typecheck
+from nemo.core.neural_types import LabelsType, LossType, NeuralType, RegressionValuesType
+
+__all__ = ['MSELoss']
+
+
+class MSELoss(nn.MSELoss, Serialization, Typing):
+ """
+ MSELoss
+ """
+
+ @property
+ def input_types(self):
+ """Returns definitions of module input ports.
+ """
+ return {
+ "preds": NeuralType(tuple('B'), RegressionValuesType()),
+ "labels": NeuralType(tuple('B'), LabelsType()),
+ }
+
+ @property
+ def output_types(self):
+ """Returns definitions of module output ports.
+ """
+ return {"loss": NeuralType(elements_type=LossType())}
+
+ def __init__(self, reduction: str = 'mean'):
+ """
+ Args:
+ reduction: type of the reduction over the batch
+ """
+ super().__init__(reduction=reduction)
+
+ @typecheck()
+ def forward(self, preds: Tensor, labels: Tensor) -> Tensor:
+ """
+ Args:
+ preds: output of the classifier
+ labels: ground truth labels
+ """
+ return super().forward(preds, labels)
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/smoothed_cross_entropy.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/smoothed_cross_entropy.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e37d29ba45347d705c19833ce8dd2da60acde04
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/smoothed_cross_entropy.py
@@ -0,0 +1,103 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+
+from nemo.core.classes import Loss, typecheck
+from nemo.core.neural_types import LabelsType, LogprobsType, LossType, MaskType, NeuralType
+
+__all__ = ["SmoothedCrossEntropyLoss"]
+
+
+class SmoothedCrossEntropyLoss(Loss):
+ """
+ Calculates Cross-entropy loss with label smoothing for a batch of sequences.
+
+ SmoothedCrossEntropyLoss:
+ 1) excludes padding tokens from loss calculation
+ 2) allows to use label smoothing regularization
+ 3) allows to calculate loss for the desired number of last tokens
+
+ Args:
+ label_smoothing (float): label smoothing regularization coefficient
+ predict_last_k (int): parameter which sets the number of last tokens to calculate the loss for, for example
+ 0: (default) calculate loss on the entire sequence (e.g., NMT)
+ 1: calculate loss on the last token only (e.g., LM evaluation)
+            Intermediate values allow controlling the trade-off between eval
+            time (proportional to the number of batches) and eval performance
+            (proportional to the number of context tokens)
+ pad_id (int): padding id
+        eps (float): the small eps number to avoid division by zero
+ """
+
+ @property
+ def input_types(self):
+ """Returns definitions of module input ports.
+ """
+ return {
+ "log_probs": NeuralType(("B", "T", "D"), LogprobsType()),
+ "labels": NeuralType(("B", "T"), LabelsType()),
+ "output_mask": NeuralType(("B", "T"), MaskType(), optional=True),
+ }
+
+ @property
+ def output_types(self):
+ """Returns definitions of module output ports.
+ """
+ return {"loss": NeuralType(elements_type=LossType())}
+
+ def __init__(
+ self,
+ pad_id: Optional[int] = None,
+ label_smoothing: Optional[float] = 0.0,
+ predict_last_k: Optional[int] = 0,
+ eps: float = 1e-6,
+ ):
+ super().__init__()
+ self._pad_id = pad_id
+ self._eps = eps
+ self._predict_last_k = predict_last_k
+ self._label_smoothing = label_smoothing
+
+ @typecheck()
+ def forward(self, log_probs, labels, output_mask=None):
+ """
+ Args:
+ log_probs: float tensor of shape batch_size x seq_len x vocab_size, values should be log probabilities
+ labels: int tensor of shape batch_size x seq_len
+ output_mask: binary tensor of shape batch_size x seq_len
+ eps: epsilon param to avoid divide by zero in loss calculation
+ """
+ if output_mask is None and self._pad_id is None:
+ raise ValueError("Both output_mask and pad_id are None")
+ if output_mask is None and self._pad_id is not None:
+ output_mask = (labels != self._pad_id).to(log_probs.dtype)
+
+ if output_mask.dtype is not log_probs.dtype:
+ output_mask = output_mask.to(log_probs.dtype)
+
+ batch_size, seq_len, vocab_size = log_probs.size()
+ smoothing = vocab_size * self._label_smoothing / (vocab_size - 1)
+ target_log_probs = log_probs.gather(2, labels.unsqueeze(2)).squeeze(2)
+
+ smoothing_log_probs = log_probs.mean(dim=-1)
+ neg_log_likelihood = (1.0 - smoothing) * target_log_probs + smoothing * smoothing_log_probs
+ neg_log_likelihood = neg_log_likelihood[:, -self._predict_last_k :]
+ output_mask = output_mask[:, -self._predict_last_k :]
+ neg_log_likelihood = -torch.sum(neg_log_likelihood * output_mask)
+ neg_log_likelihood = neg_log_likelihood / (output_mask.sum() + self._eps)
+
+ return neg_log_likelihood
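+
+
+# Usage sketch (assumed values): label-smoothed NLL over a padded batch, where
+# positions equal to pad_id are excluded from the loss.
+def _smoothed_cross_entropy_example():
+    loss_fn = SmoothedCrossEntropyLoss(pad_id=0, label_smoothing=0.1)
+    log_probs = torch.log_softmax(torch.randn(2, 4, 8), dim=-1)   # [B, T, V]
+    labels = torch.tensor([[3, 5, 1, 0],
+                           [2, 2, 0, 0]])                         # 0 is padding
+    return loss_fn(log_probs=log_probs, labels=labels)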
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/spanning_loss.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/spanning_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..a12dab64afd00f83003e8fb9fa25af3d72a360c7
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/losses/spanning_loss.py
@@ -0,0 +1,79 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from torch import nn
+
+from nemo.core.classes import Loss, typecheck
+from nemo.core.neural_types import ChannelType, LogitsType, LossType, NeuralType
+
+__all__ = ['SpanningLoss']
+
+
+class SpanningLoss(Loss):
+ """
+ implements start and end loss of a span e.g. for Question Answering.
+ """
+
+ @property
+ def input_types(self):
+ """Returns definitions of module input ports.
+ """
+ return {
+ "logits": NeuralType(('B', 'T', 'D'), LogitsType()),
+ "start_positions": NeuralType(tuple('B'), ChannelType()),
+ "end_positions": NeuralType(tuple('B'), ChannelType()),
+ }
+
+ @property
+ def output_types(self):
+ """Returns definitions of module output ports.
+ """
+ return {
+ "loss": NeuralType(elements_type=LossType()),
+ "start_logits": NeuralType(('B', 'T'), LogitsType()),
+ "end_logits": NeuralType(('B', 'T'), LogitsType()),
+ }
+
+ def __init__(self,):
+ super().__init__()
+
+ @typecheck()
+ def forward(self, logits, start_positions, end_positions):
+ """
+ Args:
+            logits: Output of question answering head, which is a token classifier.
+ start_positions: Ground truth start positions of the answer w.r.t.
+ input sequence. If question is unanswerable, this will be
+ pointing to start token, e.g. [CLS], of the input sequence.
+ end_positions: Ground truth end positions of the answer w.r.t.
+ input sequence. If question is unanswerable, this will be
+ pointing to start token, e.g. [CLS], of the input sequence.
+ """
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1)
+ end_logits = end_logits.squeeze(-1)
+        # If we are on multi-GPU, splitting adds a dimension
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+ ignored_index = start_logits.size(1)
+ start_positions.clamp_(0, ignored_index)
+ end_positions.clamp_(0, ignored_index)
+
+ loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+ return total_loss, start_logits, end_logits
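+
+
+# Usage sketch (assumed shapes): start/end span loss for extractive QA over
+# [B, T, 2] logits.
+def _spanning_loss_example():
+    import torch  # this module only imports `nn` at the top
+    span_loss = SpanningLoss()
+    logits = torch.randn(3, 16, 2)                 # [B, T, 2]
+    start_positions = torch.tensor([1, 0, 7])
+    end_positions = torch.tensor([4, 0, 9])
+    loss, start_logits, end_logits = span_loss(
+        logits=logits, start_positions=start_positions, end_positions=end_positions)
+    return loss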
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/metrics/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae2b312b6f771b40a83db2f6adec8af033fc331c
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/metrics/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from nemo.collections.common.metrics.classification_accuracy import TopKClassificationAccuracy
+from nemo.collections.common.metrics.perplexity import Perplexity
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/metrics/classification_accuracy.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/metrics/classification_accuracy.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e31fa9bd107a22f762c510633633740f032c6f4
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/metrics/classification_accuracy.py
@@ -0,0 +1,152 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+import torch
+from pytorch_lightning.metrics import Metric
+
+__all__ = ['TopKClassificationAccuracy']
+
+
+class TopKClassificationAccuracy(Metric):
+ """
+ This metric computes numerator and denominator for Overall Accuracy between logits and labels.
+ When doing distributed training/evaluation, the result of res=TopKClassificationAccuracy(logits, labels) calls
+ will be all-reduced between all workers using SUM operations.
+ res contains two numbers: res=[correctly_predicted, total_samples], and Accuracy=correctly_predicted/total_samples.
+
+ If used with a PytorchLightning LightningModule, include correct_count and total_count inside validation_step results.
+ Then aggregate (sum) them at the end of the validation epoch to correctly compute the validation accuracy.
+
+ Example:
+ def validation_step(self, batch, batch_idx):
+ ...
+ correct_count, total_count = self._accuracy(logits, labels)
+ return {'val_loss': loss_value, 'val_correct_count': correct_count, 'val_total_count': total_count}
+
+ def validation_epoch_end(self, outputs):
+ ...
+ val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean()
+ correct_counts = torch.stack([x['val_correct_count'] for x in outputs])
+ total_counts = torch.stack([x['val_total_count'] for x in outputs])
+
+ topk_scores = compute_topk_accuracy(correct_counts, total_counts)
+
+ tensorboard_log = {'val_loss': val_loss_mean}
+ for top_k, score in zip(self._accuracy.top_k, topk_scores):
+ tensorboard_log['val_epoch_top@{}'.format(top_k)] = score
+
+ return {'log': tensorboard_log}
+
+ Args:
+ top_k: Optional list of integers. Defaults to [1].
+
+ Returns:
+ res: a torch.Tensor object with two elements: [correct_count, total_count]. To correctly compute average
+ accuracy, compute acc=correct_count/total_count
+ """
+
+ def __init__(self, top_k=None, dist_sync_on_step=False):
+ super().__init__(dist_sync_on_step=dist_sync_on_step)
+
+ if top_k is None:
+ top_k = [1]
+
+ self.top_k = top_k
+ self.add_state(
+ "correct_counts_k", default=torch.zeros(len(self.top_k)), dist_reduce_fx='sum', persistent=False
+ )
+ self.add_state("total_counts_k", default=torch.zeros(len(self.top_k)), dist_reduce_fx='sum', persistent=False)
+
+ @torch.no_grad()
+ def top_k_predicted_labels(self, logits: torch.Tensor) -> torch.Tensor:
+ max_k = max(self.top_k)
+ _, predictions = logits.topk(max_k, dim=1, largest=True, sorted=True)
+ return predictions
+
+ def update(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+ with torch.no_grad():
+ predictions = self.top_k_predicted_labels(logits)
+ predictions = predictions.t()
+ correct = predictions.eq(labels.view(1, -1)).expand_as(predictions)
+
+ correct_counts_k = []
+ total_counts_k = []
+
+ for k in self.top_k:
+ correct_k = correct[:k].reshape(-1).float().sum()
+ total_k = labels.shape[0]
+
+ correct_counts_k.append(correct_k)
+ total_counts_k.append(total_k)
+
+ self.correct_counts_k = torch.tensor(correct_counts_k, dtype=labels.dtype, device=labels.device)
+ self.total_counts_k = torch.tensor(total_counts_k, dtype=labels.dtype, device=labels.device)
+
+ def compute(self):
+ """
+ Computes the top-k accuracy.
+
+ Returns:
+ A list of length `K`, such that k-th index corresponds to top-k accuracy
+ over all distributed processes.
+ """
+ if not len(self.correct_counts_k) == len(self.top_k) == len(self.total_counts_k):
+ raise ValueError("lengths of correct_counts_k and total_counts_k must match the length of top_k")
+
+ if self.top_k == [1]:
+ return [self.correct_counts_k.float() / self.total_counts_k]
+
+ else:
+ top_k_scores = compute_topk_accuracy(self.correct_counts_k, self.total_counts_k)
+
+ return top_k_scores
+
+ @property
+ def top_k(self) -> List[int]:
+ return self._top_k
+
+ @top_k.setter
+ def top_k(self, value: List[int]):
+ if value is None:
+ value = [1]
+
+ if isinstance(value, int):
+ value = [value]
+
+ if not isinstance(value, list):
+ value = list(value)
+
+ self._top_k = value
+
+
+def compute_topk_accuracy(correct_counts_k, total_counts_k):
+ """
+ Computes the top-k accuracy.
+ Args:
+ correct_counts_k: Tensor of shape [K], K being the top-k parameter.
+ total_counts_k: Tensor of shape [K], K being the top-k parameter.
+ Returns:
+ A list of length `K`, such that k-th index corresponds to top-k accuracy
+ over all distributed processes.
+ """
+ top_k_scores = []
+
+ for ki in range(len(correct_counts_k)):
+ correct_count = correct_counts_k[ki].item()
+ total_count = total_counts_k[ki].item()
+ top_k_scores.append(correct_count / float(total_count))
+
+ return top_k_scores
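+
+
+if __name__ == "__main__":
+    # Minimal smoke test added for documentation purposes (illustrative only, not part
+    # of the original NeMo module): correct/total counts for top-1 and top-5.
+    demo_correct = torch.tensor([8.0, 9.0])
+    demo_total = torch.tensor([10.0, 10.0])
+    print(compute_topk_accuracy(demo_correct, demo_total))  # -> approximately [0.8, 0.9]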
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/metrics/perplexity.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/metrics/perplexity.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2d093cee1288a67671e3a6a826778641c0abf94
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/metrics/perplexity.py
@@ -0,0 +1,74 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from pytorch_lightning.metrics import Metric
+from torch.distributions.categorical import Categorical
+
+__all__ = ['Perplexity']
+
+
+class Perplexity(Metric):
+ """
+ This class computes mean perplexity of distributions in the last dimension of inputs. It is a wrapper around
+ the :meth:`torch.distributions.Categorical.perplexity` method. You have to provide either
+ ``probs`` or ``logits`` to the :meth:`update` method. The class computes perplexities for distributions passed to
+ :meth:`update` method in ``probs`` or ``logits`` arguments and averages the perplexities. Reducing results between
+ all workers is done via SUM operations.
+
+ See :doc:`PyTorch Lightning Metrics` for the metric usage instructions.
+
+ Args:
+ compute_on_step:
+ Forward only calls ``update()`` and returns ``None`` if this is set to ``False``. default: ``True``
+ dist_sync_on_step:
+ Synchronize metric state across processes at each ``forward()``
+ before returning the value at the step.
+ process_group:
+ Specify the process group on which synchronization is called. default: ``None`` (which selects the entire
+ world)
+ validate_args:
+ If ``True``, values of the :meth:`update` method parameters are checked: ``logits`` must not contain NaNs and
+ the last dim of ``probs`` must be a valid probability distribution.
+ """
+
+ def __init__(self, compute_on_step=True, dist_sync_on_step=False, process_group=None, validate_args=True):
+ super().__init__(
+ compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group
+ )
+ self.validate_args = validate_args
+ self.add_state('perplexities_sum', torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx='sum')
+ # Total number of distributions seen since last reset
+ self.add_state('num_distributions', torch.tensor(0, dtype=torch.int64), dist_reduce_fx='sum')
+
+ def update(self, probs=None, logits=None):
+ """
+ Updates :attr:`perplexities_sum` and :attr:`num_distributions`.
+
+ Args:
+ probs: A ``torch.Tensor`` whose innermost dimension is a valid probability distribution.
+ logits: A ``torch.Tensor`` without NaNs.
+ """
+ d = Categorical(probs, logits, validate_args=self.validate_args)
+ ppl = d.perplexity()
+ self.num_distributions += ppl.numel()
+ self.perplexities_sum += ppl.sum()
+
+ def compute(self):
+ """
+ Returns the mean perplexity across all workers and resets :attr:`perplexities_sum` and :attr:`num_distributions` to 0.
+ """
+ if self.num_distributions.eq(0):
+ return None
+ return self.perplexities_sum / self.num_distributions
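+
+
+# Usage sketch (illustrative assumptions, not part of the original module), given the
+# legacy pytorch_lightning.metrics API this file targets:
+#   metric = Perplexity(validate_args=False)
+#   metric.update(logits=torch.randn(8, 32, 100))  # 8*32 categorical distributions over 100 classes
+#   print(metric.compute())                        # mean perplexity over the 256 distributions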
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f3916e3655d5aa32650d6189e580431d9b8bc32
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from nemo.collections.common.parts.multi_layer_perceptron import MultiLayerPerceptron
+from nemo.collections.common.parts.transformer_utils import *
+from nemo.collections.common.parts.utils import *
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/mem_transformer.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/mem_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd0197e4836af100044e805a72ba734332da5395
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/mem_transformer.py
@@ -0,0 +1,531 @@
+import random
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+from nemo.collections.common.parts.normalization import LayerVarNorm
+
+float32_min = np.finfo(np.float32).min
+float16_min = np.finfo(np.float16).min
+
+
+class PositionalEmbedding(nn.Module):
+ def __init__(self, demb):
+ super(PositionalEmbedding, self).__init__()
+
+ self.demb = demb
+
+ inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
+ self.register_buffer('inv_freq', inv_freq)
+
+ def forward(self, pos_seq, bsz=None):
+ sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
+ pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
+
+ if bsz is not None:
+ return pos_emb[:,None,:].expand(-1, bsz, -1)
+ else:
+ return pos_emb[:,None,:]
+
+
+def get_norm(norm_type, d_model, ln_eps):
+ if norm_type == 'ln':
+ norm = nn.LayerNorm(d_model, eps=ln_eps)
+ else:
+ assert norm_type == 'var_ln'
+ norm = LayerVarNorm(d_model, eps=ln_eps)
+ return norm
+
+
+class PositionwiseFF(nn.Module):
+ def __init__(self, d_model, d_inner, *, dropout, pre_lnorm=False, ln_eps=1e-5, norm_type='ln'):
+ super(PositionwiseFF, self).__init__()
+
+ self.d_model = d_model
+ self.d_inner = d_inner
+ self.dropout = dropout
+
+ self.CoreNet = nn.Sequential(
+ nn.Linear(d_model, d_inner), nn.ReLU(inplace=True),
+ nn.Dropout(dropout),
+ nn.Linear(d_inner, d_model),
+ nn.Dropout(dropout),
+ )
+
+ self.layer_norm = get_norm(norm_type, d_model, ln_eps)
+
+ self.pre_lnorm = pre_lnorm
+
+ def forward(self, inp):
+ if self.pre_lnorm:
+ ##### layer normalization + positionwise feed-forward
+ core_out = self.CoreNet(self.layer_norm(inp))
+
+ ##### residual connection
+ output = core_out + inp
+ else:
+ ##### positionwise feed-forward
+ core_out = self.CoreNet(inp)
+
+ ##### residual connection + layer normalization
+ output = self.layer_norm(inp + core_out)
+
+ return output
+
+
+class AdaptiveFFN(nn.Module):
+ def __init__(self, d_model, d_inner, *, dropout, ln_eps=1e-5, gate_mix_prob=0.5, gate_sample_prob=0.25,
+ identity_loss_weight=1.0, identity_threshold=0.9, init_identiy_bias=2.0,
+ ffn_residual=True, norm_in_ffn=True):
+ super(AdaptiveFFN, self).__init__()
+
+ self.ffn_net = nn.Sequential(
+ nn.Linear(d_model, d_inner),
+ nn.ReLU(inplace=True),
+ nn.Dropout(dropout),
+ nn.Linear(d_inner, d_model),
+ nn.Dropout(dropout),
+ )
+ self.ffn_residual = ffn_residual
+
+ self.layer_norm = nn.LayerNorm(d_model, eps=ln_eps)
+ self.norm_in_ffn = norm_in_ffn
+
+ self.gate_net = GateNet(d_model, gate_num=2, init_identiy_bias=init_identiy_bias)
+ self.gate_mix_prob = gate_mix_prob
+ self.gate_sample_prob = gate_sample_prob
+ assert 0 <= self.gate_mix_prob <= 1.0
+ assert 0 <= self.gate_sample_prob <= 1.0
+ assert 0 <= self.gate_mix_prob + self.gate_sample_prob <= 1.0
+ self.identity_threshold = identity_threshold
+
+ def forward(self, inputs, *, pad_mask):
+ # [T, B, D] => [T, B, gate_num]
+ gate_prob = self.gate_net(inputs)
+ # paddings may lead to NaN after softmax
+ # gate_prob = gate_prob.masked_fill(pad_mask.unsqueeze(2), 0.)
+ identity_prob, ffn_prob = torch.chunk(gate_prob, 2, dim=-1)
+
+ ffn_output = self._ffn_forward(inputs)
+
+ if self.training:
+ r = random.random()
+ if r < self.gate_mix_prob:
+ # learn gate
+ output = inputs * identity_prob + ffn_output * ffn_prob
+ adaptive_prob = identity_prob
+ elif r < self.gate_mix_prob + self.gate_sample_prob:
+ # exploit
+ identity_mask = torch.bernoulli(identity_prob)
+ output = inputs * identity_mask + ffn_output * (1 - identity_mask)
+ adaptive_prob = None
+ else:
+ # explore, by uniform sample branches
+ mask_size = (inputs.shape[0], inputs.shape[1], 1)
+ identity_mask = torch.bernoulli(torch.full(mask_size, 0.5, device=inputs.device)).type(inputs.dtype)
+ output = inputs * identity_mask + ffn_output * (1 - identity_mask)
+ adaptive_prob = None
+ else:
+ identity_mask = identity_prob > self.identity_threshold
+ output = inputs * identity_mask + ffn_output * ~identity_mask
+ adaptive_prob = identity_prob
+
+ if not self.norm_in_ffn:
+ output = self.layer_norm(output)
+
+ return output, adaptive_prob
+
+ def _ffn_forward(self, inp):
+ output = inp
+
+ output = self.ffn_net(output)
+ if self.ffn_residual:
+ output = output + inp
+
+ if self.norm_in_ffn:
+ output = self.layer_norm(output)
+
+ return output
+
+
+class GateNet(nn.Module):
+ def __init__(self, in_features, gate_num, init_identiy_bias):
+ super(GateNet, self).__init__()
+
+ assert gate_num == 2
+ self.weight = nn.Parameter(torch.Tensor(gate_num, in_features))
+ self.bias = nn.Parameter(torch.tensor([init_identiy_bias] + [0.] * (gate_num - 1)))
+
+ def forward(self, inputs):
+ logits = F.linear(inputs, self.weight, self.bias)
+ prob = F.softmax(logits, dim=-1)
+ return prob
+
+
+def _rel_shift_uni(x):
+ # x: qlen x rlen x bsz x n_head
+ zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
+ device=x.device, dtype=x.dtype)
+ x_padded = torch.cat([zero_pad, x], dim=1)
+
+ x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])
+
+ x = x_padded[1:].view_as(x)
+
+ # if zero_triu:
+ # ones = torch.ones((x.size(0), x.size(1)))
+ # x = x * torch.tril(ones, x.size(1) - x.size(0))[:,:,None,None]
+
+ return x
+
+
+def _rel_shift_bi(x, klen):
+ # x: qlen x rlen x bsz x n_head
+ """perform relative shift to form the relative attention score."""
+ x_size = x.size()
+ assert klen * 2 == x_size[1]
+
+ x = x.reshape(x_size[1], x_size[0], x_size[2], x_size[3])
+ x = torch.narrow(x, dim=0, start=1, length=x.size()[0]-1)
+ x = x.reshape(x_size[0], x_size[1] - 1, x_size[2], x_size[3])
+ x = torch.narrow(x, dim=1, start=0, length=klen)
+
+ return x
+
+
+def check_rel_shift_bi():
+ """In x, 14 means query 1, rel emb at position 4; -32 means query 3, rel emb at position -2."""
+ x = torch.tensor([[14, 13, 12, 11, 10, -11, -12, -13],
+ [24, 23, 22, 21, 20, -21, -22, -23],
+ [34, 33, 32, 31, 30, -31, -32, -33],
+ [44, 43, 42, 41, 40, -41, -42, -43]], dtype=torch.float32)
+ x = x.unsqueeze(-1).unsqueeze(-1)
+ shifted_x = _rel_shift_bi(x, klen=4)
+ shifted_x = shifted_x.squeeze(-1).squeeze(-1)
+ assert torch.equal(shifted_x,
+ torch.tensor([[10., -11., -12., -13.],
+ [21., 20., -21., -22.],
+ [32., 31., 30., -31.],
+ [43., 42., 41., 40.]]))
+ return shifted_x
+
+
+def create_pad_mask(lens, max_len):
+ mask = torch.arange(max_len).to(lens.device) >= lens.unsqueeze(-1)
+ return mask
+
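+# Illustrative example (comment only): create_pad_mask(torch.tensor([2, 4]), max_len=4)
+# returns a [B, T] boolean mask that is True at padding positions:
+#   [[False, False, True,  True ],
+#    [False, False, False, False]]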
+
+class RelPartialLearnableMultiHeadAttn(nn.Module):
+ def __init__(self, n_head, d_model, d_head, *, dropout, dropatt, pre_lnorm=False, ln_eps=1e-5, uni_attn=True,
+ norm_type='ln', pos_enc='xl'):
+ super(RelPartialLearnableMultiHeadAttn, self).__init__()
+
+ self.n_head = n_head
+ self.d_model = d_model
+ self.d_head = d_head
+ self.dropout = dropout
+
+ self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)
+
+ self.drop = nn.Dropout(dropout)
+ self.dropatt = nn.Dropout(dropatt)
+ self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
+
+ self.layer_norm = get_norm(norm_type, d_model, ln_eps)
+
+ self.scale = 1 / (d_head ** 0.5)
+
+ self.pre_lnorm = pre_lnorm
+
+ self.uni_attn = uni_attn
+
+ self.pos_enc = pos_enc
+ if self.pos_enc == 'xl':
+ self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)
+ else:
+ assert self.pos_enc is None
+
+ def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None):
+ qlen, bsz = w.size(0), w.size(1)
+
+ if self.pre_lnorm:
+ w_norm = self.layer_norm(w)
+ else:
+ w_norm = w
+ w_heads = self.qkv_net(w_norm)
+
+ w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
+
+ klen = w_head_k.size(0)
+
+ w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head
+ w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head
+ w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head
+
+ if mems is not None:
+ assert self.uni_attn
+ k_mems, v_mems, real_mlen = mems
+ new_mems = w_head_k, w_head_v
+
+ w_head_k = torch.cat([k_mems, w_head_k], 0)
+ w_head_v = torch.cat([v_mems, w_head_v], 0)
+ else:
+ new_mems = None
+
+ #### compute attention score
+ if self.pos_enc == 'xl':
+ rw_head_q = w_head_q + r_w_bias # qlen x bsz x n_head x d_head
+ else:
+ rw_head_q = w_head_q
+ attn_score = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head
+
+ if self.pos_enc == 'xl':
+ rr_head_q = w_head_q + r_r_bias
+ r_head_k = self.r_net(r)
+ r_head_k = r_head_k.view(r.size(0), self.n_head, self.d_head) # qlen x n_head x d_head
+ BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k)) # qlen x klen x bsz x n_head
+ BD = self._rel_shift(BD, attn_score.size(1))
+ # [qlen x klen x bsz x n_head]
+ attn_score = attn_score + BD
+
+ attn_score.mul_(self.scale)
+
+ neg_min = float16_min if attn_score.dtype == torch.float16 else float32_min
+ #### compute attention probability
+ if self.uni_attn:
+ # attn_mask: [qlen, klen, 1] -> [qlen, klen, 1, 1]
+ attn_score = attn_score.float().masked_fill(
+ attn_mask.unsqueeze(-1), neg_min).type_as(attn_score)
+ else:
+ # attn_mask: [klen, bsz] -> [1, klen, bsz, 1]
+ attn_score = attn_score.masked_fill(attn_mask.unsqueeze(0).unsqueeze(-1), neg_min)
+
+ # [qlen x klen x bsz x n_head]
+ attn_prob = F.softmax(attn_score, dim=1)
+ attn_prob = self.dropatt(attn_prob)
+
+ #### compute attention vector
+ attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))
+
+ # [qlen x bsz x n_head x d_head]
+ attn_vec = attn_vec.contiguous().view(
+ attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
+
+ ##### linear projection
+ attn_out = self.o_net(attn_vec)
+ attn_out = self.drop(attn_out)
+
+ if self.pre_lnorm:
+ ##### residual connection
+ output = w + attn_out
+ else:
+ ##### residual connection + layer normalization
+ output = self.layer_norm(w + attn_out)
+
+ return output, new_mems
+
+ def _rel_shift(self, x, klen):
+ if self.uni_attn:
+ return _rel_shift_uni(x)
+ else:
+ return _rel_shift_bi(x, klen)
+
+
+class RelPartialLearnableDecoderLayer(nn.Module):
+ def __init__(self, n_head, d_model, d_head, d_inner, *, dropout, dropatt, pre_lnorm, ln_eps, uni_attn, norm_type='ln',
+ pos_enc='xl', adaptive_ffn=None):
+ super(RelPartialLearnableDecoderLayer, self).__init__()
+
+ self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model, d_head, dropout=dropout, dropatt=dropatt,
+ pre_lnorm=pre_lnorm, ln_eps=ln_eps, uni_attn=uni_attn,
+ norm_type=norm_type, pos_enc=pos_enc)
+ self.use_adaptive_ffn = adaptive_ffn is not None
+ if not self.use_adaptive_ffn:
+ self.pos_ff = PositionwiseFF(d_model, d_inner, dropout=dropout, pre_lnorm=pre_lnorm, ln_eps=ln_eps, norm_type=norm_type)
+ else:
+ assert not pre_lnorm and norm_type == 'ln'
+ self.pos_ff = AdaptiveFFN(d_model, d_inner, dropout=dropout, ln_eps=ln_eps, **adaptive_ffn)
+
+ def forward(self, dec_inp, r, r_w_bias, r_r_bias, *, dec_attn_mask=None, pad_mask=None, mems=None):
+
+ output, new_mems = self.dec_attn(dec_inp, r, r_w_bias, r_r_bias,
+ attn_mask=dec_attn_mask,
+ mems=mems)
+ if not self.use_adaptive_ffn:
+ output = self.pos_ff(output)
+ ada_prob = None
+ else:
+ output, ada_prob = self.pos_ff(output, pad_mask=pad_mask)
+
+ return output, new_mems, (ada_prob,)
+
+
+class RelTransformerBlock(nn.Module):
+ def __init__(self, n_layer, d_model, n_head, d_head, d_inner, dropout, dropout_att,
+ pre_lnorm=False, norm_output=False, ln_eps=1e-5, uni_attn=True, norm_type='ln', pos_enc='xl',
+ layer_drop=0.0, adaptive_ffn=None):
+ super(RelTransformerBlock, self).__init__()
+
+ self.n_layer = n_layer
+ self.d_model = d_model
+ self.n_head = n_head
+ self.d_head = d_head
+
+ self.drop = nn.Dropout(dropout)
+
+ self.uni_attn = uni_attn
+ self.att_trunc_len = -1
+
+ self.clamp_len = 0
+
+ self.layer_drop = layer_drop
+
+ self.layers = nn.ModuleList()
+ for i in range(n_layer):
+ self.layers.append(
+ RelPartialLearnableDecoderLayer(
+ n_head, d_model, d_head, d_inner, dropout=dropout,
+ dropatt=dropout_att, pre_lnorm=pre_lnorm, ln_eps=ln_eps, uni_attn=uni_attn, norm_type=norm_type,
+ pos_enc=pos_enc, adaptive_ffn=adaptive_ffn)
+ )
+
+ if norm_output:
+ self.output_norm = nn.LayerNorm(d_model, eps=ln_eps)
+ else:
+ self.output_norm = identity
+
+ self.use_adaptive_ffn = adaptive_ffn is not None
+ if self.use_adaptive_ffn:
+ self.identity_loss_weight = adaptive_ffn.identity_loss_weight
+
+ self.pos_enc = pos_enc
+ self._create_params()
+
+ def _create_params(self):
+ if self.pos_enc == 'xl':
+ self.pos_emb = PositionalEmbedding(self.d_model)
+ self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
+ self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
+ else:
+ assert self.pos_enc is None
+ self.pos_emb = None
+ self.r_w_bias = None
+ self.r_r_bias = None
+
+ def _forward(self, dec_inp, lens=None, mems=None):
+ is_decoding = mems is not None
+ qlen, bsz, _ = dec_inp.size()
+
+ mlen = mems[0][0].size(0) if mems is not None else 0
+ klen = mlen + qlen
+
+ if not self.uni_attn or self.use_adaptive_ffn:
+ assert lens is not None
+ pad_mask = create_pad_mask(lens, max_len=qlen)
+ # [B, L] -> [L, B]
+ pad_mask = pad_mask.transpose(0, 1).contiguous()
+ else:
+ pad_mask = None
+
+ if self.uni_attn:
+ dec_attn_mask = torch.triu(
+ dec_inp.new_ones(qlen, klen), diagonal=1+mlen).bool()[:,:,None]
+ else:
+ assert pad_mask is not None
+ dec_attn_mask = pad_mask
+
+ hids = []
+ new_kv_mems = []
+ core_out = dec_inp
+
+ if self.pos_enc == 'xl':
+ if self.uni_attn:
+ pos_s, pos_e = klen-1, -1
+ else:
+ pos_s, pos_e = klen, -qlen
+ pos_seq = torch.arange(pos_s, pos_e, -1.0, device=dec_inp.device,
+ dtype=dec_inp.dtype)
+ if self.clamp_len > 0:
+ pos_seq.clamp_(max=self.clamp_len)
+ pos_emb = self.pos_emb(pos_seq)
+
+ pos_emb = self.drop(pos_emb)
+ else:
+ pos_emb = None
+
+ if self.use_adaptive_ffn:
+ adaptive_prob = []
+ else:
+ adaptive_prob = None
+
+ # hids.append(core_out)
+ for i, layer in enumerate(self.layers):
+ if self.layer_drop > 0 and self.training and np.random.random() < self.layer_drop:
+ continue
+
+ mems_i = None if not is_decoding else mems[i]
+ core_out, kv_mems, extra = layer(core_out, pos_emb, self.r_w_bias,
+ self.r_r_bias, dec_attn_mask=dec_attn_mask, pad_mask=pad_mask, mems=mems_i)
+ # hids.append(core_out)
+ new_kv_mems.append(kv_mems)
+
+ if self.use_adaptive_ffn:
+ if extra[0] is not None:
+ adaptive_prob.append(extra[0])
+
+ core_out = self.output_norm(core_out)
+
+ if self.use_adaptive_ffn:
+ if len(adaptive_prob) == 0:
+ assert self.training
+ adaptive_ffn_loss = torch.tensor(0., dtype=dec_inp.dtype, device=dec_inp.device)
+ else:
+ adaptive_prob_t = torch.stack(adaptive_prob)
+ zero_guard_eps = 1e-12
+ adaptive_logp = torch.log(adaptive_prob_t + zero_guard_eps)
+ # pad_mask: [L, B] => [1, L, B, 1]
+ adaptive_logp = adaptive_logp.masked_fill(pad_mask.unsqueeze(-1).unsqueeze(0), 0.)
+ avg_adaptive_logp = adaptive_logp.sum() / (len(adaptive_prob) * lens.sum())
+ adaptive_ffn_loss = -avg_adaptive_logp * self.identity_loss_weight
+ else:
+ adaptive_ffn_loss = None
+
+ new_mems = []
+ if is_decoding:
+ new_mems = self._update_decode_mems(new_kv_mems, mems, [self.att_trunc_len] * len(self.layers))
+
+ return core_out, new_mems, (adaptive_ffn_loss,)
+
+ def forward(self, data, *, lens=None, mems=None):
+ # data: [T, B, D]
+ output, new_mems, extra = self._forward(data, lens=lens, mems=mems)
+
+ if mems is None:
+ assert len(new_mems) == 0
+
+ return output, tuple(new_mems), extra
+
+ def _update_decode_mems(self, step_mem, prev_mems, mem_len):
+ assert prev_mems is not None and len(step_mem) == len(prev_mems)
+
+ with torch.no_grad():
+ new_mems = []
+ for i in range(len(step_mem)):
+ mem_len_i = mem_len[i]
+ k_mem, v_mem = step_mem[i]
+ k_prev_mem, v_prev_mem, real_mlen = prev_mems[i]
+ new_k_mem = torch.cat([k_prev_mem, k_mem], 0)
+ new_v_mem = torch.cat([v_prev_mem, v_mem], 0)
+ real_mlen = real_mlen + k_mem.size(0)
+ if mem_len_i > 0:
+ new_k_mem = self.neg_slice(new_k_mem, mem_len, True)
+ new_v_mem = self.neg_slice(new_v_mem, mem_len, True)
+ real_mlen = torch.min(real_mlen, mem_len)
+ new_mems.append((new_k_mem, new_v_mem, real_mlen))
+ return new_mems
+
+
+def identity(x):
+ return x
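+
+
+# Minimal usage sketch (illustrative assumptions, not part of the original file):
+#   block = RelTransformerBlock(n_layer=2, d_model=256, n_head=4, d_head=64,
+#                               d_inner=1024, dropout=0.1, dropout_att=0.1, uni_attn=False)
+#   x = torch.randn(50, 8, 256)                    # [T, B, D]
+#   lens = torch.full((8,), 50, dtype=torch.long)  # required when uni_attn=False
+#   out, mems, (adaptive_ffn_loss,) = block(x, lens=lens)
+# Note: with pos_enc='xl' the r_w_bias / r_r_bias parameters are created uninitialized
+# via torch.Tensor, so callers are expected to initialize them before training.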
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/multi_layer_perceptron.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/multi_layer_perceptron.py
new file mode 100644
index 0000000000000000000000000000000000000000..76c06bf23ea647148d4eaa0a9686e631c169a436
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/multi_layer_perceptron.py
@@ -0,0 +1,61 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+
+class MultiLayerPerceptron(torch.nn.Module):
+ """
+ A simple MLP that can either be used independently or put on top
+ of pretrained models (such as BERT) and act as a classifier.
+ Args:
+ hidden_size (int): the size of each layer
+ num_classes (int): number of output classes
+ num_layers (int): number of layers
+ activation (str): type of activations for layers in between
+ log_softmax (bool): whether to add a log_softmax layer before output
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ num_classes: int,
+ num_layers: int = 2,
+ activation: str = 'relu',
+ log_softmax: bool = True,
+ ):
+ super().__init__()
+ self.layers = 0
+ for _ in range(num_layers - 1):
+ layer = torch.nn.Linear(hidden_size, hidden_size)
+ setattr(self, f'layer{self.layers}', layer)
+ setattr(self, f'layer{self.layers + 1}', getattr(torch, activation))
+ self.layers += 2
+ layer = torch.nn.Linear(hidden_size, num_classes)
+ setattr(self, f'layer{self.layers}', layer)
+ self.layers += 1
+ self.log_softmax = log_softmax
+
+ @property
+ def last_linear_layer(self):
+ return getattr(self, f'layer{self.layers - 1}')
+
+ def forward(self, hidden_states):
+ output_states = hidden_states[:]
+ for i in range(self.layers):
+ output_states = getattr(self, f'layer{i}')(output_states)
+
+ if self.log_softmax:
+ output_states = torch.log_softmax(output_states, dim=-1)
+ return output_states
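+
+
+# Brief usage sketch (illustrative only):
+#   mlp = MultiLayerPerceptron(hidden_size=256, num_classes=10, num_layers=2)
+#   log_probs = mlp(torch.randn(4, 256))  # -> [4, 10] log-probabilities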
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/normalization.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/normalization.py
new file mode 100644
index 0000000000000000000000000000000000000000..e80e2bef0e7795479d65b3a775651a0607fcc335
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/normalization.py
@@ -0,0 +1,37 @@
+import math
+import numbers
+
+import torch
+import torch.nn as nn
+
+
+class LayerVarNorm(nn.Module):
+ __constants__ = ['normalized_shape', 'weight', 'eps', 'elementwise_affine']
+
+ def __init__(self, normalized_shape, eps=1e-6, elementwise_affine=False):
+ super(LayerVarNorm, self).__init__()
+ if isinstance(normalized_shape, numbers.Integral):
+ normalized_shape = (normalized_shape,)
+ self.normalized_shape = tuple(normalized_shape)
+ self.eps = math.sqrt(eps) # eps is added directly on std, i.e., outside sqrt(var)
+ self.elementwise_affine = elementwise_affine
+ if self.elementwise_affine:
+ self.weight = nn.Parameter(torch.Tensor(*normalized_shape))
+ else:
+ self.register_parameter('weight', None)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ if self.elementwise_affine:
+ nn.init.ones_(self.weight)
+
+ def forward(self, input):
+ std = torch.std(input, dim=-1, unbiased=False, keepdim=True)
+ output = input / (std + self.eps)
+ if self.elementwise_affine:
+ output = output * self.weight
+ return output
+
+ def extra_repr(self):
+ return '{normalized_shape}, eps={eps}, ' \
+ 'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
\ No newline at end of file
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/rnn.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/rnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea50d7e1918e4603e4eb044a09cc721b0c75df87
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/rnn.py
@@ -0,0 +1,510 @@
+# Copyright (c) 2019, Myrtle Software Limited. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+
+def rnn(
+ input_size: int,
+ hidden_size: int,
+ num_layers: int,
+ norm: Optional[str] = None,
+ forget_gate_bias: Optional[float] = 1.0,
+ dropout: Optional[float] = 0.0,
+ norm_first_rnn: Optional[bool] = None,
+ t_max: Optional[int] = None,
+) -> torch.nn.Module:
+ """
+ Utility function to provide unified interface to common LSTM RNN modules.
+
+ Args:
+ input_size: Input dimension.
+
+ hidden_size: Hidden dimension of the RNN.
+
+ num_layers: Number of RNN layers.
+
+ norm: Optional string representing type of normalization to apply to the RNN.
+ Supported values are None, batch and layer.
+
+ forget_gate_bias: float, set by default to 1.0, which constructs a forget gate
+ initialized to 1.0.
+ Reference:
+ [An Empirical Exploration of Recurrent Network Architectures](http://proceedings.mlr.press/v37/jozefowicz15.pdf)
+
+ dropout: Optional dropout to apply to end of multi-layered RNN.
+
+ norm_first_rnn: Whether to normalize the first RNN layer.
+
+ t_max: int value, set to None by default. If an int is specified, performs Chrono Initialization
+ of the LSTM network, based on the maximum number of timesteps `t_max` expected during the course
+ of training.
+ Reference:
+ [Can recurrent neural networks warp time?](https://openreview.net/forum?id=SJcKhk-Ab)
+
+ Returns:
+ A RNN module
+ """
+ if norm not in [None, "batch", "layer"]:
+ raise ValueError(f"unknown norm={norm}")
+
+ if norm is None:
+ return LSTMDropout(
+ input_size=input_size,
+ hidden_size=hidden_size,
+ num_layers=num_layers,
+ dropout=dropout,
+ forget_gate_bias=forget_gate_bias,
+ t_max=t_max,
+ )
+
+ if norm == "batch":
+ return BNRNNSum(
+ input_size=input_size,
+ hidden_size=hidden_size,
+ rnn_layers=num_layers,
+ batch_norm=True,
+ dropout=dropout,
+ forget_gate_bias=forget_gate_bias,
+ t_max=t_max,
+ norm_first_rnn=norm_first_rnn,
+ )
+
+ if norm == "layer":
+ return torch.jit.script(
+ ln_lstm(
+ input_size=input_size,
+ hidden_size=hidden_size,
+ num_layers=num_layers,
+ dropout=dropout,
+ forget_gate_bias=forget_gate_bias,
+ t_max=t_max,
+ )
+ )
+
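+# Usage sketch for the factory above (illustrative only):
+#   lstm = rnn(input_size=80, hidden_size=320, num_layers=2, norm=None, dropout=0.1)
+#   y, (h, c) = lstm(torch.randn(100, 8, 80))  # time-major input [T, B, F]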
+
+class OverLastDim(torch.nn.Module):
+ """Collapses a tensor to 2D, applies a module, and (re-)expands the tensor.
+ An n-dimensional tensor of shape (s_1, s_2, ..., s_n) is first collapsed to
+ a tensor with shape (s_1*s_2*...*s_n-1, s_n). The module is called with
+ this as input producing (s_1*s_2*...*s_n-1, s_n') --- note that the final
+ dimension can change. This is expanded to (s_1, s_2, ..., s_n-1, s_n') and
+ returned.
+ Args:
+ module (torch.nn.Module): Module to apply. Must accept a 2D tensor as
+ input and produce a 2D tensor as output, optionally changing the
+ size of the last dimension.
+ """
+
+ def __init__(self, module: torch.nn.Module):
+ super().__init__()
+ self.module = module
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ *dims, _ = x.size()
+
+ reduced_dims = 1
+ for dim in dims:
+ reduced_dims *= dim
+
+ x = x.view(reduced_dims, -1)
+ x = self.module(x)
+ x = x.view(*dims, -1)
+ return x
+
+
+class LSTMDropout(torch.nn.Module):
+ def __init__(
+ self,
+ input_size: int,
+ hidden_size: int,
+ num_layers: int,
+ dropout: Optional[float],
+ forget_gate_bias: Optional[float],
+ t_max: Optional[int] = None,
+ ):
+ """Returns an LSTM with the forget gate bias initialized to `forget_gate_bias`.
+ Args:
+ input_size: See `torch.nn.LSTM`.
+ hidden_size: See `torch.nn.LSTM`.
+ num_layers: See `torch.nn.LSTM`.
+ dropout: See `torch.nn.LSTM`.
+
+ forget_gate_bias: float, set by default to 1.0, which constructs a forget gate
+ initialized to 1.0.
+ Reference:
+ [An Empirical Exploration of Recurrent Network Architectures](http://proceedings.mlr.press/v37/jozefowicz15.pdf)
+
+ t_max: int value, set to None by default. If an int is specified, performs Chrono Initialization
+ of the LSTM network, based on the maximum number of timesteps `t_max` expected during the course
+ of training.
+ Reference:
+ [Can recurrent neural networks warp time?](https://openreview.net/forum?id=SJcKhk-Ab)
+
+ Returns:
+ A `torch.nn.LSTM`.
+ """
+ super(LSTMDropout, self).__init__()
+
+ self.lstm = torch.nn.LSTM(
+ input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, dropout=dropout,
+ )
+
+ if t_max is not None:
+ # apply chrono init
+ for name, v in self.lstm.named_parameters():
+ if 'bias' in name:
+ p = getattr(self.lstm, name)
+ n = p.nelement()
+ hidden_size = n // 4
+ p.data.fill_(0)
+ p.data[hidden_size : 2 * hidden_size] = torch.log(
+ torch.nn.init.uniform_(p.data[0:hidden_size], 1, t_max - 1)
+ )
+ # forget gate biases = log(uniform(1, Tmax-1))
+ p.data[0:hidden_size] = -p.data[hidden_size : 2 * hidden_size]
+ # input gate biases = -(forget gate biases)
+
+ elif forget_gate_bias is not None:
+ for name, v in self.lstm.named_parameters():
+ if "bias_ih" in name:
+ bias = getattr(self.lstm, name)
+ bias.data[hidden_size : 2 * hidden_size].fill_(forget_gate_bias)
+ if "bias_hh" in name:
+ bias = getattr(self.lstm, name)
+ bias.data[hidden_size : 2 * hidden_size].fill_(0)
+
+ self.dropout = torch.nn.Dropout(dropout) if dropout else None
+
+ def forward(
+ self, x: torch.Tensor, h: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ x, h = self.lstm(x, h)
+
+ if self.dropout:
+ x = self.dropout(x)
+
+ return x, h
+
+
+class RNNLayer(torch.nn.Module):
+ """A single RNNLayer with optional batch norm."""
+
+ def __init__(
+ self,
+ input_size: int,
+ hidden_size: int,
+ rnn_type: torch.nn.Module = torch.nn.LSTM,
+ batch_norm: bool = True,
+ forget_gate_bias: Optional[float] = 1.0,
+ t_max: Optional[int] = None,
+ ):
+ super().__init__()
+
+ if batch_norm:
+ self.bn = OverLastDim(torch.nn.BatchNorm1d(input_size))
+
+ if isinstance(rnn_type, torch.nn.LSTM) and not batch_norm:
+ # batch_norm will apply a bias, so there is no need to add a second one in the LSTM
+ self.rnn = LSTMDropout(
+ input_size=input_size,
+ hidden_size=hidden_size,
+ num_layers=1,
+ dropout=0.0,
+ forget_gate_bias=forget_gate_bias,
+ t_max=t_max,
+ )
+ else:
+ self.rnn = rnn_type(input_size=input_size, hidden_size=hidden_size, bias=not batch_norm)
+
+ def forward(
+ self, x: torch.Tensor, hx: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ if hasattr(self, 'bn'):
+ x = x.contiguous()
+ x = self.bn(x)
+ x, h = self.rnn(x, hx=hx)
+ return x, h
+
+ def _flatten_parameters(self):
+ self.rnn.flatten_parameters()
+
+
+class BNRNNSum(torch.nn.Module):
+ """RNN wrapper with optional batch norm.
+ Instantiates an RNN. If it is an LSTM, it initialises the forget gate
+ bias to `forget_gate_bias`. Optionally applies a batch normalisation layer to
+ the input with the statistics computed over all time steps. If dropout > 0
+ then it is applied to all layer outputs except the last.
+ """
+
+ def __init__(
+ self,
+ input_size: int,
+ hidden_size: int,
+ rnn_type: torch.nn.Module = torch.nn.LSTM,
+ rnn_layers: int = 1,
+ batch_norm: bool = True,
+ dropout: Optional[float] = 0.0,
+ forget_gate_bias: Optional[float] = 1.0,
+ norm_first_rnn: bool = False,
+ t_max: Optional[int] = None,
+ ):
+ super().__init__()
+ self.rnn_layers = rnn_layers
+
+ self.layers = torch.nn.ModuleList()
+ for i in range(rnn_layers):
+ final_layer = (rnn_layers - 1) == i
+
+ self.layers.append(
+ RNNLayer(
+ input_size,
+ hidden_size,
+ rnn_type=rnn_type,
+ batch_norm=batch_norm and (norm_first_rnn or i > 0),
+ forget_gate_bias=forget_gate_bias,
+ t_max=t_max,
+ )
+ )
+
+ if dropout is not None and dropout > 0.0 and not final_layer:
+ self.layers.append(torch.nn.Dropout(dropout))
+
+ input_size = hidden_size
+
+ def forward(
+ self, x: torch.Tensor, hx: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ hx = self._parse_hidden_state(hx)
+
+ hs = []
+ cs = []
+ rnn_idx = 0
+ for layer in self.layers:
+ if isinstance(layer, torch.nn.Dropout):
+ x = layer(x)
+ else:
+ x, h_out = layer(x, hx=hx[rnn_idx])
+ hs.append(h_out[0])
+ cs.append(h_out[1])
+ rnn_idx += 1
+ del h_out
+
+ h_0 = torch.stack(hs, dim=0)
+ c_0 = torch.stack(cs, dim=0)
+ return x, (h_0, c_0)
+
+ def _parse_hidden_state(
+ self, hx: Optional[Tuple[torch.Tensor, torch.Tensor]]
+ ) -> Union[List[None], List[Tuple[torch.Tensor, torch.Tensor]]]:
+ """
+ Dealing w. hidden state:
+ Typically in pytorch: (h_0, c_0)
+ h_0 = ``[num_layers * num_directions, batch, hidden_size]``
+ c_0 = ``[num_layers * num_directions, batch, hidden_size]``
+ """
+ if hx is None:
+ return [None] * self.rnn_layers
+ else:
+ h_0, c_0 = hx
+
+ if h_0.shape[0] != self.rnn_layers:
+ raise ValueError(
+ 'Provided initial state value `h_0` must be of shape : '
+ '[num_layers * num_directions, batch, hidden_size]'
+ )
+
+ return [(h_0[i], c_0[i]) for i in range(h_0.shape[0])]
+
+ def _flatten_parameters(self):
+ for layer in self.layers:
+ if isinstance(layer, (torch.nn.LSTM, torch.nn.GRU, torch.nn.RNN)):
+ layer._flatten_parameters()
+
+
+class StackTime(torch.nn.Module):
+ """
+ Stacks time within the feature dim, so as to behave as a downsampling operation.
+ """
+
+ def __init__(self, factor: int):
+ super().__init__()
+ self.factor = int(factor)
+
+ def forward(self, x: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
+ # T, B, U
+ x, x_lens = x
+ seq = [x]
+ for i in range(1, self.factor):
+ tmp = torch.zeros_like(x)
+ tmp[:-i, :, :] = x[i:, :, :]
+ seq.append(tmp)
+ x_lens = torch.ceil(x_lens.float() / self.factor).int()
+ return torch.cat(seq, dim=2)[:: self.factor, :, :], x_lens
+
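+# Illustrative example (comment only): StackTime(2) stacks pairs of frames along the
+# feature dim and subsamples time by the same factor, e.g.
+#   y, y_lens = StackTime(2)((torch.randn(100, 8, 64), torch.full((8,), 100)))
+#   # y: [50, 8, 128], y_lens: all 50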
+
+def ln_lstm(
+ input_size: int,
+ hidden_size: int,
+ num_layers: int,
+ dropout: Optional[float],
+ forget_gate_bias: Optional[float],
+ t_max: Optional[int],
+) -> torch.nn.Module:
+ """Returns a ScriptModule that mimics a PyTorch native LSTM."""
+ # The following are not implemented.
+ if dropout is not None and dropout != 0.0:
+ raise ValueError('`dropout` not supported with LayerNormLSTM')
+
+ if t_max is not None:
+ raise ValueError("LayerNormLSTM does not support chrono init")
+
+ return StackedLSTM(
+ num_layers,
+ LSTMLayer,
+ first_layer_args=[LayerNormLSTMCell, input_size, hidden_size, forget_gate_bias],
+ other_layer_args=[LayerNormLSTMCell, hidden_size, hidden_size, forget_gate_bias],
+ )
+
+
+class LSTMLayer(torch.nn.Module):
+ def __init__(self, cell, *cell_args):
+ super(LSTMLayer, self).__init__()
+ self.cell = cell(*cell_args)
+
+ def forward(
+ self, input: torch.Tensor, state: Tuple[torch.Tensor, torch.Tensor]
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ inputs = input.unbind(0)
+ outputs = []
+ for i in range(len(inputs)):
+ out, state = self.cell(inputs[i], state)
+ outputs += [out]
+ return torch.stack(outputs), state
+
+
+class LayerNormLSTMCell(torch.nn.Module):
+ def __init__(self, input_size, hidden_size, forget_gate_bias):
+ super().__init__()
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.weight_ih = torch.nn.Parameter(torch.randn(4 * hidden_size, input_size))
+ self.weight_hh = torch.nn.Parameter(torch.randn(4 * hidden_size, hidden_size))
+
+ # LayerNorm provides learnable biases
+ self.layernorm_i = torch.nn.LayerNorm(4 * hidden_size)
+ self.layernorm_h = torch.nn.LayerNorm(4 * hidden_size)
+ self.layernorm_c = torch.nn.LayerNorm(hidden_size)
+
+ self.reset_parameters()
+
+ self.layernorm_i.bias.data[hidden_size : 2 * hidden_size].fill_(0.0)
+ self.layernorm_h.bias.data[hidden_size : 2 * hidden_size].fill_(forget_gate_bias)
+
+ def reset_parameters(self):
+ stdv = 1.0 / math.sqrt(self.hidden_size)
+ for weight in self.parameters():
+ torch.nn.init.uniform_(weight, -stdv, stdv)
+
+ def forward(
+ self, input: torch.Tensor, state: Tuple[torch.Tensor, torch.Tensor]
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ hx, cx = state
+ igates = self.layernorm_i(torch.mm(input, self.weight_ih.t()))
+ hgates = self.layernorm_h(torch.mm(hx, self.weight_hh.t()))
+ gates = igates + hgates
+ ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
+
+ ingate = torch.sigmoid(ingate)
+ forgetgate = torch.sigmoid(forgetgate)
+ cellgate = torch.tanh(cellgate)
+ outgate = torch.sigmoid(outgate)
+
+ cy = self.layernorm_c((forgetgate * cx) + (ingate * cellgate))
+ hy = outgate * torch.tanh(cy)
+
+ return hy, (hy, cy)
+
+
+def init_stacked_lstm(
+ num_layers: int, layer: torch.nn.Module, first_layer_args: List, other_layer_args: List
+) -> torch.nn.ModuleList:
+ layers = [layer(*first_layer_args)] + [layer(*other_layer_args) for _ in range(num_layers - 1)]
+ return torch.nn.ModuleList(layers)
+
+
+class StackedLSTM(torch.nn.Module):
+ def __init__(self, num_layers: int, layer: torch.nn.Module, first_layer_args: List, other_layer_args: List):
+ super(StackedLSTM, self).__init__()
+ self.layers: torch.nn.ModuleList = init_stacked_lstm(num_layers, layer, first_layer_args, other_layer_args)
+
+ def forward(
+ self, input: torch.Tensor, states: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]
+ ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
+ if states is None:
+ temp_states: List[Tuple[torch.Tensor, torch.Tensor]] = []
+ batch = input.size(1)
+ for layer in self.layers:
+ temp_states.append(
+ (
+ torch.zeros(batch, layer.cell.hidden_size, dtype=input.dtype, device=input.device),
+ torch.zeros(batch, layer.cell.hidden_size, dtype=input.dtype, device=input.device),
+ )
+ )
+
+ states = temp_states
+
+ output_states: List[Tuple[torch.Tensor, torch.Tensor]] = []
+ output = input
+ for i, rnn_layer in enumerate(self.layers):
+ state = states[i]
+ output, out_state = rnn_layer(output, state)
+ output_states.append(out_state)
+ i += 1
+ return output, output_states
+
+
+def label_collate(labels, device=None):
+ """Collates the label inputs for the rnn-t prediction network.
+ If `labels` is already in torch.Tensor form this is a no-op.
+
+ Args:
+ labels: A torch.Tensor List of label indexes or a torch.Tensor.
+ device: Optional torch device to place the label on.
+
+ Returns:
+ A padded torch.Tensor of shape (batch, max_seq_len).
+ """
+
+ if isinstance(labels, torch.Tensor):
+ assert labels.dtype == torch.int32 or labels.dtype == torch.int64
+ return labels.type(torch.int64)
+ if not isinstance(labels, (list, tuple)):
+ raise ValueError(f"`labels` should be a list or tensor not {type(labels)}")
+
+ batch_size = len(labels)
+ max_len = max(len(label) for label in labels)
+
+ cat_labels = np.full((batch_size, max_len), fill_value=0.0, dtype=np.int32)
+ for e, l in enumerate(labels):
+ cat_labels[e, : len(l)] = l
+ labels = torch.tensor(cat_labels, dtype=torch.int64, device=device)
+
+ return labels
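+
+
+# Illustrative example (comment only): ragged label lists are zero-padded, e.g.
+#   label_collate([[1, 2, 3], [4, 5]])  ->  tensor([[1, 2, 3], [4, 5, 0]])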
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/transformer_utils.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/transformer_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..467ee9011915a78cd1956229408d1a358708cd65
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/transformer_utils.py
@@ -0,0 +1,79 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+
+__all__ = ['NEG_INF', 'form_attention_mask', 'transformer_weights_init', 'mask_padded_tokens']
+
+NEG_INF = -10000.0
+
+
+def form_attention_mask(input_mask, diagonal=None):
+ """
+ Build attention mask, optionally masking future tokens that we forbid
+ attending to (e.g. as in a Transformer decoder).
+
+ Args:
+ input_mask: binary mask of size B x L with 1s corresponding to valid
+ tokens and 0s corresponding to padding tokens
+ diagonal: diagonal where triangular future mask starts
+ None -- do not mask anything
+ 0 -- regular translation or language modeling future masking
+ 1 -- query stream masking as in XLNet architecture
+ Returns:
+ attention_mask: mask broadcastable to size B x 1 x L x L with 0s corresponding
+ to tokens we plan to attend to and -10000 otherwise
+ """
+
+ if input_mask is None:
+ return None
+ attn_shape = (1, input_mask.shape[1], input_mask.shape[1])
+ attn_mask = input_mask.byte().unsqueeze(1)
+ if diagonal is not None:
+ future_mask = torch.tril(torch.ones(attn_shape).byte().to(input_mask.device), diagonal)
+ attn_mask = attn_mask & future_mask
+ attention_mask = (1 - attn_mask.to(torch.float)) * NEG_INF
+ return attention_mask.unsqueeze(1)
+
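+# Illustrative example (comment only): for input_mask = torch.tensor([[1, 1, 0]]),
+# form_attention_mask(input_mask) is broadcastable to [B, 1, L, L]; it is 0 where the
+# key position is valid and NEG_INF (-10000.0) where the key is padding. Passing
+# diagonal=0 additionally masks positions above the diagonal (causal masking).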
+
+def transformer_weights_init(module, std_init_range=0.02, xavier=True):
+ """
+ Initialize different weights in Transformer model.
+
+ Args:
+ module: torch.nn.Module to be initialized
+ std_init_range: standard deviation of normal initializer
+ xavier: if True, the xavier initializer will be used in Linear layers,
+ as proposed in the AIAYN paper; otherwise the normal initializer
+ will be used (as in the BERT paper)
+ """
+
+ if isinstance(module, nn.Linear):
+ if xavier:
+ nn.init.xavier_uniform_(module.weight)
+ else:
+ nn.init.normal_(module.weight, mean=0.0, std=std_init_range)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0.0)
+ elif isinstance(module, nn.Embedding):
+ nn.init.normal_(module.weight, mean=0.0, std=std_init_range)
+ elif isinstance(module, nn.LayerNorm):
+ nn.init.constant_(module.weight, 1.0)
+ nn.init.constant_(module.bias, 0.0)
+
+
+def mask_padded_tokens(tokens, pad_id):
+ mask = tokens != pad_id
+ return mask
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/utils.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0542e5dce94ee7c9d9107754d8dc1c575d7c46f1
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/parts/utils.py
@@ -0,0 +1,57 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import os
+from typing import List
+
+__all__ = ['if_exist', '_compute_softmax']
+
+
+def if_exist(outfold: str, files: List[str]):
+ """
+ Returns true if all given files exist in the given folder
+ Args:
+ outfold: folder path
+ files: list of file names relative to outfold
+ """
+ if not os.path.exists(outfold):
+ return False
+ for file in files:
+ if not os.path.exists(f'{outfold}/{file}'):
+ return False
+ return True
+
+
+def _compute_softmax(scores):
+ """Compute softmax probability over raw logits."""
+ if not scores:
+ return []
+
+ max_score = None
+ for score in scores:
+ if max_score is None or score > max_score:
+ max_score = score
+
+ exp_scores = []
+ total_sum = 0.0
+ for score in scores:
+ x = math.exp(score - max_score)
+ exp_scores.append(x)
+ total_sum += x
+
+ probs = []
+ for score in exp_scores:
+ probs.append(score / total_sum)
+ return probs
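+
+
+if __name__ == "__main__":
+    # Quick illustrative check added for documentation purposes (not part of the
+    # original NeMo module).
+    print(_compute_softmax([1.0, 2.0]))  # approximately [0.269, 0.731]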
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae6f0950d6ac003d6534fb0f8c8be731b33b30c3
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from nemo.collections.common.tokenizers.char_tokenizer import CharTokenizer
+from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
+from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer
+from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
+from nemo.collections.common.tokenizers.word_tokenizer import WordTokenizer
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/char_tokenizer.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/char_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..27674256f315838bdfa8f6f3f60444306281ff0d
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/char_tokenizer.py
@@ -0,0 +1,521 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import os
+import warnings
+from collections import Counter
+from enum import Enum
+from pathlib import Path
+from typing import Dict, List, NewType, Optional, Union
+
+from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
+
+__all__ = ['CharTokenizer']
+
+
+NUMBER_OF_CHARACTERS_READ_BUFFER_SIZE = 10 ** 7
+
+
+class SpecialTokenString(Enum):
+ MASK = 'mask'
+ BOS = 'bos'
+ EOS = 'eos'
+ PAD = 'pad'
+ SEP = 'sep'
+ CLS = 'cls'
+ UNK = 'unk'
+
+ @classmethod
+ def has_value(cls, value):
+ return value in cls._value2member_map_
+
+
+SpecialTokenStringType = NewType('SpecialTokenString', SpecialTokenString)
+
+
+class CharTokenizer(TokenizerSpec):
+ rf"""
+ Each character is a token.
+ Args:
+ vocab_file: path to file with vocabulary for a tokenizer. The file consists of valid Python string literals
+ separated by the new line character. Such literals must contain 1 character. Examples of valid Python
+ literals: ``'a'``, ``'\n'``, ``"'"``, ``'ж'``, ``'\u8976'``. Optionally the first line in the file can be a
+ JSON dictionary of special tokens. The keys of the special tokens dictionary are ``'mask_token'``,
+ ``'bos_token'`` and so on. Some special token names can be omitted from the special tokens dictionary line.
+ The file ``vocab_file`` has to be in ``'utf-8'`` encoding.
+ mask_token: mask token. The following is applicable to all special tokens. The parameter ``mask_token`` is used
+ for adding a mask token to the vocabulary or for modifying a mask token present in the special tokens dictionary
+ in the first line of file ``vocab_file``. The parameter ``mask_token`` can be either of type ``bool`` or a
+ ``str`` of length 1.
+
+ If ``mask_token`` is a ``bool``, it has to be ``False``; if ``mask_token`` is ``True``, an exception is raised.
+ If ``mask_token`` is ``False`` and ``mask_token`` is present in the special tokens dictionary in vocabulary
+ file ``vocab_file``, then ``mask_token`` is removed from the special tokens dictionary.
+
+ If the parameter ``mask_token`` is a string, then such strings in the input sequence are interpreted as
+ mask tokens.
+ bos_token: the beginning of sequence token. See more in ``mask_token`` parameter description.
+ eos_token: the end of sequence token. Usually equal to sep_token. See more in ``mask_token`` parameter
+ description.
+ pad_token: token to use for padding. See more in ``mask_token`` parameter description.
+ sep_token: token used for separating sequences. See more in ``mask_token`` parameter description.
+ cls_token: class token. Usually equal to bos_token. See more in ``mask_token`` parameter description.
+ unk_token: token to use for unknown tokens. If the parameter ``unk_token`` is set and there is a character
+ in the input of the ``text_to_ids`` or ``text_to_tokens`` methods which is not in the vocabulary, then
+ such an unknown character is tokenized into ``unk_token``. If the parameter ``unk_token`` is ``False``,
+ then unknown tokens are discarded. See more in ``mask_token`` parameter description.
+ special_token_to_prepend: special token to prepend to the output of the ``text_to_ids`` or ``text_to_tokens``
+ methods. This option can be used if you decide to add EOS and BOS tokens to the input on the stage of
+ tokenization. Possible options are: {[None] + [e.value for e in SpecialTokenString]}.
+ special_token_to_append: special token to append to the output of the ``text_to_ids`` or ``text_to_tokens``
+ methods. See more in the description of ``special_token_to_prepend`` parameter.
+ special_tokens_to_remove_while_decoding: which special tokens are remove before detokenization. If this
+ parameter equals ``'all'``, then all special tokens are removed. The parameter
+ ``special_tokens_to_remove_while_decoding`` can also be a list of values from this set
+ {set(e.value for e in SpecialTokenString)}.
+ """
+
+ def __init__(
+ self,
+ vocab_file: str,
+ mask_token: Optional[Union[str, bool]] = None,
+ bos_token: Optional[Union[str, bool]] = None,
+ eos_token: Optional[Union[str, bool]] = None,
+ pad_token: Optional[Union[str, bool]] = None,
+ sep_token: Optional[Union[str, bool]] = None,
+ cls_token: Optional[Union[str, bool]] = None,
+ unk_token: Optional[Union[str, bool]] = None,
+ special_token_to_prepend: Optional[SpecialTokenStringType] = None,
+ special_token_to_append: Optional[SpecialTokenStringType] = None,
+ special_tokens_to_remove_while_decoding: Union[List[SpecialTokenStringType], str] = 'all',
+ ):
+ vocab_file = Path(vocab_file).expanduser()
+ with vocab_file.open(encoding='utf-8') as f:
+ first_line = f.readline()
+ if first_line[0] == '{':
+ special_tokens_dict = json.loads(first_line)
+ self.check_special_tokens_dict_from_file(special_tokens_dict, vocab_file)
+ vocab_list = f.readlines()
+ else:
+ special_tokens_dict = {}
+ vocab_list = [first_line] + f.readlines()
+ special_tokens_dict = self.update_special_tokens_dict(
+ special_tokens_dict, mask_token, bos_token, eos_token, pad_token, sep_token, cls_token, unk_token
+ )
+ for e in SpecialTokenString:
+ name = e.value + '_token'
+ setattr(self, name, special_tokens_dict[name] if name in special_tokens_dict else None)
+ for k, v in special_tokens_dict.items():
+ setattr(self, k, v)
+ for value, name in [
+ (special_token_to_prepend, 'special_token_to_prepend'),
+ (special_token_to_append, 'special_token_to_append'),
+ ]:
+ self.check_special_token_name(name, value, special_tokens_dict)
+ setattr(self, name, value + '_token' if isinstance(value, str) else value)
+ self.vocab = {}
+ count = 0
+ for v in special_tokens_dict.values():
+ self.vocab[v] = count
+ count += 1
+ for i, token in enumerate(vocab_list):
+ token = eval(token.strip())
+ self.check_token_from_file(token, vocab_file, i)
+ if token not in self.vocab:
+ self.vocab[token] = count
+ count += 1
+ self.inv_vocab = {v: k for k, v in self.vocab.items()}
+ self.vocab_size = len(self.vocab)
+ self.check_special_tokens_to_remove_while_decoding(
+ special_tokens_to_remove_while_decoding, special_tokens_dict
+ )
+ self.special_token_ids_to_remove_while_decoding = (
+ self.tokens_to_ids([v for v in special_tokens_dict.values()])
+ if special_tokens_to_remove_while_decoding == 'all'
+ else [getattr(self, e + '_id') for e in special_tokens_to_remove_while_decoding]
+ )
+
+ @classmethod
+ def check_special_tokens_dict_from_file(cls, special_tokens_dict, vocab_file):
+ for k, v in special_tokens_dict.items():
+ if k[-6:] != '_token' or not SpecialTokenString.has_value(k[:-6]):
+ raise ValueError(
+ f"Unsupported key {repr(k)} in special tokens dictionary in vocabulary file {vocab_file} "
+ f"(first line). Supported keys are {[e.value + '_token' for e in SpecialTokenString]}."
+ )
+            if not isinstance(v, str):
+                raise ValueError(
+                    f"Values of the special tokens dictionary in vocabulary file {vocab_file} (first line) have to "
+                    f"be of type `str`, whereas the value of item '{k}', {repr(v)}, is of type `{type(v)}`."
+                )
+            elif len(v) == 0:
+                raise ValueError(
+                    f"Values of the special tokens dictionary in vocabulary file {vocab_file} (first line) have to "
+                    f"be non-empty strings, whereas the value of item '{k}' is an empty string."
+                )
+ cls.check_special_tokens_dict_for_duplicate_values(
+ special_tokens_dict, f"Loaded from vocabulary file {vocab_file}"
+ )
+
+ @staticmethod
+ def check_special_tokens_dict_for_duplicate_values(special_tokens_dict, err_msg_prefix):
+ if len(special_tokens_dict) != len(set(special_tokens_dict.values())):
+ tokens_with_equal_values = []
+ duplicate_values = []
+ for k, v in list(reversed(list(special_tokens_dict.items())))[:-1]:
+ tokens = [k]
+ for kk, vv in special_tokens_dict.items():
+ if kk == k:
+ break
+ if v == vv:
+ tokens.append(kk)
+ if len(tokens) > 1:
+ duplicate_values.append(v)
+ tokens_with_equal_values.append(tokens)
+ if duplicate_values:
+ dup_values_msg = '. '.join(
+ [f"Tokens {t} have value '{v}'" for t, v in zip(tokens_with_equal_values, duplicate_values)]
+ )
+ raise ValueError(
+ err_msg_prefix + f" special tokens dictionary has duplicate values. " + dup_values_msg
+ )
+
+ @classmethod
+ def update_special_tokens_dict(
+ cls,
+ init_special_tokens_dict: Dict[str, str],
+ mask_token: Optional[Union[str, bool]] = None,
+ bos_token: Optional[Union[str, bool]] = None,
+ eos_token: Optional[Union[str, bool]] = None,
+ pad_token: Optional[Union[str, bool]] = None,
+ sep_token: Optional[Union[str, bool]] = None,
+ cls_token: Optional[Union[str, bool]] = None,
+ unk_token: Optional[Union[str, bool]] = None,
+ ):
+ special_tokens_dict = init_special_tokens_dict.copy()
+ for value, name in zip(
+ [pad_token, unk_token, bos_token, eos_token, sep_token, mask_token, cls_token],
+ ['pad_token', 'unk_token', 'bos_token', 'eos_token', 'sep_token', 'mask_token', 'cls_token'],
+ ):
+ if value is not None:
+ if isinstance(value, bool):
+ if value:
+ raise ValueError(
+ f"If `CharTokenizer` constructor parameter `{name}` is `bool` it has to be `False`"
+ )
+ else:
+ if name in special_tokens_dict:
+ del special_tokens_dict[name]
+ else:
+ warnings.warn(
+ f"Cannot remove special token `{name}` since it is not in special tokens dictionary "
+ f"{special_tokens_dict}."
+ )
+ elif not isinstance(value, str):
+ raise ValueError(
+ f"`CharTokenizer` constructor parameter `{name}` has to be either `False` or belong to type "
+ f"`str`, whereas type of `{name}` is `{type(value)}`."
+ )
+ else:
+ special_tokens_dict[name] = value
+ cls.check_special_tokens_dict_for_duplicate_values(
+ special_tokens_dict,
+ "After updating special tokens dictionary with tokens passed in `CharTokenizer` constructor parameters",
+ )
+ return special_tokens_dict
+
+ @staticmethod
+ def check_token_from_file(token, vocab_file, line_i):
+        if not isinstance(token, str) or len(token) != 1:
+            raise ValueError(
+                f"Each line in the vocabulary has to be a Python string literal containing exactly 1 character. "
+                f"Encountered {repr(token)} on line {line_i} in file {vocab_file}."
+            )
+
+ @staticmethod
+ def check_special_token_name(parameter_name, value, special_tokens_dict):
+ if value is not None:
+ if not SpecialTokenString.has_value(value):
+ raise ValueError(
+ f"Value {repr(value)} of parameter `{parameter_name}` is wrong. Supported values are "
+ f"{[e.value for e in SpecialTokenString]}."
+ )
+ elif value + '_token' not in special_tokens_dict:
+ raise ValueError(
+ f"You should provide `{value + '_token'}` parameter to `CharTokenizer` constructor if "
+ f"you wish to pass token {repr(value)} in parameter `{parameter_name}`."
+ )
+
+ @staticmethod
+ def check_special_tokens_to_remove_while_decoding(special_tokens_to_remove_while_decoding, special_tokens_dict):
+ if isinstance(special_tokens_to_remove_while_decoding, list):
+ for i, value in enumerate(special_tokens_to_remove_while_decoding):
+ if not SpecialTokenString.has_value(value):
+ raise ValueError(
+ f'Wrong element with value {repr(value)} in position {i} of parameter '
+ f'`special_tokens_to_remove_while_decoding` of `CharTokenizer` constructor. Supported values '
+ f'are {[e.value for e in SpecialTokenString]}.'
+ )
+ elif value + '_token' not in special_tokens_dict:
+ raise ValueError(
+ f"You should provide `{value + '_token'}` parameter to `CharTokenizer` constructor if "
+ f"you wish to pass token {repr(value)} in parameter `special_tokens_to_remove_while_decoding`. "
+ f"`{value + '_token'}` was detected in position {i} in "
+ f"`special_tokens_to_remove_while_decoding`."
+ )
+ elif (
+ isinstance(special_tokens_to_remove_while_decoding, str)
+ and special_tokens_to_remove_while_decoding != 'all'
+ or not isinstance(special_tokens_to_remove_while_decoding, str)
+ ):
+ raise ValueError(
+ f"Parameter `special_tokens_to_remove_while_decoding` of `CharTokenizer` constructor has to be "
+ f"equal to a string 'all' or be a list of values from set {set(e.value for e in SpecialTokenString)} "
+ f"whereas `special_tokens_to_remove_while_decoding={repr(special_tokens_to_remove_while_decoding)}`"
+ )
+
+ def text_to_tokens(self, text: str) -> List[str]:
+ token_candidates = [char for char in text]
+ tokens = []
+ if self.special_token_to_prepend is not None:
+ tokens.append(getattr(self, self.special_token_to_prepend))
+ for i, token in enumerate(token_candidates):
+ if token in self.vocab:
+ tokens.append(token)
+ elif self.unk_token is not None:
+ tokens.append(self.unk_token)
+ else:
+                warnings.warn(
+                    f"Character {repr(token)} in position {i} is not present in the vocabulary and no `unk_token` "
+                    f"was set. Character {repr(token)} is discarded."
+                )
+ if self.special_token_to_append is not None:
+ tokens.append(getattr(self, self.special_token_to_append))
+ return tokens
+
+ def tokens_to_text(self, tokens: List[str]) -> str:
+ return self.ids_to_text(self.tokens_to_ids(tokens))
+
+ def text_to_ids(self, text: str) -> List[int]:
+ ids = [self.vocab[token] for token in self.text_to_tokens(text)]
+ return ids
+
+ def ids_to_text(self, ids: List[int]) -> str:
+ ids_ = [id_ for id_ in ids if id_ not in self.special_token_ids_to_remove_while_decoding]
+ return "".join(self.ids_to_tokens(ids_))
+
+ def tokens_to_ids(self, tokens: List[str]) -> List[int]:
+ return [self.vocab[token] for token in tokens]
+
+ def token_to_id(self, token: str) -> int:
+ return self.vocab[token]
+
+ def ids_to_tokens(self, ids: List[int]) -> List[str]:
+ return [self.inv_vocab[id] for id in ids]
+
+ @staticmethod
+ def check_special_token_id_getting(special_token, id_name):
+ if special_token is None:
+ token_param = id_name[:-3] + '_token'
+ raise ValueError(
+ f"Cannot return `{id_name}` since `{token_param}` is not set. To obtain `{id_name}` you need to pass "
+ f"parameter `{token_param}` to `CharTokenizer` constructor."
+ )
+
+ @property
+ def pad_id(self):
+ self.check_special_token_id_getting(self.pad_token, 'pad_id')
+ return self.vocab[self.pad_token]
+
+ @property
+ def bos_id(self):
+ self.check_special_token_id_getting(self.bos_token, 'bos_id')
+ return self.vocab[self.bos_token]
+
+ @property
+ def eos_id(self):
+ self.check_special_token_id_getting(self.eos_token, 'eos_id')
+ return self.vocab[self.eos_token]
+
+ @property
+ def unk_id(self):
+ self.check_special_token_id_getting(self.unk_token, 'unk_id')
+ return self.vocab[self.unk_token]
+
+ @property
+ def mask_id(self):
+ self.check_special_token_id_getting(self.mask_token, 'mask_id')
+ return self.vocab[self.mask_token]
+
+ @property
+ def sep_id(self):
+ self.check_special_token_id_getting(self.sep_token, 'sep_id')
+ return self.vocab[self.sep_token]
+
+ @property
+ def cls_id(self):
+ self.check_special_token_id_getting(self.cls_token, 'cls_id')
+ return self.vocab[self.cls_token]
+
+ @staticmethod
+ def create_special_tokens_dict(
+ mask_token: Optional[str] = None,
+ bos_token: Optional[str] = None,
+ eos_token: Optional[str] = None,
+ pad_token: Optional[str] = None,
+ sep_token: Optional[str] = None,
+ cls_token: Optional[str] = None,
+ unk_token: Optional[str] = None,
+ ):
+ special_tokens_dict = {}
+ for value, name in zip(
+ [pad_token, unk_token, bos_token, eos_token, sep_token, mask_token, cls_token],
+ ['pad_token', 'unk_token', 'bos_token', 'eos_token', 'sep_token', 'mask_token', 'cls_token'],
+ ):
+ if value is not None:
+ if not isinstance(value, str):
+ raise ValueError(
+ f"The type of parameter `{name}` has to be `None` or `str`, found `{type(value)}`"
+ )
+ elif len(value) == 0:
+ raise ValueError(f"If the parameter `{name}` is `str`, then its length has to be nonzero.")
+ elif value in special_tokens_dict.values():
+ other_name = None
+ for k, v in special_tokens_dict.items():
+ if v == value:
+ other_name = k
+ raise ValueError(
+ f"The value {repr(value)} of special token `{name}` is the same as the value of special token "
+ f"`{other_name}`."
+ )
+ special_tokens_dict[name] = value
+ return special_tokens_dict
+
+ @staticmethod
+ def check_characters_to_exclude_from_vocabulary(characters_to_exclude_from_vocabulary):
+ for i, char in enumerate(characters_to_exclude_from_vocabulary):
+ if not isinstance(char, str):
+ raise ValueError(
+                    f"A character to exclude from the vocabulary has to be a `str`, whereas the element in "
+                    f"position {i} is of type `{type(char)}`."
+ )
+ elif len(char) != 1:
+ raise ValueError(
+                    f"The length of each element of the `characters_to_exclude_from_vocabulary` parameter has to "
+                    f"be 1. The length of the element in position {i} is {len(char)}."
+ )
+
+ @staticmethod
+ def check_text_and_text_file_name(text, text_file_name):
+ if text is None and text_file_name is None:
+ raise ValueError(
+ f'Exactly one of parameters `text` and `text_file_name` should be provided whereas both parameters '
+ f'are `None`.'
+ )
+ if text is not None and text_file_name is not None:
+ raise ValueError(
+ f"Exactly one of parameters `text` and `text_file_name` has to be provided, whereas both parameters "
+ f"are not `None`."
+ )
+ if text is not None:
+ if not isinstance(text, str):
+ raise ValueError(
+ f"Parameter `text` has to be of type `str`, whereas it belongs to type `{type(text)}`."
+ )
+
+ @classmethod
+ def build_vocab(
+ cls,
+ save_path: Union[str, bytes, os.PathLike],
+ text: Optional[str] = None,
+ text_file_name: Optional[Union[str, bytes, os.PathLike]] = None,
+ characters_to_exclude: Optional[List[str]] = None,
+ vocab_size: int = None,
+ mask_token: Optional[str] = None,
+ bos_token: Optional[str] = None,
+ eos_token: Optional[str] = None,
+ pad_token: Optional[str] = None,
+ sep_token: Optional[str] = None,
+ cls_token: Optional[str] = None,
+ unk_token: Optional[str] = None,
+ ):
+ """
+        Creates a character vocabulary and saves it to file ``save_path``. You should provide exactly one of the
+        parameters ``text`` and ``text_file_name``. The format of the created character vocabulary file is the
+        following:
+ ```
+ {['mask_token': "ANY NON EMPTY STRING", ]['bos_token': "ANY NON EMPTY STRING", ] and so on}
+ ' '
+ 'e'
+ ...
+ ```
+        The first line is a JSON dictionary which contains special tokens. These special tokens are set using the
+        parameters ``mask_token``, ``bos_token``, ``eos_token``, ``pad_token``, ``sep_token``, ``cls_token``,
+        ``unk_token``. Other lines in the created vocabulary file are Python string literals containing one
+        character each.
+
+ Args:
+            save_path: path to the output text file. If the parent directory of ``save_path`` does not exist, it
+                will be created.
+            text: string whose characters are used for vocabulary creation.
+            text_file_name: path to a file whose characters are used for vocabulary creation. Use this parameter if
+                the text in the file is too large to be loaded into memory.
+            characters_to_exclude: a list of characters which will not be added to the vocabulary.
+            vocab_size: vocabulary size. If this parameter is set, only the most frequent ``vocab_size`` characters
+                are added to the vocabulary.
+ mask_token: mask token
+ bos_token: the beginning of sequence token
+ eos_token: the end of sequence token. Usually equal to sep_token.
+ pad_token: token to use for padding.
+ sep_token: token used for separating sequences.
+ cls_token: class token. Usually equal to bos_token.
+            unk_token: token to use for unknown tokens. If the parameter ``unk_token`` is set and there is a
+                character in the input of the ``text_to_ids`` or ``text_to_tokens`` methods which is not in the
+                vocabulary, then such an unknown character is tokenized into ``unk_token``. If the parameter
+                ``unk_token`` is not set, then unknown characters are discarded.
+ """
+ special_tokens_dict = cls.create_special_tokens_dict(
+ mask_token, bos_token, eos_token, pad_token, sep_token, cls_token, unk_token
+ )
+ if characters_to_exclude is None:
+ characters_to_exclude = []
+ else:
+ cls.check_characters_to_exclude_from_vocabulary(characters_to_exclude)
+ cls.check_text_and_text_file_name(text, text_file_name)
+ if text is not None:
+ counter = Counter(text)
+ else:
+ assert text_file_name is not None
+ text_file_name = Path(text_file_name).expanduser()
+ counter = Counter()
+ with text_file_name.open(encoding='utf-8') as f:
+ while True:
+ segment = f.read(NUMBER_OF_CHARACTERS_READ_BUFFER_SIZE)
+ if not segment:
+ break
+ counter.update(segment)
+ for char in characters_to_exclude:
+ if char in counter:
+ del counter[char]
+ save_path = Path(save_path).expanduser()
+ save_path.parent.mkdir(exist_ok=True, parents=True)
+ with save_path.open('w', encoding='utf-8') as f:
+ f.write(json.dumps(special_tokens_dict) + '\n')
+ if vocab_size is None:
+ for c, _ in sorted(counter.items(), key=lambda x: -x[1]):
+ f.write(repr(c) + '\n')
+ else:
+ vocab_size -= len(special_tokens_dict)
+ for i, (c, _) in enumerate(sorted(counter.items(), key=lambda x: -x[1])):
+ if i < vocab_size:
+ f.write(repr(c) + '\n')
+ else:
+ break
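+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (illustrative only): build a character vocabulary from a tiny text inside a
+    # temporary directory, then round-trip a string through the tokenizer.
+    import tempfile
+
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        vocab_path = os.path.join(tmp_dir, 'char_vocab.txt')
+        CharTokenizer.build_vocab(vocab_path, text='hello world', pad_token='<pad>', unk_token='<unk>')
+        tokenizer = CharTokenizer(vocab_path, pad_token='<pad>', unk_token='<unk>')
+        ids = tokenizer.text_to_ids('hello')
+        assert tokenizer.ids_to_text(ids) == 'hello'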
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/g2p_table_tokenizer.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/g2p_table_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa886ca6b698303f208ada7caa57fa42894c210a
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/g2p_table_tokenizer.py
@@ -0,0 +1,106 @@
+import random
+from typing import List
+
+from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
+
+
+class G2PTableTokenizer(TokenizerSpec):
+ def __init__(self, vocab, g2p_table, unk_label=''):
+ self.vocab = vocab
+ self.vocab_size = len(self.vocab)
+ self.inv_vocab = {v: k for k, v in self.vocab.items()}
+
+ self.unk_label = unk_label
+ self.unk_id = self.vocab[unk_label]
+
+ self.g2p_table = g2p_table
+
+ self.g2p_id_table = {}
+ for word, phone_seqs in self.g2p_table.items():
+ phone_id_seqs = []
+ for phone_seq_i in phone_seqs:
+ phone_seq_id_i = tuple(self.vocab[phone] for phone in phone_seq_i)
+ assert phone_seq_id_i not in phone_id_seqs
+ phone_id_seqs.append(phone_seq_id_i)
+ assert len(phone_id_seqs) > 0
+ self.g2p_id_table[word] = phone_id_seqs
+
+ @classmethod
+ def load(cls, vocab_fp, lexicon_fp):
+ vocab = load_vocab(vocab_fp)
+
+ g2p_table = {}
+ with open(lexicon_fp, encoding='utf-8') as f:
+ for line in f:
+ line = line.strip()
+ word, phone_seq = line.split('\t')
+ if word not in g2p_table:
+ g2p_table[word] = []
+ phone_seq = tuple(phone_seq.split(' '))
+ if phone_seq in g2p_table[word]:
+ print('found duplicated mapping: ', line)
+ else:
+ assert len(phone_seq) > 0
+ g2p_table[word].append(phone_seq)
+ return cls(vocab, g2p_table)
+
+ def text_to_tokens(self, text: str) -> List[str]:
+ text = text.strip()
+ words = text.split()
+ tokens = []
+ for word_i in words:
+ phones_list = self.g2p_table.get(word_i)
+ if phones_list:
+ if len(phones_list) == 1:
+ tokens.extend(phones_list[0])
+ else:
+ tokens.extend(random.choice(phones_list))
+ else:
+ tokens.append(self.unk_label)
+ return tokens
+
+ def tokens_to_text(self, tokens: List[str]) -> str:
+ return ' '.join(tokens)
+
+ def text_to_ids(self, text: str) -> List[int]:
+ words = text.split()
+ ids = []
+ for word_i in words:
+ phone_ids_list = self.g2p_id_table.get(word_i)
+ if phone_ids_list:
+ if len(phone_ids_list) == 1:
+ ids.extend(phone_ids_list[0])
+ else:
+ ids.extend(random.choice(phone_ids_list))
+ else:
+ print('found unk: ', text)
+ ids.append(self.unk_id)
+ return ids
+
+ def ids_to_text(self, ids: List[int]) -> str:
+ ids_ = [id_ for id_ in ids]
+ return ' '.join(self.ids_to_tokens(ids_))
+
+ def tokens_to_ids(self, tokens: List[str]) -> List[int]:
+ return [self.vocab[token] for token in tokens]
+
+ def token_to_id(self, token: str) -> int:
+ return self.vocab[token]
+
+ def ids_to_tokens(self, ids: List[int]) -> List[str]:
+ return [self.inv_vocab[id] for id in ids]
+
+
+def load_vocab(path, encoding='utf-8'):
+ vocab_dict = {}
+ with open(path, encoding=encoding) as f:
+ idx = 0
+ for line in f:
+ line = line.strip()
+ if not line:
+ continue
+ token = line.split('\t')[0]
+ assert token not in vocab_dict
+ vocab_dict[token] = idx
+ idx += 1
+ return vocab_dict
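+
+
+# Expected lexicon format (inferred from `load` above; file names below are hypothetical): one
+# "word<TAB>space-separated phone sequence" entry per line, e.g. "hello<TAB>HH AH L OW". A word may appear on
+# several lines to provide alternative pronunciations, from which tokenization samples uniformly at random.
+#
+# A minimal construction sketch (assumes every phone and the '<unk>' label appear as lines in the vocab file):
+#
+#     vocab = load_vocab('phones_vocab.txt')
+#     g2p_table = {'hello': [('HH', 'AH', 'L', 'OW')], 'world': [('W', 'ER', 'L', 'D')]}
+#     tokenizer = G2PTableTokenizer(vocab, g2p_table, unk_label='<unk>')
+#     ids = tokenizer.text_to_ids('hello world')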
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbc120ba925c87717b1ce1f4fe52f259e7f52122
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py
@@ -0,0 +1,262 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import re
+from typing import Dict, List, Optional, Union
+
+import sentencepiece
+
+from nemo.collections.common.parts.utils import if_exist
+from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
+from nemo.utils import logging
+
+__all__ = ['SentencePieceTokenizer', 'create_spt_model']
+
+
+class SentencePieceTokenizer(TokenizerSpec):
+ '''
+    SentencePiece tokenizer, see https://github.com/google/sentencepiece.
+ '''
+
+ def __init__(self, model_path: str, special_tokens: Optional[Union[Dict[str, str], List[str]]] = None):
+ """
+ Args:
+ model_path: path to sentence piece tokenizer model. To create the model use create_spt_model()
+ special_tokens: either list of special tokens or dictionary of token name to token value
+ """
+ if not model_path or not os.path.exists(model_path):
+ raise ValueError(f"model_path: {model_path} is invalid")
+ self.tokenizer = sentencepiece.SentencePieceProcessor()
+ self.tokenizer.Load(model_path)
+ self.original_vocab_size = self.tokenizer.get_piece_size()
+ self.vocab_size = self.tokenizer.get_piece_size()
+ self.special_token_to_id = {}
+ self.id_to_special_token = {}
+ if special_tokens:
+ self.add_special_tokens(special_tokens)
+
+ def text_to_tokens(self, text):
+ tokens = []
+ idx = 0
+ last_idx = 0
+
+ while 1:
+ indices = {}
+
+ for token in self.special_token_to_id:
+ try:
+ indices[token] = text[idx:].index(token)
+ except ValueError:
+ continue
+
+ if len(indices) == 0:
+ break
+
+ next_token = min(indices, key=indices.get)
+ next_idx = idx + indices[next_token]
+
+ tokens.extend(self.tokenizer.encode_as_pieces(text[idx:next_idx]))
+ tokens.append(next_token)
+ idx = next_idx + len(next_token)
+
+ tokens.extend(self.tokenizer.encode_as_pieces(text[idx:]))
+ return tokens
+
+ def text_to_ids(self, text, nbest_size=None, alpha=None):
+ ids = []
+ idx = 0
+ last_idx = 0
+
+ while 1:
+ indices = {}
+
+ for token in self.special_token_to_id:
+ try:
+ indices[token] = text[idx:].index(token)
+ except ValueError:
+ continue
+
+ if len(indices) == 0:
+ break
+
+ next_token = min(indices, key=indices.get)
+ next_idx = idx + indices[next_token]
+
+ if nbest_size is not None:
+ ids.extend(self.tokenizer.sample_encode_as_ids(text[idx:next_idx], nbest_size=nbest_size, alpha=alpha))
+ else:
+ ids.extend(self.tokenizer.encode_as_ids(text[idx:next_idx]))
+ ids.append(self.special_token_to_id[next_token])
+ idx = next_idx + len(next_token)
+
+ if nbest_size is not None:
+ ids.extend(self.tokenizer.sample_encode_as_ids(text[idx:], nbest_size=nbest_size, alpha=alpha))
+ else:
+ ids.extend(self.tokenizer.encode_as_ids(text[idx:]))
+ return ids
+
+ def tokens_to_text(self, tokens):
+ return self.tokenizer.decode_pieces(tokens)
+
+ def ids_to_text(self, ids):
+ text = ""
+ last_i = 0
+
+ for i, id in enumerate(ids):
+ if id in self.id_to_special_token:
+ text += self.tokenizer.decode_ids(ids[last_i:i]) + " "
+ text += self.id_to_special_token[id] + " "
+ last_i = i + 1
+
+ text += self.tokenizer.decode_ids(ids[last_i:])
+ return text.strip()
+
+ def token_to_id(self, token):
+ if token in self.special_token_to_id:
+ return self.special_token_to_id[token]
+ return self.tokenizer.piece_to_id(token)
+
+ def ids_to_tokens(self, ids):
+ tokens = []
+ for id in ids:
+ if id >= self.original_vocab_size:
+ tokens.append(self.id_to_special_token[id])
+ else:
+ tokens.append(self.tokenizer.id_to_piece(id))
+ return tokens
+
+ def tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
+ if isinstance(tokens, str):
+ tokens = [tokens]
+ ids = []
+ for token in tokens:
+ ids.append(self.token_to_id(token))
+ return ids
+
+ def add_special_tokens(self, special_tokens):
+ if isinstance(special_tokens, list):
+ for token in special_tokens:
+ if (
+ self.tokenizer.piece_to_id(token) == self.tokenizer.unk_id()
+ and token not in self.special_token_to_id
+ ):
+ self.special_token_to_id[token] = self.vocab_size
+ self.id_to_special_token[self.vocab_size] = token
+ self.vocab_size += 1
+ elif isinstance(special_tokens, dict):
+ for token_name, token in special_tokens.items():
+ setattr(self, token_name, token)
+ if (
+ self.tokenizer.piece_to_id(token) == self.tokenizer.unk_id()
+ and token not in self.special_token_to_id
+ ):
+ self.special_token_to_id[token] = self.vocab_size
+ self.id_to_special_token[self.vocab_size] = token
+ self.vocab_size += 1
+
+ @property
+ def pad_id(self):
+ return self.tokens_to_ids([self.pad_token])[0]
+
+ @property
+ def bos_id(self):
+ return self.tokens_to_ids([self.bos_token])[0]
+
+ @property
+ def eos_id(self):
+ return self.tokens_to_ids([self.eos_token])[0]
+
+ @property
+ def sep_id(self):
+ return self.tokens_to_ids([self.sep_token])[0]
+
+ @property
+ def cls_id(self):
+ return self.tokens_to_ids([self.cls_token])[0]
+
+
+def create_spt_model(
+ data_file: str,
+ vocab_size: int,
+ sample_size: int,
+ do_lower_case: bool,
+ tokenizer_type: str = 'unigram',
+ output_dir: Optional[str] = None,
+ character_coverage: float = 1.0,
+):
+ """
+ Creates sentence piece tokenizer model from data file.
+ Args:
+ data_file: data file
+ vocab_size: vocabulary size
+        sample_size: maximum number of sentences the trainer loads
+        do_lower_case: whether the text should be lower-cased before the tokenizer model is created
+        character_coverage: float value between 0 and 1 (as a fraction). For languages with a vast charset it
+            can be < 1.0, but for all other languages it should be set to 1.0
+        output_dir: folder in which to save the created tokenizer model. If not specified, the model is stored in
+            an `spt` folder next to ``data_file``
+ """
+
+ if not data_file or not os.path.exists(data_file):
+ raise ValueError(f"data_file must be valid file path, but got {data_file}")
+ data_dir = os.path.dirname(data_file)
+ vocab = []
+ if not output_dir:
+ output_dir = f'{data_dir}/spt'
+ if if_exist(output_dir, ['tokenizer.model']):
+ logging.info(f"tokenizer model {output_dir}/tokenizer.model already exists")
+ return f'{output_dir}/tokenizer.model', f'{output_dir}/vocab.txt'
+    logging.info(f'Processing {data_file} and storing the result at {output_dir}')
+ os.makedirs(output_dir, exist_ok=True)
+
+ cmd = (
+ f"--input={data_file} --model_prefix={output_dir}/tokenizer "
+ f"--vocab_size={vocab_size} "
+ f"--shuffle_input_sentence=true --hard_vocab_limit=false "
+ f"--model_type={tokenizer_type} "
+ f"--character_coverage={character_coverage} "
+ f"--bos_id=-1 --eos_id=-1"
+ )
+ if do_lower_case:
+ cmd += " --normalization_rule_name=nmt_nfkc_cf"
+
+ if sample_size > 0:
+ cmd += f" --input_sentence_size={sample_size}"
+
+ sentencepiece.SentencePieceTrainer.Train(cmd)
+
+ # Add BERT control symbols
+ tokens = []
+
+ with open(f"{output_dir}/tokenizer.vocab", "r") as f:
+        f.readline()  # skip the first <unk> token
+
+ # Read tokens from each line and parse for vocab
+ for line in f:
+ piece = line.split("\t")[0]
+            token = piece[1:] if piece.startswith("▁") else f"##{piece}"
+
+ if len(token) > 0:
+ tokens.append(token)
+ else:
+ tokens.append(piece[0])
+
+ vocab.extend(tokens)
+
+ # Save vocabulary to output file
+ vocab_file = f'{output_dir}/vocab.txt'
+ with open(vocab_file, "w") as f:
+ for token in vocab:
+ f.write(f"{token}\n")
+ return f'{output_dir}/tokenizer.model', vocab_file
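+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (illustrative only; 'corpus.txt' is a hypothetical path to a plain-text training file).
+    # Trains a small unigram model, then wraps it with SentencePieceTokenizer and adds one user-defined
+    # special token.
+    model_path, vocab_path = create_spt_model(
+        data_file='corpus.txt', vocab_size=1000, sample_size=-1, do_lower_case=True
+    )
+    tokenizer = SentencePieceTokenizer(model_path, special_tokens=['<mask>'])
+    print(tokenizer.text_to_tokens('hello <mask> world'))
+    print(tokenizer.ids_to_text(tokenizer.text_to_ids('hello <mask> world')))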
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/tokenizer_spec.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/tokenizer_spec.py
new file mode 100644
index 0000000000000000000000000000000000000000..252571d76ef205e8f29283c9143eccde11a3d6f8
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/tokenizer_spec.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from typing import List
+
+__all__ = ['TokenizerSpec']
+
+
+class TokenizerSpec(ABC):
+ """
+ Inherit this class to implement a new tokenizer.
+ """
+
+ @abstractmethod
+ def text_to_tokens(self, text):
+ pass
+
+ @abstractmethod
+ def tokens_to_text(self, tokens):
+ pass
+
+ @abstractmethod
+ def tokens_to_ids(self, tokens):
+ pass
+
+ @abstractmethod
+ def ids_to_tokens(self, ids):
+ pass
+
+ @abstractmethod
+ def text_to_ids(self, text):
+ pass
+
+ @abstractmethod
+ def ids_to_text(self, ids):
+ pass
+
+ def add_special_tokens(self, special_tokens: List[str]):
+ raise NotImplementedError("To be implemented")
+
+ @property
+ def name(self):
+ return type(self).__name__
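+
+
+# A minimal subclass sketch (illustrative only): a whitespace tokenizer over a fixed token-to-id mapping,
+# showing which abstract methods a new tokenizer has to implement.
+#
+#     class WhitespaceTokenizer(TokenizerSpec):
+#         def __init__(self, vocab):
+#             self.vocab = vocab
+#             self.inv_vocab = {v: k for k, v in vocab.items()}
+#
+#         def text_to_tokens(self, text):
+#             return text.split()
+#
+#         def tokens_to_text(self, tokens):
+#             return ' '.join(tokens)
+#
+#         def tokens_to_ids(self, tokens):
+#             return [self.vocab[t] for t in tokens]
+#
+#         def ids_to_tokens(self, ids):
+#             return [self.inv_vocab[i] for i in ids]
+#
+#         def text_to_ids(self, text):
+#             return self.tokens_to_ids(self.text_to_tokens(text))
+#
+#         def ids_to_text(self, ids):
+#             return self.tokens_to_text(self.ids_to_tokens(ids))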
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/word_tokenizer.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/word_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3431af9d734707f63c8d68f38a090cf9f76478b
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/collections/common/tokenizers/word_tokenizer.py
@@ -0,0 +1,72 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from nemo.collections.common.tokenizers.char_tokenizer import CharTokenizer
+
+__all__ = ['WordTokenizer']
+
+
+class WordTokenizer(CharTokenizer):
+ "Tokenizes at word boundary"
+
+ def __init__(
+ self,
+ vocab_file: str,
+ mask_token: Optional[str] = None,
+ bos_token: Optional[str] = None,
+ eos_token: Optional[str] = None,
+ pad_token: Optional[str] = None,
+ sep_token: Optional[str] = None,
+ cls_token: Optional[str] = None,
+ unk_token: Optional[str] = None,
+ ):
+ """
+ Args:
+            vocab_file: path to a file with vocabulary entries separated by the new line character
+ mask_token: mask token
+ bos_token: the beginning of sequence token
+ eos_token: the end of sequence token. Usually equal to sep_token
+ pad_token: token to use for padding
+ sep_token: token used for separating sequences
+ cls_token: class token. Usually equal to bos_token
+ unk_token: token to use for unknown tokens
+ """
+
+ super().__init__(
+ vocab_file=vocab_file,
+ mask_token=mask_token,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ pad_token=pad_token,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ cls_token=cls_token,
+ )
+
+ def text_to_tokens(self, text):
+ token_candidates = text.strip().split()
+ tokens = []
+ for token in token_candidates:
+ if token in self.vocab:
+ tokens.append(token)
+ else:
+ tokens.append(self.unk_token)
+ return tokens
+
+ def ids_to_text(self, ids):
+        ids_ = [id_ for id_ in ids if id_ not in self.special_token_ids_to_remove_while_decoding]
+ return " ".join(self.ids_to_tokens(ids_))
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/constants.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..f678ea13a2fbfc3c0efa68d180a72f1e9e87e44c
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/constants.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+NEMO_ENV_VARNAME_ENABLE_COLORING = "NEMO_ENABLE_COLORING"
+NEMO_ENV_VARNAME_REDIRECT_LOGS_TO_STDERR = "NEMO_REDIRECT_LOGS_TO_STDERR"
+NEMO_ENV_VARNAME_TESTING = "NEMO_TESTING" # Set to True to enable nemo.util.logging's debug mode
+NEMO_ENV_VARNAME_VERSION = "NEMO_EXPM_VERSION" # Used for nemo.utils.exp_manager versioning
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..44d13c9efd936f2d6e602e486943829f5b3b613c
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import nemo.core.neural_types
+from nemo.core.classes import *
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9226c0ea84ff8c3d5926af6f753589ed94370b2
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/__init__.py
@@ -0,0 +1,21 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nemo.core.classes.common import FileIO, Model, Serialization, Typing, is_typecheck_enabled, typecheck
+from nemo.core.classes.dataset import Dataset, IterableDataset
+from nemo.core.classes.exportable import Exportable, ExportFormat
+from nemo.core.classes.loss import Loss
+from nemo.core.classes.modelPT import ModelPT
+from nemo.core.classes.module import NeuralModule
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/common.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0aa82481ccb739a3556f51eb07eea37985b4d94
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/common.py
@@ -0,0 +1,553 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Interfaces common to all Neural Modules and Models."""
+import hashlib
+from abc import ABC, abstractmethod
+from contextlib import contextmanager
+from dataclasses import dataclass
+from enum import Enum
+from pathlib import Path
+from typing import Dict, List, Optional, Union
+
+import hydra
+import wrapt
+from omegaconf import DictConfig, OmegaConf
+
+import nemo
+from nemo.core.neural_types import NeuralType, NeuralTypeComparisonResult
+from nemo.utils import logging
+from nemo.utils.cloud import maybe_download_from_cloud
+from nemo.utils.model_utils import maybe_update_config_version
+
+__all__ = ['Typing', 'FileIO', 'Model', 'Serialization', 'typecheck']
+
+
+_TYPECHECK_ENABLED = True
+
+
+def is_typecheck_enabled():
+ """
+ Getter method for typechecking state.
+ """
+ return _TYPECHECK_ENABLED
+
+
+class Typing(ABC):
+ """
+    An interface which endows a module with neural types.
+ """
+
+ @property
+ def input_types(self) -> Optional[Dict[str, NeuralType]]:
+ """Define these to enable input neural type checks"""
+ return None
+
+ @property
+ def output_types(self) -> Optional[Dict[str, NeuralType]]:
+ """Define these to enable output neural type checks"""
+ return None
+
+ def _validate_input_types(self, input_types=None, **kwargs):
+ """
+ This function does a few things.
+        1) It ensures that the number of mandatory input types <= len(kwargs) <= the total number of input types.
+ 2) For each (keyword name, keyword value) passed as input to the wrapped function:
+ - Check if the keyword name exists in the list of valid self.input_types names.
+ - Check if keyword value has the `neural_type` property.
+ - If it does, then perform a comparative check and assert that neural types
+ are compatible (SAME or GREATER).
+ - Check if keyword value is a container type (list or tuple). If yes,
+ then perform the elementwise test of neural type above on each element
+ of the nested structure, recursively.
+
+ Args:
+ input_types: Either the `input_types` defined at class level, or the local function
+ overridden type definition.
+ kwargs: Dictionary of argument_name:argument_value pairs passed to the wrapped
+ function upon call.
+ """
+ # TODO: Properly implement this
+ if input_types is not None:
+ total_input_types = len(input_types)
+ mandatory_input_types = len(
+ [type_val for type_key, type_val in input_types.items() if not type_val.optional]
+ )
+
+ if len(kwargs) < mandatory_input_types or len(kwargs) > total_input_types:
+ raise TypeError(
+ f"Number of input arguments provided ({len(kwargs)}) is not as expected. Function has "
+ f"{total_input_types} total inputs with {mandatory_input_types} mandatory inputs."
+ )
+
+ for key, value in kwargs.items():
+ # Check if keys exists in the defined input types
+ if key not in input_types:
+ raise TypeError(
+ f"Input argument {key} has no corresponding input_type match. "
+ f"Existing input_types = {input_types.keys()}"
+ )
+
+ # Perform neural type check
+ if hasattr(value, 'neural_type') and not input_types[key].compare(value.neural_type) in (
+ NeuralTypeComparisonResult.SAME,
+ NeuralTypeComparisonResult.GREATER,
+ ):
+ error_msg = [
+ f"{input_types[key].compare(value.neural_type)} :",
+ f"Input type expected : {input_types[key]}",
+ f"Input type found : {value.neural_type}",
+ ]
+ for i, dict_tuple in enumerate(input_types[key].elements_type.type_parameters.items()):
+ error_msg.insert(i + 2, f' input param_{i} : {dict_tuple[0]}: {dict_tuple[1]}')
+ for i, dict_tuple in enumerate(value.neural_type.elements_type.type_parameters.items()):
+ error_msg.append(f' input param_{i} : {dict_tuple[0]}: {dict_tuple[1]}')
+ raise TypeError("\n".join(error_msg))
+
+ # Perform input ndim check
+ if hasattr(value, 'shape'):
+ value_shape = value.shape
+ type_shape = input_types[key].axes
+ name = key
+
+ if type_shape is not None and len(value_shape) != len(type_shape):
+ raise TypeError(
+                            f"Input shape mismatch occurred for {name} in module {self.__class__.__name__} : \n"
+ f"Input shape expected = {input_types[key].axes} | \n"
+ f"Input shape found : {value_shape}"
+ )
+
+ # Perform recursive neural type check for homogeneous elements
+ elif isinstance(value, list) or isinstance(value, tuple):
+ for ind, val in enumerate(value):
+ self.__check_neural_type(val, input_types[key], name=key)
+
+ def _attach_and_validate_output_types(self, out_objects, output_types=None):
+ """
+ This function does a few things.
+ 1) It ensures that len(out_object) == len(self.output_types).
+ 2) If the output is a tensor (or list/tuple of list/tuple ... of tensors), it
+ attaches a neural_type to it. For objects without the neural_type attribute,
+ such as python objects (dictionaries and lists, primitive data types, structs),
+ no neural_type is attached.
+
+ Note: tensor.neural_type is only checked during _validate_input_types which is
+ called prior to forward().
+
+ Args:
+ output_types: Either the `output_types` defined at class level, or the local function
+ overridden type definition.
+ out_objects: The outputs of the wrapped function.
+ """
+ # TODO: Properly implement this
+ if output_types is not None:
+ out_types_list = list(output_types.items())
+
+ # First convert all outputs to list/tuple format to check correct number of outputs
+ if type(out_objects) in (list, tuple):
+ out_container = out_objects
+ else:
+ out_container = [out_objects]
+
+ if len(output_types) != len(out_container):
+ raise TypeError(
+ "Number of output arguments provided ({}) is not as expected ({})".format(
+ len(out_container), len(output_types)
+ )
+ )
+
+ # Attach types recursively, if possible
+ if not isinstance(out_objects, tuple) and not isinstance(out_objects, list):
+ try:
+ out_objects.neural_type = out_types_list[0][1]
+ except Exception:
+ pass
+
+ # Perform output ndim check
+ if hasattr(out_objects, 'shape'):
+ value_shape = out_objects.shape
+ type_shape = out_types_list[0][1].axes
+ name = out_types_list[0][0]
+
+ if type_shape is not None and len(value_shape) != len(type_shape):
+ raise TypeError(
+                            f"Output shape mismatch occurred for {name} in module {self.__class__.__name__} : \n"
+ f"Output shape expected = {type_shape} | \n"
+ f"Output shape found : {value_shape}"
+ )
+ else:
+ for ind, res in enumerate(out_objects):
+ self.__attach_neural_type(res, out_types_list[ind][1], name=out_types_list[ind][0])
+
+ def __check_neural_type(self, obj, type_val, name=None):
+ if isinstance(obj, tuple) or isinstance(obj, list):
+ for elem in obj:
+ self.__check_neural_type(elem, type_val, name=name)
+ return # after processing nest, return to avoid testing nest itself
+
+ if hasattr(obj, 'neural_type') and not type_val.compare(obj.neural_type) in (
+ NeuralTypeComparisonResult.SAME,
+ NeuralTypeComparisonResult.GREATER,
+ ):
+ raise TypeError(
+ f"{type_val.compare(obj.neural_type)} : \n"
+ f"Input type expected = {type_val} | \n"
+ f"Input type found : {obj.neural_type}"
+ )
+
+ # Perform input ndim check
+ if hasattr(obj, 'shape'):
+ value_shape = obj.shape
+ type_shape = type_val.axes
+
+ if type_shape is not None and len(value_shape) != len(type_shape):
+ raise TypeError(
+                    f"Input shape mismatch occurred for {name} in module {self.__class__.__name__} : \n"
+ f"Input shape expected = {type_shape} | \n"
+ f"Input shape found : {value_shape}"
+ )
+
+ def __attach_neural_type(self, obj, type_val, name=None):
+ if isinstance(obj, tuple) or isinstance(obj, list):
+ for elem in obj:
+ self.__attach_neural_type(elem, type_val, name=name)
+ return # after processing nest, return to avoid argument insertion into nest itself
+
+ try:
+ obj.neural_type = type_val
+ except Exception:
+ pass
+
+ # Perform output ndim check
+ if hasattr(obj, 'shape'):
+ value_shape = obj.shape
+ type_shape = type_val.axes
+
+ if type_shape is not None and len(value_shape) != len(type_shape):
+ raise TypeError(
+                    f"Output shape mismatch occurred for {name} in module {self.__class__.__name__} : \n"
+ f"Output shape expected = {type_shape} | \n"
+ f"Output shape found : {value_shape}"
+ )
+
+
+class Serialization(ABC):
+ @classmethod
+ def from_config_dict(cls, config: DictConfig):
+ """Instantiates object using DictConfig-based configuration"""
+ # Resolve the config dict
+ if isinstance(config, DictConfig):
+ config = OmegaConf.to_container(config, resolve=True)
+ config = OmegaConf.create(config)
+ OmegaConf.set_struct(config, True)
+
+ config = maybe_update_config_version(config)
+
+ if ('cls' in config or 'target' in config) and 'params' in config:
+ # regular hydra-based instantiation
+ instance = hydra.utils.instantiate(config=config)
+ elif '_target_' in config:
+ # regular hydra-based instantiation
+ instance = hydra.utils.instantiate(config=config)
+ else:
+ # models are handled differently for now
+ instance = cls(cfg=config)
+
+ if not hasattr(instance, '_cfg'):
+ instance._cfg = config
+ return instance
+
+ def to_config_dict(self) -> DictConfig:
+ """Returns object's configuration to config dictionary"""
+ if hasattr(self, '_cfg') and self._cfg is not None and isinstance(self._cfg, DictConfig):
+ # Resolve the config dict
+ config = OmegaConf.to_container(self._cfg, resolve=True)
+ config = OmegaConf.create(config)
+ OmegaConf.set_struct(config, True)
+
+ config = maybe_update_config_version(config)
+
+ self._cfg = config
+
+ return self._cfg
+ else:
+ raise NotImplementedError(
+ 'to_config_dict() can currently only return object._cfg but current object does not have it.'
+ )
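+
+    # A minimal `from_config_dict` sketch (illustrative only; the `_target_` below is just an example class path
+    # and assumes hydra/omegaconf and torch are installed):
+    #
+    #     cfg = OmegaConf.create({'_target_': 'torch.nn.Linear', 'in_features': 16, 'out_features': 8})
+    #     layer = Serialization.from_config_dict(cfg)  # instantiated via hydra.utils.instantiate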
+
+
+class FileIO(ABC):
+ def save_to(self, save_path: str):
+ """Saves module/model with weights"""
+ raise NotImplementedError()
+
+ @classmethod
+ def restore_from(
+ cls,
+ restore_path: str,
+ override_config_path: Optional[str] = None,
+ map_location: Optional['torch.device'] = None,
+ strict: bool = True,
+ ):
+ """Restores module/model with weights"""
+ raise NotImplementedError()
+
+ @classmethod
+ def from_config_file(cls, path2yaml_file: str):
+ """
+ Instantiates an instance of NeMo Model from YAML config file.
+ Weights will be initialized randomly.
+ Args:
+ path2yaml_file: path to yaml file with model configuration
+
+ Returns:
+
+ """
+ if issubclass(cls, Serialization):
+ conf = OmegaConf.load(path2yaml_file)
+ return cls.from_config_dict(config=conf)
+ else:
+ raise NotImplementedError()
+
+ def to_config_file(self, path2yaml_file: str):
+ """
+ Saves current instance's configuration to YAML config file. Weights will not be saved.
+ Args:
+            path2yaml_file: path to the yaml file where the model configuration will be saved
+
+ Returns:
+ """
+ if hasattr(self, '_cfg'):
+ self._cfg = maybe_update_config_version(self._cfg)
+ with open(path2yaml_file, 'w', encoding='utf-8') as fout:
+ OmegaConf.save(config=self._cfg, f=fout, resolve=True)
+ else:
+ raise NotImplementedError()
+
+
+@dataclass
+class PretrainedModelInfo:
+ pretrained_model_name: str
+ description: str
+ location: str
+ class_: 'Model' = None
+
+
+class Model(Typing, Serialization, FileIO):
+ """
+ Abstract class offering interface which should be implemented by all NeMo models.
+ """
+
+ @classmethod
+ @abstractmethod
+ def list_available_models(cls) -> Optional[PretrainedModelInfo]:
+ """
+ Should list all pre-trained models available via NVIDIA NGC cloud
+
+ Returns:
+ A list of PretrainedModelInfo entries
+ """
+ pass
+
+ @classmethod
+ def get_available_model_names(cls) -> List[str]:
+ """
+        Returns the list of model names available via the NVIDIA NGC cloud.
+        To get the complete model description use list_available_models().
+ Returns:
+ A list of model names
+ """
+ model_names = []
+ if cls.list_available_models() is not None:
+ model_names = [model.pretrained_model_name for model in cls.list_available_models()]
+ return model_names
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ model_name: str,
+ refresh_cache: bool = False,
+ override_config_path: Optional[str] = None,
+ map_location: Optional['torch.device'] = None,
+ strict: bool = True,
+ ):
+ """
+        Instantiates an instance of a NeMo model from the NVIDIA NGC cloud.
+ Use restore_from() to instantiate from a local .nemo file.
+ Args:
+ model_name: string key which will be used to find the module.
+ refresh_cache: If set to True, then when fetching from cloud, this will re-fetch the file
+ from cloud even if it is already found in a cache locally.
+ override_config_path: path to a yaml config that will override the internal
+ config file
+ map_location: Optional torch.device() to map the instantiated model to a device.
+ By default (None), it will select a GPU if available, falling back to CPU otherwise.
+ strict: Passed to torch.load_state_dict
+
+ Returns:
+ A model instance of a particular model class
+ """
+ location_in_the_cloud = None
+ description = None
+ if cls.list_available_models() is not None:
+ for pretrained_model_info in cls.list_available_models():
+ if pretrained_model_info.pretrained_model_name == model_name:
+ location_in_the_cloud = pretrained_model_info.location
+ description = pretrained_model_info.description
+ class_ = pretrained_model_info.class_
+ if location_in_the_cloud is None:
+ raise FileNotFoundError(
+ f"Model {model_name} was not found. Check cls.list_available_models() for the list of all available models."
+ )
+ filename = location_in_the_cloud.split("/")[-1]
+ url = location_in_the_cloud.replace(filename, "")
+ cache_dir = Path.joinpath(Path.home(), f'.cache/torch/NeMo/NeMo_{nemo.__version__}/{filename[:-5]}')
+        # If either the description or the location in the cloud changes, this will force a re-download
+ cache_subfolder = hashlib.md5((location_in_the_cloud + description).encode('utf-8')).hexdigest()
+ # if file exists on cache_folder/subfolder, it will be re-used, unless refresh_cache is True
+ nemo_model_file_in_cache = maybe_download_from_cloud(
+ url=url, filename=filename, cache_dir=cache_dir, subfolder=cache_subfolder, refresh_cache=refresh_cache
+ )
+ logging.info("Instantiating model from pre-trained checkpoint")
+ if class_ is None:
+ class_ = cls
+ instance = class_.restore_from(
+ restore_path=nemo_model_file_in_cache,
+ override_config_path=override_config_path,
+ map_location=map_location,
+ strict=strict,
+ )
+ return instance
+
+
+class typecheck:
+ class TypeState(Enum):
+ """
+ Placeholder to denote the default value of type information provided.
+ If the constructor of this decorator is used to override the class level type definition,
+ this enum value indicate that types will be overridden.
+        this enum value indicates that types will be overridden.
+
+ UNINITIALIZED = 0
+
+ def __init__(
+ self,
+ input_types: Union[TypeState, Dict[str, NeuralType]] = TypeState.UNINITIALIZED,
+ output_types: Union[TypeState, Dict[str, NeuralType]] = TypeState.UNINITIALIZED,
+ ):
+ """
+ A decorator which performs input-output neural type checks, and attaches
+ neural types to the output of the function that it wraps.
+
+ Requires that the class inherit from `nemo.core.Typing` in order to perform
+ type checking, and will raise an error if that is not the case.
+
+ # Usage (Class level type support)
+ @typecheck()
+ def fn(self, arg1, arg2, ...):
+ ...
+
+ # Usage (Function level type support)
+ @typecheck(input_types=..., output_types=...)
+ def fn(self, arg1, arg2, ...):
+ ...
+
+ Points to be noted:
+ 1) The brackets () in `@typecheck()` are necessary.
+
+        Without those brackets you will encounter ``TypeError: __init__() takes 1 positional argument but X
+        were given``.
+
+ 2) The function can take any number of positional arguments during definition.
+
+ When you call this function, all arguments must be passed using kwargs only.
+
+ """
+ self.input_types = input_types
+ self.output_types = output_types
+
+ if input_types == self.TypeState.UNINITIALIZED:
+ self.input_override = False
+ else:
+ self.input_override = True
+
+ if output_types == self.TypeState.UNINITIALIZED:
+ self.output_override = False
+ else:
+ self.output_override = True
+
+ @wrapt.decorator(enabled=is_typecheck_enabled)
+ def __call__(self, wrapped, instance: Typing, args, kwargs):
+ if instance is None:
+ raise RuntimeError("Only classes which inherit nemo.core.Typing can use this decorator !")
+
+ if not isinstance(instance, Typing):
+ raise RuntimeError("Only classes which inherit nemo.core.Typing can use this decorator !")
+
+ if hasattr(instance, 'input_ports') or hasattr(instance, 'output_ports'):
+ raise RuntimeError(
+ "Typing requires override of `input_types()` and `output_types()`, "
+                "not `input_ports()` and `output_ports()`"
+ )
+
+ # Preserve type information
+ if self.input_types is typecheck.TypeState.UNINITIALIZED:
+ self.input_types = instance.input_types
+
+ if self.output_types is typecheck.TypeState.UNINITIALIZED:
+ self.output_types = instance.output_types
+
+ # Resolve global type or local overridden type
+ if self.input_override:
+ input_types = self.input_types
+ else:
+ input_types = instance.input_types
+
+ if self.output_override:
+ output_types = self.output_types
+ else:
+ output_types = instance.output_types
+
+ # If types are not defined, skip type checks and just call the wrapped method
+ if input_types is None and output_types is None:
+ return wrapped(*args, **kwargs)
+
+ # Check that all arguments are kwargs
+ if input_types is not None and len(args) > 0:
+ raise TypeError("All arguments must be passed by kwargs only for typed methods")
+
+ # Perform rudimentary input checks here
+ instance._validate_input_types(input_types=input_types, **kwargs)
+
+ # Call the method - this can be forward, or any other callable method
+ outputs = wrapped(*args, **kwargs)
+
+ instance._attach_and_validate_output_types(output_types=output_types, out_objects=outputs)
+
+ return outputs
+
+ @staticmethod
+ def set_typecheck_enabled(enabled: bool = True):
+ global _TYPECHECK_ENABLED
+ _TYPECHECK_ENABLED = enabled
+
+ @staticmethod
+ @contextmanager
+ def disable_checks():
+ typecheck.set_typecheck_enabled(enabled=False)
+ try:
+ yield
+ finally:
+ typecheck.set_typecheck_enabled(enabled=True)
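+
+
+# A minimal usage sketch of `typecheck` (illustrative only; `VoidType` stands in for any element type from
+# nemo.core.neural_types, and `some_tensor` is any torch.Tensor):
+#
+#     from nemo.core.neural_types import NeuralType, VoidType
+#
+#     class Identity(Typing):
+#         @property
+#         def input_types(self):
+#             return {"x": NeuralType(axes=None, elements_type=VoidType())}
+#
+#         @property
+#         def output_types(self):
+#             return {"y": NeuralType(axes=None, elements_type=VoidType())}
+#
+#         @typecheck()
+#         def transform(self, *, x):  # typed methods must be called with kwargs only
+#             return x
+#
+#     Identity().transform(x=some_tensor)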
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/dataset.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..51bd46ef79010660616eba88461b1c74ce8fbd04
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/dataset.py
@@ -0,0 +1,109 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Optional
+
+from torch.utils import data
+
+from nemo.core.classes import Serialization, Typing, typecheck
+
+__all__ = ['Dataset', 'IterableDataset']
+
+
+class Dataset(data.Dataset, Typing, Serialization):
+ """Dataset with output ports
+
+ Please Note: Subclasses of Dataset should *not* implement input_types.
+ """
+
+ def _collate_fn(self, batch):
+ """
+ A default implementation of a collation function.
+ Users should override this method to define custom collation behavior.
+ """
+ return data.dataloader.default_collate(batch)
+
+ @typecheck()
+ def collate_fn(self, batch):
+ """
+ This is the method that the user passes as a functor to the DataLoader.
+ The method optionally performs neural type checking and adds types to the outputs.
+
+ Please note, subclasses of Dataset should not implement `input_types`.
+
+ # Usage:
+ dataloader = torch.utils.data.DataLoader(
+ ....,
+ collate_fn=dataset.collate_fn,
+ ....
+ )
+
+ Returns:
+ Collated batch, with or without types.
+ """
+ if self.input_types is not None:
+ raise TypeError("Datasets should not implement `input_types` as they are not checked")
+
+ # Simply forward the inner `_collate_fn`
+ return self._collate_fn(batch)
+
+
+class IterableDataset(data.IterableDataset, Typing, Serialization):
+ """Iterable Dataset with output ports
+
+ Please Note: Subclasses of IterableDataset should *not* implement input_types.
+ """
+
+ def _collate_fn(self, batch):
+ """
+ A default implementation of a collation function.
+ Users should override this method to define custom collation behavior.
+ """
+ return data.dataloader.default_collate(batch)
+
+ @typecheck()
+ def collate_fn(self, batch):
+ """
+ This is the method that the user passes as a functor to the DataLoader.
+ The method optionally performs neural type checking and adds types to the outputs.
+
+ # Usage:
+ dataloader = torch.utils.data.DataLoader(
+ ....,
+ collate_fn=dataset.collate_fn,
+ ....
+ )
+
+ Returns:
+ Collated batch, with or without types.
+ """
+ if self.input_types is not None:
+ raise TypeError("Datasets should not implement `input_types` as they are not checked")
+
+ # Simply forward the inner `_collate_fn`
+ return self._collate_fn(batch)
+
+
+@dataclass
+class DatasetConfig:
+ """
+ Common data loader configuration options shared by dataset configs.
+ """
+
+ # ...
+ batch_size: int = 32
+ drop_last: bool = False
+ shuffle: bool = False
+ num_workers: Optional[int] = None
+ pin_memory: bool = True
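+
+ # ------------------------------------------------------------------------
+ # Hedged usage sketch (illustrative, not part of the original file): a
+ # subclass defines `output_types` only and hands `collate_fn` to a standard
+ # DataLoader. `RangeDataset` and the no-argument NeuralType() are assumptions.
+ #
+ #   import torch
+ #   from nemo.core.neural_types import NeuralType
+ #
+ #   class RangeDataset(Dataset):
+ #       @property
+ #       def output_types(self):
+ #           return {"sample": NeuralType()}   # do NOT define input_types
+ #
+ #       def __len__(self):
+ #           return 8
+ #
+ #       def __getitem__(self, idx):
+ #           return torch.tensor([float(idx)])
+ #
+ #   ds = RangeDataset()
+ #   loader = torch.utils.data.DataLoader(ds, batch_size=4, collate_fn=ds.collate_fn)
+ # ------------------------------------------------------------------------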
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/exportable.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/exportable.py
new file mode 100644
index 0000000000000000000000000000000000000000..fddc2f4df5f8da51b54c3ada2618c10592b865ba
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/exportable.py
@@ -0,0 +1,212 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from abc import ABC
+from collections import defaultdict
+from enum import Enum
+from typing import Dict
+
+import onnx
+import torch
+
+from nemo.core.classes import typecheck
+from nemo.core.neural_types import AxisKind, NeuralType
+from nemo.utils.export_utils import replace_for_export
+
+__all__ = ['ExportFormat', 'Exportable']
+
+
+class ExportFormat(Enum):
+ """Which format to use when exporting a Neural Module for deployment"""
+
+ ONNX = (1,)
+ TORCHSCRIPT = (2,)
+
+
+_EXT_DICT = {
+ ".pt": ExportFormat.TORCHSCRIPT,
+ ".onnx": ExportFormat.ONNX,
+}
+
+
+class Exportable(ABC):
+ """
+ This interface should be implemented by classes derived from nemo.core.NeuralModule or nemo.core.ModelPT.
+ It gives these entities the ability to be exported for deployment to formats such as ONNX.
+ """
+
+ @staticmethod
+ def get_format(filename: str):
+ _, ext = os.path.splitext(filename)
+ try:
+ return _EXT_DICT[ext]
+ except KeyError:
+ raise ValueError(f"Export file {filename} extension does not correspond to any export format!")
+
+ def export(
+ self,
+ output: str,
+ input_example=None,
+ output_example=None,
+ verbose=False,
+ export_params=True,
+ do_constant_folding=True,
+ keep_initializers_as_inputs=False,
+ onnx_opset_version: int = 12,
+ try_script: bool = False,
+ set_eval: bool = True,
+ check_trace: bool = True,
+ use_dynamic_axes: bool = True,
+ ):
+ try:
+ # Disable typechecks
+ typecheck.set_typecheck_enabled(enabled=False)
+
+ # Set module to eval mode
+ if set_eval:
+ self.eval()
+
+ format = self.get_format(output)
+ self._prepare_for_export()
+
+ if input_example is not None:
+ _in_example = input_example
+ else:
+ _in_example = self.input_example()
+
+ if output_example is None:
+ _out_example = self.forward(*_in_example)
+ else:
+ _out_example = output_example
+
+ if not (hasattr(self, 'input_types') and hasattr(self, 'output_types')):
+ raise NotImplementedError('For export to work you must define input and output types')
+ input_names = list(self.input_types.keys())
+ output_names = list(self.output_types.keys())
+ # dynamic axis is a mapping from input/output_name => list of "dynamic" indices
+ dynamic_axes = defaultdict(list)
+
+ # extract dynamic axes and remove unnecessary inputs/outputs
+ # for input_ports
+ for _name, ntype in self.input_types.items():
+ if _name in self.disabled_deployment_input_names:
+ input_names.remove(_name)
+ continue
+ if use_dynamic_axes:
+ dynamic_axes = {**dynamic_axes, **self._extract_dynamic_axes(_name, ntype)}
+ # for output_ports
+ for _name, ntype in self.output_types.items():
+ if _name in self.disabled_deployment_output_names:
+ output_names.remove(_name)
+ continue
+ if use_dynamic_axes:
+ dynamic_axes = {**dynamic_axes, **self._extract_dynamic_axes(_name, ntype)}
+
+ if len(dynamic_axes) == 0:
+ dynamic_axes = None
+
+ with torch.jit.optimized_execution(True):
+ jitted_model = None
+ if try_script:
+ try:
+ jitted_model = torch.jit.script(self)
+ except Exception as e:
+ print("jit.script() failed!", e)
+ if _in_example is None:
+ raise ValueError('Example input is None, and jit.script() either failed or was not attempted')
+
+ if isinstance(_in_example, Dict):
+ _in_example = tuple(_in_example.values())
+
+ if jitted_model is None:
+ jitted_model = torch.jit.trace(self, _in_example, check_trace=check_trace)
+
+ if format == ExportFormat.TORCHSCRIPT:
+ jitted_model.save(output)
+ assert os.path.exists(output)
+ elif format == ExportFormat.ONNX:
+ if _out_example is None:
+ if isinstance(_in_example, tuple):
+ _out_example = self.forward(*_in_example)
+ else:
+ _out_example = self.forward(_in_example)
+
+ torch.onnx.export(
+ jitted_model,
+ _in_example,
+ output,
+ input_names=input_names,
+ output_names=output_names,
+ verbose=verbose,
+ export_params=export_params,
+ do_constant_folding=do_constant_folding,
+ keep_initializers_as_inputs=keep_initializers_as_inputs,
+ dynamic_axes=dynamic_axes,
+ opset_version=onnx_opset_version,
+ example_outputs=_out_example,
+ )
+
+ # Verify the model can be read, and is valid
+ onnx_model = onnx.load(output)
+ onnx.checker.check_model(onnx_model, full_check=True)
+ return onnx_model
+ else:
+ raise ValueError(f'Encountered unknown export format {format}.')
+ finally:
+ typecheck.set_typecheck_enabled(enabled=True)
+ return [output] # Subclasses may create more than one file.
+
+ @property
+ def disabled_deployment_input_names(self):
+ """Implement this method to return a set of input names disabled for export"""
+ return set()
+
+ @property
+ def disabled_deployment_output_names(self):
+ """Implement this method to return a set of output names disabled for export"""
+ return set()
+
+ @property
+ def supported_export_formats(self):
+ """Implement this method to return a set of export formats supported. Default is all types."""
+ return set([ExportFormat.ONNX, ExportFormat.TORCHSCRIPT])
+
+ @staticmethod
+ def _extract_dynamic_axes(name: str, ntype: NeuralType):
+ """
+ Implement this method to provide dynamic axes id for ONNX export.
+ By default, this method will extract BATCH and TIME dimension ids from each provided input/output name argument.
+
+ For example, if module/model accepts argument named "input_signal" with type corresponding to [Batch, Time, Dim]
+ shape, then the returned result should contain "input_signal" -> [0, 1] because Batch and Time are dynamic axes
+ as they can change from call to call during inference.
+
+ Args:
+ name: Name of input or output parameter
+ ntype: Corresponding Neural Type
+
+ Returns:
+ A dict mapping the argument name to the list of its dynamic axis indices.
+ """
+ dynamic_axes = defaultdict(list)
+ if ntype.axes:
+ for ind, axis in enumerate(ntype.axes):
+ if axis.kind in [AxisKind.Batch, AxisKind.Time, AxisKind.Width, AxisKind.Height]:
+ dynamic_axes[name].append(ind)
+ return dynamic_axes
+
+ def _prepare_for_export(self):
+ """
+ Override this method to prepare module for export. This is in-place operation.
+ Base version does common necessary module replacements (Apex etc)
+ """
+ replace_for_export(self)
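+
+ # ------------------------------------------------------------------------
+ # Hedged usage sketch (illustrative, not part of the original file): a module
+ # that mixes in Exportable is exported by file extension. `MyEncoder` is an
+ # assumed NeuralModule + Exportable subclass with input/output types defined.
+ #
+ #   module = MyEncoder(...)
+ #   module.export("encoder.onnx")                  # ONNX, checked with onnx.checker
+ #   module.export("encoder.pt", try_script=True)   # TorchScript (scripted, else traced)
+ # ------------------------------------------------------------------------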
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/loss.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6ede4049c5374548cecc389cd8529e8db576dfd
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/loss.py
@@ -0,0 +1,26 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from nemo.core.classes.common import Serialization, Typing
+
+__all__ = ['Loss']
+
+
+class Loss(torch.nn.modules.loss._Loss, Typing, Serialization):
+ """Inherit this class to implement custom loss."""
+
+ def __init__(self, **kwargs):
+ super(Loss, self).__init__(**kwargs)
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/modelPT.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/modelPT.py
new file mode 100644
index 0000000000000000000000000000000000000000..beb75d4341337a0e9ad55bccacbf51fae94754bf
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/modelPT.py
@@ -0,0 +1,1303 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import inspect
+import os
+import shutil
+import tarfile
+import tempfile
+from abc import abstractmethod
+from os import path
+from typing import Callable, Dict, List, Optional, Union
+
+import hydra
+import torch
+from omegaconf import DictConfig, OmegaConf, open_dict
+from pytorch_lightning import LightningModule, Trainer
+from pytorch_lightning.utilities import rank_zero_only
+
+from nemo.core import optim
+from nemo.core.classes.common import Model
+from nemo.core.config.modelPT import ModelPTConfig
+from nemo.core.optim import prepare_lr_scheduler
+from nemo.utils import config_utils, logging, model_utils
+from nemo.utils.app_state import AppState
+from nemo.utils.get_rank import is_global_rank_zero
+
+# Need to set them before EFF import as it is using them.
+_MODEL_CONFIG_YAML = "model_config.yaml"
+_MODEL_WEIGHTS = "model_weights.ckpt"
+
+try:
+ # Try to import strategies for .nemo archive.
+ from eff.cookbooks import NeMoCookbook
+
+ _EFF_PRESENT_ = True
+except ImportError:
+ _EFF_PRESENT_ = False
+
+__all__ = ['ModelPT']
+
+"""
+Internal global flags that determine core functionality of ModelPT.
+
+_MODEL_IS_RESTORED:
+ This flag determines the context of the model - whether the model is currently being
+ restored or not.
+ - When set, it can be assumed that the model will disable all automatic methods -
+ setup_training_data(), setup_validation/test_data() and their multi equivalents.
+ - If a model is being restored from an archive file (tarfile), it can be assumed that
+ under this context, the cwd is *inside* the tarfile itself.
+
+_MODEL_RESTORE_PATH:
+ A string path to a file from which the model is being restored.
+ This file can either be a PyTorch Lightning checkpoint, or an archive (tarfile) that contains
+ artifact objects.
+ If it is an archive file, during restoration, the cwd will be temporarily moved to inside the
+ archive itself.
+
+_MODEL_EFF_SAVE:
+ A global flag that switches the format of the archive file that will be stored.
+ This flag only enables EFF when the package support is available.
+"""
+_MODEL_IS_RESTORED = False
+_MODEL_RESTORE_PATH = None
+_MODEL_EFF_SAVE = True
+
+
+class ModelPT(LightningModule, Model):
+ """
+ Interface for Pytorch-lightning based NeMo models
+ """
+
+ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
+ """
+ Base class from which all NeMo models should inherit
+
+ Args:
+ cfg (DictConfig): configuration object.
+ The cfg object should have (optionally) the following sub-configs:
+
+ * train_ds - to instantiate training dataset
+ * validation_ds - to instantiate validation dataset
+ * test_ds - to instantiate testing dataset
+ * optim - to instantiate optimizer with learning rate scheduler
+
+ trainer (Optional): Pytorch Lightning Trainer instance
+ """
+ if trainer is not None and not isinstance(trainer, Trainer):
+ raise ValueError(
+ f"trainer constructor argument must be either None or pytorch_lightning.Trainer. But got {type(trainer)} instead."
+ )
+ super().__init__()
+
+ # Convert config to a DictConfig
+ cfg = model_utils.convert_model_config_to_dict_config(cfg)
+
+ # Convert config to support Hydra 1.0+ instantiation
+ cfg = model_utils.maybe_update_config_version(cfg)
+
+ if 'target' not in cfg:
+ # This is for Jarvis service.
+ OmegaConf.set_struct(cfg, False)
+ cfg.target = "{0}.{1}".format(self.__class__.__module__, self.__class__.__name__)
+ OmegaConf.set_struct(cfg, True)
+
+ self._cfg = cfg
+
+ self.save_hyperparameters(self._cfg)
+ self._train_dl = None
+ self._validation_dl = None
+ self._test_dl = None
+ self._optimizer = None
+ self._scheduler = None
+ self._trainer = trainer
+
+ # Set device_id in AppState
+ if torch.cuda.is_available() and torch.cuda.current_device() is not None:
+ app_state = AppState()
+ app_state.device_id = torch.cuda.current_device()
+
+ if self._cfg is not None and not self._is_model_being_restored():
+ if 'train_ds' in self._cfg and self._cfg.train_ds is not None:
+ self.setup_training_data(self._cfg.train_ds)
+
+ if 'validation_ds' in self._cfg and self._cfg.validation_ds is not None:
+ self.setup_multiple_validation_data(val_data_config=None)
+
+ if 'test_ds' in self._cfg and self._cfg.test_ds is not None:
+ self.setup_multiple_test_data(test_data_config=None)
+
+ else:
+ if 'train_ds' in self._cfg and self._cfg.train_ds is not None:
+ logging.warning(
+ f"Please call the ModelPT.setup_training_data() method "
+ f"and provide a valid configuration file to setup the train data loader.\n"
+ f"Train config : \n{OmegaConf.to_yaml(self._cfg.train_ds)}"
+ )
+
+ if 'validation_ds' in self._cfg and self._cfg.validation_ds is not None:
+ logging.warning(
+ f"Please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method "
+ f"and provide a valid configuration file to setup the validation data loader(s). \n"
+ f"Validation config : \n{OmegaConf.to_yaml(self._cfg.validation_ds)}"
+ )
+
+ if 'test_ds' in self._cfg and self._cfg.test_ds is not None:
+ logging.warning(
+ f"Please call the ModelPT.setup_test_data() or ModelPT.setup_multiple_test_data() method "
+ f"and provide a valid configuration file to setup the test data loader(s).\n"
+ f"Test config : \n{OmegaConf.to_yaml(self._cfg.test_ds)}"
+ )
+
+ # ModelPT wrappers over subclass implementations
+ self.training_step = model_utils.wrap_training_step(self.training_step)
+
+ def register_artifact(self, config_path: str, src: str):
+ """
+ Register model artifacts with this function. These artifacts (files) will be included inside .nemo file
+ when model.save_to("mymodel.nemo") is called.
+
+ WARNING: If you specified /example_folder/example.txt but ./example.txt exists, then ./example.txt will be used.
+
+ Args:
+ config_path: config path where artifact is used
+ src: path to the artifact
+
+ Returns:
+ path to be used when accessing artifact. If src='' or None then '' or None will be returned
+ """
+ if not hasattr(self, 'artifacts'):
+ self.artifacts = {}
+ if self.artifacts is None:
+ self.artifacts = {}
+ if src is not None and src.strip() != '':
+ archive_item = model_utils.ArtifactItem()
+
+ basename_src = os.path.basename(src)
+ # filename exists in current workdir - use it and raise warning
+ # this case is during model restoration or when file is written to cwd.
+ if os.path.exists(basename_src):
+ logging.warning(f"Using {os.path.abspath(basename_src)} instead of {src}.")
+ used_src = basename_src
+
+ # Case: register_artifact() called inside restoration context
+ if self._is_model_being_restored() and self._is_restore_type_tarfile():
+ archive_item.path_type = model_utils.ArtifactPathType.TAR_PATH
+ else:
+ archive_item.path_type = model_utils.ArtifactPathType.LOCAL_PATH
+
+ else:
+ used_src = src
+ archive_item.path_type = model_utils.ArtifactPathType.LOCAL_PATH
+
+ if not os.path.exists(used_src):
+ # File not found in local path or by basename
+ # Try to locate it inside the .nemo archive (if model was restored)
+ # Case: register_artifact() called outside restoration context
+ if self._is_restore_type_tarfile():
+ # Get path where the command is executed - the artifacts will be "retrieved" there
+ # (original .nemo behavior)
+ cwd = os.getcwd()
+ try:
+ # Step into the nemo archive to try and find the file
+ with tempfile.TemporaryDirectory() as tmpdir:
+ self.__unpack_nemo_file(path2file=_MODEL_RESTORE_PATH, out_folder=tmpdir)
+ os.chdir(tmpdir)
+ if os.path.exists(basename_src):
+ logging.warning(f"Using {os.path.abspath(basename_src)} instead of {src}.")
+ used_src = basename_src
+
+ archive_item.path = used_src
+ archive_item.path_type = model_utils.ArtifactPathType.TAR_PATH
+ else:
+ # No further action can be taken, file not found anywhere
+ raise FileNotFoundError(
+ f"Could not find {used_src} inside "
+ f"tarfile {_MODEL_RESTORE_PATH} or on the local filesystem"
+ )
+ finally:
+ # change back working directory
+ os.chdir(cwd)
+ else:
+ # No further action can be taken, file not found anywhere
+ raise FileNotFoundError(f"Could not find {used_src}")
+ else:
+ # Found filepath
+ archive_item.path = used_src
+
+ # Regardless of whether the artifact is "local" or "remote" - always store the original path.
+ # This fixes issues arising when finetuning NLP models that create and register tokenizer vocabs.
+ if config_path in self.artifacts:
+ logging.warning(
+ f"Artifact {config_path} with value '{self.artifacts[config_path]}' "
+ f"already exists and will be overwritten with value '{src}'!"
+ )
+
+ self.artifacts[config_path] = archive_item
+ return used_src
+ else:
+ return src
+
+ def _default_save_to(self, save_path: str):
+ """
+ Saves model instance (weights and configuration) into a .nemo file.
+ You can use the "restore_from" method to fully restore the instance from the .nemo file.
+
+ .nemo file is an archive (tar.gz) with the following:
+ model_config.yaml - model configuration in .yaml format. You can deserialize this into the cfg argument for the model's constructor
+ model_weights.ckpt - model checkpoint
+
+ Args:
+ save_path: Path to .nemo file where model instance should be saved
+ """
+ with tempfile.TemporaryDirectory() as tmpdir:
+ config_yaml = path.join(tmpdir, _MODEL_CONFIG_YAML)
+ model_weights = path.join(tmpdir, _MODEL_WEIGHTS)
+
+ if hasattr(self, 'artifacts') and self.artifacts is not None:
+ for (conf_path, src) in self.artifacts.items(): # type: (str, model_utils.ArtifactItem)
+ try:
+ if src.path_type == model_utils.ArtifactPathType.LOCAL_PATH and os.path.exists(src.path):
+ shutil.copy2(src.path, tmpdir)
+ elif src.path_type == model_utils.ArtifactPathType.TAR_PATH:
+ # Need to step into nemo archive to extract file
+ # Get path where the command is executed - the artifacts will be "retrieved" there
+ # (original .nemo behavior)
+ cwd = os.getcwd()
+ try:
+ # Step into the nemo archive to try and find the file
+ with tempfile.TemporaryDirectory() as archive_dir:
+ self.__unpack_nemo_file(path2file=_MODEL_RESTORE_PATH, out_folder=archive_dir)
+ os.chdir(archive_dir)
+ shutil.copy2(src.path, tmpdir)
+ finally:
+ # change back working directory
+ os.chdir(cwd)
+ else:
+ raise ValueError(f"Invalid ArchivePathType found: {src.path_type}")
+ except Exception:
+ logging.error(f"Could not copy artifact {src} used in {conf_path}")
+
+ self.to_config_file(path2yaml_file=config_yaml)
+ torch.save(self.state_dict(), model_weights)
+ self.__make_nemo_file_from_folder(filename=save_path, source_dir=tmpdir)
+
+ def _eff_save_to(self, save_path: str):
+ """
+ Saves model instance (weights, configuration and artifacts) into an EFF archive using
+ the default `save_to` recipe from NeMoCookbook.
+
+ .. note::
+ For NVIDIA NeMo the EFF archives will also use .nemo postfix.
+
+ Method creates an EFF-based file that is an archive (tar.gz) with the following:
+ manifest.yaml - yaml file describing the content of the archive.
+ model_config.yaml - model configuration in .yaml format.
+ You can deserialize this into the cfg argument for the model's constructor
+ model_weights.ckpt - model checkpoint
+
+ Args:
+ save_path: Path to archive file where model instance should be saved.
+ """
+ NeMoCookbook().save_to(obj=self, save_path=save_path)
+
+ @rank_zero_only
+ def save_to(self, save_path: str):
+ """
+ Saves model instance (weights and configuration) into an EFF archive or a .nemo file.
+ You can use the "restore_from" method to fully restore the instance from the file.
+
+ .nemo file is an archive (tar.gz) with the following:
+ model_config.yaml - model configuration in .yaml format. You can deserialize this into the cfg argument for the model's constructor
+ model_weights.ckpt - model checkpoint
+
+ Args:
+ save_path: Path to .nemo file where model instance should be saved
+ """
+
+ # Add nemo rank check as well
+ if not is_global_rank_zero():
+ return
+
+ if _EFF_PRESENT_ and self.use_eff_save():
+ # Save EFF archive.
+ self._eff_save_to(save_path)
+ else:
+ # Save .nemo tar archive.
+ self._default_save_to(save_path)
+
+ @classmethod
+ def _default_restore_from(
+ cls,
+ restore_path: str,
+ override_config_path: Optional[str] = None,
+ map_location: Optional[torch.device] = None,
+ strict: bool = False,
+ ):
+ """
+ Restores model instance (weights and configuration) from a .nemo file
+ Args:
+ restore_path: path to .nemo file from which model should be instantiated
+ override_config_path: path to a yaml config that will override the internal
+ config file
+ map_location: Optional torch.device() to map the instantiated model to a device.
+ By default (None), it will select a GPU if available, falling back to CPU otherwise.
+ strict: Passed to load_state_dict.
+
+ Example:
+ ```
+ model = nemo.collections.asr.models.EncDecCTCModel.restore_from('asr.nemo')
+ assert isinstance(model, nemo.collections.asr.models.EncDecCTCModel)
+ ```
+
+ Returns:
+ An instance of type cls
+ """
+ # Get path where the command is executed - the artifacts will be "retrieved" there
+ # (original .nemo behavior)
+ cwd = os.getcwd()
+
+ if map_location is None:
+ if torch.cuda.is_available():
+ map_location = torch.device('cuda')
+ else:
+ map_location = torch.device('cpu')
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ try:
+ cls._set_model_restore_state(is_being_restored=True)
+ cls.__unpack_nemo_file(path2file=restore_path, out_folder=tmpdir)
+ os.chdir(tmpdir)
+ if override_config_path is None:
+ config_yaml = path.join(tmpdir, _MODEL_CONFIG_YAML)
+ else:
+ config_yaml = override_config_path
+ conf = OmegaConf.load(config_yaml)
+ if override_config_path is not None:
+ # Resolve the override config
+ conf = OmegaConf.to_container(conf, resolve=True)
+ conf = OmegaConf.create(conf)
+ # If override is top level config, extract just `model` from it
+ if 'model' in conf:
+ conf = conf.model
+ model_weights = path.join(tmpdir, _MODEL_WEIGHTS)
+ OmegaConf.set_struct(conf, True)
+ instance = cls.from_config_dict(config=conf)
+ instance = instance.to(map_location)
+ instance.load_state_dict(torch.load(model_weights, map_location=map_location), strict=strict)
+
+ logging.info(f'Model {cls.__name__} was successfully restored from {restore_path}.')
+ finally:
+ cls._set_model_restore_state(is_being_restored=False)
+ os.chdir(cwd)
+
+ return instance
+
+ @classmethod
+ def _eff_restore_from(
+ cls,
+ restore_path: str,
+ override_config_path: Optional[str] = None,
+ map_location: Optional[torch.device] = None,
+ strict: bool = False,
+ ):
+ """
+ Restores model instance (weights, configuration and artifacts) from EFF Archive using
+ the default `restore_from` recipe from NeMoCookbook.
+
+ Args:
+ restore_path: path to file from which model should be instantiated
+ override_config_path: path to a yaml config that will override the internal
+ config file
+ map_location: Optional torch.device() to map the instantiated model to a device.
+ By default (None), it will select a GPU if available, falling back to CPU otherwise.
+ strict: Passed to load_state_dict.
+
+ Returns:
+ An instance of type cls
+ """
+ return NeMoCookbook().restore_from(
+ restore_path=restore_path,
+ obj_cls=cls,
+ override_config_path=override_config_path,
+ map_location=map_location,
+ strict=strict,
+ )
+
+ @classmethod
+ def restore_from(
+ cls,
+ restore_path: str,
+ override_config_path: Optional[str] = None,
+ map_location: Optional[torch.device] = None,
+ strict: bool = False,
+ ):
+ """
+ Restores model instance (weights and configuration) from file.
+
+ The method first tries to load the file as an EFF archive.
+ If the EFF library is not present on the system, or the indicated file is not an EFF archive,
+ the method falls back to the original .nemo restore method.
+
+ Args:
+ restore_path: path to .nemo file from which model should be instantiated
+ override_config_path: path to a yaml config that will override the internal
+ config file
+ map_location: Optional torch.device() to map the instantiated model to a device.
+ By default (None), it will select a GPU if available, falling back to CPU otherwise.
+ strict: Passed to load_state_dict.
+
+ Example:
+ ```
+ model = nemo.collections.asr.models.EncDecCTCModel.restore_from('asr.nemo')
+ assert isinstance(model, nemo.collections.asr.models.EncDecCTCModel)
+ ```
+
+ Returns:
+ An instance of type cls
+ """
+ if not path.exists(restore_path):
+ raise FileNotFoundError(f"Can't find {restore_path}")
+
+ global _MODEL_RESTORE_PATH
+ _MODEL_RESTORE_PATH = os.path.abspath(os.path.expanduser(restore_path))
+
+ if _EFF_PRESENT_:
+ # Try to load the EFF archive.
+ try:
+ return cls._eff_restore_from(restore_path, override_config_path, map_location, strict)
+ except (FileNotFoundError, TypeError):
+ # Default to the old .nemo tar archive restore method.
+ return cls._default_restore_from(restore_path, override_config_path, map_location, strict)
+ else:
+ # Load .nemo tar archive using the old restore method.
+ return cls._default_restore_from(restore_path, override_config_path, map_location, strict)
+
+ @classmethod
+ def extract_state_dict_from(cls, restore_path: str, save_dir: str, split_by_module: bool = False):
+ """
+ Extract the state dict(s) from a provided .nemo tarfile and save it to a directory.
+ Args:
+ restore_path: path to .nemo file from which state dict(s) should be extracted
+ save_dir: directory in which the saved state dict(s) should be stored
+ split_by_module: bool flag, which determines whether the output checkpoint should
+ be for the entire Model, or the individual modules that comprise the Model
+
+ Example:
+ To convert the .nemo tarfile into a single Model level PyTorch checkpoint
+ ```
+ state_dict = nemo.collections.asr.models.EncDecCTCModel.extract_state_dict_from('asr.nemo', './asr_ckpts')
+ ```
+
+ To restore a model from a Model level checkpoint
+ ```
+ model = nemo.collections.asr.models.EncDecCTCModel(cfg) # or any other method of restoration
+ model.load_state_dict(torch.load("./asr_ckpts/model_weights.ckpt"))
+ ```
+
+ To convert the .nemo tarfile into multiple Module level PyTorch checkpoints
+ ```
+ state_dict = nemo.collections.asr.models.EncDecCTCModel.extract_state_dict_from('asr.nemo', './asr_ckpts',
+ split_by_module=True)
+ ```
+
+ To restore a module from a Module level checkpoint
+ ```
+ model = nemo.collections.asr.models.EncDecCTCModel(cfg) # or any other method of restoration
+
+ # load the individual components
+ model.preprocessor.load_state_dict(torch.load("./asr_ckpts/preprocessor.ckpt"))
+ model.encoder.load_state_dict(torch.load("./asr_ckpts/encoder.ckpt"))
+ model.decoder.load_state_dict(torch.load("./asr_ckpts/decoder.ckpt"))
+ ```
+
+ Returns:
+ The state dict that was loaded from the original .nemo checkpoint
+ """
+ if not path.exists(restore_path):
+ raise FileNotFoundError(f"Can't find {restore_path}")
+
+ cwd = os.getcwd()
+
+ save_dir = os.path.abspath(save_dir)
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir, exist_ok=True)
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ try:
+ cls.__unpack_nemo_file(path2file=restore_path, out_folder=tmpdir)
+ os.chdir(tmpdir)
+ model_weights = path.join(tmpdir, _MODEL_WEIGHTS)
+ state_dict = torch.load(model_weights)
+
+ if not split_by_module:
+ filepath = os.path.join(save_dir, _MODEL_WEIGHTS)
+ torch.save(state_dict, filepath)
+
+ else:
+ key_set = set([key.split(".")[0] for key in state_dict.keys()])
+ for primary_key in key_set:
+ inner_keys = [key for key in state_dict.keys() if key.split(".")[0] == primary_key]
+ state_dict_subset = {
+ ".".join(inner_key.split(".")[1:]): state_dict[inner_key] for inner_key in inner_keys
+ }
+ filepath = os.path.join(save_dir, f"{primary_key}.ckpt")
+ torch.save(state_dict_subset, filepath)
+
+ logging.info(f'Checkpoints from {restore_path} were successfully extracted into {save_dir}.')
+ finally:
+ os.chdir(cwd)
+
+ return state_dict
+
+ @classmethod
+ def load_from_checkpoint(
+ cls,
+ checkpoint_path: str,
+ *args,
+ map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
+ hparams_file: Optional[str] = None,
+ strict: bool = True,
+ **kwargs,
+ ):
+ """
+ Loads ModelPT from a checkpoint, with some maintenance of the restoration state.
+ For documentation, please refer to LightningModule.load_from_checkpoint().
+ """
+ checkpoint = None
+ try:
+ cls._set_model_restore_state(is_being_restored=True)
+
+ checkpoint = super().load_from_checkpoint(
+ checkpoint_path=checkpoint_path,
+ *args,
+ map_location=map_location,
+ hparams_file=hparams_file,
+ strict=strict,
+ **kwargs,
+ )
+
+ finally:
+ cls._set_model_restore_state(is_being_restored=False)
+ return checkpoint
+
+ @classmethod
+ def load_state_from_checkpoint(
+ cls,
+ model,
+ checkpoint_path: str,
+ map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
+ strict: bool = True,
+ ):
+ try:
+ cls._set_model_restore_state(is_being_restored=True)
+
+ from pytorch_lightning.utilities.cloud_io import load as pl_load
+ if map_location is not None:
+ checkpoint = pl_load(checkpoint_path, map_location=map_location)
+ else:
+ checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
+
+ # for past checkpoint need to add the new key
+ if cls.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint:
+ checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = {}
+
+ # give model a chance to load something
+ model.on_load_checkpoint(checkpoint)
+
+ # load the state_dict on the model automatically
+ model.load_state_dict(checkpoint['state_dict'], strict=strict)
+ finally:
+ cls._set_model_restore_state(is_being_restored=False)
+
+ @abstractmethod
+ def setup_training_data(self, train_data_config: Union[DictConfig, Dict]):
+ """
+ Sets up the data loader to be used in training.
+
+ Args:
+ train_data_config: training data layer parameters.
+ Returns:
+
+ """
+ pass
+
+ @abstractmethod
+ def setup_validation_data(self, val_data_config: Union[DictConfig, Dict]):
+ """
+ Sets up the data loader to be used in validation.
+
+ Args:
+ val_data_config: validation data layer parameters.
+ Returns:
+
+ """
+ pass
+
+ def setup_test_data(self, test_data_config: Union[DictConfig, Dict]):
+ """
+ (Optionally) Sets up the data loader to be used in testing.
+
+ Args:
+ test_data_config: test data layer parameters.
+ Returns:
+
+ """
+ raise NotImplementedError()
+
+ def setup_multiple_validation_data(self, val_data_config: Union[DictConfig, Dict]):
+ """
+ (Optionally) Sets up the data loader(s) to be used in validation, with support for multiple data loaders.
+
+ Args:
+ val_data_config: validation data layer parameters.
+ """
+ # Set some placeholders, overridden by the helper method
+ self._val_dl_idx = 0
+ self._validation_names = None
+ self._validation_dl = None # type: torch.utils.data.DataLoader
+
+ # preserve config
+ self._update_dataset_config(dataset_name='validation', config=val_data_config)
+
+ try:
+ self._multi_dataset_mode = True
+ model_utils.resolve_validation_dataloaders(model=self)
+ finally:
+ self._multi_dataset_mode = False
+
+ if self._validation_names is None:
+ if self._validation_dl is not None and type(self._validation_dl) in [list, tuple]:
+ self._validation_names = ['val_{}_'.format(idx) for idx in range(len(self._validation_dl))]
+
+ def setup_multiple_test_data(self, test_data_config: Union[DictConfig, Dict]):
+ """
+ (Optionally) Sets up the data loader(s) to be used in testing, with support for multiple data loaders.
+
+ Args:
+ test_data_config: test data layer parameters.
+ """
+ # Set some placeholders, overridden by the helper method
+ self._test_dl_idx = 0
+ self._test_names = None
+ self._test_dl = None # type: torch.utils.data.DataLoader
+
+ # preserve config
+ self._update_dataset_config(dataset_name='test', config=test_data_config)
+
+ try:
+ self._multi_dataset_mode = True
+ model_utils.resolve_test_dataloaders(model=self)
+ finally:
+ self._multi_dataset_mode = False
+
+ if self._test_names is None:
+ if self._test_dl is not None and type(self._test_dl) in [list, tuple]:
+ self._test_names = ['test_{}_'.format(idx) for idx in range(len(self._test_dl))]
+
+ def setup_optimization(self, optim_config: Optional[Union[DictConfig, Dict]] = None):
+ """
+ Prepares an optimizer from a string name and its optional config parameters.
+
+ Args:
+ optim_config: A dictionary containing the following keys:
+
+ * "lr": mandatory key for learning rate. Will raise ValueError if not provided.
+ * "optimizer": string name pointing to one of the available optimizers in the registry. \
+ If not provided, defaults to "adam".
+ * "opt_args": Optional list of strings, in the format "arg_name=arg_value". \
+ The list of "arg_value" will be parsed and a dictionary of optimizer kwargs \
+ will be built and supplied to instantiate the optimizer.
+ """
+ # If config was not explicitly passed to us
+ if optim_config is None:
+ # See if internal config has `optim` namespace
+ if self._cfg is not None and hasattr(self._cfg, 'optim'):
+ optim_config = self._cfg.optim
+
+ # If config is still None, or internal config has no Optim, return without instantiation
+ if optim_config is None:
+ logging.info('No optimizer config provided, therefore no optimizer was created')
+ return
+
+ else:
+ # Preserve the configuration
+ if not isinstance(optim_config, DictConfig):
+ optim_config = OmegaConf.create(optim_config)
+
+ # See if internal config has `optim` namespace before preservation
+ if self._cfg is not None and hasattr(self._cfg, 'optim'):
+ if self._cfg.optim is None:
+ self._cfg.optim = copy.deepcopy(optim_config)
+ else:
+ with open_dict(self._cfg.optim):
+ self._cfg.optim = copy.deepcopy(optim_config)
+
+ # Setup optimizer and scheduler
+ if optim_config is not None and isinstance(optim_config, DictConfig):
+ optim_config = OmegaConf.to_container(optim_config, resolve=True)
+
+ if 'sched' in optim_config and self._trainer is not None:
+ if not isinstance(self._trainer.accumulate_grad_batches, int):
+ raise ValueError("We do not currently support gradient accumulation that is not an integer.")
+ if self._trainer.max_steps is None:
+ # Store information needed to calculate max_steps
+ optim_config['sched']['t_max_epochs'] = self._trainer.max_epochs
+ optim_config['sched']['t_accumulate_grad_batches'] = self._trainer.accumulate_grad_batches
+ optim_config['sched']['t_limit_train_batches'] = self._trainer.limit_train_batches
+ if self._trainer.distributed_backend is None:
+ optim_config['sched']['t_num_workers'] = self._trainer.num_gpus or 1
+ elif self._trainer.distributed_backend == "ddp_cpu":
+ optim_config['sched']['t_num_workers'] = self._trainer.num_processes * self._trainer.num_nodes
+ elif self._trainer.distributed_backend == "ddp":
+ optim_config['sched']['t_num_workers'] = self._trainer.num_gpus * self._trainer.num_nodes
+ else:
+ logging.warning(
+ f"The lightning trainer received accelerator: {self._trainer.distributed_backend}. We "
+ "recommend to use 'ddp' instead."
+ )
+ optim_config['sched']['t_num_workers'] = self._trainer.num_gpus * self._trainer.num_nodes
+ else:
+ optim_config['sched']['max_steps'] = self._trainer.max_steps
+
+ # Force into DictConfig from nested structure
+ optim_config = OmegaConf.create(optim_config)
+ # Get back the nested dict so it's mutable
+ optim_config = OmegaConf.to_container(optim_config, resolve=True)
+
+ # Extract scheduler config if inside optimizer config
+ if 'sched' in optim_config:
+ scheduler_config = optim_config.pop('sched')
+ else:
+ scheduler_config = None
+
+ # Check if caller provided optimizer name, default to Adam otherwise
+ optimizer_cls = optim_config.get('_target_', None)
+
+ if optimizer_cls is None:
+ # Try to get optimizer name for dynamic resolution, defaulting to Adam
+ optimizer_name = optim_config.get('name', 'adam')
+ else:
+ if inspect.isclass(optimizer_cls):
+ optimizer_name = optimizer_cls.__name__.lower()
+ else:
+ # resolve the class name (lowercase) from the class path if not provided
+ optimizer_name = optimizer_cls.split(".")[-1].lower()
+
+ # We are guaranteed to have lr since it is required by the argparser,
+ # but maybe the user forgot to pass it to this function
+ lr = optim_config.get('lr', None)
+
+ # Check if caller has optimizer kwargs, default to empty dictionary
+ if 'args' in optim_config:
+ optimizer_args = optim_config.pop('args')
+ optimizer_args = optim.parse_optimizer_args(optimizer_name, optimizer_args)
+ else:
+ optimizer_args = copy.deepcopy(optim_config)
+
+ # Remove extra parameters from optimizer_args nest
+ # Assume all other parameters are to be passed into optimizer constructor
+ optimizer_args.pop('name', None)
+ optimizer_args.pop('cls', None)
+ optimizer_args.pop('lr', None)
+
+ # Adaptive schedulers don't need `lr`
+ if lr is not None:
+ optimizer_args['lr'] = lr
+
+ # Actually instantiate the optimizer
+ if optimizer_cls is not None:
+ if inspect.isclass(optimizer_cls):
+ optimizer = optimizer_cls(self.optim_param_groups(), **optimizer_args)
+ logging.info("Optimizer config = %s", str(optimizer))
+
+ self._optimizer = optimizer
+
+ else:
+ # Attempt class path resolution
+ try:
+ optimizer_cls = OmegaConf.create({'_target_': optimizer_cls})
+ if lr is not None:
+ optimizer_config = {'lr': lr}
+ else:
+ optimizer_config = {}
+ optimizer_config.update(optimizer_args)
+
+ optimizer_instance = hydra.utils.instantiate(
+ optimizer_cls, self.optim_param_groups(), **optimizer_config
+ ) # type: DictConfig
+
+ logging.info("Optimizer config = %s", str(optimizer_instance))
+
+ self._optimizer = optimizer_instance
+
+ except Exception as e:
+ logging.error(
+ "Could not instantiate class path - {} with kwargs {}".format(
+ optimizer_cls, str(optimizer_config)
+ )
+ )
+ raise e
+
+ else:
+ optimizer = optim.get_optimizer(optimizer_name)
+ optimizer = optimizer(self.optim_param_groups(), **optimizer_args)
+
+ logging.info("Optimizer config = %s", str(optimizer))
+
+ self._optimizer = optimizer
+
+ # Try to instantiate scheduler for optimizer
+ self._scheduler = prepare_lr_scheduler(
+ optimizer=self._optimizer, scheduler_config=scheduler_config, train_dataloader=self._train_dl
+ )
+
+ # Return the optimizer with/without scheduler
+ # This return allows multiple optimizers or schedulers to be created
+ return self._optimizer, self._scheduler
+
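+ # Hedged config sketch (illustrative, not part of the original file): the
+ # `optim` section consumed by setup_optimization() above typically looks like
+ # the following; names and values here are assumptions, with "adam" resolved
+ # through optim.get_optimizer() and "sched" handled by prepare_lr_scheduler().
+ #
+ #   optim:
+ #     name: adam
+ #     lr: 1e-3
+ #     weight_decay: 0.0        # extra keys are forwarded to the optimizer
+ #     sched:
+ #       name: CosineAnnealing
+ #       warmup_steps: 1000
+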
+ def optim_param_groups(self):
+ return self.parameters()
+
+ def configure_optimizers(self):
+ self.setup_optimization()
+
+ if self._scheduler is None:
+ return self._optimizer
+ else:
+ return [self._optimizer], [self._scheduler]
+
+ def train_dataloader(self):
+ if self._train_dl is not None:
+ return self._train_dl
+
+ def val_dataloader(self):
+ if self._validation_dl is not None:
+ return self._validation_dl
+
+ def test_dataloader(self):
+ if self._test_dl is not None:
+ return self._test_dl
+
+ def validation_epoch_end(
+ self, outputs: Union[List[Dict[str, torch.Tensor]], List[List[Dict[str, torch.Tensor]]]]
+ ) -> Optional[Dict[str, Dict[str, torch.Tensor]]]:
+ """
+ Default validation epoch end hook, which automatically supports multiple data loaders
+ via `multi_validation_epoch_end`.
+
+ If multi dataset support is not required, override this method entirely in base class.
+ In such a case, there is no need to implement `multi_validation_epoch_end` either.
+
+ .. note::
+ If more than one data loader exists, and they all provide `val_loss`,
+ only the `val_loss` of the first data loader will be used by default.
+ This default can be changed by passing the special key `val_dl_idx: int`
+ inside the `validation_ds` config.
+
+ Args:
+ outputs: Single or nested list of tensor outputs from one or more data loaders.
+
+ Returns:
+ A dictionary containing the union of all items from individual data_loaders,
+ along with merged logs from all data loaders.
+ """
+ # Case where we don't provide data loaders
+ if outputs is not None and len(outputs) == 0:
+ return {}
+
+ # Case where we provide exactly 1 data loader
+ if type(outputs[0]) == dict:
+ output_dict = self.multi_validation_epoch_end(outputs, dataloader_idx=0)
+
+ if output_dict is not None and 'log' in output_dict:
+ self.log_dict(output_dict.pop('log'), on_epoch=True)
+
+ return output_dict
+
+ else: # Case where we provide more than 1 data loader
+ output_dict = {'log': {}}
+
+ # The output is a list of list of dicts, outer list corresponds to dataloader idx
+ for dataloader_idx, val_outputs in enumerate(outputs):
+ # Get prefix and dispatch call to multi epoch end
+ dataloader_prefix = self.get_validation_dataloader_prefix(dataloader_idx)
+ dataloader_logs = self.multi_validation_epoch_end(val_outputs, dataloader_idx=dataloader_idx)
+
+ # If result was not provided, generate empty dict
+ dataloader_logs = dataloader_logs or {}
+
+ # Perform `val_loss` resolution first (if provided outside logs)
+ if 'val_loss' in dataloader_logs:
+ if 'val_loss' not in output_dict and dataloader_idx == self._val_dl_idx:
+ output_dict['val_loss'] = dataloader_logs['val_loss']
+
+ # For every item in the result dictionary
+ for k, v in dataloader_logs.items():
+ # If the key is `log`
+ if k == 'log':
+ # Parse every element of the log, and attach the prefix name of the data loader
+ log_dict = {}
+
+ for k_log, v_log in v.items():
+ # If we are logging the metric, but don't provide it at result level,
+ # store it twice - once in log and once in result level.
+ # Also mark log with prefix name to avoid log level clash with other data loaders
+ if k_log not in output_dict['log'] and dataloader_idx == self._val_dl_idx:
+ new_k_log = k_log
+
+ # Also insert duplicate key with prefix for ease of comparison / avoid name clash
+ log_dict[dataloader_prefix + k_log] = v_log
+
+ else:
+ # Simply prepend prefix to key and save
+ new_k_log = dataloader_prefix + k_log
+
+ # Store log value
+ log_dict[new_k_log] = v_log
+
+ # Update log storage of individual data loader
+ output_logs = output_dict['log']
+ output_logs.update(log_dict)
+
+ # Update global log storage
+ output_dict['log'] = output_logs
+
+ else:
+ # If any values are stored outside 'log', simply prefix name and store
+ new_k = dataloader_prefix + k
+ output_dict[new_k] = v
+
+ if 'log' in output_dict:
+ self.log_dict(output_dict.pop('log'), on_epoch=True)
+
+ # return everything else
+ return output_dict
+
+ def test_epoch_end(
+ self, outputs: Union[List[Dict[str, torch.Tensor]], List[List[Dict[str, torch.Tensor]]]]
+ ) -> Optional[Dict[str, Dict[str, torch.Tensor]]]:
+ """
+ Default test epoch end hook, which automatically supports multiple data loaders
+ via `multi_test_epoch_end`.
+
+ If multi dataset support is not required, override this method entirely in base class.
+ In such a case, there is no need to implement `multi_test_epoch_end` either.
+
+ .. note::
+ If more than one data loader exists, and they all provide `test_loss`,
+ only the `test_loss` of the first data loader will be used by default.
+ This default can be changed by passing the special key `test_dl_idx: int`
+ inside the `test_ds` config.
+
+ Args:
+ outputs: Single or nested list of tensor outputs from one or more data loaders.
+
+ Returns:
+ A dictionary containing the union of all items from individual data_loaders,
+ along with merged logs from all data loaders.
+ """
+ # Case where we don't provide data loaders
+ if outputs is not None and len(outputs) == 0:
+ return {}
+
+ # Case where we provide exactly 1 data loader
+ if type(outputs[0]) == dict:
+ output_dict = self.multi_test_epoch_end(outputs, dataloader_idx=0)
+
+ if output_dict is not None and 'log' in output_dict:
+ self.log_dict(output_dict.pop('log'), on_epoch=True)
+
+ return output_dict
+
+ else: # Case where we provide more than 1 data loader
+ output_dict = {'log': {}}
+
+ # The output is a list of list of dicts, outer list corresponds to dataloader idx
+ for dataloader_idx, test_outputs in enumerate(outputs):
+ # Get prefix and dispatch call to multi epoch end
+ dataloader_prefix = self.get_test_dataloader_prefix(dataloader_idx)
+ dataloader_logs = self.multi_test_epoch_end(test_outputs, dataloader_idx=dataloader_idx)
+
+ # If result was not provided, generate empty dict
+ dataloader_logs = dataloader_logs or {}
+
+ # Perform `test_loss` resolution first (if provided outside logs)
+ if 'test_loss' in dataloader_logs:
+ if 'test_loss' not in output_dict and dataloader_idx == self._test_dl_idx:
+ output_dict['test_loss'] = dataloader_logs['test_loss']
+
+ # For every item in the result dictionary
+ for k, v in dataloader_logs.items():
+ # If the key is `log`
+ if k == 'log':
+ # Parse every element of the log, and attach the prefix name of the data loader
+ log_dict = {}
+ for k_log, v_log in v.items():
+ # If we are logging the loss, but don't provide it at result level,
+ # store it twice - once in log and once in result level.
+ # Also mark log with prefix name to avoid log level clash with other data loaders
+ if k_log not in output_dict['log'] and dataloader_idx == self._test_dl_idx:
+ new_k_log = k_log
+
+ # Also insert duplicate key with prefix for ease of comparison / avoid name clash
+ log_dict[dataloader_prefix + k_log] = v_log
+
+ else:
+ # Simply prepend prefix to key and save
+ new_k_log = dataloader_prefix + k_log
+
+ log_dict[new_k_log] = v_log
+
+ # Update log storage of individual data loader
+ output_logs = output_dict.get('log', {})
+ output_logs.update(log_dict)
+
+ # Update global log storage
+ output_dict['log'] = output_logs
+
+ else:
+ # If any values are stored outside 'log', simply prefix name and store
+ new_k = dataloader_prefix + k
+ output_dict[new_k] = v
+
+ if 'log' in output_dict:
+ self.log_dict(output_dict.pop('log'), on_epoch=True)
+
+ # return everything else
+ return output_dict
+
+ def multi_validation_epoch_end(
+ self, outputs: List[Dict[str, torch.Tensor]], dataloader_idx: int = 0
+ ) -> Optional[Dict[str, Dict[str, torch.Tensor]]]:
+ logging.warning(
+ "Multi data loader support has been enabled, but "
+ "`multi_validation_epoch_end(outputs, dataloader_idx) has not been implemented.\n"
+ "If you require multi data loader support for validation sets, please override this method.\n"
+ "If you do not require multi data loader support, please instead override "
+ "`validation_epoch_end(outputs)."
+ )
+
+ def multi_test_epoch_end(
+ self, outputs: List[Dict[str, torch.Tensor]], dataloader_idx: int = 0
+ ) -> Optional[Dict[str, Dict[str, torch.Tensor]]]:
+ logging.warning(
+ "Multi data loader support has been enabled, but "
+ "`multi_test_epoch_end(outputs, dataloader_idx) has not been implemented.\n"
+ "If you require multi data loader support for test sets, please override this method.\n"
+ "If you do not require multi data loader support, please instead override "
+ "`test_epoch_end(outputs)."
+ )
+
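+ # Hedged override sketch (illustrative, not part of the original file): a
+ # subclass with several validation data loaders would typically implement
+ # something like the following, which `validation_epoch_end` above then
+ # aggregates per data loader. The metric names are assumptions.
+ #
+ #   def multi_validation_epoch_end(self, outputs, dataloader_idx=0):
+ #       val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
+ #       return {'val_loss': val_loss, 'log': {'val_loss': val_loss}}
+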
+ def get_validation_dataloader_prefix(self, dataloader_idx: int = 0) -> str:
+ """
+ Get the name of one or more data loaders, which will be prepended to all logs.
+
+ Args:
+ dataloader_idx: Index of the data loader.
+
+ Returns:
+ str name of the data loader at index provided.
+ """
+ return self._validation_names[dataloader_idx]
+
+ def get_test_dataloader_prefix(self, dataloader_idx: int = 0) -> str:
+ """
+ Get the name of one or more data loaders, which will be prepended to all logs.
+
+ Args:
+ dataloader_idx: Index of the data loader.
+
+ Returns:
+ str name of the data loader at index provided.
+ """
+ return self._test_names[dataloader_idx]
+
+ def teardown(self, stage: str):
+ """
+ Called at the end of fit and test.
+
+ Args:
+ stage: either 'fit' or 'test'
+ """
+ if stage == 'fit':
+ # Update env variable to bypass multi gpu issue after training
+ # This fix affects usage of trainer.test() after trainer.train()
+ # If trainer.train() was done on multiple GPUs, then trainer.test()
+ # will try to do ddp, even if it's a new Trainer object with just 1 GPU.
+ # Temporary patch to fix that
+ if 'PL_TRAINER_GPUS' in os.environ:
+ os.environ.pop('PL_TRAINER_GPUS')
+
+ super().teardown(stage)
+
+ def prepare_test(self, trainer: 'Trainer') -> bool:
+ """
+ Helper method to check whether the model can safely be tested
+ on a dataset after training (or loading a checkpoint).
+
+ # Usage:
+ trainer = Trainer()
+ if model.prepare_test(trainer):
+ trainer.test(model)
+
+ Returns:
+ bool which declares the model safe to test. Provides warnings if it has to
+ return False to guide the user.
+ """
+ if not hasattr(self._cfg, 'test_ds'):
+ logging.info("No `test_ds` config found within the manifest.")
+ return False
+
+ # Replace ddp multi-gpu until PTL has a fix
+ DDP_WARN = """\n\nDuring testing, it is currently advisable to construct a new Trainer
+ with a single GPU and no DDP to obtain accurate results.
+ The following pattern should be used:
+ gpu = 1 if cfg.trainer.gpus != 0 else 0
+ trainer = Trainer(gpus=gpu)
+ if model.prepare_test(trainer):
+     trainer.test(model)\n\n"""
+
+ if trainer is not None:
+ if trainer.num_gpus > 1:
+ logging.warning(DDP_WARN)
+ return False
+
+ # Assign trainer to the model
+ self.set_trainer(trainer)
+ return True
+
+ def set_trainer(self, trainer: Trainer):
+ """
+ Set an instance of Trainer object.
+
+ Args:
+ trainer: PyTorch Lightning Trainer object.
+ """
+ self._trainer = trainer
+ self.set_world_size(self._trainer)
+
+ def set_world_size(self, trainer: Trainer):
+ """
+ Determines the world size from the PyTorch Lightning Trainer and then updates AppState.
+
+ Args:
+ trainer (Trainer): PyTorch Lightning Trainer object
+ """
+ # Update AppState with world information from trainer
+ if isinstance(trainer, Trainer):
+ app_state = AppState()
+ if self._trainer.num_gpus and self._trainer.num_nodes:
+ app_state.world_size = self._trainer.num_gpus * self._trainer.num_nodes
+ else:
+ logging.warning(f'World size can only be set by PyTorch Lightning Trainer.')
+
+ def _update_dataset_config(self, dataset_name: str, config: Optional[Union[DictConfig, Dict]]):
+ """
+ Update the config (if not None) of the dataset by given name.
+ Preserves said config after updating.
+
+ Args:
+ dataset_name: str name of the dataset whose config is being updated.
+ Can be one of `train`, `validation` and `test`.
+ config: Optional DictConfig or dict. If None is passed, this method simply returns.
+ If dict is passed, it is cast into a DictConfig.
+ The internal config is updated with the passed config.
+ """
+ if hasattr(self, '_multi_dataset_mode') and self._multi_dataset_mode is True:
+ return
+
+ if config is not None:
+ if not isinstance(config, DictConfig):
+ config = OmegaConf.create(config)
+
+ if dataset_name in ['train', 'validation', 'test']:
+ OmegaConf.set_struct(self.cfg, False)
+
+ key_name = dataset_name + "_ds"
+ self.cfg[key_name] = config
+
+ OmegaConf.set_struct(self.cfg, True)
+
+ # Update hyper parameters by calling property setter
+ self.cfg = self._cfg
+ else:
+ raise ValueError("`dataset_name` when updating config must be one of [train, validation, test]")
+
+ @property
+ def num_weights(self):
+ return sum(p.numel() for p in self.parameters() if p.requires_grad)
+
+ @property
+ def cfg(self):
+ return self._cfg
+
+ @cfg.setter
+ def cfg(self, cfg):
+ self._cfg = cfg
+ self._set_hparams(cfg)
+
+ @staticmethod
+ def __make_nemo_file_from_folder(filename, source_dir):
+ with tarfile.open(filename, "w:gz") as tar:
+ # tar.add(source_dir, arcname=path.basename(source_dir))
+ tar.add(source_dir, arcname="./")
+
+ @staticmethod
+ def __unpack_nemo_file(path2file: str, out_folder: str) -> str:
+ if not path.exists(path2file):
+ raise FileNotFoundError(f"{path2file} does not exist")
+ tar = tarfile.open(path2file, "r:gz")
+ tar.extractall(path=out_folder)
+ tar.close()
+ return out_folder
+
+ @staticmethod
+ def _is_model_being_restored() -> bool:
+ global _MODEL_IS_RESTORED
+ return _MODEL_IS_RESTORED
+
+ @staticmethod
+ def _set_model_restore_state(is_being_restored: bool):
+ global _MODEL_IS_RESTORED
+ _MODEL_IS_RESTORED = is_being_restored
+
+ @staticmethod
+ def _is_restore_type_tarfile() -> bool:
+ """
+ Utility method that checks if the restore path of the underlying Model
+ is a tarfile (can be any valid archive).
+ """
+ global _MODEL_RESTORE_PATH
+
+ if _MODEL_RESTORE_PATH is None:
+ return False
+ else:
+ if tarfile.is_tarfile(_MODEL_RESTORE_PATH):
+ return True
+ else:
+ return False
+
+ @staticmethod
+ def set_eff_save(use_eff_save: bool):
+ global _MODEL_EFF_SAVE
+ _MODEL_EFF_SAVE = use_eff_save
+
+ @staticmethod
+ def use_eff_save() -> bool:
+ global _MODEL_EFF_SAVE
+ return _MODEL_EFF_SAVE
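+
+ # ------------------------------------------------------------------------
+ # Hedged usage sketch (illustrative, not part of the original file): the
+ # typical save / restore round trip for a concrete ModelPT subclass. `MyModel`
+ # is an assumed subclass; file and directory names are placeholders.
+ #
+ #   model = MyModel(cfg=cfg.model, trainer=trainer)
+ #   trainer.fit(model)
+ #   model.save_to("my_model.nemo")        # rank zero only; .nemo tar or EFF archive
+ #
+ #   restored = MyModel.restore_from("my_model.nemo", map_location=torch.device("cpu"))
+ #
+ #   # Pull raw checkpoints out of the archive without instantiating the model:
+ #   MyModel.extract_state_dict_from("my_model.nemo", "./ckpts", split_by_module=True)
+ # ------------------------------------------------------------------------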
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/module.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/module.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3e1ace2f483c1a7bec504c63d9d5854bb0e5231
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/classes/module.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from contextlib import contextmanager
+
+from torch.nn import Module
+
+from nemo.core.classes.common import FileIO, Serialization, Typing
+
+__all__ = ['NeuralModule']
+
+
+class NeuralModule(Module, Typing, Serialization, FileIO):
+ """
+ Abstract class offering interface shared between all PyTorch Neural Modules.
+ """
+
+ @property
+ def num_weights(self):
+ return sum(p.numel() for p in self.parameters() if p.requires_grad)
+
+ def input_example(self):
+ """
+ Override this method if random inputs won't work
+ Returns:
+ A tuple sample of valid input data.
+ """
+
+ return
+
+ def freeze(self):
+ r"""
+ Freeze all params for inference and return their previous requires_grad states.
+ """
+ requires_grad_states = []
+ for param in self.parameters():
+ requires_grad_states.append(param.requires_grad)
+ param.requires_grad = False
+
+ self.eval()
+ return requires_grad_states
+
+ def unfreeze(self, requires_grad_states) -> None:
+ """
+ Restore the requires_grad states saved by freeze() and switch back to training mode.
+ """
+ for i, param in enumerate(self.parameters()):
+ param.requires_grad = requires_grad_states[i]
+
+ self.train()
+
+ @contextmanager
+ def as_frozen(self):
+ """
+ Context manager which temporarily freezes a module, yields control and finally unfreezes the module.
+ """
+ requires_grad_states = self.freeze()
+
+ try:
+ yield
+ finally:
+ self.unfreeze(requires_grad_states)
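+
+
+# Minimal usage sketch of the freeze/unfreeze round trip (illustrative only; the
+# toy module below is hypothetical and assumes the mixins above need no extra setup):
+if __name__ == "__main__":
+    import torch
+
+    class _ToyModule(NeuralModule):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(4, 2)
+
+    toy = _ToyModule()
+    with toy.as_frozen():
+        # All parameters are frozen and the module is in eval mode here.
+        assert all(not p.requires_grad for p in toy.parameters())
+    # Previous requires_grad states are restored on exit.
+    assert all(p.requires_grad for p in toy.parameters())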
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3470ed2eb99a28343bb806c0247caa15d28b294
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/__init__.py
@@ -0,0 +1,47 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from nemo.core.config.base_config import Config
+from nemo.core.config.optimizers import (
+ AdadeltaParams,
+ AdagradParams,
+ AdamaxParams,
+ AdamParams,
+ AdamWParams,
+ NovogradParams,
+ OptimizerParams,
+ RMSpropParams,
+ RpropParams,
+ SGDParams,
+ get_optimizer_config,
+ register_optimizer_params,
+)
+from nemo.core.config.pytorch import DataLoaderConfig
+from nemo.core.config.pytorch_lightning import TrainerConfig, HTrainerConfig
+from nemo.core.config.schedulers import (
+ CosineAnnealingParams,
+ InverseSquareRootAnnealingParams,
+ NoamAnnealingParams,
+ PolynomialDecayAnnealingParams,
+ PolynomialHoldDecayAnnealingParams,
+ SchedulerParams,
+ SquareAnnealingParams,
+ SquareRootAnnealingParams,
+ WarmupAnnealingParams,
+ WarmupHoldSchedulerParams,
+ WarmupSchedulerParams,
+ get_scheduler_config,
+ register_scheduler_params,
+)
+from nemo.core.config.set_config import hydra_runner, set_config
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/base_config.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/base_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..44649ef7b4fb460f677e3d1e37b7e3a9d459e34d
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/base_config.py
@@ -0,0 +1,30 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Optional
+
+__all__ = ['Config']
+
+
+@dataclass
+class Config:
+ """
+ Abstract NeMo Configuration class.
+
+ Args:
+ name: name of the module/dataset/loss/model object (used in serialization, DEFAULT: None)
+ """
+
+ name: Optional[str] = None
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/modelPT.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/modelPT.py
new file mode 100644
index 0000000000000000000000000000000000000000..c729f29663c3d1da6e4ffd6d55ca31fad665066b
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/modelPT.py
@@ -0,0 +1,66 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+from typing import Any, Dict, Optional
+
+from omegaconf import MISSING
+
+from nemo.core import config
+from nemo.core.classes.dataset import DatasetConfig
+from nemo.utils import exp_manager
+
+
+@dataclass
+class SchedConfig:
+ name: str = MISSING
+ min_lr: float = 0.0
+ last_epoch: int = -1
+
+
+@dataclass
+class OptimConfig:
+ name: str = MISSING
+ lr: float = MISSING
+ sched: Optional[SchedConfig] = None
+
+
+@dataclass
+class ModelConfig:
+ """
+ Model component inside ModelPT
+ """
+
+ # ...
+ train_ds: Optional[DatasetConfig] = None
+ validation_ds: Optional[DatasetConfig] = None
+ test_ds: Optional[DatasetConfig] = None
+ optim: Optional[OptimConfig] = None
+
+
+@dataclass
+class HydraConfig:
+ run: Dict[str, Any] = field(default_factory=lambda: {"dir": "."})
+ job_logging: Dict[str, Any] = field(default_factory=lambda: {"root": {"handlers": None}})
+
+
+@dataclass
+class ModelPTConfig:
+ name: str = MISSING
+ model: ModelConfig = MISSING
+ trainer: config.TrainerConfig = config.TrainerConfig(
+ accelerator="ddp", checkpoint_callback=False, logger=False, log_every_n_steps=1
+ )
+ exp_manager: Optional[Any] = exp_manager.ExpManagerConfig()
+ hydra: HydraConfig = HydraConfig()
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/optimizers.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/optimizers.py
new file mode 100644
index 0000000000000000000000000000000000000000..18ace2dc0fc67eeabc334a3d935d88bee11d7d79
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/optimizers.py
@@ -0,0 +1,263 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from functools import partial
+from typing import Any, Dict, Optional, Tuple
+
+from omegaconf import OmegaConf
+
+__all__ = [
+ 'OptimizerParams',
+ 'AdamParams',
+ 'NovogradParams',
+ 'SGDParams',
+ 'AdadeltaParams',
+ 'AdamaxParams',
+ 'AdagradParams',
+ 'AdamWParams',
+ 'RMSpropParams',
+ 'RpropParams',
+]
+
+
+@dataclass
+class OptimizerParams:
+ """
+ Base optimizer params with no values. Users can choose to explicitly override them via
+ command-line arguments.
+ """
+
+
+@dataclass
+class SGDParams(OptimizerParams):
+ """
+ Default configuration for SGD optimizer.
+ It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).
+
+ ..note:
+ For the details on the function/meanings of the arguments, please refer to:
+ https://pytorch.org/docs/stable/optim.html?highlight=sgd#torch.optim.SGD
+ """
+
+ momentum: float = 0
+ dampening: float = 0
+ weight_decay: float = 0
+ nesterov: bool = False
+
+
+@dataclass
+class AdamParams(OptimizerParams):
+ """
+ Default configuration for Adam optimizer.
+ It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).
+
+ ..note:
+ For the details on the function/meanings of the arguments, please refer to:
+ https://pytorch.org/docs/stable/optim.html?highlight=adam#torch.optim.Adam
+ """
+
+ # betas: Tuple[float, float] = (0.9, 0.999)
+ eps: float = 1e-08
+ weight_decay: float = 0
+ amsgrad: bool = False
+
+
+@dataclass
+class AdamWParams(OptimizerParams):
+ """
+ Default configuration for AdamW optimizer.
+ It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).
+
+ ..note:
+ For the details on the function/meanings of the arguments, please refer to:
+ https://pytorch.org/docs/stable/optim.html#torch.optim.AdamW
+ """
+
+ betas: Tuple[float, float] = (0.9, 0.999)
+ eps: float = 1e-08
+ weight_decay: float = 0
+ amsgrad: bool = False
+
+
+@dataclass
+class AdadeltaParams(OptimizerParams):
+ """
+ Default configuration for Adadelta optimizer.
+ It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).
+
+ ..note:
+ For the details on the function/meanings of the arguments, please refer to:
+ https://pytorch.org/docs/stable/optim.html#torch.optim.Adadelta
+ """
+
+ rho: float = 0.9
+ eps: float = 1e-6
+ weight_decay: float = 0
+
+
+@dataclass
+class AdamaxParams(OptimizerParams):
+ """
+ Default configuration for Adamax optimizer.
+ It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).
+
+ ..note:
+ For the details on the function/meanings of the arguments, please refer to:
+ https://pytorch.org/docs/stable/optim.html#torch.optim.Adamax
+ """
+
+ betas: Tuple[float, float] = (0.9, 0.999)
+ eps: float = 1e-8
+ weight_decay: float = 0
+
+
+@dataclass
+class AdagradParams(OptimizerParams):
+ """
+ Default configuration for Adagrad optimizer.
+ It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).
+
+ ..note:
+ For the details on the function/meanings of the arguments, please refer to:
+ https://pytorch.org/docs/stable/optim.html#torch.optim.Adagrad
+ """
+
+ lr_decay: float = 0
+ weight_decay: float = 0
+ initial_accumulator_value: float = 0
+ eps: float = 1e-10
+
+
+@dataclass
+class RMSpropParams(OptimizerParams):
+ """
+ Default configuration for RMSprop optimizer.
+ It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).
+
+ ..note:
+ For the details on the function/meanings of the arguments, please refer to:
+ https://pytorch.org/docs/stable/optim.html#torch.optim.RMSprop
+ """
+
+ alpha: float = 0.99
+ eps: float = 1e-8
+ weight_decay: float = 0
+ momentum: float = 0
+ centered: bool = False
+
+
+@dataclass
+class RpropParams(OptimizerParams):
+ """
+ Default configuration for Rprop optimizer.
+ It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).
+
+ ..note:
+ For the details on the function/meanings of the arguments, please refer to:
+ https://pytorch.org/docs/stable/optim.html#torch.optim.Rprop
+ """
+
+ etas: Tuple[float, float] = (0.5, 1.2)
+ step_sizes: Tuple[float, float] = (1e-6, 50)
+
+
+@dataclass
+class NovogradParams(OptimizerParams):
+ """
+ Configuration of the Novograd optimizer.
+
+ It has been proposed in "Stochastic Gradient Methods with Layer-wise
+ Adaptive Moments for Training of Deep Networks"
+ (https://arxiv.org/abs/1905.11286)
+
+ Args:
+ lr (float, optional): learning rate (default: 1e-3)
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its square (default: (0.9, 0.999))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-8)
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+ amsgrad (boolean, optional): whether to use the AMSGrad variant of this
+ algorithm from the paper "On the Convergence of Adam and Beyond"
+ """
+
+ betas: Tuple[float, float] = (0.95, 0.98)
+ eps: float = 1e-8
+ weight_decay: float = 0
+ grad_averaging: bool = False
+ amsgrad: bool = False
+ luc: bool = False
+ luc_trust: float = 1e-3
+ luc_eps: float = 1e-8
+
+
+def register_optimizer_params(name: str, optimizer_params: OptimizerParams):
+ """
+ Checks if the optimizer param name exists in the registry, and if it doesn't, adds it.
+
+ This allows custom optimizer params to be added and called by name during instantiation.
+
+ Args:
+ name: Name of the optimizer. Will be used as key to retrieve the optimizer.
+ optimizer_params: Optimizer class
+ """
+ if name in AVAILABLE_OPTIMIZER_PARAMS:
+ raise ValueError(f"Cannot override pre-existing optimizers. Conflicting optimizer name = {name}")
+
+ AVAILABLE_OPTIMIZER_PARAMS[name] = optimizer_params
+
+
+def get_optimizer_config(name: str, **kwargs: Optional[Dict[str, Any]]) -> OptimizerParams:
+ """
+ Convenience method to obtain an OptimizerParams class and partially instantiate it with optimizer kwargs.
+
+ Args:
+ name: Name of the OptimizerParams in the registry.
+ kwargs: Optional kwargs of the optimizer used during instantiation.
+
+ Returns:
+ a partially instantiated OptimizerParams
+ """
+ if name is None:
+ return kwargs
+
+ if name not in AVAILABLE_OPTIMIZER_PARAMS:
+ raise ValueError(
+ f"Cannot resolve optimizer parameters '{name}'. Available optimizer parameters are : "
+ f"{AVAILABLE_OPTIMIZER_PARAMS.keys()}"
+ )
+
+ optimizer_params = AVAILABLE_OPTIMIZER_PARAMS[name]
+
+ if kwargs is not None and len(kwargs) != 0:
+ kwargs = OmegaConf.create(kwargs)
+ OmegaConf.merge(optimizer_params(), kwargs) # validates the overrides against the dataclass schema
+
+ optimizer_params = partial(optimizer_params, **kwargs)
+ return optimizer_params
+
+
+AVAILABLE_OPTIMIZER_PARAMS = {
+ 'optim_params': OptimizerParams,
+ 'adam_params': AdamParams,
+ 'novograd_params': NovogradParams,
+ 'sgd_params': SGDParams,
+ 'adadelta_params': AdadeltaParams,
+ 'adamax_params': AdamaxParams,
+ 'adagrad_params': AdagradParams,
+ 'adamw_params': AdamWParams,
+ 'rmsprop_params': RMSpropParams,
+ 'rprop_params': RpropParams,
+}
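+
+
+# Minimal usage sketch (illustrative only; exercises the registry defined above):
+if __name__ == "__main__":
+    # Look up AdamW params by name and override one field; the result is a partial
+    # that produces a fully populated params dataclass when called.
+    adamw_factory = get_optimizer_config('adamw_params', weight_decay=0.01)
+    print(adamw_factory())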
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/pytorch.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..944df02ed802e3a91281c02a5082f522e22e044a
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/pytorch.py
@@ -0,0 +1,45 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Any, Optional
+
+from omegaconf import MISSING
+
+__all__ = ['DataLoaderConfig']
+
+
+@dataclass
+class DataLoaderConfig:
+ """
+ Configuration of PyTorch DataLoader.
+
+ It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).
+
+ ..note:
+ For the details on the function/meanings of the arguments, please refer to:
+ https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
+ """
+
+ batch_size: int = MISSING
+ shuffle: bool = False
+ sampler: Optional[Any] = None
+ batch_sampler: Optional[Any] = None
+ num_workers: int = 0
+ collate_fn: Optional[Any] = None
+ pin_memory: bool = False
+ drop_last: bool = False
+ timeout: int = 0
+ worker_init_fn: Optional[Any] = None
+ multiprocessing_context: Optional[Any] = None
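+
+
+# Minimal usage sketch (illustrative only): materialize the dataclass as an OmegaConf node.
+if __name__ == "__main__":
+    from omegaconf import OmegaConf
+
+    loader_cfg = OmegaConf.structured(DataLoaderConfig(batch_size=32, shuffle=True, num_workers=4))
+    print(OmegaConf.to_yaml(loader_cfg))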
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/pytorch_lightning.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/pytorch_lightning.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b248959bc7446a1925ef7c4cdf6367581d60035
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/pytorch_lightning.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Union
+
+from hydra.core.config_store import ConfigStore
+
+__all__ = ['TrainerConfig', 'HTrainerConfig']
+
+
+cs = ConfigStore.instance()
+
+
+@dataclass
+class TrainerConfig:
+ """
+ Configuration of PyTorch Lightning Trainer.
+ It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).
+ ..warning:
+ Picked just a few params of the PTL trainer for now. This needs to be discussed.
+ ..note:
+ For the details on the function/meanings of the arguments, please refer to:
+ https://pytorch-lightning.readthedocs.io/en/latest/trainer.html#
+ """
+
+ logger: Any = True
+ checkpoint_callback: Any = True
+ callbacks: Optional[Any] = None
+ default_root_dir: Optional[str] = None
+ gradient_clip_val: float = 0
+ process_position: int = 0
+ num_nodes: int = 1
+ num_processes: int = 1
+ gpus: Optional[Any] = None
+ auto_select_gpus: bool = False
+ tpu_cores: Optional[Any] = None
+ log_gpu_memory: Optional[str] = None
+ progress_bar_refresh_rate: int = 1
+ overfit_batches: Any = 0.0
+ track_grad_norm: Any = -1
+ check_val_every_n_epoch: int = 1
+ fast_dev_run: bool = False
+ accumulate_grad_batches: Any = 1
+ max_epochs: int = 1000
+ min_epochs: int = 1
+ max_steps: Optional[int] = None
+ min_steps: Optional[int] = None
+ limit_train_batches: Any = 1.0
+ limit_val_batches: Any = 1.0
+ limit_test_batches: Any = 1.0
+ val_check_interval: Any = 1.0
+ flush_logs_every_n_steps: int = 100
+ log_every_n_steps: int = 50
+ accelerator: Optional[str] = None
+ sync_batchnorm: bool = False
+ precision: int = 32
+ weights_summary: Optional[str] = "full" # ModelSummary.MODE_DEFAULT
+ weights_save_path: Optional[str] = None
+ num_sanity_val_steps: int = 2
+ truncated_bptt_steps: Optional[int] = None
+ resume_from_checkpoint: Optional[str] = None
+ profiler: Optional[Any] = None
+ benchmark: bool = False
+ deterministic: bool = False
+ reload_dataloaders_every_epoch: bool = False
+ auto_lr_find: Any = False
+ replace_sampler_ddp: bool = True
+ terminate_on_nan: bool = False
+ auto_scale_batch_size: Any = False
+ prepare_data_per_node: bool = True
+ amp_backend: str = 'native'
+ amp_level: str = 'O2' # backward compatible, todo: remove in v1.0.0
+
+
+@dataclass
+class HTrainerConfig(TrainerConfig):
+ find_unused_parameters: bool = False
+ ema_decay: float = 0.0
+ ema_start_step: int = 0
+
+
+# Register the trainer config.
+cs.store(
+ group="trainer", name="trainer", node=TrainerConfig,
+)
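+
+
+# Minimal usage sketch (illustrative only): build a structured trainer config and
+# override a few fields, mirroring how Hydra would consume the registered node above.
+if __name__ == "__main__":
+    from omegaconf import OmegaConf
+
+    trainer_cfg = OmegaConf.structured(TrainerConfig(gpus=1, max_epochs=50, precision=16))
+    print(OmegaConf.to_yaml(trainer_cfg))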
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/schedulers.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/schedulers.py
new file mode 100644
index 0000000000000000000000000000000000000000..629a1fb9fcef93c7bbe7e7d455048db85f030738
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/schedulers.py
@@ -0,0 +1,250 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from functools import partial
+from typing import Any, Dict, Optional
+
+
+@dataclass
+class SchedulerParams:
+ """
+ Base configuration for all schedulers.
+ It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).
+ """
+
+ last_epoch: int = -1
+
+
+@dataclass
+class WarmupSchedulerParams(SchedulerParams):
+ """
+ Base configuration for schedulers with a warmup phase.
+ It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).
+ """
+
+ warmup_steps: Optional[float] = None
+ warmup_ratio: Optional[float] = None
+
+
+@dataclass
+class WarmupHoldSchedulerParams(WarmupSchedulerParams):
+ """
+ Base configuration for all schedulers.
+ It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).
+ """
+
+ hold_steps: Optional[float] = None
+ hold_ratio: Optional[float] = None
+ min_lr: float = 0.0
+
+
+@dataclass
+class SquareAnnealingParams(WarmupSchedulerParams):
+ """
+ Square Annealing parameter config
+ It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).
+ """
+
+ min_lr: float = 1e-5
+
+
+@dataclass
+class SquareRootAnnealingParams(WarmupSchedulerParams):
+ """
+ Square Root Annealing parameter config
+ It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).
+ """
+
+ min_lr: float = 0.0
+
+
+@dataclass
+class CosineAnnealingParams(WarmupSchedulerParams):
+ """
+ Cosine Annealing parameter config
+ It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).
+ """
+
+ min_lr: float = 0.0
+
+
+@dataclass
+class NoamAnnealingParams(WarmupSchedulerParams):
+ """
+ Noam Annealing parameter config
+ It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).
+ """
+
+ min_lr: float = 0.0
+
+
+@dataclass
+class WarmupAnnealingParams(WarmupSchedulerParams):
+ """
+ Warmup Annealing parameter config
+ It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).
+ """
+
+ warmup_ratio: float = 0.0
+
+
+@dataclass
+class InverseSquareRootAnnealingParams(WarmupSchedulerParams):
+ """
+ Inverse Square Root Annealing parameter config
+ It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).
+ """
+
+
+@dataclass
+class PolynomialDecayAnnealingParams(WarmupSchedulerParams):
+ """
+ Polynomial Decay Annealing parameter config
+ It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).
+ """
+
+ power: float = 1.0
+ cycle: bool = False
+
+
+@dataclass
+class PolynomialHoldDecayAnnealingParams(WarmupSchedulerParams):
+ """
+ Polynomial Hold Decay Annealing parameter config
+ It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).
+ """
+
+ power: float = 1.0
+ cycle: bool = False
+
+
+"""
+PyTorch learning-rate schedulers
+"""
+
+
+@dataclass
+class StepLRParams(SchedulerParams):
+ """
+ Config for StepLR.
+ It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).
+ """
+
+ step_size: float = 0.1
+ gamma: float = 0.1
+
+
+@dataclass
+class ExponentialLRParams(SchedulerParams):
+ """
+ Config for ExponentialLR.
+ It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).
+ """
+
+ gamma: float = 0.9
+
+
+@dataclass
+class ReduceLROnPlateauParams:
+ """
+ Config for ReduceLROnPlateau.
+ It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).
+ """
+
+ mode: str = 'min'
+ factor: float = 0.1
+ patience: int = 10
+ verbose: bool = False
+ threshold: float = 1e-4
+ threshold_mode: str = 'rel'
+ cooldown: int = 0
+ min_lr: float = 0
+ eps: float = 1e-8
+
+
+@dataclass
+class CyclicLRParams(SchedulerParams):
+ """
+ Config for CyclicLR.
+ NOTE:
+ `scale_fn` is not supported.
+
+ It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).
+ """
+
+ base_lr: float = 0.001
+ max_lr: float = 0.1
+ step_size_up: int = 2000
+ step_size_down: Optional[int] = None
+ mode: str = 'triangular'
+ gamma: float = 1.0
+ scale_mode: str = 'cycle'
+ # scale_fn is not supported
+ cycle_momentum: bool = True
+ base_momentum: float = 0.8
+ max_momentum: float = 0.9
+
+
+def register_scheduler_params(name: str, scheduler_params: SchedulerParams):
+ """
+ Checks if the scheduler config name exists in the registry, and if it doesn't, adds it.
+
+ This allows custom schedulers to be added and called by name during instantiation.
+
+ Args:
+ name: Name of the scheduler. Will be used as key to retrieve the scheduler config.
+ scheduler_params: SchedulerParams class
+ """
+ if name in AVAILABLE_SCHEDULER_PARAMS:
+ raise ValueError(f"Cannot override pre-existing schedulers. Conflicting scheduler name = {name}")
+
+ AVAILABLE_SCHEDULER_PARAMS[name] = scheduler_params
+
+
+def get_scheduler_config(name: str, **kwargs: Optional[Dict[str, Any]]) -> SchedulerParams:
+ """
+ Convenience method to obtain a SchedulerParams class and partially instantiate it with scheduler kwargs.
+
+ Args:
+ name: Name of the SchedulerParams in the registry.
+ kwargs: Optional kwargs of the scheduler used during instantiation.
+
+ Returns:
+ a partially instantiated SchedulerParams
+ """
+ if name not in AVAILABLE_SCHEDULER_PARAMS:
+ raise ValueError(
+ f"Cannot resolve scheduler parameters '{name}'. Available scheduler parameters are : "
+ f"{AVAILABLE_SCHEDULER_PARAMS.keys()}"
+ )
+
+ scheduler_params = AVAILABLE_SCHEDULER_PARAMS[name]
+ scheduler_params = partial(scheduler_params, **kwargs)
+ return scheduler_params
+
+
+AVAILABLE_SCHEDULER_PARAMS = {
+ 'SchedulerParams': SchedulerParams,
+ 'WarmupPolicyParams': WarmupSchedulerParams,
+ 'WarmupHoldPolicyParams': WarmupHoldSchedulerParams,
+ 'SquareAnnealingParams': SquareAnnealingParams,
+ 'SquareRootAnnealingParams': SquareRootAnnealingParams,
+ 'InverseSquareRootAnnealingParams': InverseSquareRootAnnealingParams,
+ 'CosineAnnealingParams': CosineAnnealingParams,
+ 'NoamAnnealingParams': NoamAnnealingParams,
+ 'WarmupAnnealingParams': WarmupAnnealingParams,
+ 'PolynomialDecayAnnealingParams': PolynomialDecayAnnealingParams,
+ 'PolynomialHoldDecayAnnealingParams': PolynomialHoldDecayAnnealingParams,
+}
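+
+
+# Minimal usage sketch (illustrative only; exercises the registry defined above):
+if __name__ == "__main__":
+    cosine_factory = get_scheduler_config('CosineAnnealingParams', warmup_steps=1000, min_lr=1e-5)
+    print(cosine_factory())  # CosineAnnealingParams with the overrides applied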
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/set_config.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/set_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..adc8f9c1cc2085c0bf465f253675f85f9094c434
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/config/set_config.py
@@ -0,0 +1,123 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+from typing import Any, Callable, Optional
+
+from hydra._internal.utils import _run_hydra, get_args_parser
+from hydra.core.config_store import ConfigStore
+from hydra.types import TaskFunction
+from omegaconf import DictConfig
+
+from nemo.core.config import Config
+
+
+def hydra_runner(
+ config_path: Optional[str] = None, config_name: Optional[str] = None, schema: Optional[Any] = None
+) -> Callable[[TaskFunction], Any]:
+ """
+ Decorator used for passing the Config paths to main function.
+ Optionally registers a schema used for validation/providing default values.
+
+ Args:
+ config_path: Path to the directory where the config exists.
+ config_name: Name of the config file.
+ schema: Structured config type representing the schema used for validation/providing default values.
+ """
+ if schema is not None:
+ # Create config store.
+ cs = ConfigStore.instance()
+ # Register the configuration as a node under a given name.
+ cs.store(name=config_name.replace(".yaml", ""), node=schema)
+
+ def decorator(task_function: TaskFunction) -> Callable[[], None]:
+ @functools.wraps(task_function)
+ def wrapper(cfg_passthrough: Optional[DictConfig] = None) -> Any:
+ # Check if a config was passed through directly.
+ if cfg_passthrough is not None:
+ return task_function(cfg_passthrough)
+ else:
+ args = get_args_parser()
+
+ # Parse arguments in order to retrieve overrides
+ parsed_args = args.parse_args()
+
+ # Get overriding args in dot string format
+ overrides = parsed_args.overrides # type: list
+
+ # Update overrides
+ overrides.append("hydra.run.dir=.")
+ overrides.append('hydra.job_logging.root.handlers=null')
+
+ # Wrap a callable object with name `parse_args`
+ # This is to mimic the ArgParser.parse_args() API.
+ class _argparse_wrapper:
+ def __init__(self, arg_parser):
+ self.arg_parser = arg_parser
+ self._actions = arg_parser._actions
+
+ def parse_args(self, args=None, namespace=None):
+ return parsed_args
+
+ # no return value from run_hydra() as it may sometimes actually run the task_function
+ # multiple times (--multirun)
+ _run_hydra(
+ args_parser=_argparse_wrapper(args),
+ task_function=task_function,
+ config_path=config_path,
+ config_name=config_name,
+ strict=None,
+ )
+
+ return wrapper
+
+ return decorator
+
+
+def set_config(config: Config) -> Callable[[TaskFunction], Any]:
+ """
+ Decorator used for passing the Structured Configs to main function.
+
+ Args:
+ config: config class derived from Config.
+ """
+ # Get class name. Not sure how important this is, but couldn't get the name by accessing type().__name__.
+ class_name = str(config)
+ # Create config store.
+ cs = ConfigStore.instance()
+ # Register the configuration as a node under a given name.
+ cs.store(name=class_name, node=config)
+
+ def decorator(task_function: TaskFunction) -> Callable[[], None]:
+ @functools.wraps(task_function)
+ def wrapper(cfg_passthrough: Optional[DictConfig] = None) -> Any:
+ # Check if a config was passed through directly.
+ if cfg_passthrough is not None:
+ return task_function(cfg_passthrough)
+ else:
+ args = get_args_parser()
+
+ # no return value from run_hydra() as it may sometimes actually run the task_function
+ # multiple times (--multirun)
+ _run_hydra(
+ args_parser=args,
+ task_function=task_function,
+ config_path=None,
+ config_name=class_name,
+ strict=None,
+ )
+
+ return wrapper
+
+ return decorator
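+
+
+# Typical usage sketch for `hydra_runner` (illustrative only; assumes a hypothetical
+# `conf/config.yaml` next to the training script and a command-line launch):
+#
+#     from omegaconf import DictConfig, OmegaConf
+#     from nemo.core.config import hydra_runner
+#
+#     @hydra_runner(config_path="conf", config_name="config")
+#     def main(cfg: DictConfig) -> None:
+#         print(OmegaConf.to_yaml(cfg))
+#
+#     if __name__ == "__main__":
+#         main()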
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/neural_types/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/neural_types/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8598bbe67f232e73be763ac0a5a37c7f7e6a5260
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/neural_types/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nemo.core.neural_types.axes import *
+from nemo.core.neural_types.comparison import *
+from nemo.core.neural_types.elements import *
+from nemo.core.neural_types.neural_type import *
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/neural_types/axes.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/neural_types/axes.py
new file mode 100644
index 0000000000000000000000000000000000000000..cda781f021cc544572eec543d2828fd38c9bf0a6
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/neural_types/axes.py
@@ -0,0 +1,101 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
+from typing import Optional
+
+__all__ = ['AxisKindAbstract', 'AxisKind', 'AxisType']
+
+
+class AxisKindAbstract(Enum):
+ """This is an abstract Enum to represents what does varying axis dimension mean.
+ In practice, you will almost always use AxisKind Enum. This Enum should be inherited by
+ your OWN Enum if you aren't satisfied with AxisKind. Then your own Enum can be used
+ instead of AxisKind."""
+
+ pass
+
+
+class AxisKind(AxisKindAbstract):
+ """This Enum represents what does varying axis dimension mean.
+ For example, does this dimension correspond to width, batch, time, etc.
+ The "Dimension" and "Channel" kinds are the same and used to represent
+ a general axis. "Any" axis will accept any axis kind fed to it.
+ """
+
+ Batch = 0
+ Time = 1
+ Dimension = 2
+ Channel = 2
+ Width = 3
+ Height = 4
+ Any = 5
+ Sequence = 6
+ FlowGroup = 7
+ Singleton = 8 # Used to represent an axis that has size 1
+
+ def __repr__(self):
+ return self.__str__()
+
+ def __str__(self):
+ return str(self.name).lower()
+
+ @staticmethod
+ def from_str(label):
+ """Returns AxisKind instance based on short string representation"""
+ _label = label.lower().strip()
+ if _label == "b" or _label == "n" or _label == "batch":
+ return AxisKind.Batch
+ elif _label == "t" or _label == "time":
+ return AxisKind.Time
+ elif _label == "d" or _label == "c" or _label == "channel":
+ return AxisKind.Dimension
+ elif _label == "w" or _label == "width":
+ return AxisKind.Width
+ elif _label == "h" or _label == "height":
+ return AxisKind.Height
+ elif _label == "s" or _label == "singleton":
+ return AxisKind.Singleton
+ elif _label == "flowgroup":
+ return AxisKind.FlowGroup
+ elif _label == "any":
+ return AxisKind.Any
+ else:
+ raise ValueError(f"Can't create AxisKind from {label}")
+
+
+class AxisType(object):
+ """This class represents axis semantics and (optionally) it's dimensionality
+ Args:
+ kind (AxisKindAbstract): what kind of axis it is? For example Batch, Height, etc.
+ size (int, optional): specify if the axis should have a fixed size. By default it is set to None and you
+ typically do not want to set it for Batch and Time
+ is_list (bool, default=False): whether this is a list or a tensor axis
+ """
+
+ def __init__(self, kind: AxisKindAbstract, size: Optional[int] = None, is_list=False):
+ if size is not None and is_list:
+ raise ValueError("The axis can't be list and have a fixed size")
+ self.kind = kind
+ self.size = size
+ self.is_list = is_list
+
+ def __repr__(self):
+ if self.size is None:
+ representation = str(self.kind)
+ else:
+ representation = f"{str(self.kind)}:{self.size}"
+ if self.is_list:
+ representation += "_listdim"
+ return representation
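+
+
+# Minimal usage sketch (illustrative only):
+if __name__ == "__main__":
+    batch_axis = AxisType(AxisKind.from_str("B"))
+    time_axis = AxisType(AxisKind.Time, size=128)
+    print(batch_axis, time_axis)  # prints: batch time:128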
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/neural_types/comparison.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/neural_types/comparison.py
new file mode 100644
index 0000000000000000000000000000000000000000..391096221d478fd9717c019c3faf0d52ab6169da
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/neural_types/comparison.py
@@ -0,0 +1,32 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
+
+__all__ = ['NeuralTypeComparisonResult']
+
+
+class NeuralTypeComparisonResult(Enum):
+ """The result of comparing two neural type objects for compatibility.
+ When comparing A.compare(B):"""
+
+ SAME = 0
+ LESS = 1 # A is B
+ GREATER = 2 # B is A
+ DIM_INCOMPATIBLE = 3 # Resize connector might fix incompatibility
+ TRANSPOSE_SAME = 4 # A transpose and/or converting between lists and tensors will make them same
+ CONTAINER_SIZE_MISMATCH = 5 # A and B contain different number of elements
+ INCOMPATIBLE = 6 # A and B are incompatible
+ SAME_TYPE_INCOMPATIBLE_PARAMS = 7 # A and B are of the same type but parametrized differently
+ UNCHECKED = 8 # type comparison wasn't done
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/neural_types/elements.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/neural_types/elements.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e94783851577e818ea1704502e23b3926b5afb0
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/neural_types/elements.py
@@ -0,0 +1,320 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+from abc import ABC
+from typing import Dict, Optional, Tuple
+
+from nemo.core.neural_types.comparison import NeuralTypeComparisonResult
+
+__all__ = [
+ 'ElementType',
+ 'VoidType',
+ 'ChannelType',
+ 'AcousticEncodedRepresentation',
+ 'AudioSignal',
+ 'SpectrogramType',
+ 'MelSpectrogramType',
+ 'MFCCSpectrogramType',
+ 'LogitsType',
+ 'LabelsType',
+ 'HypothesisType',
+ 'LossType',
+ 'RegressionValuesType',
+ 'CategoricalValuesType',
+ 'PredictionsType',
+ 'LogprobsType',
+ 'LengthsType',
+ 'EmbeddedTextType',
+ 'EncodedRepresentation',
+ 'MaskType',
+ 'Target',
+ 'ClassificationTarget',
+ 'ImageFeatureValue',
+ 'Index',
+ 'ImageValue',
+ 'NormalizedImageValue',
+ 'StringLabel',
+ 'StringType',
+ 'TokenIndex',
+ 'Length',
+ 'IntType',
+ 'FloatType',
+ 'NormalDistributionSamplesType',
+ 'NormalDistributionMeanType',
+ 'NormalDistributionLogVarianceType',
+ 'TokenDurationType',
+ 'TokenLogDurationType',
+ 'LogDeterminantType',
+ 'SequenceToSequenceAlignmentType',
+]
+
+
+class ElementType(ABC):
+ """Abstract class defining semantics of the tensor elements.
+ We are relying on Python for inheritance checking"""
+
+ def __str__(self):
+ return self.__doc__
+
+ def __repr__(self):
+ return self.__class__.__name__
+
+ @property
+ def type_parameters(self) -> Dict:
+ """Override this property to parametrize your type. For example, you can specify 'storage' type such as
+ float, int, bool with 'dtype' keyword. Another example: if you want to represent a signal with a
+ particular property (say, sample frequency), then you can put sample_freq->value in there.
+ When two types are compared their type_parameters must match."""
+ return {}
+
+ @property
+ def fields(self) -> Optional[Tuple]:
+ """This should be used to logically represent tuples/structures. For example, if you want to represent a
+ bounding box (x, y, width, height) you can put a tuple with names ('x', 'y', 'w', 'h') in here.
+ Under the hood this should be converted to the last tensor dimension of fixed size = len(fields).
+ When two types are compared their fields must match."""
+ return None
+
+ def compare(self, second) -> NeuralTypeComparisonResult:
+ # First, check general compatibility
+ first_t = type(self)
+ second_t = type(second)
+
+ if first_t == second_t:
+ result = NeuralTypeComparisonResult.SAME
+ elif issubclass(first_t, second_t):
+ result = NeuralTypeComparisonResult.LESS
+ elif issubclass(second_t, first_t):
+ result = NeuralTypeComparisonResult.GREATER
+ else:
+ result = NeuralTypeComparisonResult.INCOMPATIBLE
+
+ if result != NeuralTypeComparisonResult.SAME:
+ return result
+ else:
+ # now check that all parameters match
+ check_params = set(self.type_parameters.keys()) == set(second.type_parameters.keys())
+ if check_params is False:
+ return NeuralTypeComparisonResult.SAME_TYPE_INCOMPATIBLE_PARAMS
+ else:
+ for k1, v1 in self.type_parameters.items():
+ if v1 is None or second.type_parameters[k1] is None:
+ # Treat None as Void
+ continue
+ if v1 != second.type_parameters[k1]:
+ return NeuralTypeComparisonResult.SAME_TYPE_INCOMPATIBLE_PARAMS
+ # check that all fields match
+ if self.fields == second.fields:
+ return NeuralTypeComparisonResult.SAME
+ else:
+ return NeuralTypeComparisonResult.INCOMPATIBLE
+
+
+class VoidType(ElementType):
+ """Void-like type which is compatible with everything.
+ It is a good practice to use this type only as necessary.
+ For example, when you need template-like functionality.
+ """
+
+ def compare(self, second: abc.ABCMeta) -> NeuralTypeComparisonResult:
+ return NeuralTypeComparisonResult.SAME
+
+
+# TODO: Consider moving these files elsewhere
+class ChannelType(ElementType):
+ """Element to represent convolutional input/output channel.
+ """
+
+
+class EmbeddedTextType(ChannelType):
+ """Element to represent output on word/text embedding layers
+ """
+
+
+class LogitsType(ElementType):
+ """Element type to represent logits"""
+
+
+class LogprobsType(ElementType):
+ """Element type to represent log-probabilities. For example, outputs of softmax layers."""
+
+
+class LabelsType(ElementType):
+ """Element type to represent some sort of labels. This is often used as a base class to create
+ a more concrete types such as RegressionValuesType, etc."""
+
+
+class HypothesisType(LabelsType):
+ """Element type to represent some decoded hypothesis, which may further be processed to obtain
+ a concrete label."""
+
+
+class LengthsType(ElementType):
+ """Element type representing lengths of something"""
+
+
+class LossType(ElementType):
+ """Element type to represent outputs of Loss modules"""
+
+
+class EncodedRepresentation(ChannelType):
+ """Element type to represent encoded representation, for example, encoder's output"""
+
+
+class AcousticEncodedRepresentation(EncodedRepresentation):
+ """Element type to represent encoded representation returned by the acoustic encoder model"""
+
+
+class AudioSignal(ElementType):
+ """Element type to represent encoded representation returned by the acoustic encoder model
+ Args:
+ freq (int): sampling frequency of a signal. Note that two signals will only be the same if their
+ freq is the same.
+ """
+
+ def __init__(self, freq: int = None):
+ self._params = {}
+ self._params['freq'] = freq
+
+ @property
+ def type_parameters(self):
+ return self._params
+
+
+class SpectrogramType(ChannelType):
+ """Element type to represent generic spectrogram signal"""
+
+
+class MelSpectrogramType(SpectrogramType):
+ """Element type to represent mel spectrogram signal"""
+
+
+class MFCCSpectrogramType(SpectrogramType):
+ """Element type to represent MFCC spectrogram signal"""
+
+
+class PredictionsType(LabelsType):
+ """Element type to represent some sort of predictions returned by model"""
+
+
+class RegressionValuesType(PredictionsType):
+ """Element type to represent labels for regression task"""
+
+
+class CategoricalValuesType(PredictionsType):
+ """Element type to represent labels for categorical classification task"""
+
+
+class MaskType(PredictionsType):
+ """Element type to represent a boolean mask"""
+
+
+class Index(ElementType):
+ """Type representing an element being an index of the sample."""
+
+
+class Target(ElementType):
+ """
+ Type representing an element being a target value.
+ """
+
+
+class ClassificationTarget(Target):
+ """
+ Type representing an element being target value in the classification task, i.e. identifier of a desired class.
+ """
+
+
+class ImageValue(ElementType):
+ """
+ Type representing an element/value of a single image channel,
+ e.g. a single element (R) of RGB image.
+ """
+
+
+class NormalizedImageValue(ImageValue):
+ """
+ Type representing an element/value of a single image channel normalized to <0-1> range,
+ e.g. a single element (R) of normalized RGB image.
+ """
+
+
+class ImageFeatureValue(ImageValue):
+ """Type representing an element (single value) of a (image) feature maps."""
+
+
+class StringType(ElementType):
+ """Element type representing a single string"""
+
+
+class StringLabel(StringType):
+ """
+ Type representing a label that is a string with a class name (e.g. the "hamster" class in CIFAR100).
+ """
+
+
+class BoolType(ElementType):
+ """Element type representing a single integer"""
+
+
+class IntType(ElementType):
+ """Element type representing a single integer"""
+
+
+class FloatType(ElementType):
+ """Element type representing a single float"""
+
+
+class TokenIndex(IntType):
+ """Type representing an element being index of a token in some kind of a vocabulary."""
+
+
+class Length(IntType):
+ """Type representing an element storing a "length" (e.g. length of a list)."""
+
+
+class ProbabilityDistributionSamplesType(ElementType):
+ """Element to represent tensors that meant to be sampled from a valid probability distribution
+ """
+
+
+class NormalDistributionSamplesType(ProbabilityDistributionSamplesType):
+ """Element to represent tensors that meant to be sampled from a valid normal distribution
+ """
+
+
+class SequenceToSequenceAlignmentType(ElementType):
+ """Class to represent the alignment from seq-to-seq attention outputs. Generally a mapping from endcoder time steps
+ to decoder time steps."""
+
+
+class NormalDistributionMeanType(ElementType):
+ """Element to represent the mean of a normal distribution"""
+
+
+class NormalDistributionLogVarianceType(ElementType):
+ """Element to represent the log variance of a normal distribution"""
+
+
+class TokenDurationType(ElementType):
+ """Element for representing the duration of a token"""
+
+
+class TokenLogDurationType(ElementType):
+ """Element for representing the log-duration of a token"""
+
+
+class LogDeterminantType(ElementType):
+ """Element for representing log determinants usually used in flow models"""
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/neural_types/neural_type.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/neural_types/neural_type.py
new file mode 100644
index 0000000000000000000000000000000000000000..8714d700b08192c57955425e601d4a029bf6e17c
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/neural_types/neural_type.py
@@ -0,0 +1,223 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple
+
+from nemo.core.neural_types.axes import AxisKind, AxisType
+from nemo.core.neural_types.comparison import NeuralTypeComparisonResult
+from nemo.core.neural_types.elements import ElementType, VoidType
+
+__all__ = [
+ 'NeuralType',
+ 'NeuralTypeError',
+ 'NeuralPortNameMismatchError',
+ 'NeuralPortNmTensorMismatchError',
+]
+
+
+class NeuralType(object):
+ """This is the main class which would represent neural type concept.
+ It is used to represent *the types* of inputs and outputs.
+ Args:
+ axes (Optional[Tuple]): a tuple of AxisTypes objects representing the semantics of what varying each axis means
+ You can use a short, string-based form here. For example: ('B', 'C', 'H', 'W') would correspond to an NCHW
+ format frequently used in computer vision. ('B', 'T', 'D') is frequently used for signal processing and
+ means [batch, time, dimension/channel].
+ elements_type (ElementType): an instance of ElementType class representing the semantics of what is stored
+ inside the tensor. For example: logits (LogitsType), log probabilities (LogprobType), etc.
+ optional (bool): By default, this is False. If set to True, it means that the input to the port of this
+ type can be optional.
+ """
+
+ def __str__(self):
+
+ if self.axes is not None:
+ return f"axes: {self.axes}; elements_type: {self.elements_type.__class__.__name__}"
+ else:
+ return f"axes: None; elements_type: {self.elements_type.__class__.__name__}"
+
+ def __init__(self, axes: Optional[Tuple] = None, elements_type: ElementType = VoidType(), optional=False):
+ if not isinstance(elements_type, ElementType):
+ raise ValueError(
+ "elements_type of NeuralType must be an instance of a class derived from ElementType. "
+ "Did you pass a class instead?"
+ )
+ self.elements_type = elements_type
+ if axes is not None:
+ NeuralType.__check_sanity(axes)
+ axes_list = []
+ for axis in axes:
+ if isinstance(axis, str):
+ axes_list.append(AxisType(AxisKind.from_str(axis), None))
+ elif isinstance(axis, AxisType):
+ axes_list.append(axis)
+ else:
+ raise ValueError("axis type must be either str or AxisType instance")
+ self.axes = tuple(axes_list)
+ else:
+ self.axes = None
+ self.optional = optional
+
+ def compare(self, second) -> NeuralTypeComparisonResult:
+ """Performs neural type comparison of self with second. When you chain two modules' inputs/outputs via
+ __call__ method, this comparison will be called to ensure neural type compatibility."""
+ # First, handle dimensionality
+ axes_a = self.axes
+ axes_b = second.axes
+
+ # "Big void" type
+ if isinstance(self.elements_type, VoidType) and self.axes is None:
+ return NeuralTypeComparisonResult.SAME
+
+ if self.axes is None:
+ if second.axes is None:
+ return self.elements_type.compare(second.elements_type)
+ else:
+ return NeuralTypeComparisonResult.INCOMPATIBLE
+
+ dimensions_pass = NeuralType.__compare_axes(axes_a, axes_b)
+ element_comparison_result = self.elements_type.compare(second.elements_type)
+
+ # SAME DIMS
+ if dimensions_pass == 0:
+ return element_comparison_result
+ # TRANSPOSE_SAME DIMS
+ elif dimensions_pass == 1:
+ if element_comparison_result == NeuralTypeComparisonResult.SAME:
+ return NeuralTypeComparisonResult.TRANSPOSE_SAME
+ else:
+ return NeuralTypeComparisonResult.INCOMPATIBLE
+ # DIM_INCOMPATIBLE DIMS
+ elif dimensions_pass == 2:
+ if element_comparison_result == NeuralTypeComparisonResult.SAME:
+ return NeuralTypeComparisonResult.DIM_INCOMPATIBLE
+ else:
+ return NeuralTypeComparisonResult.INCOMPATIBLE
+ else:
+ return NeuralTypeComparisonResult.INCOMPATIBLE
+
+ def compare_and_raise_error(self, parent_type_name, port_name, second_object):
+ """ Method compares definition of one type with another and raises an error if not compatible. """
+ type_compatibility = self.compare(second_object)
+ if (
+ type_compatibility != NeuralTypeComparisonResult.SAME
+ and type_compatibility != NeuralTypeComparisonResult.GREATER
+ ):
+ raise NeuralPortNmTensorMismatchError(
+ parent_type_name, port_name, str(self), str(second_object.ntype), type_compatibility
+ )
+
+ def __eq__(self, other):
+ if isinstance(other, NeuralType):
+ return self.compare(other)
+
+ return False
+
+ @staticmethod
+ def __check_sanity(axes):
+ # check that list dimensions come before any tensor dimensions
+ are_strings = True
+ for axis in axes:
+ if not isinstance(axis, str):
+ are_strings = False
+ if isinstance(axis, str) and not are_strings:
+ raise ValueError("Either use full class names or all strings")
+ if are_strings:
+ return
+ checks_passed = True
+ saw_tensor_dim = False
+ for axis in axes:
+ if not axis.is_list:
+ saw_tensor_dim = True
+ else: # current axis is a list
+ if saw_tensor_dim: # which is preceded by tensor dim
+ checks_passed = False
+ if not checks_passed:
+ raise ValueError(
+ "You have list dimension after Tensor dimension. All list dimensions must preceed Tensor dimensions"
+ )
+
+ @staticmethod
+ def __compare_axes(axes_a, axes_b) -> int:
+ """
+ Compares axes_a and axes_b
+ Args:
+ axes_a: first axes tuple
+ axes_b: second axes tuple
+
+ Returns:
+ 0 - if they are exactly the same
+ 1 - if they are "TRANSPOSE_SAME"
+ 2 - if they are "DIM_INCOMPATIBLE"
+ 3 - if they are different
+ """
+ if axes_a is None and axes_b is None:
+ return 0
+ elif axes_a is None and axes_b is not None:
+ return 3
+ elif axes_a is not None and axes_b is None:
+ return 3
+ elif len(axes_a) != len(axes_b):
+ return 3
+ # After these ifs we know that len(axes_a) == len(axes_b)
+
+ same = True
+ kinds_a = dict()
+ kinds_b = dict()
+ for axis_a, axis_b in zip(axes_a, axes_b):
+ kinds_a[axis_a.kind] = axis_a.size
+ kinds_b[axis_b.kind] = axis_b.size
+ if axis_a.kind == AxisKind.Any:
+ same = True
+ elif (
+ axis_a.kind != axis_b.kind
+ or axis_a.is_list != axis_b.is_list
+ or (axis_a.size != axis_b.size and axis_a.size is not None)
+ ):
+ same = False
+ if same:
+ return 0
+ else:
+ # can be TRANSPOSE_SAME, DIM_INCOMPATIBLE
+ if kinds_a.keys() == kinds_b.keys():
+ for key, value in kinds_a.items():
+ if kinds_b[key] != value:
+ return 2
+ return 1
+ else:
+ return 3
+
+
+class NeuralTypeError(Exception):
+ """Base class for neural type related exceptions."""
+
+
+class NeuralPortNameMismatchError(NeuralTypeError):
+ """Exception raised when neural module is called with incorrect port
+ names."""
+
+ def __init__(self, input_port_name):
+ super().__init__()
+ self.message = "Wrong input port name: {0}".format(input_port_name)
+
+
+class NeuralPortNmTensorMismatchError(NeuralTypeError):
+ """Exception raised when a port is fed with a NmTensor of incompatible
+ type."""
+
+    def __init__(self, class_name, port_name, first_type, second_type, type_compatibility):
+        super().__init__()
+        self.message = "\nIn {}. \nPort: {} and a NmTensor it was fed are \n".format(class_name, port_name)
+        self.message += "of incompatible neural types:\n\n{} \n\n and \n\n{}".format(first_type, second_type)
+        self.message += "\n\nType comparison result: {}".format(type_compatibility)
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/optim/__init__.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/optim/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b421582c00347cb61f4d502886942a675d8a944c
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/optim/__init__.py
@@ -0,0 +1,29 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from nemo.core.optim.lr_scheduler import (
+ CosineAnnealing,
+ InverseSquareRootAnnealing,
+ NoamAnnealing,
+ PolynomialDecayAnnealing,
+ PolynomialHoldDecayAnnealing,
+ SquareAnnealing,
+ SquareRootAnnealing,
+ WarmupAnnealing,
+ WarmupHoldPolicy,
+ WarmupPolicy,
+ prepare_lr_scheduler,
+)
+from nemo.core.optim.novograd import Novograd
+from nemo.core.optim.optimizers import get_optimizer, parse_optimizer_args, register_optimizer
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/optim/lr_scheduler.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/optim/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..87df3ea1ebc16be650be7821fb1dcc0da7603b70
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/optim/lr_scheduler.py
@@ -0,0 +1,688 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import dataclasses
+import math
+import warnings
+from functools import partial
+from typing import Any, Dict, Optional, Union
+
+import hydra
+import torch.optim as optim
+import torch.optim.lr_scheduler as pt_scheduler
+import torch.utils.data.dataloader as dataloader
+from omegaconf import DictConfig, OmegaConf
+from torch.optim.lr_scheduler import _LRScheduler
+
+from nemo.core.config import SchedulerParams, get_scheduler_config, register_scheduler_params
+from nemo.utils import logging
+
+
+class WarmupPolicy(_LRScheduler):
+ """Adds warmup kwargs and warmup logic to lr policy.
+ All arguments should be passed as kwargs for clarity,
+ Args:
+ warmup_steps: Number of training steps in warmup stage
+ warmup_ratio: Ratio of warmup steps to total steps
+ max_steps: Total number of steps while training or `None` for
+ infinite training
+ """
+
+ def __init__(self, optimizer, *, warmup_steps=None, warmup_ratio=None, warmup_power=None, max_steps=None, min_lr=0.0, last_epoch=-1):
+ assert not (
+ warmup_steps is not None and warmup_ratio is not None
+        ), "Either use a particular number of steps or a ratio"
+        assert warmup_ratio is None or max_steps is not None, "If a ratio is given, max_steps must also be provided"
+
+ # It is necessary to assign all attributes *before* __init__,
+ # as class is wrapped by an inner class.
+ self.max_steps = max_steps
+ if warmup_steps is not None:
+ self.warmup_steps = warmup_steps
+ elif warmup_ratio is not None:
+ self.warmup_steps = int(warmup_ratio * max_steps)
+ else:
+ self.warmup_steps = 0
+
+ self.warmup_power = warmup_power
+
+ self.min_lr = min_lr
+ super().__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ if not self._get_lr_called_within_step:
+ warnings.warn(
+ "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning
+ )
+
+ step = self.last_epoch
+
+ if step <= self.warmup_steps:
+ lr_val = (step + 1) / (self.warmup_steps + 1)
+ if self.warmup_power:
+ lr_val = lr_val ** self.warmup_power
+
+ return [initial_lr * lr_val for initial_lr in self.base_lrs]
+
+ if step > self.max_steps:
+ return [self.min_lr for _ in self.base_lrs]
+
+ return self._get_lr(step)
+
+ def _get_lr(self, step):
+ """Simple const lr policy"""
+ return self.base_lrs
+
+
+class WarmupHoldPolicy(WarmupPolicy):
+ """Variant of WarmupPolicy which maintains high learning rate for a defined number of steps.
+ All arguments should be passed as kwargs for clarity,
+ Args:
+ warmup_steps: Number of training steps in warmup stage
+ warmup_ratio: Ratio of warmup steps to total steps
+ hold_steps: Number of training steps to hold the learning rate after warm up
+ hold_ratio: Ratio of hold steps to total steps
+ max_steps: Total number of steps while training or `None` for
+ infinite training
+ """
+
+ def __init__(
+ self,
+ optimizer,
+ *,
+ warmup_steps=None,
+ warmup_ratio=None,
+ warmup_power=None,
+ hold_steps=None,
+ hold_ratio=None,
+ max_steps=None,
+ min_lr=0.0,
+ last_epoch=-1,
+ ):
+        assert not (hold_steps is not None and hold_ratio is not None), "Either use a particular number of steps or a ratio"
+        assert hold_ratio is None or max_steps is not None, "If a ratio is given, max_steps must also be provided"
+
+ self.min_lr = min_lr
+ self._last_warmup_lr = 0.0
+
+ # Necessary to duplicate as class attributes are hidden in inner class
+ self.max_steps = max_steps
+ if warmup_steps is not None:
+ self.warmup_steps = warmup_steps
+ elif warmup_ratio is not None:
+ self.warmup_steps = int(warmup_ratio * max_steps)
+ else:
+ self.warmup_steps = 0
+
+ if hold_steps is not None:
+ self.hold_steps = hold_steps + self.warmup_steps
+ elif hold_ratio is not None:
+ self.hold_steps = int(hold_ratio * max_steps) + self.warmup_steps
+ else:
+ self.hold_steps = 0
+
+ super().__init__(
+ optimizer,
+ warmup_steps=warmup_steps,
+ warmup_ratio=warmup_ratio,
+ warmup_power=warmup_power,
+ max_steps=max_steps,
+ last_epoch=last_epoch,
+ min_lr=min_lr,
+ )
+
+ def get_lr(self):
+ if not self._get_lr_called_within_step:
+ warnings.warn(
+ "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", UserWarning
+ )
+
+ step = self.last_epoch
+
+ # Warmup phase
+ if step <= self.warmup_steps:
+ lr_val = (step + 1) / (self.warmup_steps + 1)
+ return [initial_lr * lr_val for initial_lr in self.base_lrs]
+
+ # Hold phase
+ if (step >= self.warmup_steps) and (step < self.hold_steps):
+ return self.base_lrs
+
+ if step > self.max_steps:
+ return [self.min_lr for _ in self.base_lrs]
+
+ return self._get_lr(step)
+
+
+def _squareroot_annealing(initial_lr, step, max_steps, min_lr):
+ mult = ((max_steps - step) / max_steps) ** 0.5
+ out_lr = initial_lr * mult
+ out_lr = max(out_lr, min_lr)
+ return out_lr
+
+
+def _square_annealing(initial_lr, step, max_steps, min_lr):
+ mult = ((max_steps - step) / max_steps) ** 2
+ out_lr = initial_lr * mult
+ out_lr = max(out_lr, min_lr)
+ return out_lr
+
+
+def _cosine_annealing(initial_lr, step, max_steps, min_lr):
+ mult = 0.5 * (1 + math.cos(math.pi * step / max_steps))
+ out_lr = (initial_lr - min_lr) * mult + min_lr
+ return out_lr
+
+
+def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle):
+ if cycle:
+ multiplier = 1.0 if step == 0 else math.ceil(step / decay_steps)
+ decay_steps *= multiplier
+ else:
+ step = min(step, decay_steps)
+ p = step / decay_steps
+ lr = (initial_lr - min_lr) * math.pow(1.0 - p, power)
+ lr += min_lr
+ return lr
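+# Illustrative worked example (the numbers below are assumptions, not from the original file):
+# with initial_lr=1e-3, min_lr=0.0, power=1.0, cycle=False, decay_steps=100 and step=50,
+# p = 0.5 and the returned lr is (1e-3 - 0.0) * (1 - 0.5)**1 + 0.0 = 5e-4.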
+
+
+class SquareAnnealing(WarmupPolicy):
+ def __init__(self, optimizer, *, max_steps, min_lr=1e-5, last_epoch=-1, **kwargs):
+ super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs)
+
+ def _get_lr(self, step):
+ new_lrs = [
+ _square_annealing(
+ initial_lr=initial_lr,
+ step=step - self.warmup_steps,
+ max_steps=self.max_steps - self.warmup_steps,
+ min_lr=self.min_lr,
+ )
+ for initial_lr in self.base_lrs
+ ]
+ return new_lrs
+
+
+class SquareRootAnnealing(WarmupPolicy):
+ def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1, **kwargs):
+ super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs)
+
+ def _get_lr(self, step):
+ new_lrs = [
+ _squareroot_annealing(initial_lr=initial_lr, step=step, max_steps=self.max_steps, min_lr=self.min_lr)
+ for initial_lr in self.base_lrs
+ ]
+ return new_lrs
+
+
+class CosineAnnealing(WarmupPolicy):
+ def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1, **kwargs):
+ super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs)
+
+ def _get_lr(self, step):
+ for initial_lr in self.base_lrs:
+ if initial_lr < self.min_lr:
+ raise ValueError(
+ f"{self} received an initial learning rate that was lower than the minimum learning rate."
+ )
+
+ new_lrs = [
+ _cosine_annealing(
+ initial_lr=initial_lr,
+ step=step - self.warmup_steps,
+ max_steps=self.max_steps - self.warmup_steps,
+ min_lr=self.min_lr,
+ )
+ for initial_lr in self.base_lrs
+ ]
+ return new_lrs
+
+
+class NoamAnnealing(_LRScheduler):
+ def __init__(
+ self, optimizer, *, d_model, warmup_steps=None, warmup_ratio=None, max_steps=None, min_lr=0.0, last_epoch=-1
+ ):
+ self._normalize = d_model ** (-0.5)
+ assert not (
+ warmup_steps is not None and warmup_ratio is not None
+        ), "Either use a particular number of steps or a ratio"
+        assert warmup_ratio is None or max_steps is not None, "If a ratio is given, max_steps must also be provided"
+
+ # It is necessary to assign all attributes *before* __init__,
+ # as class is wrapped by an inner class.
+ self.max_steps = max_steps
+ if warmup_steps is not None:
+ self.warmup_steps = warmup_steps
+ elif warmup_ratio is not None:
+ self.warmup_steps = int(warmup_ratio * max_steps)
+ else:
+ self.warmup_steps = 0
+
+ self.min_lr = min_lr
+ super().__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ if not self._get_lr_called_within_step:
+ warnings.warn(
+ "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning
+ )
+
+ step = max(1, self.last_epoch)
+
+ if step > self.max_steps:
+ return [self.min_lr for _ in self.base_lrs]
+
+ for initial_lr in self.base_lrs:
+ if initial_lr < self.min_lr:
+ raise ValueError(
+ f"{self} received an initial learning rate that was lower than the minimum learning rate."
+ )
+
+ new_lrs = [self._noam_annealing(initial_lr=initial_lr, step=step) for initial_lr in self.base_lrs]
+ return new_lrs
+
+ def _noam_annealing(self, initial_lr, step):
+ mult = self._normalize * min(step ** (-0.5), step * (self.warmup_steps ** (-1.5)))
+ out_lr = initial_lr * mult
+ if step > self.warmup_steps:
+ out_lr = max(out_lr, self.min_lr)
+ return out_lr
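+    # Illustrative worked example (the numbers below are assumptions, not from the original file):
+    # with d_model=512 and warmup_steps=4000, the multiplier peaks at step 4000, where
+    # step ** -0.5 == step * warmup_steps ** -1.5 == 4000 ** -0.5 ~= 0.0158, giving
+    # mult ~= 512 ** -0.5 * 0.0158 ~= 7.0e-4 times the initial learning rate.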
+
+
+class WarmupAnnealing(WarmupPolicy):
+ def __init__(self, optimizer, *, max_steps, last_epoch=-1, min_lr=0.0, **kwargs):
+ super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs)
+
+ def _get_lr(self, step):
+ progress = float(step / self.max_steps)
+ warmup_ratio = float(self.warmup_steps / self.max_steps)
+
+ mult = max((progress - 1.0) / (warmup_ratio - 1.0), 0.0)
+ out_lr = [initial_lr * mult for initial_lr in self.base_lrs]
+
+ return out_lr
+
+
+class InverseSquareRootAnnealing(WarmupPolicy):
+ def __init__(self, optimizer, *, max_steps, last_epoch=-1, min_lr=0.0, **kwargs):
+ super().__init__(optimizer=optimizer, max_steps=max_steps, **kwargs, last_epoch=last_epoch, min_lr=min_lr)
+
+ def _get_lr(self, step):
+ denom = ((step + 1) / (self.warmup_steps + 1)) ** 0.5
+ out_lr = [initial_lr / denom for initial_lr in self.base_lrs]
+ return out_lr
+
+
+class PolynomialDecayAnnealing(WarmupPolicy):
+ def __init__(self, optimizer, *, max_steps, min_lr=0.0, power=1.0, cycle=False, last_epoch=-1, **kwargs):
+ self.power = power
+ self.cycle = cycle
+
+ super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs)
+
+ def _get_lr(self, step):
+ new_lrs = [
+ _poly_decay(
+ initial_lr,
+ step=step - self.warmup_steps,
+ decay_steps=self.max_steps - self.warmup_steps,
+ power=self.power,
+ min_lr=self.min_lr,
+ cycle=self.cycle,
+ )
+ for initial_lr in self.base_lrs
+ ]
+ return new_lrs
+
+
+class PolynomialHoldDecayAnnealing(WarmupHoldPolicy):
+ def __init__(self, optimizer, *, max_steps, min_lr=0.0, power=1.0, cycle=False, last_epoch=-1, **kwargs):
+ self.power = power
+ self.cycle = cycle
+
+ super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs)
+
+ def _get_lr(self, step):
+ new_lrs = [
+ _poly_decay(
+ initial_lr,
+ step=step - self.hold_steps,
+ decay_steps=self.max_steps - max(self.warmup_steps, self.hold_steps),
+ power=self.power,
+ min_lr=self.min_lr,
+ cycle=self.cycle,
+ )
+ for initial_lr in self.base_lrs
+ ]
+ return new_lrs
+
+
+def register_scheduler(name: str, scheduler: _LRScheduler, scheduler_params: SchedulerParams):
+ """
+    Checks if the scheduler name exists in the registry, and if it doesn't, adds it.
+
+ This allows custom schedulers to be added and called by name during instantiation.
+
+ Args:
+        name: Name of the scheduler. Will be used as key to retrieve the scheduler.
+ scheduler: Scheduler class (inherits from _LRScheduler)
+ scheduler_params: The parameters as a dataclass of the scheduler
+ """
+ if name in AVAILABLE_SCHEDULERS:
+ raise ValueError(f"Cannot override pre-existing schedulers. Conflicting scheduler name = {name}")
+
+ AVAILABLE_SCHEDULERS[name] = scheduler
+
+ sched_name = "{}_params".format(scheduler.__name__)
+ register_scheduler_params(name=sched_name, scheduler_params=scheduler_params)
+
+
+def get_scheduler(name: str, **kwargs: Optional[Dict[str, Any]]) -> _LRScheduler:
+ """
+    Convenience method to obtain an _LRScheduler class and partially instantiate it with scheduler kwargs.
+
+ Args:
+ name: Name of the scheduler in the registry.
+ kwargs: Optional kwargs of the scheduler used during instantiation.
+
+ Returns:
+ a partially instantiated _LRScheduler
+ """
+ if name not in AVAILABLE_SCHEDULERS:
+ raise ValueError(
+            f"Cannot resolve scheduler '{name}'. Available schedulers are: {AVAILABLE_SCHEDULERS.keys()}"
+ )
+
+ scheduler_cls = AVAILABLE_SCHEDULERS[name]
+ scheduler = partial(scheduler_cls, **kwargs)
+ return scheduler
+
+
+def prepare_lr_scheduler(
+ optimizer: optim.Optimizer,
+ scheduler_config: Union[Dict[str, Any], DictConfig],
+ train_dataloader: Optional[dataloader.DataLoader] = None,
+) -> Optional[Dict[str, Any]]:
+ """
+ Constructs an LR Scheduler (optionally) for a given optimizer, based on a config with the following schema
+
+ optim:
+ name:
+ lr:
+
+ #
+ args:
+ name: auto # special keyword, resolves to correct optimizer config for given optimizer name
+ # cls: nemo.core.config.optimizers.NovogradParams # explicit instantiation by class path
+ params: # optional override parameters for the optimizer config
+ betas: [0.8, 0.5]
+ weight_decay: 0.001
+
+ # scheduler setup
+ sched:
+ name:
+ iters_per_batch: null # computed at runtime; mandatory to have
+ max_steps: null # computed at runtime or explicitly set here; mandatory to have
+
+ # pytorch lightning args
+ monitor: val_loss
+ reduce_on_plateau: false
+
+ #
+ args:
+ name: auto # special keyword, resolves to correct optimizer config for given optimizer name
+ # cls: nemo.core.config.schedulers.CosineAnnealingParams # explicit instantiation by class path
+ params: # optional override parameters for the optimizer config
+ warmup_steps: null
+ warmup_ratio: null
+ min_lr: 0.0
+ last_epoch: -1
+
+ Args:
+ optimizer: An instantiated Optimizer.
+ scheduler_config: A dictionary / config dict which follows the above schema.
+ train_dataloader: Optional requirement, must be passed if "iters_per_batch" is defined
+ instead of "max_steps". Used to compute effective "max_steps".
+
+ Returns:
+ A dictionary containing the LR Scheduler implementation if the config was successfully parsed
+ along with other parameters required by Pytorch Lightning, otherwise None.
+ """
+ # Build nested dictionary for convenience out of structured objects
+ if isinstance(scheduler_config, DictConfig):
+ scheduler_config = OmegaConf.to_container(scheduler_config, resolve=True)
+
+ elif dataclasses.is_dataclass(scheduler_config):
+ # Recursively transform data classes to basic dictionaries
+ scheduler_config = OmegaConf.create(scheduler_config)
+ scheduler_config = OmegaConf.to_container(scheduler_config, resolve=True)
+
+ # Test to see if config follows above schema
+
+ add_max_args_flag = True
+ interval = 'step'
+ if scheduler_config is not None:
+ if 'args' in scheduler_config:
+ scheduler_args = scheduler_config.pop('args')
+ else:
+ scheduler_args = copy.deepcopy(scheduler_config)
+
+ # Remove extra parameters from scheduler_args nest
+ # Assume all other parameters are to be passed into scheduler constructor
+
+ if 'name' in scheduler_args and scheduler_args['name'] == 'ReduceLROnPlateau':
+ add_max_args_flag = False
+ interval = 'epoch'
+
+ scheduler_args.pop('name', None)
+ scheduler_args.pop('t_max_epochs', None)
+ scheduler_args.pop('t_accumulate_grad_batches', None)
+ scheduler_args.pop('t_limit_train_batches', None)
+ scheduler_args.pop('t_num_workers', None)
+ scheduler_args.pop('monitor', None)
+ scheduler_args.pop('reduce_on_plateau', None)
+
+ else:
+ # Return gracefully in case `sched` was not supplied; inform user
+ logging.info('Scheduler not initialized as no `sched` config supplied to setup_optimizer()')
+ return None
+
+ # Try instantiation of scheduler params from config class path
+ if 'name' not in scheduler_config:
+ scheduler_args_cfg = OmegaConf.create(scheduler_args)
+ scheduler_conf = hydra.utils.instantiate(scheduler_args_cfg)
+ scheduler_args = vars(scheduler_conf)
+
+ # Get name of the scheduler
+ scheduler_name = scheduler_conf.__class__.__name__
+
+ if 'Params' in scheduler_name:
+ scheduler_name = scheduler_name.replace('Params', '')
+ else:
+ # Class path instantiation failed; try resolving "name" component
+
+ # Get name of the scheduler
+ if 'name' in scheduler_config:
+ scheduler_name = scheduler_config['name']
+ else:
+ logging.warning(
+ "Could not resolve classpath for Scheduler Config, and `name` "
+ "was not provided either. \n"
+ "Scheduler cannot be instantiated !"
+ )
+ return None
+
+ # If class path was not provided, perhaps `name` is provided for resolution
+ if 'name' in scheduler_args:
+ # If `auto` is passed as name for resolution of optimizer name,
+ # then lookup optimizer name and resolve its parameter config
+ if scheduler_args['name'] == 'auto':
+ scheduler_params_name = "{}Params".format(scheduler_name)
+ else:
+ scheduler_params_name = scheduler_args['name']
+
+ # Get override arguments provided in the config yaml file / Dict Config
+ scheduler_params_override = scheduler_args.get('params', {})
+
+ # If params is itself a dict config object provided explicitly in Dict Config
+ # Resolve to dictionary for convenience
+ if isinstance(scheduler_params_override, DictConfig):
+ scheduler_params_override = OmegaConf.to_container(scheduler_params_override, resolve=True)
+
+ # Get and instantiate the Config dataclass for this scheduler
+ scheduler_params_cls = get_scheduler_config(scheduler_params_name, **scheduler_params_override)
+ scheduler_params = scheduler_params_cls() # instantiate the parameters object
+ scheduler_args = vars(scheduler_params) # extract just the dictionary from the Config object
+
+ else:
+        # assume the input dictionary is scheduler args (from dataclasses / omegaconf)
+ pass
+
+ # Extract value to monitor in losses, if provided.
+ if 'monitor' in scheduler_config:
+ monitor = scheduler_config.get('monitor')
+ else:
+ # Default to train loss
+ monitor = 'loss'
+
+ # Store exact max_steps if it is provided
+ if 'max_steps' in scheduler_config and scheduler_config['max_steps'] is not None:
+ max_steps = scheduler_config['max_steps']
+
+ elif 't_max_epochs' in scheduler_config:
+ # Compute effective max_steps if t_max_epochs is provided
+ if train_dataloader is None:
+ logging.warning(
+ 'As `t_max_epochs` is provided/computed, it is required to pass the train dataloader in order\n'
+ 'to compute effective maximum number of steps.\n'
+ 'Scheduler will not be instantiated !'
+ )
+ return None
+
+ # Raise exception if neither `max_steps` nor `t_max_epochs` is provided
+ if scheduler_config.get('t_max_epochs', None) is None:
+ logging.warning(
+                "`t_max_epochs` cannot be None when `max_steps` is not provided.\n"
+ "This can occur when `train dataloader` is not available to correctly "
+ "prepare the scheduler.\n"
+ "Scheduler will not be instantiated !"
+ )
+ return None
+
+ # Get iters_per_batch
+ max_epochs = scheduler_config.get('t_max_epochs')
+ accumulate_grad_batches = scheduler_config.get('t_accumulate_grad_batches')
+ limit_train_batches = scheduler_config.get('t_limit_train_batches')
+ num_workers = scheduler_config.get('t_num_workers')
+
+ # Compute effective num max_steps
+ num_samples = len(train_dataloader.dataset)
+ batch_size = train_dataloader.batch_size
+ drop_last = train_dataloader.drop_last
+
+ max_steps = compute_max_steps(
+ max_epochs=max_epochs,
+ accumulate_grad_batches=accumulate_grad_batches,
+ limit_train_batches=limit_train_batches,
+ num_workers=num_workers,
+ num_samples=num_samples,
+ batch_size=batch_size,
+ drop_last=drop_last,
+ )
+
+ else:
+ logging.warning(
+ "Neither `max_steps` nor `iters_per_batch` were provided to `optim.sched`, "
+ "cannot compute effective `max_steps` !\n"
+ "Scheduler will not be instantiated !"
+ )
+ return None
+
+ # Inject max_steps (effective or provided) into the scheduler config
+ if add_max_args_flag:
+ scheduler_args['max_steps'] = max_steps
+
+ # Get the scheduler class from the config
+ scheduler_cls = get_scheduler(scheduler_name, **scheduler_args)
+
+ # Instantiate the LR schedule
+ schedule = scheduler_cls(optimizer, **scheduler_args)
+
+ logging.info(
+ 'Scheduler "%s" \nwill be used during training (effective maximum steps = %d) - \nParameters : \n(%s)',
+ str(schedule),
+ max_steps,
+ OmegaConf.to_yaml(OmegaConf.create(scheduler_args)),
+ )
+
+ # Wrap the schedule in PTL arguments to perform stepwise computation
+ # Rather than epoch level computation
+ if isinstance(schedule, optim.lr_scheduler.ReduceLROnPlateau):
+ reduce_lr_on_plateau = True
+ else:
+ reduce_lr_on_plateau = False
+
+ schedule_dict = {
+ 'scheduler': schedule,
+ 'interval': interval,
+ 'frequency': 1,
+ 'monitor': monitor,
+ 'reduce_on_plateau': reduce_lr_on_plateau,
+ }
+ return schedule_dict
+
+
+def compute_max_steps(
+ max_epochs, accumulate_grad_batches, limit_train_batches, num_workers, num_samples, batch_size, drop_last
+):
+ _round = math.floor if drop_last else math.ceil
+
+ sampler_num_samples = math.ceil(num_samples / num_workers)
+
+ if drop_last and num_workers > 1:
+ logging.warning(
+ "Please note that drop_last is broken in pytorch 1.6.0. We will fix when pytorch 1.7.0 is released"
+ )
+        # TODO: Master version, not in pytorch 1.6.0
+ # sampler_num_samples = math.ceil((num_samples - num_workers)/ num_workers)
+
+ steps_per_epoch = _round(sampler_num_samples / batch_size)
+ if isinstance(limit_train_batches, int) or limit_train_batches == 0.0:
+ steps_per_epoch = min(steps_per_epoch, int(limit_train_batches))
+ elif steps_per_epoch != float('inf'):
+ # limit_train_batches is a percentage of batches per epoch
+ steps_per_epoch = int(steps_per_epoch * limit_train_batches)
+ if accumulate_grad_batches == 1:
+ steps_per_epoch = max(steps_per_epoch, 1)
+
+ return math.ceil(steps_per_epoch / accumulate_grad_batches) * max_epochs
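+# Illustrative worked example (the numbers below are assumptions, not from the original file):
+# num_samples=1000, num_workers=1, batch_size=32, drop_last=False, limit_train_batches=1.0,
+# accumulate_grad_batches=1 and max_epochs=10 give steps_per_epoch = ceil(1000 / 32) = 32
+# and a returned max_steps of 32 * 10 = 320.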
+
+
+AVAILABLE_SCHEDULERS = {
+ 'WarmupPolicy': WarmupPolicy,
+ 'WarmupHoldPolicy': WarmupHoldPolicy,
+ 'SquareAnnealing': SquareAnnealing,
+ 'CosineAnnealing': CosineAnnealing,
+ 'NoamAnnealing': NoamAnnealing,
+ 'WarmupAnnealing': WarmupAnnealing,
+ 'InverseSquareRootAnnealing': InverseSquareRootAnnealing,
+ 'SquareRootAnnealing': SquareRootAnnealing,
+ 'PolynomialDecayAnnealing': PolynomialDecayAnnealing,
+ 'PolynomialHoldDecayAnnealing': PolynomialHoldDecayAnnealing,
+ 'StepLR': pt_scheduler.StepLR,
+ 'ExponentialLR': pt_scheduler.ExponentialLR,
+ 'ReduceLROnPlateau': pt_scheduler.ReduceLROnPlateau,
+ 'CyclicLR': pt_scheduler.CyclicLR,
+}
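+
+
+# --- Illustrative usage (not part of the upstream NeMo file; a hedged sketch only) ---
+# The toy parameter, optimizer settings and step counts below are assumptions chosen
+# purely to demonstrate driving one of the schedulers defined above.
+if __name__ == "__main__":
+    import torch
+
+    toy_param = torch.nn.Parameter(torch.zeros(1))
+    toy_opt = optim.SGD([toy_param], lr=1e-3)
+    sched = CosineAnnealing(toy_opt, max_steps=100, warmup_steps=10, min_lr=1e-5)
+
+    for _ in range(5):
+        toy_opt.step()
+        sched.step()
+    print("lr after 5 steps:", sched.get_last_lr())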
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/optim/novograd.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/optim/novograd.py
new file mode 100644
index 0000000000000000000000000000000000000000..34415a505e48f947baab5740a86686a63cbd1215
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/optim/novograd.py
@@ -0,0 +1,198 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from torch.optim.optimizer import Optimizer
+
+__all__ = ['Novograd']
+
+
+def _check_valid_opt_params(lr, eps, betas):
+ if lr < 0:
+ raise ValueError(f"Invalid learning rate: {lr}")
+ if eps < 0:
+ raise ValueError(f"Invalid epsilon value: {eps}")
+ if not (0.0 <= betas[0] < 1.0 and 0.0 <= betas[1] < 1.0):
+ raise ValueError(f"Betas have to be between 0 and 1: {betas}")
+
+
+class Novograd(Optimizer):
+ """Implements Novograd algorithm.
+ It has been proposed in "Stochastic Gradient Methods with Layer-wise
+ Adaptive Moments for Training of Deep Networks"
+ (https://arxiv.org/abs/1905.11286)
+ Arguments:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 1e-3)
+ betas (Tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its square (default: (0.95, 0.98))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-8)
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+ amsgrad (boolean, optional): whether to use the AMSGrad variant of this
+ algorithm from the paper "On the Convergence of Adam and Beyond"
+ """
+
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ betas=(0.95, 0.98),
+ eps=1e-8,
+ eps_in_sqrt=False,
+ weight_decay=0,
+ weight_decay_ema=True,
+ grad_averaging=False,
+ amsgrad=False,
+ luc=False,
+ luc_grad_trust=0.0,
+ luc_grad_trust_rel=False,
+ luc_trust=1e-3,
+ luc_trust_min=0.0,
+ luc_eps=1e-8,
+ luc_update_min=1e-7,
+ luc_update_max=1.0,
+ ):
+ _check_valid_opt_params(lr, eps, betas)
+ assert isinstance(eps_in_sqrt, bool)
+ defaults = dict(
+ lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, grad_averaging=grad_averaging, amsgrad=amsgrad,
+ eps_in_sqrt=eps_in_sqrt,
+ luc=luc,
+ luc_grad_trust=luc_grad_trust,
+ luc_grad_trust_rel=luc_grad_trust_rel,
+ luc_trust=luc_trust,
+ luc_trust_min=luc_trust_min,
+ luc_eps=luc_eps,
+ luc_update_min=luc_update_min,
+ luc_update_max=luc_update_max,
+ weight_decay_ema=weight_decay_ema
+ )
+ super(Novograd, self).__init__(params, defaults)
+
+ def __setstate__(self, state):
+ super(Novograd, self).__setstate__(state)
+ for group in self.param_groups:
+ group.setdefault("amsgrad", False)
+
+ def step(self, closure=None):
+ """Performs a single optimization step.
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group["params"]:
+ if p.grad is None:
+ continue
+ grad = p.grad.data
+ if grad.is_sparse:
+ raise RuntimeError("Sparse gradients are not supported.")
+
+ amsgrad = group["amsgrad"]
+ state = self.state[p]
+
+ # State initialization
+ if not state:
+ state["step"] = 0
+ # Exponential moving average of gradient values
+ state["exp_avg"] = torch.zeros_like(p.data)
+ # Exponential moving average of squared gradient values
+ state["exp_avg_sq"] = torch.zeros([]).to(state["exp_avg"].device)
+ if amsgrad:
+ # Maintains max of all exp moving avg of squared grad
+ state["max_exp_avg_sq"] = torch.zeros([]).to(state["exp_avg"].device)
+
+ exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+ if amsgrad:
+ max_exp_avg_sq = state["max_exp_avg_sq"]
+ beta1, beta2 = group["betas"]
+
+ state["step"] += 1
+
+ if group['luc'] and group['luc_grad_trust'] > 0:
+ if not group['luc_grad_trust_rel']:
+                        # Clip grads so that they are less than eta*weights
+ luc_factor = get_luc_factor(p.data, grad, luc_trust=group['luc_grad_trust'], luc_trust_min=0.0)
+ grad.mul_(luc_factor)
+ else:
+ if exp_avg_sq != 0:
+ luc_factor = get_luc_factor(exp_avg_sq.sqrt(), grad, luc_trust=group['luc_grad_trust'], luc_trust_min=0.0)
+ grad.mul_(luc_factor)
+
+ norm = grad.norm().pow(2)
+
+ if exp_avg_sq == 0:
+ exp_avg_sq.copy_(norm)
+ else:
+ exp_avg_sq.mul_(beta2).add_(norm, alpha=1.0 - beta2)
+
+ if amsgrad:
+ # Maintains max of all 2nd moment running avg till now
+ torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
+ # Use the max for normalizing running avg. of gradient
+ if not group['eps_in_sqrt']:
+ denom = max_exp_avg_sq.sqrt().add_(group["eps"])
+ else:
+ denom = max_exp_avg_sq.add_(group["eps"]).sqrt()
+ else:
+ if not group['eps_in_sqrt']:
+ denom = exp_avg_sq.sqrt().add_(group["eps"])
+ else:
+ denom = exp_avg_sq.add_(group["eps"]).sqrt()
+
+ grad.div_(denom)
+ if group["weight_decay"] != 0 and group['weight_decay_ema']:
+ grad.add_(p.data, alpha=group["weight_decay"])
+ if group["grad_averaging"]:
+ grad.mul_(1 - beta1)
+ exp_avg.mul_(beta1).add_(grad)
+
+ update = exp_avg
+ if group["weight_decay"] != 0 and not group['weight_decay_ema']:
+ update = update.add(p.data, alpha=group["weight_decay"])
+
+ lr = group["lr"]
+ if group['luc'] and group['luc_trust'] > 0:
+ # Clip lr so that updates are less than eta*weights
+ luc_factor = get_luc_factor(p.data, update.data, luc_trust=group['luc_trust'], luc_trust_min=group['luc_trust_min'])
+ lr = luc_factor * lr
+
+ p.data.add_(update, alpha=-lr)
+
+ return loss
+
+
+def get_luc_factor(param, grad, *, luc_trust, luc_trust_min):
+ param_norm = torch.norm(param)
+ param_norm = max(param_norm, 1e-3)
+ grad_norm = torch.norm(grad)
+ if grad_norm == 0:
+ return 1.0
+
+ max_grad_norm = param_norm * luc_trust
+ min_grad_norm = param_norm * luc_trust_min
+ if grad_norm > max_grad_norm:
+ luc_factor = max_grad_norm / grad_norm
+ elif grad_norm < min_grad_norm:
+ luc_factor = min_grad_norm / grad_norm
+ else:
+ luc_factor = 1.0
+
+ return luc_factor
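+
+
+# --- Illustrative usage (not part of the upstream NeMo file; a hedged sketch only) ---
+# The synthetic regression data and model below are assumptions chosen purely to
+# demonstrate a few optimization steps with the Novograd optimizer defined above.
+if __name__ == "__main__":
+    torch.manual_seed(0)
+    model = torch.nn.Linear(4, 1)
+    opt = Novograd(model.parameters(), lr=1e-2, weight_decay=1e-3)
+
+    inputs = torch.randn(16, 4)
+    targets = torch.randn(16, 1)
+    for _ in range(10):
+        opt.zero_grad()
+        loss = torch.nn.functional.mse_loss(model(inputs), targets)
+        loss.backward()
+        opt.step()
+    print(f"final loss: {loss.item():.4f}")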
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/optim/optimizers.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/optim/optimizers.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc326589d15fb2227d0a6a1650747028c7a8a833
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/core/optim/optimizers.py
@@ -0,0 +1,164 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+from functools import partial
+from typing import Any, Dict, List, Optional, Union
+
+import hydra
+import torch.optim as optim
+from omegaconf import DictConfig, OmegaConf
+from torch.optim import adadelta, adagrad, adamax, rmsprop, rprop
+from torch.optim.optimizer import Optimizer
+
+from nemo.core.config import OptimizerParams, get_optimizer_config, register_optimizer_params
+from nemo.core.optim.novograd import Novograd
+
+__all__ = ['get_optimizer', 'register_optimizer', 'parse_optimizer_args']
+
+
+AVAILABLE_OPTIMIZERS = {
+ 'sgd': optim.SGD,
+ 'adam': optim.Adam,
+ 'adamw': optim.AdamW,
+ # 'adadelta': adadelta.Adadelta,
+ # 'adamax': adamax.Adamax,
+ # 'adagrad': adagrad.Adagrad,
+ # 'rmsprop': rmsprop.RMSprop,
+ # 'rprop': rprop.Rprop,
+ 'novograd': Novograd,
+}
+
+
+def parse_optimizer_args(
+ optimizer_name: str, optimizer_kwargs: Union[DictConfig, Dict[str, Any]]
+) -> Union[Dict[str, Any], DictConfig]:
+ """
+    Parses the optimizer kwargs, provided either as a DictConfig or a plain dictionary,
+    into the final dictionary of keyword arguments.
+
+ This dictionary is then used to instantiate the chosen Optimizer.
+
+ Args:
+ optimizer_name: string name of the optimizer, used for auto resolution of params
+        optimizer_kwargs: Either a DictConfig or a plain dictionary. A class path inside
+            the config is instantiated directly when possible; otherwise the `name` key
+            (or the special value `auto`) is used to resolve the registered parameter
+            dataclass, with optional overrides taken from a nested `params` section.
+            If neither is present, the dictionary is returned unchanged.
+
+ Returns:
+ A dictionary
+ """
+ kwargs = {}
+
+ if optimizer_kwargs is None:
+ return kwargs
+
+ optimizer_kwargs = copy.deepcopy(optimizer_kwargs)
+
+ if isinstance(optimizer_kwargs, DictConfig):
+ optimizer_kwargs = OmegaConf.to_container(optimizer_kwargs, resolve=True)
+
+ # If it is a dictionary, perform stepwise resolution
+ if hasattr(optimizer_kwargs, 'keys'):
+ # Attempt class path resolution
+ try:
+ optimizer_kwargs_config = OmegaConf.create(optimizer_kwargs)
+ optimizer_instance = hydra.utils.instantiate(optimizer_kwargs_config) # type: DictConfig
+ optimizer_instance = vars(optimizer_instance)
+ return optimizer_instance
+ except Exception:
+ pass
+
+ # If class path was not provided, perhaps `name` is provided for resolution
+ if 'name' in optimizer_kwargs:
+ # If `auto` is passed as name for resolution of optimizer name,
+ # then lookup optimizer name and resolve its parameter config
+ if optimizer_kwargs['name'] == 'auto':
+ optimizer_params_name = "{}_params".format(optimizer_name)
+ optimizer_kwargs.pop('name')
+ else:
+ optimizer_params_name = optimizer_kwargs.pop('name')
+
+ # Override arguments provided in the config yaml file
+ if 'params' in optimizer_kwargs:
+ # If optimizer kwarg overrides are wrapped in yaml `params`
+ optimizer_params_override = optimizer_kwargs.get('params')
+ else:
+ # If the kwargs themselves are a DictConfig
+ optimizer_params_override = optimizer_kwargs
+
+ if isinstance(optimizer_params_override, DictConfig):
+ optimizer_params_override = OmegaConf.to_container(optimizer_params_override, resolve=True)
+
+ optimizer_params_cls = get_optimizer_config(optimizer_params_name, **optimizer_params_override)
+
+ # If we are provided just a Config object, simply return the dictionary of that object
+ if optimizer_params_name is None:
+ optimizer_params = vars(optimizer_params_cls)
+ return optimizer_params
+
+ else:
+ # If we are provided a partial class instantiation of a Config,
+ # Instantiate it and retrieve its vars as a dictionary
+ optimizer_params = optimizer_params_cls() # instantiate the parameters object
+ optimizer_params = vars(optimizer_params)
+ return optimizer_params
+
+ # simply return the dictionary that was provided
+ return optimizer_kwargs
+
+ return kwargs
+
+
+def register_optimizer(name: str, optimizer: Optimizer, optimizer_params: OptimizerParams):
+ """
+    Checks if the optimizer name exists in the registry, and if it doesn't, adds it.
+
+ This allows custom optimizers to be added and called by name during instantiation.
+
+ Args:
+ name: Name of the optimizer. Will be used as key to retrieve the optimizer.
+ optimizer: Optimizer class
+ optimizer_params: The parameters as a dataclass of the optimizer
+ """
+ if name in AVAILABLE_OPTIMIZERS:
+ raise ValueError(f"Cannot override pre-existing optimizers. Conflicting optimizer name = {name}")
+
+ AVAILABLE_OPTIMIZERS[name] = optimizer
+
+ optim_name = "{}_params".format(optimizer.__name__)
+ register_optimizer_params(name=optim_name, optimizer_params=optimizer_params)
+
+
+def get_optimizer(name: str, **kwargs: Optional[Dict[str, Any]]) -> Optimizer:
+ """
+ Convenience method to obtain an Optimizer class and partially instantiate it with optimizer kwargs.
+
+ Args:
+ name: Name of the Optimizer in the registry.
+ kwargs: Optional kwargs of the optimizer used during instantiation.
+
+ Returns:
+ a partially instantiated Optimizer
+ """
+ if name not in AVAILABLE_OPTIMIZERS:
+ raise ValueError(
+            f"Cannot resolve optimizer '{name}'. Available optimizers are: {AVAILABLE_OPTIMIZERS.keys()}"
+ )
+
+ optimizer = AVAILABLE_OPTIMIZERS[name]
+ optimizer = partial(optimizer, **kwargs)
+ return optimizer
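+
+
+# --- Illustrative usage (not part of the upstream NeMo file; a hedged sketch only) ---
+# The toy parameter below is an assumption chosen purely to show how an optimizer is
+# resolved by name and then instantiated from the returned partial.
+if __name__ == "__main__":
+    import torch
+
+    opt_cls = get_optimizer('novograd', lr=1e-3, betas=(0.95, 0.98))
+    opt = opt_cls([torch.nn.Parameter(torch.zeros(1))])
+    print(type(opt).__name__)  # prints "Novograd"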
diff --git a/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/package_info.py b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/package_info.py
new file mode 100644
index 0000000000000000000000000000000000000000..3843cf41a6090d636866e565aab1907af94dcf22
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/SPIRAL_L2_BN_FSQ_CTC/nemo/package_info.py
@@ -0,0 +1,35 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+MAJOR = 1
+MINOR = 0
+PATCH = 0
+PRE_RELEASE = 'b4'
+
+# Use the following formatting: (major, minor, patch, pre-release)
+VERSION = (MAJOR, MINOR, PATCH, PRE_RELEASE)
+
+__shortversion__ = '.'.join(map(str, VERSION[:3]))
+__version__ = '.'.join(map(str, VERSION[:3])) + ''.join(VERSION[3:])
+
+__package_name__ = 'nemo_toolkit'
+__contact_names__ = 'NVIDIA'
+__contact_emails__ = 'nemo-toolkit@nvidia.com'
+__homepage__ = 'https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/'
+__repository_url__ = 'https://github.com/nvidia/nemo'
+__download_url__ = 'https://github.com/NVIDIA/NeMo/releases'
+__description__ = 'NeMo - a toolkit for Conversational AI'
+__license__ = 'Apache2'
+__keywords__ = 'deep learning, machine learning, gpu, NLP, NeMo, nvidia, pytorch, torch, tts, speech, language'
diff --git a/MMaDA/models/speech_tokenization/UVITS/attentions.py b/MMaDA/models/speech_tokenization/UVITS/attentions.py
new file mode 100644
index 0000000000000000000000000000000000000000..383c5da5c34103003a973aab97cc39180db35fda
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/UVITS/attentions.py
@@ -0,0 +1,313 @@
+import copy
+import math
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+import commons
+import modules
+from modules import LayerNorm
+
+
+class Encoder(nn.Module):
+ def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4,
+ **kwargs):
+ super().__init__()
+ self.hidden_channels = hidden_channels
+ self.filter_channels = filter_channels
+ self.n_heads = n_heads
+ self.n_layers = n_layers
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.window_size = window_size
+
+ self.drop = nn.Dropout(p_dropout)
+ self.attn_layers = nn.ModuleList()
+ self.norm_layers_1 = nn.ModuleList()
+ self.ffn_layers = nn.ModuleList()
+ self.norm_layers_2 = nn.ModuleList()
+ for i in range(self.n_layers):
+ self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout,
+ window_size=window_size))
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
+ self.ffn_layers.append(
+ FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+ def forward(self, x, x_mask):
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+ x = x * x_mask
+ for i in range(self.n_layers):
+ y = self.attn_layers[i](x, x, attn_mask)
+ y = self.drop(y)
+ x = self.norm_layers_1[i](x + y)
+
+ y = self.ffn_layers[i](x, x_mask)
+ y = self.drop(y)
+ x = self.norm_layers_2[i](x + y)
+ x = x * x_mask
+ return x
+
+
+class Decoder(nn.Module):
+ def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.,
+ proximal_bias=False, proximal_init=True, **kwargs):
+ super().__init__()
+ self.hidden_channels = hidden_channels
+ self.filter_channels = filter_channels
+ self.n_heads = n_heads
+ self.n_layers = n_layers
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.proximal_bias = proximal_bias
+ self.proximal_init = proximal_init
+
+ self.drop = nn.Dropout(p_dropout)
+ self.self_attn_layers = nn.ModuleList()
+ self.norm_layers_0 = nn.ModuleList()
+ self.encdec_attn_layers = nn.ModuleList()
+ self.norm_layers_1 = nn.ModuleList()
+ self.ffn_layers = nn.ModuleList()
+ self.norm_layers_2 = nn.ModuleList()
+ for i in range(self.n_layers):
+ self.self_attn_layers.append(
+ MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout,
+ proximal_bias=proximal_bias, proximal_init=proximal_init))
+ self.norm_layers_0.append(LayerNorm(hidden_channels))
+ self.encdec_attn_layers.append(
+ MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
+ self.ffn_layers.append(
+ FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True))
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+ def forward(self, x, x_mask, h, h_mask):
+ """
+ x: decoder input
+ h: encoder output
+ """
+ self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
+ encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+ x = x * x_mask
+ for i in range(self.n_layers):
+ y = self.self_attn_layers[i](x, x, self_attn_mask)
+ y = self.drop(y)
+ x = self.norm_layers_0[i](x + y)
+
+ y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
+ y = self.drop(y)
+ x = self.norm_layers_1[i](x + y)
+
+ y = self.ffn_layers[i](x, x_mask)
+ y = self.drop(y)
+ x = self.norm_layers_2[i](x + y)
+ x = x * x_mask
+ return x
+
+
+class MultiHeadAttention(nn.Module):
+ def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True,
+ block_length=None, proximal_bias=False, proximal_init=False):
+ super().__init__()
+ assert channels % n_heads == 0
+
+ self.channels = channels
+ self.out_channels = out_channels
+ self.n_heads = n_heads
+ self.p_dropout = p_dropout
+ self.window_size = window_size
+ self.heads_share = heads_share
+ self.block_length = block_length
+ self.proximal_bias = proximal_bias
+ self.proximal_init = proximal_init
+ self.attn = None
+
+ self.k_channels = channels // n_heads
+ self.conv_q = nn.Conv1d(channels, channels, 1)
+ self.conv_k = nn.Conv1d(channels, channels, 1)
+ self.conv_v = nn.Conv1d(channels, channels, 1)
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
+ self.drop = nn.Dropout(p_dropout)
+
+ if window_size is not None:
+ n_heads_rel = 1 if heads_share else n_heads
+ rel_stddev = self.k_channels ** -0.5
+ self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
+ self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
+
+ nn.init.xavier_uniform_(self.conv_q.weight)
+ nn.init.xavier_uniform_(self.conv_k.weight)
+ nn.init.xavier_uniform_(self.conv_v.weight)
+ if proximal_init:
+ with torch.no_grad():
+ self.conv_k.weight.copy_(self.conv_q.weight)
+ self.conv_k.bias.copy_(self.conv_q.bias)
+
+ def forward(self, x, c, attn_mask=None):
+ q = self.conv_q(x)
+ k = self.conv_k(c)
+ v = self.conv_v(c)
+
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
+
+ x = self.conv_o(x)
+ return x
+
+ def attention(self, query, key, value, mask=None):
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
+ b, d, t_s, t_t = (*key.size(), query.size(2))
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
+ if self.window_size is not None:
+ assert t_s == t_t, "Relative attention is only available for self-attention."
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
+ rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
+ scores = scores + scores_local
+ if self.proximal_bias:
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
+ scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
+ if mask is not None:
+ scores = scores.masked_fill(mask == 0, -1e4)
+ if self.block_length is not None:
+ assert t_s == t_t, "Local attention is only available for self-attention."
+ block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
+ scores = scores.masked_fill(block_mask == 0, -1e4)
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
+ p_attn = self.drop(p_attn)
+ output = torch.matmul(p_attn, value)
+ if self.window_size is not None:
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
+ value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
+ output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
+ output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
+ return output, p_attn
+
+ def _matmul_with_relative_values(self, x, y):
+ """
+ x: [b, h, l, m]
+ y: [h or 1, m, d]
+ ret: [b, h, l, d]
+ """
+ ret = torch.matmul(x, y.unsqueeze(0))
+ return ret
+
+ def _matmul_with_relative_keys(self, x, y):
+ """
+ x: [b, h, l, d]
+ y: [h or 1, m, d]
+ ret: [b, h, l, m]
+ """
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
+ return ret
+
+ def _get_relative_embeddings(self, relative_embeddings, length):
+ max_relative_position = 2 * self.window_size + 1
+ # Pad first before slice to avoid using cond ops.
+ pad_length = max(length - (self.window_size + 1), 0)
+ slice_start_position = max((self.window_size + 1) - length, 0)
+ slice_end_position = slice_start_position + 2 * length - 1
+ if pad_length > 0:
+ padded_relative_embeddings = F.pad(
+ relative_embeddings,
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
+ else:
+ padded_relative_embeddings = relative_embeddings
+ used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
+ return used_relative_embeddings
+
+ def _relative_position_to_absolute_position(self, x):
+ """
+ x: [b, h, l, 2*l-1]
+ ret: [b, h, l, l]
+ """
+ batch, heads, length, _ = x.size()
+ # Concat columns of pad to shift from relative to absolute indexing.
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
+
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
+ x_flat = x.view([batch, heads, length * 2 * length])
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
+
+ # Reshape and slice out the padded elements.
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:]
+ return x_final
+
+ def _absolute_position_to_relative_position(self, x):
+ """
+ x: [b, h, l, l]
+ ret: [b, h, l, 2*l-1]
+ """
+ batch, heads, length, _ = x.size()
+        # pad along the column dimension
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
+ x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)])
+ # add 0's in the beginning that will skew the elements after reshape
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
+ return x_final
+
+ def _attention_bias_proximal(self, length):
+ """Bias for self-attention to encourage attention to close positions.
+ Args:
+ length: an integer scalar.
+ Returns:
+ a Tensor with shape [1, 1, length, length]
+ """
+ r = torch.arange(length, dtype=torch.float32)
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
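+    # Illustrative worked example (the numbers below are assumptions, not from the original file):
+    # for length=3 the bias is -log1p(|i - j|), i.e. [[0, -log 2, -log 3],
+    # [-log 2, 0, -log 2], [-log 3, -log 2, 0]] (~ 0, -0.69, -1.10), broadcast over
+    # batch and heads so that attention to distant positions is discouraged.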
+
+
+class FFN(nn.Module):
+ def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None,
+ causal=False):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.filter_channels = filter_channels
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.activation = activation
+ self.causal = causal
+
+ if causal:
+ self.padding = self._causal_padding
+ else:
+ self.padding = self._same_padding
+
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
+ self.drop = nn.Dropout(p_dropout)
+
+ def forward(self, x, x_mask):
+ x = self.conv_1(self.padding(x * x_mask))
+ if self.activation == "gelu":
+ x = x * torch.sigmoid(1.702 * x)
+ else:
+ x = torch.relu(x)
+ x = self.drop(x)
+ x = self.conv_2(self.padding(x * x_mask))
+ return x * x_mask
+
+ def _causal_padding(self, x):
+ if self.kernel_size == 1:
+ return x
+ pad_l = self.kernel_size - 1
+ pad_r = 0
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+ x = F.pad(x, commons.convert_pad_shape(padding))
+ return x
+
+ def _same_padding(self, x):
+ if self.kernel_size == 1:
+ return x
+ pad_l = (self.kernel_size - 1) // 2
+ pad_r = self.kernel_size // 2
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+ x = F.pad(x, commons.convert_pad_shape(padding))
+ return x
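+
+
+# --- Illustrative usage (not part of the upstream UVITS file; a hedged sketch only) ---
+# The tensor shapes below are assumptions chosen purely to demonstrate a forward pass
+# through the relative-position Encoder defined above.
+if __name__ == "__main__":
+    batch, channels, frames = 2, 192, 50
+    enc = Encoder(hidden_channels=channels, filter_channels=768, n_heads=2, n_layers=4)
+    x = torch.randn(batch, channels, frames)
+    x_mask = torch.ones(batch, 1, frames)
+    out = enc(x, x_mask)
+    print(out.shape)  # expected: torch.Size([2, 192, 50])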
diff --git a/MMaDA/models/speech_tokenization/UVITS/commons.py b/MMaDA/models/speech_tokenization/UVITS/commons.py
new file mode 100644
index 0000000000000000000000000000000000000000..970489852841b0350f945b10e1c6e572860e9da8
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/UVITS/commons.py
@@ -0,0 +1,161 @@
+import math
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+def init_weights(m, mean=0.0, std=0.01):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+ return int((kernel_size * dilation - dilation) / 2)
+
+
+def convert_pad_shape(pad_shape):
+ l = pad_shape[::-1]
+ pad_shape = [item for sublist in l for item in sublist]
+ return pad_shape
+
+
+def intersperse(lst, item):
+ result = [item] * (len(lst) * 2 + 1)
+ result[1::2] = lst
+ return result
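+# Illustrative worked example (the values below are assumptions, not from the original file):
+# intersperse([1, 2, 3], 0) returns [0, 1, 0, 2, 0, 3, 0], i.e. the blank id 0
+# placed between and around the original symbol ids.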
+
+
+def kl_divergence(m_p, logs_p, m_q, logs_q):
+ """KL(P||Q)"""
+ kl = (logs_q - logs_p) - 0.5
+ kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2. * logs_q)
+ return kl
+
+
+def rand_gumbel(shape):
+ """Sample from the Gumbel distribution, protect from overflows."""
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
+ return -torch.log(-torch.log(uniform_samples))
+
+
+def rand_gumbel_like(x):
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
+ return g
+
+
+def slice_segments(x, ids_str, segment_size=4):
+ ret = torch.zeros_like(x[:, :, :segment_size])
+ for i in range(x.size(0)):
+ idx_str = ids_str[i]
+ idx_end = idx_str + segment_size
+ ret[i] = x[i, :, idx_str:idx_end]
+ return ret
+
+
+def rand_slice_segments(x, x_lengths=None, segment_size=4):
+ b, d, t = x.size()
+ if x_lengths is None:
+ x_lengths = t
+ ids_str_max = x_lengths - segment_size + 1
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
+ ret = slice_segments(x, ids_str, segment_size)
+ return ret, ids_str
+
+
+def get_timing_signal_1d(
+ length, channels, min_timescale=1.0, max_timescale=1.0e4):
+ position = torch.arange(length, dtype=torch.float)
+ num_timescales = channels // 2
+ log_timescale_increment = (
+ math.log(float(max_timescale) / float(min_timescale)) /
+ (num_timescales - 1))
+ inv_timescales = min_timescale * torch.exp(
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
+ signal = signal.view(1, channels, length)
+ return signal
+
+
+def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
+ b, channels, length = x.size()
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+ return x + signal.to(dtype=x.dtype, device=x.device)
+
+
+def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
+ b, channels, length = x.size()
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
+
+
+def subsequent_mask(length):
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
+ return mask
+
+
+@torch.jit.script
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+ n_channels_int = n_channels[0]
+ in_act = input_a + input_b
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+ acts = t_act * s_act
+ return acts
+
+
+def convert_pad_shape(pad_shape):
+ l = pad_shape[::-1]
+ pad_shape = [item for sublist in l for item in sublist]
+ return pad_shape
+
+
+def shift_1d(x):
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
+ return x
+
+
+def sequence_mask(length, max_length=None):
+ if max_length is None:
+ max_length = length.max()
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+ return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+def generate_path(duration, mask):
+ """
+ duration: [b, 1, t_x]
+ mask: [b, 1, t_y, t_x]
+ """
+ device = duration.device
+
+ b, _, t_y, t_x = mask.shape
+ cum_duration = torch.cumsum(duration, -1)
+
+ cum_duration_flat = cum_duration.view(b * t_x)
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
+ path = path.view(b, t_x, t_y)
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
+ path = path.unsqueeze(1).transpose(2, 3) * mask
+ return path
+
+
+def clip_grad_value_(parameters, clip_value, norm_type=2):
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
+ norm_type = float(norm_type)
+ if clip_value is not None:
+ clip_value = float(clip_value)
+
+ total_norm = 0
+ for p in parameters:
+ param_norm = p.grad.data.norm(norm_type)
+ total_norm += param_norm.item() ** norm_type
+ if clip_value is not None:
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
+ total_norm = total_norm ** (1. / norm_type)
+ return total_norm
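+
+
+# --- Illustrative usage (not part of the upstream UVITS file; a hedged sketch only) ---
+# The toy lengths and durations below are assumptions chosen purely to demonstrate
+# sequence_mask and generate_path defined above.
+if __name__ == "__main__":
+    lengths = torch.tensor([3, 5])
+    print(sequence_mask(lengths))  # [2, 5] boolean mask, True inside each sequence
+
+    duration = torch.tensor([[[2.0, 3.0]]])  # [b=1, 1, t_x=2] token durations
+    mask = torch.ones(1, 1, 5, 2)            # [b, 1, t_y=5, t_x=2]
+    path = generate_path(duration, mask)
+    print(path.squeeze())  # hard monotonic alignment of 5 frames to 2 tokens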
diff --git a/MMaDA/models/speech_tokenization/UVITS/data_utils.py b/MMaDA/models/speech_tokenization/UVITS/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8621250a8d201b0c3e19b24b6955bfdd2cd1503d
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/UVITS/data_utils.py
@@ -0,0 +1,398 @@
+import time
+import os
+import random
+import numpy as np
+import torch
+import torch.utils.data
+
+import commons
+from mel_processing import spectrogram_torch
+from utils import load_wav_to_torch, load_filepaths_and_text
+from text import text_to_sequence, cleaned_text_to_sequence
+import librosa
+
+
+class TextAudioLoader(torch.utils.data.Dataset):
+ """
+ 1) loads audio, text pairs
+ 2) normalizes text and converts them to sequences of integers
+ 3) computes spectrograms from audio files.
+ """
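+    # Each entry yielded by load_filepaths_and_text is expected to be an
+    # (audiopath, text) pair; the exact on-disk format is defined by that helper.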
+
+ def __init__(self, audiopaths_and_text, hparams, is_training=True):
+ self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text)
+ self.text_cleaners = hparams.text_cleaners
+ self.max_wav_value = hparams.max_wav_value
+ self.sampling_rate = hparams.sampling_rate
+ self.filter_length = hparams.filter_length
+ self.hop_length = hparams.hop_length
+ self.win_length = hparams.win_length
+ self.sampling_rate = hparams.sampling_rate
+
+ self.cleaned_text = getattr(hparams, "cleaned_text", False)
+
+ self.add_blank = hparams.add_blank
+ self.min_text_len = getattr(hparams, "min_text_len", 1)
+ self.max_text_len = getattr(hparams, "max_text_len", 190)
+
+ if is_training:
+ random.seed(1234)
+ random.shuffle(self.audiopaths_and_text)
+ self._filter()
+
+ def _filter(self):
+ """
+ Filter text & store spec lengths
+ """
+ # Store spectrogram lengths for Bucketing
+ # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2)
+ # spec_length = wav_length // hop_length
+
+ audiopaths_and_text_new = []
+ lengths = []
+ for audiopath, text in self.audiopaths_and_text:
+ if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
+ if os.path.exists(audiopath):
+ audiopaths_and_text_new.append([audiopath, text])
+ lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length))
+ self.audiopaths_and_text = audiopaths_and_text_new
+ self.lengths = lengths
+
+ def get_audio_text_pair(self, audiopath_and_text):
+ # separate filename and text
+ audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
+ text = self.get_text(text)
+ spec, wav = self.get_audio(audiopath)
+ return (text, spec, wav)
+
+ def get_audio(self, filename):
+ audio, sampling_rate = librosa.load(filename, sr=self.sampling_rate)
+ audio_norm = torch.FloatTensor(audio.astype(np.float32)).unsqueeze(0)
+
+        spec = spectrogram_torch(audio_norm, self.filter_length,
+                                 self.sampling_rate, self.hop_length, self.win_length,
+                                 center=False)
+        spec = torch.squeeze(spec, 0)  # drop the batch dim so the collate fn sees [n_freq, t], matching TextAudioSpeakerLoader
+        return spec, audio_norm
+
+ def get_text(self, text):
+ if self.cleaned_text:
+ text_norm = cleaned_text_to_sequence(text)
+ else:
+ text_norm = text_to_sequence(text, self.text_cleaners)
+ if self.add_blank:
+ text_norm = commons.intersperse(text_norm, 0)
+ text_norm = torch.LongTensor(text_norm)
+ return text_norm
+
+ def __getitem__(self, index):
+ return self.get_audio_text_pair(self.audiopaths_and_text[index])
+
+ def __len__(self):
+ return len(self.audiopaths_and_text)
+
+
+class TextAudioCollate():
+ """ Zero-pads model inputs and targets
+ """
+
+ def __init__(self, return_ids=False):
+ self.return_ids = return_ids
+
+ def __call__(self, batch):
+        """Collates training batch from normalized text and audio
+ PARAMS
+ ------
+ batch: [text_normalized, spec_normalized, wav_normalized]
+ """
+ # Right zero-pad all one-hot text sequences to max input length
+ _, ids_sorted_decreasing = torch.sort(
+ torch.LongTensor([x[1].size(1) for x in batch]),
+ dim=0, descending=True)
+
+ max_text_len = max([len(x[0]) for x in batch])
+ max_spec_len = max([x[1].size(1) for x in batch])
+ max_wav_len = max([x[2].size(1) for x in batch])
+
+ text_lengths = torch.LongTensor(len(batch))
+ spec_lengths = torch.LongTensor(len(batch))
+ wav_lengths = torch.LongTensor(len(batch))
+
+ text_padded = torch.LongTensor(len(batch), max_text_len)
+ spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
+ wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
+ text_padded.zero_()
+ spec_padded.zero_()
+ wav_padded.zero_()
+ for i in range(len(ids_sorted_decreasing)):
+ row = batch[ids_sorted_decreasing[i]]
+
+ text = row[0]
+ text_padded[i, :text.size(0)] = text
+ text_lengths[i] = text.size(0)
+
+ spec = row[1]
+ spec_padded[i, :, :spec.size(1)] = spec
+ spec_lengths[i] = spec.size(1)
+
+ wav = row[2]
+ wav_padded[i, :, :wav.size(1)] = wav
+ wav_lengths[i] = wav.size(1)
+
+ if self.return_ids:
+ return text_padded, text_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, ids_sorted_decreasing
+ return text_padded, text_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths
+
+
+"""Multi speaker version"""
+
+
+class TextAudioSpeakerLoader(torch.utils.data.Dataset):
+ """
+ 1) loads audio, speaker_id, text pairs
+ 2) normalizes text and converts them to sequences of integers
+ 3) computes spectrograms from audio files.
+ """
+
+ def __init__(self, audiopaths_sid_text, hparams, is_training=True):
+ self.audiopaths_sid_text = load_filepaths_and_text(audiopaths_sid_text, hparams, is_training=is_training)
+ self.text_cleaners = hparams.text_cleaners
+ self.max_wav_value = hparams.max_wav_value
+ self.sampling_rate = hparams.sampling_rate
+ self.filter_length = hparams.filter_length
+ self.hop_length = hparams.hop_length
+ self.win_length = hparams.win_length
+ self.sampling_rate = hparams.sampling_rate
+
+ self.cleaned_text = getattr(hparams, "cleaned_text", False)
+
+ self.add_blank = hparams.add_blank
+ self.min_text_len = getattr(hparams, "min_text_len", 1)
+ self.max_text_len = getattr(hparams, "max_text_len", 190)
+
+ if is_training:
+ random.seed(1234)
+ random.shuffle(self.audiopaths_sid_text)
+ self._filter()
+
+ def _filter(self):
+ """
+ Filter text & store spec lengths
+ """
+ # Store spectrogram lengths for Bucketing
+ # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2)
+ # spec_length = wav_length // hop_length
+
+ audiopaths_sid_text_new = []
+ lengths = []
+ filters = 0
+ for audiopath, sid, text in self.audiopaths_sid_text:
+ if not self.cleaned_text and 'text_split' in self.text_cleaners:
+ text_ = text.split()
+ if self.min_text_len <= len(text_) and len(text_) <= self.max_text_len and os.path.exists(audiopath):
+ audiopaths_sid_text_new.append([audiopath, sid, text])
+ lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length))
+ else:
+ filters += 1
+ else:
+ if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
+ audiopaths_sid_text_new.append([audiopath, sid, text])
+ lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length))
+ else:
+ filters += 1
+ self.audiopaths_sid_text = audiopaths_sid_text_new
+ self.lengths = lengths
+        print(f"Filtered out {filters} files")
+
+ def get_audio_text_speaker_pair(self, audiopath_sid_text):
+ # separate filename, speaker_id and text
+ audiopath, sid, text = audiopath_sid_text[0], audiopath_sid_text[1], audiopath_sid_text[2]
+ text = self.get_text(text)
+ spec, wav = self.get_audio(audiopath)
+ sid = self.get_sid(sid)
+ return (text, spec, wav, sid)
+
+ def get_audio(self, filename):
+ audio, sampling_rate = librosa.load(filename, sr=self.sampling_rate)
+ audio_norm = torch.FloatTensor(audio.astype(np.float32)).unsqueeze(0)
+
+ spec = spectrogram_torch(audio_norm, self.filter_length,
+ self.sampling_rate, self.hop_length, self.win_length,
+ center=False)
+ spec = torch.squeeze(spec, 0)
+ return spec, audio_norm
+
+ def get_text(self, text):
+ if self.cleaned_text:
+ text_norm = cleaned_text_to_sequence(text)
+ else:
+ text_norm = text_to_sequence(text, self.text_cleaners)
+ if self.add_blank:
+ text_norm = commons.intersperse(text_norm, 0)
+ text_norm = torch.LongTensor(text_norm)
+ return text_norm
+
+ def get_sid(self, sid):
+ sid = torch.LongTensor([int(sid)])
+ return sid
+
+ def __getitem__(self, index):
+ return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])
+
+ def __len__(self):
+ return len(self.audiopaths_sid_text)
+
+
+class TextAudioSpeakerCollate():
+ """ Zero-pads model inputs and targets
+ """
+
+ def __init__(self, return_ids=False):
+ self.return_ids = return_ids
+
+ def __call__(self, batch):
+        """Collates training batch from normalized text, audio and speaker identities
+ PARAMS
+ ------
+ batch: [text_normalized, spec_normalized, wav_normalized, sid]
+ """
+ # Right zero-pad all one-hot text sequences to max input length
+ _, ids_sorted_decreasing = torch.sort(
+ torch.LongTensor([x[1].size(1) for x in batch]),
+ dim=0, descending=True)
+
+ max_text_len = max([len(x[0]) for x in batch])
+ max_spec_len = max([x[1].size(1) for x in batch])
+ max_wav_len = max([x[2].size(1) for x in batch])
+
+ text_lengths = torch.LongTensor(len(batch))
+ spec_lengths = torch.LongTensor(len(batch))
+ wav_lengths = torch.LongTensor(len(batch))
+ sid = torch.LongTensor(len(batch))
+
+ text_padded = torch.LongTensor(len(batch), max_text_len)
+ spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
+ wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
+ text_padded.zero_()
+ spec_padded.zero_()
+ wav_padded.zero_()
+ for i in range(len(ids_sorted_decreasing)):
+ row = batch[ids_sorted_decreasing[i]]
+
+ text = row[0]
+ text_padded[i, :text.size(0)] = text
+ text_lengths[i] = text.size(0)
+
+ spec = row[1]
+ spec_padded[i, :, :spec.size(1)] = spec
+ spec_lengths[i] = spec.size(1)
+
+ wav = row[2]
+ wav_padded[i, :, :wav.size(1)] = wav
+ wav_lengths[i] = wav.size(1)
+
+ sid[i] = row[3]
+
+ if self.return_ids:
+ return text_padded, text_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, sid, ids_sorted_decreasing
+ return text_padded, text_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, sid
+
+
+class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
+ """
+ Maintain similar input lengths in a batch.
+ Length groups are specified by boundaries.
+    Ex) boundaries = [b1, b2, b3] -> every batch contains only samples x with b1 < length(x) <= b2, or only samples with b2 < length(x) <= b3.
+
+    Samples whose lengths fall outside the boundaries are removed.
+    Ex) boundaries = [b1, b2, b3] -> any x with length(x) <= b1 or length(x) > b3 is discarded.
+ """
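+    # Example with hypothetical values: boundaries=[32, 300, 500] puts samples with
+    # 32 < length <= 300 in bucket 0 and 300 < length <= 500 in bucket 1; samples
+    # with length <= 32 or length > 500 are dropped (_bisect returns -1).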
+
+ def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True):
+ super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
+ self.lengths = dataset.lengths
+ self.batch_size = batch_size
+ self.boundaries = boundaries
+
+ self.buckets, self.num_samples_per_bucket = self._create_buckets()
+ self.total_size = sum(self.num_samples_per_bucket)
+ self.num_samples = self.total_size // self.num_replicas
+
+ def _create_buckets(self):
+ buckets = [[] for _ in range(len(self.boundaries) - 1)]
+ for i in range(len(self.lengths)):
+ length = self.lengths[i]
+ idx_bucket = self._bisect(length)
+ if idx_bucket != -1:
+ buckets[idx_bucket].append(i)
+
+        for i in range(len(buckets) - 1, -1, -1):  # iterate in reverse; the stop value is -1 (not 0) so that bucket 0 is also removed if empty
+ if len(buckets[i]) == 0:
+ buckets.pop(i)
+ self.boundaries.pop(i + 1)
+
+ num_samples_per_bucket = []
+ for i in range(len(buckets)):
+ len_bucket = len(buckets[i])
+ total_batch_size = self.num_replicas * self.batch_size
+ rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size
+ num_samples_per_bucket.append(len_bucket + rem)
+ return buckets, num_samples_per_bucket
+
+ def __iter__(self):
+ # deterministically shuffle based on epoch
+ g = torch.Generator()
+ g.manual_seed(self.epoch)
+
+ indices = []
+ if self.shuffle:
+ for bucket in self.buckets:
+ indices.append(torch.randperm(len(bucket), generator=g).tolist())
+ else:
+ for bucket in self.buckets:
+ indices.append(list(range(len(bucket))))
+
+ batches = []
+ for i in range(len(self.buckets)):
+ bucket = self.buckets[i]
+ len_bucket = len(bucket)
+ ids_bucket = indices[i]
+ num_samples_bucket = self.num_samples_per_bucket[i]
+
+ # add extra samples to make it evenly divisible
+ rem = num_samples_bucket - len_bucket
+ ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[:(rem % len_bucket)]
+
+ # subsample
+ ids_bucket = ids_bucket[self.rank::self.num_replicas]
+
+ # batching
+ for j in range(len(ids_bucket) // self.batch_size):
+ batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size:(j + 1) * self.batch_size]]
+ batches.append(batch)
+
+ if self.shuffle:
+ batch_ids = torch.randperm(len(batches), generator=g).tolist()
+ batches = [batches[i] for i in batch_ids]
+ self.batches = batches
+
+ assert len(self.batches) * self.batch_size == self.num_samples
+ return iter(self.batches)
+
+ def _bisect(self, x, lo=0, hi=None):
+ if hi is None:
+ hi = len(self.boundaries) - 1
+
+ if hi > lo:
+ mid = (hi + lo) // 2
+ if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
+ return mid
+ elif x <= self.boundaries[mid]:
+ return self._bisect(x, lo, mid)
+ else:
+ return self._bisect(x, mid + 1, hi)
+ else:
+ return -1
+
+ def __len__(self):
+ return self.num_samples // self.batch_size
diff --git a/MMaDA/models/speech_tokenization/UVITS/mel_processing.py b/MMaDA/models/speech_tokenization/UVITS/mel_processing.py
new file mode 100644
index 0000000000000000000000000000000000000000..572fb4e9e7306c0675f423fbcf6349780c402017
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/UVITS/mel_processing.py
@@ -0,0 +1,114 @@
+import math
+import os
+import random
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torch.utils.data
+import numpy as np
+import librosa
+import librosa.util as librosa_util
+from librosa.util import normalize, pad_center, tiny
+from scipy.signal import get_window
+from scipy.io.wavfile import read
+from librosa.filters import mel as librosa_mel_fn
+
+MAX_WAV_VALUE = 32768.0
+
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+ """
+ PARAMS
+ ------
+ C: compression factor
+ """
+ return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression_torch(x, C=1):
+ """
+ PARAMS
+ ------
+ C: compression factor used to compress
+ """
+ return torch.exp(x) / C
+
+
+def spectral_normalize_torch(magnitudes):
+ output = dynamic_range_compression_torch(magnitudes)
+ return output
+
+
+def spectral_de_normalize_torch(magnitudes):
+ output = dynamic_range_decompression_torch(magnitudes)
+ return output
+
+
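+# Module-level caches for Hann windows and mel filterbanks, keyed by window size or
+# fmax together with dtype and device, so each is built only once per configuration.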
+mel_basis = {}
+hann_window = {}
+
+
+def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
+ if torch.min(y) < -1.:
+ print('min value is ', torch.min(y))
+ if torch.max(y) > 1.:
+ print('max value is ', torch.max(y))
+
+ global hann_window
+ dtype_device = str(y.dtype) + '_' + str(y.device)
+ wnsize_dtype_device = str(win_size) + '_' + dtype_device
+ if wnsize_dtype_device not in hann_window:
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
+
+ y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
+ mode='reflect')
+ y = y.squeeze(1)
+
+    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
+                      center=center, pad_mode='reflect', normalized=False, onesided=True,
+                      return_complex=False)  # keep the real [..., 2] layout expected by .pow(2).sum(-1) below
+
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+ return spec
+
+
+def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
+ global mel_basis
+ dtype_device = str(spec.dtype) + '_' + str(spec.device)
+ fmax_dtype_device = str(fmax) + '_' + dtype_device
+ if fmax_dtype_device not in mel_basis:
+        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)  # keyword args for librosa >= 0.10 compatibility
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
+ spec = spectral_normalize_torch(spec)
+ return spec
+
+
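+# Equivalent to spectrogram_torch followed by spec_to_mel_torch, fused into one call.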
+def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+ if torch.min(y) < -1.:
+ print('min value is ', torch.min(y))
+ if torch.max(y) > 1.:
+ print('max value is ', torch.max(y))
+
+ global mel_basis, hann_window
+ dtype_device = str(y.dtype) + '_' + str(y.device)
+ fmax_dtype_device = str(fmax) + '_' + dtype_device
+ wnsize_dtype_device = str(win_size) + '_' + dtype_device
+ if fmax_dtype_device not in mel_basis:
+        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)  # keyword args for librosa >= 0.10 compatibility
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
+ if wnsize_dtype_device not in hann_window:
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
+
+ y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
+ mode='reflect')
+ y = y.squeeze(1)
+
+    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
+                      center=center, pad_mode='reflect', normalized=False, onesided=True,
+                      return_complex=False)  # keep the real [..., 2] layout expected by .pow(2).sum(-1) below
+
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
+ spec = spectral_normalize_torch(spec)
+
+ return spec
diff --git a/MMaDA/models/speech_tokenization/UVITS/models.py b/MMaDA/models/speech_tokenization/UVITS/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd7db0b97805a2a3ecfbdc14cd818669248ef67f
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/UVITS/models.py
@@ -0,0 +1,674 @@
+import copy
+import math
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.autograd import Variable
+
+import commons
+import modules
+import attentions
+import monotonic_align
+
+from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
+from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+from commons import init_weights, get_padding
+from script import vae_utils
+
+
+class StochasticDurationPredictor(nn.Module):
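+    """Flow-based duration predictor: stacks of ConvFlow/Flip steps model the
+    log-duration distribution; an extra posterior flow over the observed durations
+    is used only in the (non-reverse) training pass."""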
+ def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
+ super().__init__()
+        filter_channels = in_channels  # NOTE: this override should be removed in a future version.
+ self.in_channels = in_channels
+ self.filter_channels = filter_channels
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.n_flows = n_flows
+ self.gin_channels = gin_channels
+
+ self.log_flow = modules.Log()
+ self.flows = nn.ModuleList()
+ self.flows.append(modules.ElementwiseAffine(2))
+ for i in range(n_flows):
+ self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
+ self.flows.append(modules.Flip())
+
+ self.post_pre = nn.Conv1d(1, filter_channels, 1)
+ self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
+ self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
+ self.post_flows = nn.ModuleList()
+ self.post_flows.append(modules.ElementwiseAffine(2))
+ for i in range(4):
+ self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
+ self.post_flows.append(modules.Flip())
+
+ self.pre = nn.Conv1d(in_channels, filter_channels, 1)
+ self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
+ self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
+ if gin_channels != 0:
+ self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
+
+ def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
+ x = torch.detach(x)
+ x = self.pre(x)
+ if g is not None:
+ g = torch.detach(g)
+ x = x + self.cond(g)
+ x = self.convs(x, x_mask)
+ x = self.proj(x) * x_mask
+
+ if not reverse:
+ flows = self.flows
+ assert w is not None
+
+ logdet_tot_q = 0
+ h_w = self.post_pre(w)
+ h_w = self.post_convs(h_w, x_mask)
+ h_w = self.post_proj(h_w) * x_mask
+ e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
+ z_q = e_q
+ for flow in self.post_flows:
+ z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
+ logdet_tot_q += logdet_q
+ z_u, z1 = torch.split(z_q, [1, 1], 1)
+ u = torch.sigmoid(z_u) * x_mask
+ z0 = (w - u) * x_mask
+ logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
+ logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q ** 2)) * x_mask, [1, 2]) - logdet_tot_q
+
+ logdet_tot = 0
+ z0, logdet = self.log_flow(z0, x_mask)
+ logdet_tot += logdet
+ z = torch.cat([z0, z1], 1)
+ for flow in flows:
+ z, logdet = flow(z, x_mask, g=x, reverse=reverse)
+ logdet_tot = logdet_tot + logdet
+ nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot
+ return nll + logq # [b]
+ else:
+ flows = list(reversed(self.flows))
+ flows = flows[:-2] + [flows[-1]] # remove a useless vflow
+ z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
+ for flow in flows:
+ z = flow(z, x_mask, g=x, reverse=reverse)
+ z0, z1 = torch.split(z, [1, 1], 1)
+ logw = z0
+ return logw
+
+
+class DurationPredictor(nn.Module):
+ def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.filter_channels = filter_channels
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.gin_channels = gin_channels
+
+ self.drop = nn.Dropout(p_dropout)
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
+ self.norm_1 = modules.LayerNorm(filter_channels)
+ self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
+ self.norm_2 = modules.LayerNorm(filter_channels)
+ self.proj = nn.Conv1d(filter_channels, 1, 1)
+
+ if gin_channels != 0:
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
+
+ def forward(self, x, x_mask, g=None):
+ x = torch.detach(x)
+ if g is not None:
+ g = torch.detach(g)
+ x = x + self.cond(g)
+ x = self.conv_1(x * x_mask)
+ x = torch.relu(x)
+ x = self.norm_1(x)
+ x = self.drop(x)
+ x = self.conv_2(x * x_mask)
+ x = torch.relu(x)
+ x = self.norm_2(x)
+ x = self.drop(x)
+ x = self.proj(x * x_mask)
+ return x * x_mask
+
+
+class TextEncoder(nn.Module):
+ def __init__(self,
+ n_vocab,
+ out_channels,
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size,
+ p_dropout):
+ super().__init__()
+ self.n_vocab = n_vocab
+ self.out_channels = out_channels
+ self.hidden_channels = hidden_channels
+ self.filter_channels = filter_channels
+ self.n_heads = n_heads
+ self.n_layers = n_layers
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+
+ self.emb = nn.Embedding(n_vocab, hidden_channels)
+ nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)
+
+ self.encoder = attentions.Encoder(
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size,
+ p_dropout)
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
+
+ def forward(self, x, x_lengths):
+ x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
+ x = torch.transpose(x, 1, -1) # [b, h, t]
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
+
+ x = self.encoder(x * x_mask, x_mask)
+ stats = self.proj(x) * x_mask
+
+ m, logs = torch.split(stats, self.out_channels, dim=1)
+ return x, m, logs, x_mask
+
+
+class ResidualCouplingBlock(nn.Module):
+ def __init__(self,
+ channels,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ n_flows=4,
+ gin_channels=0):
+ super().__init__()
+ self.channels = channels
+ self.hidden_channels = hidden_channels
+ self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.n_flows = n_flows
+ self.gin_channels = gin_channels
+
+ self.flows = nn.ModuleList()
+ for i in range(n_flows):
+ self.flows.append(
+ modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers,
+ gin_channels=gin_channels, mean_only=True))
+ self.flows.append(modules.Flip())
+
+ def forward(self, x, x_mask, g=None, reverse=False):
+ if not reverse:
+ for flow in self.flows:
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
+ else:
+ for flow in reversed(self.flows):
+ x = flow(x, x_mask, g=g, reverse=reverse)
+ return x
+
+
+class PosteriorEncoder(nn.Module):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ gin_channels=0):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.hidden_channels = hidden_channels
+ self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.gin_channels = gin_channels
+
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
+ self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
+
+ def forward(self, x, x_lengths, g=None):
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
+ x = self.pre(x) * x_mask
+ x = self.enc(x, x_mask, g=g)
+ stats = self.proj(x) * x_mask
+ m, logs = torch.split(stats, self.out_channels, dim=1)
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
+ return z, m, logs, x_mask
+
+
+class Generator(torch.nn.Module):
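+    """HiFi-GAN-style decoder: a pre-conv, transposed-conv upsampling stages, a bank
+    of residual blocks averaged per stage, then a final conv + tanh to waveform."""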
+ def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
+ upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
+ super(Generator, self).__init__()
+ self.num_kernels = len(resblock_kernel_sizes)
+ self.num_upsamples = len(upsample_rates)
+ self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
+ resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2
+
+ self.ups = nn.ModuleList()
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+ self.ups.append(weight_norm(
+ ConvTranspose1d(upsample_initial_channel // (2 ** i), upsample_initial_channel // (2 ** (i + 1)),
+ k, u, padding=(k - u) // 2)))
+
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = upsample_initial_channel // (2 ** (i + 1))
+ for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
+ self.resblocks.append(resblock(ch, k, d))
+
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
+ self.ups.apply(init_weights)
+
+ if gin_channels != 0:
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
+
+ def forward(self, x, g=None):
+ x = self.conv_pre(x)
+ if g is not None:
+ x = x + self.cond(g)
+
+ for i in range(self.num_upsamples):
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
+ x = self.ups[i](x)
+ xs = None
+ for j in range(self.num_kernels):
+ if xs is None:
+ xs = self.resblocks[i * self.num_kernels + j](x)
+ else:
+ xs += self.resblocks[i * self.num_kernels + j](x)
+ x = xs / self.num_kernels
+ x = F.leaky_relu(x)
+ x = self.conv_post(x)
+ x = torch.tanh(x)
+
+ return x
+
+ def remove_weight_norm(self):
+ print('Removing weight norm...')
+ for l in self.ups:
+ remove_weight_norm(l)
+ for l in self.resblocks:
+ l.remove_weight_norm()
+
+
+class DiscriminatorP(torch.nn.Module):
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
+ super(DiscriminatorP, self).__init__()
+ self.period = period
+ self.use_spectral_norm = use_spectral_norm
+        norm_f = spectral_norm if use_spectral_norm else weight_norm
+ self.convs = nn.ModuleList([
+ norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
+ norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
+ norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
+ norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
+ norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
+ ])
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+ def forward(self, x):
+ fmap = []
+
+ # 1d to 2d
+ b, c, t = x.shape
+ if t % self.period != 0: # pad first
+ n_pad = self.period - (t % self.period)
+ x = F.pad(x, (0, n_pad), "reflect")
+ t = t + n_pad
+ x = x.view(b, c, t // self.period, self.period)
+
+ for l in self.convs:
+ x = l(x)
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
+ fmap.append(x)
+ x = self.conv_post(x)
+ fmap.append(x)
+ x = torch.flatten(x, 1, -1)
+
+ return x, fmap
+
+
+class DiscriminatorS(torch.nn.Module):
+ def __init__(self, use_spectral_norm=False):
+ super(DiscriminatorS, self).__init__()
+        norm_f = spectral_norm if use_spectral_norm else weight_norm
+ self.convs = nn.ModuleList([
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
+ ])
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
+
+ def forward(self, x):
+ fmap = []
+
+ for l in self.convs:
+ x = l(x)
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
+ fmap.append(x)
+ x = self.conv_post(x)
+ fmap.append(x)
+ x = torch.flatten(x, 1, -1)
+
+ return x, fmap
+
+
+class MultiPeriodDiscriminator(torch.nn.Module):
+ def __init__(self, use_spectral_norm=False):
+ super(MultiPeriodDiscriminator, self).__init__()
+ periods = [2, 3, 5, 7, 11]
+
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
+ discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
+ self.discriminators = nn.ModuleList(discs)
+
+ def forward(self, y, y_hat):
+ y_d_rs = []
+ y_d_gs = []
+ fmap_rs = []
+ fmap_gs = []
+ for i, d in enumerate(self.discriminators):
+ y_d_r, fmap_r = d(y)
+ y_d_g, fmap_g = d(y_hat)
+ y_d_rs.append(y_d_r)
+ y_d_gs.append(y_d_g)
+ fmap_rs.append(fmap_r)
+ fmap_gs.append(fmap_g)
+
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class MelStyleEncoder(nn.Module):
+    '''MelStyleEncoder: extracts a fixed-size style embedding from a spectrogram via spectral FC layers, Conv1dGLU temporal blocks, multi-head self-attention and temporal average pooling.'''
+
+ def __init__(self, in_dim, style_hidden, style_vector_dim, style_kernel_size, style_head, dropout,
+ mode='determinant'):
+ super(MelStyleEncoder, self).__init__()
+ self.in_dim = in_dim
+ self.hidden_dim = style_hidden
+ self.out_dim = style_vector_dim
+ self.kernel_size = style_kernel_size
+ self.n_head = style_head
+ self.dropout = dropout
+ self.mode = mode
+
+ self.spectral = nn.Sequential(
+ modules.LinearNorm(self.in_dim, self.hidden_dim),
+ modules.Mish(),
+ nn.Dropout(self.dropout),
+ modules.LinearNorm(self.hidden_dim, self.hidden_dim),
+ modules.Mish(),
+ nn.Dropout(self.dropout)
+ )
+
+ self.temporal = nn.Sequential(
+ modules.Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
+ modules.Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
+ )
+
+ self.slf_attn = modules.MultiHeadAttention(self.n_head, self.hidden_dim,
+ self.hidden_dim // self.n_head, self.hidden_dim // self.n_head,
+ self.dropout)
+ self.fc = modules.LinearNorm(self.hidden_dim, self.out_dim)
+
+ def temporal_avg_pool(self, x, mask=None):
+ if mask is None:
+ out = torch.mean(x, dim=1)
+ else:
+ len_ = (~mask).sum(dim=1).unsqueeze(1)
+ x = x.masked_fill(mask.unsqueeze(-1), 0)
+ x = x.sum(dim=1)
+ out = torch.div(x, len_)
+ return out
+
+ def forward(self, x, mask=None):
+
+ max_len = x.shape[1]
+ if mask is not None:
+ mask = (mask.int() == 0).squeeze(1)
+ slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)
+ else:
+ slf_attn_mask = None
+ # spectral
+ x = self.spectral(x)
+ # temporal
+ x = x.transpose(1, 2)
+ x = self.temporal(x)
+ x = x.transpose(1, 2)
+ # self-attention
+ if mask is not None:
+ x = x.masked_fill(mask.unsqueeze(-1), 0)
+ x, _ = self.slf_attn(x, mask=slf_attn_mask)
+ # fc
+ x = self.fc(x)
+        # temporal average pooling
+ w = self.temporal_avg_pool(x, mask=mask)
+ w = F.normalize(w, dim=1)
+ return w
+
+
+class SynthesizerTrn(nn.Module):
+ """
+ Synthesizer for Training
+ """
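+    # Components: enc_p (text encoder / prior), enc_q (posterior encoder over the
+    # linear spectrogram), flow (residual coupling block), dec (waveform decoder),
+    # dp (stochastic or deterministic duration predictor), and optionally enc_r
+    # (MelStyleEncoder reference encoder) plus emb_g speaker embeddings.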
+
+ def __init__(self,
+ n_vocab,
+ spec_channels,
+ segment_size,
+ inter_channels,
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size,
+ p_dropout,
+ resblock,
+ resblock_kernel_sizes,
+ resblock_dilation_sizes,
+ upsample_rates,
+ upsample_initial_channel,
+ upsample_kernel_sizes,
+ n_speakers=0,
+ gin_channels=0,
+ use_sdp=True,
+ use_ref_enc=False,
+ ref_mode='determinant',
+ **kwargs):
+
+ super().__init__()
+ self.n_vocab = n_vocab
+ self.spec_channels = spec_channels
+ self.inter_channels = inter_channels
+ self.hidden_channels = hidden_channels
+ self.filter_channels = filter_channels
+ self.n_heads = n_heads
+ self.n_layers = n_layers
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.resblock = resblock
+ self.resblock_kernel_sizes = resblock_kernel_sizes
+ self.resblock_dilation_sizes = resblock_dilation_sizes
+ self.upsample_rates = upsample_rates
+ self.upsample_initial_channel = upsample_initial_channel
+ self.upsample_kernel_sizes = upsample_kernel_sizes
+ self.segment_size = segment_size
+ self.n_speakers = n_speakers
+ self.gin_channels = gin_channels
+
+ self.use_sdp = use_sdp
+ self.use_ref_enc = use_ref_enc
+
+ self.enc_p = TextEncoder(n_vocab,
+ inter_channels,
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size,
+ p_dropout)
+ self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
+ upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
+ self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16,
+ gin_channels=gin_channels)
+ self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
+
+ if use_sdp:
+ self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels)
+ else:
+ self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)
+
+ self.ref_mode = ref_mode
+ if n_speakers > 1:
+ if ref_mode == 'entropy':
+ self.emb_g = nn.Parameter(
+ torch.randn(n_speakers, gin_channels * 2),
+ requires_grad=True)
+ else:
+ self.emb_g = nn.Embedding(n_speakers, gin_channels)
+
+ if use_ref_enc:
+ self.enc_r = MelStyleEncoder(spec_channels, hidden_channels, gin_channels, 5, 2, p_dropout, mode=ref_mode)
+
+ def forward(self, x, x_lengths, y, y_lengths, sid=None):
+
+ x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
+ g, spk_centroid = None, None
+ y_mask = None
+ aux = {}
+ if self.n_speakers > 0:
+ if self.use_ref_enc:
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype) # [b, 1, t]
+ aux['y_mask'] = y_mask
+
+ if self.ref_mode == 'determinant':
+ g = self.enc_r(y.transpose(1, 2), y_mask).unsqueeze(-1) # [b, h, 1]
+ spk_centroid = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
+ aux['spk_centroid'] = spk_centroid
+
+ elif self.ref_mode == 'entropy':
+ spk_mean, spk_var = vae_utils.gaussian_parameters(self.emb_g) # p^2
+ mean, var = spk_mean[sid], spk_var[sid]
+
+ g = self.enc_r(y.transpose(1, 2), y_mask)
+ g = g.unsqueeze(-1)
+
+ aux['spk_mean'] = mean
+ aux['spk_var'] = var
+ else:
+ raise NotImplementedError
+ else:
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
+
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
+ z_p = self.flow(z, y_mask, g=g)
+
+ with torch.no_grad():
+ # negative cross-entropy
+ s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
+ neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True) # [b, 1, t_s]
+ neg_cent2 = torch.matmul(-0.5 * (z_p ** 2).transpose(1, 2),
+ s_p_sq_r) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
+ neg_cent3 = torch.matmul(z_p.transpose(1, 2), (m_p * s_p_sq_r)) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
+ neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, [1], keepdim=True) # [b, 1, t_s]
+ neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
+
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
+ attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()
+
+ w = attn.sum(2)
+ if self.use_sdp:
+ l_length = self.dp(x, x_mask, w, g=g)
+ l_length = l_length / torch.sum(x_mask)
+ else:
+ logw_ = torch.log(w + 1e-6) * x_mask
+ logw = self.dp(x, x_mask, g=g)
+ l_length = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(x_mask) # for averaging
+
+ # expand prior
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
+
+ z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
+ o = self.dec(z_slice, g=g)
+ if self.n_speakers > 0:
+ return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q), g, aux
+ else:
+ return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
+
+ def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None, spec=None,
+ spec_lengths=None):
+ x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
+ g = None
+ if sid is not None:
+ if self.n_speakers > 0:
+ if self.use_ref_enc:
+ assert spec is not None
+ spec_mask = torch.unsqueeze(commons.sequence_mask(spec_lengths, spec.size(2)), 1).to(spec.dtype)
+ g = self.enc_r(spec.transpose(1, 2), spec_mask).unsqueeze(-1)
+ else:
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
+
+ if self.use_sdp:
+ logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
+ else:
+ logw = self.dp(x, x_mask, g=g)
+ w = torch.exp(logw) * x_mask * length_scale
+ w_ceil = torch.ceil(w)
+ y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
+ attn = commons.generate_path(w_ceil, attn_mask)
+
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
+        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
+
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
+ z = self.flow(z_p, y_mask, g=g, reverse=True)
+ o = self.dec((z * y_mask)[:, :, :max_len], g=g)
+ return o, attn, y_mask, (z, z_p, m_p, logs_p)
+
+ def get_speaker_style_embedding(self, spec, spec_lengths):
+ assert self.use_ref_enc
+ spec_mask = torch.unsqueeze(commons.sequence_mask(spec_lengths, spec.size(2)), 1).to(spec.dtype)
+ g = self.enc_r(spec.transpose(1, 2), spec_mask)
+ return g
+
+    def synthesis_from_content_unit_style_embedding(self, x, x_lengths, style_embedding, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None):
+        x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
+
+        g = style_embedding  # should be of shape [b, h, 1]
+
+ if self.use_sdp:
+ logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
+ else:
+ logw = self.dp(x, x_mask, g=g)
+ w = torch.exp(logw) * x_mask * length_scale
+ w_ceil = torch.ceil(w)
+ y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
+ attn = commons.generate_path(w_ceil, attn_mask)
+
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
+        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
+
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
+ z = self.flow(z_p, y_mask, g=g, reverse=True)
+ o = self.dec((z * y_mask)[:, :, :max_len], g=g)
+ return o, attn, y_mask, (z, z_p, m_p, logs_p)
diff --git a/MMaDA/models/speech_tokenization/UVITS/modules.py b/MMaDA/models/speech_tokenization/UVITS/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4dd724eea79b871136732f5957ae5f065cf7c8a
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/UVITS/modules.py
@@ -0,0 +1,569 @@
+import copy
+import math
+import numpy as np
+import scipy
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
+from torch.nn.utils import weight_norm, remove_weight_norm
+
+import commons
+from commons import init_weights, get_padding
+from transforms import piecewise_rational_quadratic_transform
+
+LRELU_SLOPE = 0.1
+
+
+class LayerNorm(nn.Module):
+ def __init__(self, channels, eps=1e-5):
+ super().__init__()
+ self.channels = channels
+ self.eps = eps
+
+ self.GAMMA = nn.Parameter(torch.ones(channels))
+ self.BETA = nn.Parameter(torch.zeros(channels))
+
+ def forward(self, x):
+ x = x.transpose(1, -1)
+ x = F.layer_norm(x, (self.channels,), self.GAMMA, self.BETA, self.eps)
+ return x.transpose(1, -1)
+
+
+class ConvReluNorm(nn.Module):
+ def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
+ super().__init__()
+ self.in_channels = in_channels
+ self.hidden_channels = hidden_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.n_layers = n_layers
+ self.p_dropout = p_dropout
+        assert n_layers > 1, "Number of layers should be larger than 1."
+
+ self.conv_layers = nn.ModuleList()
+ self.norm_layers = nn.ModuleList()
+ self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
+ self.norm_layers.append(LayerNorm(hidden_channels))
+ self.relu_drop = nn.Sequential(
+ nn.ReLU(),
+ nn.Dropout(p_dropout))
+ for _ in range(n_layers - 1):
+ self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
+ self.norm_layers.append(LayerNorm(hidden_channels))
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
+ self.proj.weight.data.zero_()
+ self.proj.bias.data.zero_()
+
+ def forward(self, x, x_mask):
+ x_org = x
+ for i in range(self.n_layers):
+ x = self.conv_layers[i](x * x_mask)
+ x = self.norm_layers[i](x)
+ x = self.relu_drop(x)
+ x = x_org + self.proj(x)
+ return x * x_mask
+
+
+class DDSConv(nn.Module):
+ """
+    Dilated and Depth-Separable Convolution
+ """
+
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
+ super().__init__()
+ self.channels = channels
+ self.kernel_size = kernel_size
+ self.n_layers = n_layers
+ self.p_dropout = p_dropout
+
+ self.drop = nn.Dropout(p_dropout)
+ self.convs_sep = nn.ModuleList()
+ self.convs_1x1 = nn.ModuleList()
+ self.norms_1 = nn.ModuleList()
+ self.norms_2 = nn.ModuleList()
+ for i in range(n_layers):
+ dilation = kernel_size ** i
+ padding = (kernel_size * dilation - dilation) // 2
+ self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
+ groups=channels, dilation=dilation, padding=padding
+ ))
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
+ self.norms_1.append(LayerNorm(channels))
+ self.norms_2.append(LayerNorm(channels))
+
+ def forward(self, x, x_mask, g=None):
+ if g is not None:
+ x = x + g
+ for i in range(self.n_layers):
+ y = self.convs_sep[i](x * x_mask)
+ y = self.norms_1[i](y)
+ y = F.gelu(y)
+ y = self.convs_1x1[i](y)
+ y = self.norms_2[i](y)
+ y = F.gelu(y)
+ y = self.drop(y)
+ x = x + y
+ return x * x_mask
+
+
+class WN(torch.nn.Module):
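+    """WaveNet-like stack of dilated convolutions with gated tanh/sigmoid
+    activations, optional global conditioning (gin_channels) and summed skip
+    connections."""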
+ def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
+ super(WN, self).__init__()
+ assert (kernel_size % 2 == 1)
+ self.hidden_channels = hidden_channels
+        self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.gin_channels = gin_channels
+ self.p_dropout = p_dropout
+
+ self.in_layers = torch.nn.ModuleList()
+ self.res_skip_layers = torch.nn.ModuleList()
+ self.drop = nn.Dropout(p_dropout)
+
+ if gin_channels != 0:
+ cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
+
+ for i in range(n_layers):
+ dilation = dilation_rate ** i
+ padding = int((kernel_size * dilation - dilation) / 2)
+ in_layer = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size,
+ dilation=dilation, padding=padding)
+ in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
+ self.in_layers.append(in_layer)
+
+ # last one is not necessary
+ if i < n_layers - 1:
+ res_skip_channels = 2 * hidden_channels
+ else:
+ res_skip_channels = hidden_channels
+
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
+ self.res_skip_layers.append(res_skip_layer)
+
+ def forward(self, x, x_mask, g=None, **kwargs):
+ output = torch.zeros_like(x)
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
+
+ if g is not None:
+ g = self.cond_layer(g)
+
+ for i in range(self.n_layers):
+ x_in = self.in_layers[i](x)
+ if g is not None:
+ cond_offset = i * 2 * self.hidden_channels
+ g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :]
+ else:
+ g_l = torch.zeros_like(x_in)
+
+ acts = commons.fused_add_tanh_sigmoid_multiply(
+ x_in,
+ g_l,
+ n_channels_tensor)
+ acts = self.drop(acts)
+
+ res_skip_acts = self.res_skip_layers[i](acts)
+ if i < self.n_layers - 1:
+ res_acts = res_skip_acts[:, :self.hidden_channels, :]
+ x = (x + res_acts) * x_mask
+ output = output + res_skip_acts[:, self.hidden_channels:, :]
+ else:
+ output = output + res_skip_acts
+ return output * x_mask
+
+ def remove_weight_norm(self):
+ if self.gin_channels != 0:
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
+ for l in self.in_layers:
+ torch.nn.utils.remove_weight_norm(l)
+ for l in self.res_skip_layers:
+ torch.nn.utils.remove_weight_norm(l)
+
+
+class ResBlock1(torch.nn.Module):
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
+ super(ResBlock1, self).__init__()
+ self.convs1 = nn.ModuleList([
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
+ padding=get_padding(kernel_size, dilation[2])))
+ ])
+ self.convs1.apply(init_weights)
+
+ self.convs2 = nn.ModuleList([
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+ padding=get_padding(kernel_size, 1))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+ padding=get_padding(kernel_size, 1))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+ padding=get_padding(kernel_size, 1)))
+ ])
+ self.convs2.apply(init_weights)
+
+ def forward(self, x, x_mask=None):
+ for c1, c2 in zip(self.convs1, self.convs2):
+ xt = F.leaky_relu(x, LRELU_SLOPE)
+ if x_mask is not None:
+ xt = xt * x_mask
+ xt = c1(xt)
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
+ if x_mask is not None:
+ xt = xt * x_mask
+ xt = c2(xt)
+ x = xt + x
+ if x_mask is not None:
+ x = x * x_mask
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs1:
+ remove_weight_norm(l)
+ for l in self.convs2:
+ remove_weight_norm(l)
+
+
+class ResBlock2(torch.nn.Module):
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
+ super(ResBlock2, self).__init__()
+ self.convs = nn.ModuleList([
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1])))
+ ])
+ self.convs.apply(init_weights)
+
+ def forward(self, x, x_mask=None):
+ for c in self.convs:
+ xt = F.leaky_relu(x, LRELU_SLOPE)
+ if x_mask is not None:
+ xt = xt * x_mask
+ xt = c(xt)
+ x = xt + x
+ if x_mask is not None:
+ x = x * x_mask
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs:
+ remove_weight_norm(l)
+
+
+class Log(nn.Module):
+ def forward(self, x, x_mask, reverse=False, **kwargs):
+ if not reverse:
+ y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
+ logdet = torch.sum(-y, [1, 2])
+ return y, logdet
+ else:
+ x = torch.exp(x) * x_mask
+ return x
+
+
+class Flip(nn.Module):
+ def forward(self, x, *args, reverse=False, **kwargs):
+ x = torch.flip(x, [1])
+ if not reverse:
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
+ return x, logdet
+ else:
+ return x
+
+
+class ElementwiseAffine(nn.Module):
+ def __init__(self, channels):
+ super().__init__()
+ self.channels = channels
+ self.m = nn.Parameter(torch.zeros(channels, 1))
+ self.logs = nn.Parameter(torch.zeros(channels, 1))
+
+ def forward(self, x, x_mask, reverse=False, **kwargs):
+ if not reverse:
+ y = self.m + torch.exp(self.logs) * x
+ y = y * x_mask
+ logdet = torch.sum(self.logs * x_mask, [1, 2])
+ return y, logdet
+ else:
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
+ return x
+
+
+class ResidualCouplingLayer(nn.Module):
+ def __init__(self,
+ channels,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ p_dropout=0,
+ gin_channels=0,
+ mean_only=False):
+ assert channels % 2 == 0, "channels should be divisible by 2"
+ super().__init__()
+ self.channels = channels
+ self.hidden_channels = hidden_channels
+ self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.half_channels = channels // 2
+ self.mean_only = mean_only
+
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
+ self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout,
+ gin_channels=gin_channels)
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
+ self.post.weight.data.zero_()
+ self.post.bias.data.zero_()
+
+ def forward(self, x, x_mask, g=None, reverse=False):
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
+ h = self.pre(x0) * x_mask
+ h = self.enc(h, x_mask, g=g)
+ stats = self.post(h) * x_mask
+ if not self.mean_only:
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
+ else:
+ m = stats
+ logs = torch.zeros_like(m)
+
+ if not reverse:
+ x1 = m + x1 * torch.exp(logs) * x_mask
+ x = torch.cat([x0, x1], 1)
+ logdet = torch.sum(logs, [1, 2])
+ return x, logdet
+ else:
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
+ x = torch.cat([x0, x1], 1)
+ return x
+
+
+class ConvFlow(nn.Module):
+ def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0):
+ super().__init__()
+ self.in_channels = in_channels
+ self.filter_channels = filter_channels
+ self.kernel_size = kernel_size
+ self.n_layers = n_layers
+ self.num_bins = num_bins
+ self.tail_bound = tail_bound
+ self.half_channels = in_channels // 2
+
+ self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
+ self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.)
+ self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
+ self.proj.weight.data.zero_()
+ self.proj.bias.data.zero_()
+
+ def forward(self, x, x_mask, g=None, reverse=False):
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
+ h = self.pre(x0)
+ h = self.convs(h, x_mask, g=g)
+ h = self.proj(h) * x_mask
+
+ b, c, t = x0.shape
+ h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
+
+ unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels)
+ unnormalized_heights = h[..., self.num_bins:2 * self.num_bins] / math.sqrt(self.filter_channels)
+ unnormalized_derivatives = h[..., 2 * self.num_bins:]
+
+ x1, logabsdet = piecewise_rational_quadratic_transform(x1,
+ unnormalized_widths,
+ unnormalized_heights,
+ unnormalized_derivatives,
+ inverse=reverse,
+ tails='linear',
+ tail_bound=self.tail_bound
+ )
+
+ x = torch.cat([x0, x1], 1) * x_mask
+ logdet = torch.sum(logabsdet * x_mask, [1, 2])
+ if not reverse:
+ return x, logdet
+ else:
+ return x
+
+
+class LinearNorm(nn.Module):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ bias=True,
+ spectral_norm=False,
+ ):
+ super(LinearNorm, self).__init__()
+ self.fc = nn.Linear(in_channels, out_channels, bias)
+
+ if spectral_norm:
+ self.fc = nn.utils.spectral_norm(self.fc)
+
+ def forward(self, input):
+ out = self.fc(input)
+ return out
+
+
+class Mish(nn.Module):
+ def __init__(self):
+ super(Mish, self).__init__()
+
+ def forward(self, x):
+ return x * torch.tanh(F.softplus(x))
+
+
+class ConvNorm(nn.Module):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=None,
+ dilation=1,
+ bias=True,
+ spectral_norm=False,
+ ):
+ super(ConvNorm, self).__init__()
+
+ if padding is None:
+ assert (kernel_size % 2 == 1)
+ padding = int(dilation * (kernel_size - 1) / 2)
+
+ self.conv = torch.nn.Conv1d(in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ bias=bias)
+
+ if spectral_norm:
+ self.conv = nn.utils.spectral_norm(self.conv)
+
+ def forward(self, input):
+ out = self.conv(input)
+ return out
+
+
+class MultiHeadAttention(nn.Module):
+ ''' Multi-Head Attention module '''
+
+ def __init__(self, n_head, d_model, d_k, d_v, dropout=0., spectral_norm=False):
+ super().__init__()
+
+ self.n_head = n_head
+ self.d_k = d_k
+ self.d_v = d_v
+
+ self.w_qs = nn.Linear(d_model, n_head * d_k)
+ self.w_ks = nn.Linear(d_model, n_head * d_k)
+ self.w_vs = nn.Linear(d_model, n_head * d_v)
+
+ self.attention = ScaledDotProductAttention(temperature=np.power(d_model, 0.5), dropout=dropout)
+
+ self.fc = nn.Linear(n_head * d_v, d_model)
+ self.dropout = nn.Dropout(dropout)
+
+ if spectral_norm:
+ self.w_qs = nn.utils.spectral_norm(self.w_qs)
+ self.w_ks = nn.utils.spectral_norm(self.w_ks)
+ self.w_vs = nn.utils.spectral_norm(self.w_vs)
+ self.fc = nn.utils.spectral_norm(self.fc)
+
+ def forward(self, x, mask=None):
+ d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
+ sz_b, len_x, _ = x.size()
+
+ residual = x
+
+ q = self.w_qs(x).view(sz_b, len_x, n_head, d_k)
+ k = self.w_ks(x).view(sz_b, len_x, n_head, d_k)
+ v = self.w_vs(x).view(sz_b, len_x, n_head, d_v)
+ q = q.permute(2, 0, 1, 3).contiguous().view(-1,
+ len_x, d_k) # (n*b) x lq x dk
+ k = k.permute(2, 0, 1, 3).contiguous().view(-1,
+ len_x, d_k) # (n*b) x lk x dk
+ v = v.permute(2, 0, 1, 3).contiguous().view(-1,
+ len_x, d_v) # (n*b) x lv x dv
+
+ if mask is not None:
+ slf_mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
+ else:
+ slf_mask = None
+ output, attn = self.attention(q, k, v, mask=slf_mask)
+
+ output = output.view(n_head, sz_b, len_x, d_v)
+ output = output.permute(1, 2, 0, 3).contiguous().view(
+ sz_b, len_x, -1) # b x lq x (n*dv)
+
+ output = self.fc(output)
+
+ output = self.dropout(output) + residual
+ return output, attn
+
+
+class ScaledDotProductAttention(nn.Module):
+ ''' Scaled Dot-Product Attention '''
+
+ def __init__(self, temperature, dropout):
+ super().__init__()
+ self.temperature = temperature
+ self.softmax = nn.Softmax(dim=2)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, q, k, v, mask=None):
+ attn = torch.bmm(q, k.transpose(1, 2))
+ attn = attn / self.temperature
+
+ if mask is not None:
+ attn = attn.masked_fill(mask, -np.inf)
+
+ attn = self.softmax(attn)
+ p_attn = self.dropout(attn)
+
+ output = torch.bmm(p_attn, v)
+ return output, attn
+
+
+class Conv1dGLU(nn.Module):
+ '''
+    Conv1d + GLU (Gated Linear Unit) with a residual connection.
+    For GLU, refer to https://arxiv.org/abs/1612.08083.
+ '''
+
+ def __init__(self, in_channels, out_channels, kernel_size, dropout):
+ super(Conv1dGLU, self).__init__()
+ self.out_channels = out_channels
+ self.conv1 = ConvNorm(in_channels, 2 * out_channels, kernel_size=kernel_size)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x):
+ residual = x
+ x = self.conv1(x)
+ x1, x2 = torch.split(x, split_size_or_sections=self.out_channels, dim=1)
+ x = x1 * torch.sigmoid(x2)
+ x = residual + self.dropout(x)
+ return x
diff --git a/MMaDA/models/speech_tokenization/UVITS/my_UVITS_model/ROMA_1n8g_u2s_40ms_multilingual_8888_xujing_cosyvoice_EN_CH_female_male_FT_20240812/config.json b/MMaDA/models/speech_tokenization/UVITS/my_UVITS_model/ROMA_1n8g_u2s_40ms_multilingual_8888_xujing_cosyvoice_EN_CH_female_male_FT_20240812/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..5c935b729734c5b41edcd286f1669e0c68434151
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/UVITS/my_UVITS_model/ROMA_1n8g_u2s_40ms_multilingual_8888_xujing_cosyvoice_EN_CH_female_male_FT_20240812/config.json
@@ -0,0 +1,91 @@
+{
+ "train": {
+ "log_interval": 200,
+ "eval_interval": 1000,
+ "seed": 1234,
+ "epochs": 10000,
+ "learning_rate": 2e-4,
+ "betas": [
+ 0.8,
+ 0.99
+ ],
+ "eps": 1e-9,
+ "batch_size": 32,
+ "fp16_run": false,
+ "lr_decay": 0.999875,
+ "segment_size": 8192,
+ "init_lr_ratio": 1,
+ "warmup_epochs": 0,
+ "c_mel": 45,
+ "c_kl": 1.0,
+ "c_spk": 0.0
+ },
+ "data": {
+ "training_files": "/home/ma-user/work/daxintan/data/for_U2S_training/xujing_cosyvoice_40ms_multilingual_8888/EN_CH_female_male/xujing_cosyvoice_EN_CH_female_male_wav_sid_reduced_unit.txt",
+ "validation_files": "/home/ma-user/work/daxintan/data/for_U2S_training/LibriTTS_AISHELL1_40ms_multilingual_8888/LibriTTS_test-clean_AISHELL1_test_wav_sid_reduced_unit.txt",
+ "text_cleaners": [
+ "text_split"
+ ],
+ "max_wav_value": 32768.0,
+ "sampling_rate": 22050,
+ "filter_length": 1024,
+ "hop_length": 256,
+ "win_length": 1024,
+ "n_mel_channels": 80,
+ "mel_fmin": 0.0,
+ "mel_fmax": null,
+ "add_blank": false,
+ "n_speakers": 1344,
+ "cleaned_text": false,
+ "max_text_len": 1000
+ },
+ "model": {
+ "inter_channels": 192,
+ "hidden_channels": 192,
+ "filter_channels": 768,
+ "n_heads": 2,
+ "n_layers": 6,
+ "kernel_size": 3,
+ "p_dropout": 0.1,
+ "resblock": "1",
+ "resblock_kernel_sizes": [
+ 3,
+ 7,
+ 11
+ ],
+ "resblock_dilation_sizes": [
+ [
+ 1,
+ 3,
+ 5
+ ],
+ [
+ 1,
+ 3,
+ 5
+ ],
+ [
+ 1,
+ 3,
+ 5
+ ]
+ ],
+ "upsample_rates": [
+ 8,
+ 8,
+ 2,
+ 2
+ ],
+ "upsample_initial_channel": 512,
+ "upsample_kernel_sizes": [
+ 16,
+ 16,
+ 4,
+ 4
+ ],
+ "n_layers_q": 3,
+ "use_spectral_norm": false,
+ "gin_channels": 256,
+ "use_ref_enc": true
+ }
+}
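
For reference, a standard-library sketch of how a config file like the one above could be loaded into attribute-style hyperparameters. This is an illustration, not the loader used by this repo; the load_hparams name and the path are assumptions.

import json
from types import SimpleNamespace

def load_hparams(config_path):
    # Recursively wrap nested dicts so fields read as hps.data.sampling_rate, hps.model.n_heads, etc.
    def wrap(obj):
        if isinstance(obj, dict):
            return SimpleNamespace(**{k: wrap(v) for k, v in obj.items()})
        return obj
    with open(config_path, "r") as f:
        return wrap(json.load(f))

# Hypothetical usage; point the path at the config shown above.
hps = load_hparams("config.json")
print(hps.data.sampling_rate)    # 22050
print(hps.model.upsample_rates)  # [8, 8, 2, 2]
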
diff --git a/MMaDA/models/speech_tokenization/UVITS/my_synthesis/my_synthesis_for_speech_unit_sequence_recombination.py b/MMaDA/models/speech_tokenization/UVITS/my_synthesis/my_synthesis_for_speech_unit_sequence_recombination.py
new file mode 100644
index 0000000000000000000000000000000000000000..49b5c246de8a08223bfa2ca588bee5fa2871dbe3
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/UVITS/my_synthesis/my_synthesis_for_speech_unit_sequence_recombination.py
@@ -0,0 +1,47 @@
+from pathlib import Path
+
+import torch
+from scipy.io.wavfile import write
+from tqdm import tqdm
+
+import utils
+from models.speech_tokenization.UVITS.models import SynthesizerTrn
+from text import text_to_sequence
+
+import random
+import shutil
+import librosa
+import numpy as np
+import json
+
+from data_utils import load_wav_to_torch, spectrogram_torch
+
+
+def get_audio(filename, hps):
+ audio, sampling_rate = librosa.load(filename, sr=hps.sampling_rate)
+ audio_norm = torch.FloatTensor(audio.astype(np.float32)).unsqueeze(0)
+ spec = spectrogram_torch(audio_norm, hps.filter_length,
+ hps.sampling_rate, hps.hop_length, hps.win_length,
+ center=False)
+ spec = torch.squeeze(spec, 0)
+ return spec, audio_norm
+
+
+def get_text_unit_list(TTS_input_output_file):
+    text_list, unit_seq_list = [], []
+    with open(TTS_input_output_file, 'r') as f:
+        for line in f:
+            # Drop the prefix by slicing: str.lstrip removes a character set, not a prefix,
+            # and could eat leading characters of the payload as well.
+            if line.startswith('input:'):
+                text_list.append(line[len('input:'):].strip())
+            elif line.startswith('output:'):
+                unit_seq_list.append(line[len('output:'):].strip())
+    return text_list, unit_seq_list
+
+
+def get_U2S_config_checkpoint_file(unit_type, language='English'):
+ assert language in ['English', 'Chinese']
+    assert unit_type == '40ms_multilingual_8888_xujing_cosyvoice_FT'
+    # English and Chinese use the same UVITS model and config.
+ config_file = "./speech_tokenization/UVITS/my_UVITS_model/ROMA_1n8g_u2s_40ms_multilingual_8888_xujing_cosyvoice_EN_CH_female_male_FT_20240812/config.json"
+ checkpoint_file = "./speech_tokenization/UVITS/my_UVITS_model/ROMA_1n8g_u2s_40ms_multilingual_8888_xujing_cosyvoice_EN_CH_female_male_FT_20240812/saved_checkpoint/G_322000.pth"
+ return config_file, checkpoint_file
\ No newline at end of file
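
To make the expected input format concrete, here is a small, made-up example of the 'input:' / 'output:' file that get_text_unit_list parses; the text and unit tokens are invented for illustration only, and the snippet assumes it runs where get_text_unit_list is importable.

import tempfile

# Two fabricated TTS text/unit pairs in the line format the parser expects.
sample = (
    "input: hello world\n"
    "output: <12> <7> <7> <340>\n"
    "input: good morning\n"
    "output: <98> <5> <5>\n"
)

with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
    f.write(sample)
    path = f.name

texts, units = get_text_unit_list(path)
assert texts == ["hello world", "good morning"]
assert units == ["<12> <7> <7> <340>", "<98> <5> <5>"]
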
diff --git a/MMaDA/models/speech_tokenization/UVITS/script/grad_multiply.py b/MMaDA/models/speech_tokenization/UVITS/script/grad_multiply.py
new file mode 100644
index 0000000000000000000000000000000000000000..dec0ccfe395abdcbc03d53cc3dce9873c44586e8
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/UVITS/script/grad_multiply.py
@@ -0,0 +1,63 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+
+class GradMultiply(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, scale):
+ ctx.scale = scale
+ res = x.new(x)
+ return res
+
+ @staticmethod
+ def backward(ctx, grad):
+ return grad * ctx.scale, None
+
+
+def grad_multiply_wrapper(module, rate):
+ class GMModel(torch.nn.Module):
+ def __init__(self):
+ super(GMModel, self).__init__()
+ self.module = module
+ self.rate = rate
+
+ def forward(self, *args, **kwargs):
+ if self.rate > 0:
+ features = self.module(*args, **kwargs)
+ if self.rate != 1.0:
+ if isinstance(features, torch.Tensor):
+ features = GradMultiply.apply(features, self.rate)
+ elif isinstance(features, tuple):
+                        # Materialize as a tuple; a bare generator expression would be returned lazily.
+                        features = tuple(GradMultiply.apply(f, self.rate) for f in features)
+ elif isinstance(features, dict):
+ features = {k: GradMultiply.apply(f, self.rate) for k, f in features.items()}
+ else:
+ with torch.no_grad():
+ features = module(*args, **kwargs)
+ return features
+
+ return GMModel()
+
+
+if __name__ == '__main__':
+ class M(torch.nn.Module):
+ def __init__(self):
+ super(M, self).__init__()
+ self.a = torch.nn.Parameter(torch.zeros(1))
+
+ def forward(self, x):
+ return self.a * x
+
+
+ m = M()
+ m = grad_multiply_wrapper(m, 0.3)
+ optim = torch.optim.SGD(m.parameters(), lr=0.1, momentum=0)
+ optim.zero_grad()
+ loss = m(torch.ones(1))
+ loss.backward()
+ optim.step()
+ print(m.module.a)
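
As a quick numerical sanity check (illustrative, not part of the repo), the wrapper should leave the forward pass untouched while scaling parameter gradients by the given rate. The helper below is a hypothetical addition that verifies this for rate 0.3 against rate 1.0, assuming grad_multiply_wrapper from above is in scope.

import torch

def grad_for_rate(rate):
    # Hypothetical check helper: measure d(loss)/d(weight) for a toy linear module under `rate`.
    base = torch.nn.Linear(1, 1, bias=False)
    torch.nn.init.constant_(base.weight, 2.0)
    wrapped = grad_multiply_wrapper(base, rate)
    out = wrapped(torch.ones(1, 1))
    out.sum().backward()
    return base.weight.grad.item(), out.item()

g_full, y_full = grad_for_rate(1.0)
g_scaled, y_scaled = grad_for_rate(0.3)
assert abs(y_full - y_scaled) < 1e-6        # forward output is unchanged by the wrapper
assert abs(g_scaled - 0.3 * g_full) < 1e-6  # gradient is scaled by the rate
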
diff --git a/MMaDA/models/speech_tokenization/UVITS/script/html_generation.py b/MMaDA/models/speech_tokenization/UVITS/script/html_generation.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ebeb9dcd162fef5e998531d257a3f252a080885
--- /dev/null
+++ b/MMaDA/models/speech_tokenization/UVITS/script/html_generation.py
@@ -0,0 +1,232 @@
+"""
+Author: zhengnianzu
+Place: shenzhen
+Time: 2020.8.25
+Update: 2022.10.17
+"""
+import os
+import sys
+import base64
+import argparse
+import shutil
+
+
+def add_description(title, description):
+ return f"""